diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..09c9a597 --- /dev/null +++ b/.gitignore @@ -0,0 +1,191 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +*.swp +*.pyc +*.py~ +*.bak +.pytest_cache +.DS_Store +.idea +.vscode +.coverage +.coverage.* +__pycache__/ +_build/ +*.npz +*.pth +.pytype/ +git_rewrite_commit_history.sh + +# Setuptools distribution and build folders. 
+/dist/ +/build +keys/ + +# Virtualenv +/env +/venv + + +*.sublime-project +*.sublime-workspace + +.idea + +logs/ + +.ipynb_checkpoints +ghostdriver.log + +htmlcov + +junk +src + +*.egg-info +.cache +*.lprof +*.prof + +configs/arch_gym_configs.py + +sims/Sniper/docker/SniperLink +sims/Sniper/CPU2017/intspeed +sims/Sniper/CPU2017/fpspeed +sims/Sniper/config + +*.code-workspace + +saves/ + +# Astra-sim and results +sims/AstraSim/astra-sim +sims/AstraSim/results diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..9d7547a5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ + +[submodule "sims/AstraSim/astra-sim"] + path = sims/AstraSim/astra-sim + url = https://github.com/astra-sim/astra-sim.git diff --git a/Project_FARSI/.gitignore b/Project_FARSI/.gitignore new file mode 100644 index 00000000..8f5b1848 --- /dev/null +++ b/Project_FARSI/.gitignore @@ -0,0 +1,6 @@ +.idea/ +data_collection/data +data_collection/.DS_Store +*pyc +__pycache__ +specs/database_data/hardcoded diff --git a/Project_FARSI/BUCKCONFIG_FILE_EXISTS.md b/Project_FARSI/BUCKCONFIG_FILE_EXISTS.md new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/CODE_OF_CONDUCT.md b/Project_FARSI/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..83f431e8 --- /dev/null +++ b/Project_FARSI/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. 
+Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/Project_FARSI/CONTRIBUTING.md b/Project_FARSI/CONTRIBUTING.md new file mode 100644 index 00000000..4261e030 --- /dev/null +++ b/Project_FARSI/CONTRIBUTING.md @@ -0,0 +1,35 @@ +# Contributing to Project_FARSI +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Coding Style +* 2 spaces for indentation rather than tabs +* 80 character line length +* ... + +## License +By contributing to Project_FARSI , you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. 
diff --git a/Project_FARSI/DOCUSAURUS_ENABLED.md b/Project_FARSI/DOCUSAURUS_ENABLED.md
new file mode 100644
index 00000000..e69de29b
diff --git a/Project_FARSI/DSE_utils/__init__.py b/Project_FARSI/DSE_utils/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/Project_FARSI/DSE_utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/Project_FARSI/DSE_utils/design_space_exploration_handler.py b/Project_FARSI/DSE_utils/design_space_exploration_handler.py
new file mode 100644
index 00000000..0b94b6bb
--- /dev/null
+++ b/Project_FARSI/DSE_utils/design_space_exploration_handler.py
@@ -0,0 +1,630 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+from zipfile import ZipFile
+from os.path import basename
+from design_utils.design import *
+from DSE_utils import hill_climbing
+from specs.data_base import *
+from visualization_utils import vis_hardware, vis_stats, plot
+import csv
+import dill
+import pickle
+import matplotlib.pyplot as plt
+if config.simulation_method == "power_knobs":
+    from specs import database_input_powerKnobs as database_input
+elif config.simulation_method == "performance":
+    from specs import database_input
+else:
+    raise NameError("Simulation method unavailable")
+import random
+
+# class used for design handling.
+# This class uses an exploration algorithm (such as hill climbing) to explore the design space;
+# specify the exploration algorithm in the config file.
+class DSEHandler:
+    def __init__(self, result_dir=os.getcwd()):
+        self.check_pointed_best_sim_dps = []  # list of check-pointed simulated designs
+        self.check_pointed_best_ex_dps = []  # list of check-pointed example designs
+        self.dse = None  # design space exploration algorithm
+        self.database = None  # database (contains the hw and sw databases for mapping/allocation of hw/sw)
+        self.IP_library = []
+        self.result_dir = result_dir
+        self.check_point_folder_name = "check_points"
+        self.check_point_ctr = 0
+        return None
+    # ---------------
+    # Functionality:
+    #       exhaustively generate, simulate, and profile all design points
+    #       (dse_type must be set to "exhaustive" in the config file)
+    # ---------------
+    def explore_exhaustively(self, db_input, hw_sampling, system_workers):
+        mapping_process_id = system_workers[1]
+        FARSI_gen_process_id = system_workers[3]
+
+        if config.dse_type == "exhaustive":
+            self.database = DataBase(db_input, hw_sampling)
+            self.dse = hill_climbing.HillClimbing(self.database, self.result_dir)
+
+            # generate light systems
+            start = time.time()
+            all_light_systems = self.dse.dh.light_system_gen_exhaustively(system_workers, self.database)
+            print("light system generation time: " + str(time.time() - start))
+            print("----- all light systems generated for process: " + str(mapping_process_id) + "_" + str(FARSI_gen_process_id))
+
+            # generate FARSI systems
+            start = time.time()
+            all_exs = self.dse.dh.FARSI_system_gen_exhaustively(all_light_systems, system_workers)
+            print("FARSI system generation time: " + str(time.time() - start))
+            print("----- all FARSI systems generated for process: " + str(mapping_process_id) + "_" + str(FARSI_gen_process_id))
+
+            # simulate them
+            start = time.time()
+            all_sims = []
+            for ex_dp in all_exs:
+                sim_dp = self.dse.eval_design(ex_dp, self.database)
+                if config.RUN_VERIFICATION_PER_GEN or config.RUN_VERIFICATION_PER_NEW_CONFIG or config.RUN_VERIFICATION_PER_IMPROVMENT:
+                    self.dse.gen_verification_data(sim_dp, ex_dp)
+                all_sims.append(sim_dp)
+
+            print("simulation time: " + str(time.time() - start))
+            print("----- all FARSI systems simulated for process: " + str(mapping_process_id) + "_" + str(FARSI_gen_process_id))
+
+            # collect data
+            latency = [sim.get_dp_stats().get_system_complex_metric("latency") for sim in all_sims]
+            power = [sim.get_dp_stats().get_system_complex_metric("power") for sim in all_sims]
+            area = [sim.get_dp_stats().get_system_complex_metric("area") for sim in all_sims]
+            energy = [sim.get_dp_stats().get_system_complex_metric("energy") for sim in all_sims]
+            x = range(0, len(latency))
+
+            # write into a file
+            base_dir = os.getcwd()
+            result_dir = os.path.join(base_dir, config.exhaustive_result_dir)
+            if not os.path.exists(result_dir):
+                os.makedirs(result_dir)
+
+            result_in_list_file_addr = os.path.join(result_dir,
+                                                    config.exhaustive_output_file_prefix +
+                                                    str(mapping_process_id) + "_" + str(FARSI_gen_process_id) + '.txt')
+            with open(result_in_list_file_addr, 'w') as f:
+                f.write('itr_count: ')
+                for listitem in x:
+                    f.write('%s,' % str(listitem))
+
+            result_in_pdf_file_addr = os.path.join(result_dir,
+                                                   'exhaustive_for_pid' + str(mapping_process_id) + "_" + str(
+                                                       FARSI_gen_process_id) + '.txt')
+            with open(result_in_pdf_file_addr, 'a') as f:
+                f.write('\n')
+                f.write('latency: ')
+                for listitem in latency:
+                    f.write('%s,' % str(listitem))
+
+                f.write('\n')
+                f.write('power: ')
+                for listitem in power:
+                    f.write('%s,' % str(listitem))
+
+                f.write('\n')
+                f.write('energy: ')
+                for listitem in energy:
+                    f.write('%s,' % str(listitem))
+
+                f.write('\n')
+                f.write('area: ')
+                for listitem in area:
+                    f.write('%s,' % str(listitem))
+
+            # plot
+            for metric in ["latency", "area", "power", "energy"]:
+                fig, ax = plt.subplots()
+                if metric == "latency":
+                    # latency entries are per-workload dicts; plot the slowest workload
+                    y = [max(list(el.values())) for el in vars()[metric]]
+                else:
+                    y = vars()[metric]  # vars() fetches the local list named by the metric
+                ax.scatter(x, y, marker="x")
+                ax.set_xlabel("iteration count")
+                ax.set_ylabel(metric)
+                fig.savefig("exhaustive_" + metric + "_for_pid_" + str(mapping_process_id) + "_" + str(FARSI_gen_process_id) + ".pdf")
+
+            print("done")
+        else:
+            print("cannot explore exhaustively with dse_type:" + config.dse_type)
+            exit(0)
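For orientation, a minimal driver for the heuristic path of this handler might look like the sketch below (db_input and hw_sampling are placeholders that would come from specs.data_base; the methods it calls are defined later in this file):

    # hypothetical driver sketch; argument values are placeholders
    handler = DSEHandler(result_dir="./results")
    handler.setup_an_explorer(db_input, hw_sampling)  # requires config.dse_type == "hill_climbing"
    handler.prepare_for_exploration(boost_SOC=False)  # bootstrap an initial design from scratch
    handler.explore()                                 # run the heuristic named in config.heuristic_type
    handler.check_point_best_design(unique_number=0)  # optionally snapshot the best design so far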
+    # ---------------
+    # Functionality:
+    #       set up an exploration dse.
+    #       specify the explorer type in the config file
+    # ---------------
+    def setup_an_explorer(self, db_input, hw_sampling):
+        # body
+        if config.dse_type == "hill_climbing" or config.dse_type == "moos" or config.dse_type == "simple_greedy_one_sample":
+            exploration_start_time = time.time()  # time hooks (for data collection)
+            self.database = DataBase(db_input, hw_sampling)  # initialize the database
+            # initialize a design space explorer of a certain type
+            self.dse = hill_climbing.HillClimbing(self.database, self.result_dir)
+        elif config.dse_type == "exhaustive":
+            print("this main is not suitable for exhaustive search")
+            exit(0)
+            # TODO: fix the following. The commented code below is kept as a guide
+            # for writing it.
+            """
+            self.database = DataBase(database_input.tasksL, database_input.blocksL,
+                                     database_input.pe_mapsL, database_input.pe_schedulesL, database_input.SOCsL)
+
+            # change this later
+            self.dse = hill_climbing.HillClimbing(self.database)
+            all_exs = self.dse.dh.gen_des_exhaustively()
+            all_sims = []
+            for ex_dp in all_exs:
+                all_sims.append(self.dse.eval_design(ex_dp, self.database))

+            latency = [sim.get_dp_stats().get_system_complex_metric("latency") for sim in all_sims]
+            plot.scatter_plot(range(0, len(latency)), latency, latency, self.database)
+            self.dse = Exhaustive(self.database)
+            """
+
+    # populate the IP library from an external source
+    # mode: {"python", "csv"}
+    def populate_IP_library(self, mode="python"):
+        if (mode == "python"):
+            for task_name, blocksL in self.database.get_mappable_blocksL_to_tasks().items():
+                for task in self.database.get_tasks():
+                    if task.name == task_name:
+                        for blockL_ in blocksL:
+                            IP_library_element = IPLibraryElement()
+                            IP_library_element.set_task(task)
+                            IP_library_element.set_blockL(blockL_)
+                            IP_library_element.generate()
+                            self.IP_library.append(IP_library_element)
+
+        # dump one CSV per PPAC metric
+        for metric in ["latency", "energy", "area", "power"]:
+            IP_library_dict = defaultdict(dict)
+            for IP_library_element in self.IP_library:
+                IP_library_dict[IP_library_element.blockL.block_instance_name][IP_library_element.get_task().name] = \
+                    IP_library_element.get_PPAC()[metric]
+
+            all_task_names = [task.name for task in self.database.get_tasks()]
+            IP_library_dict_ordered = defaultdict(dict)
+            for IP_library_key in IP_library_dict.keys():
+                for task_name in all_task_names:
+                    if task_name in IP_library_dict[IP_library_key].keys():  # check if exists
+                        IP_library_dict_ordered[IP_library_key][task_name] = IP_library_dict[IP_library_key][task_name]
+                    else:
+                        IP_library_dict_ordered[IP_library_key][task_name] = "NA"
+
+            # writing into a file
+            fields = ["tasks"] + all_task_names
+            with open("IP_library_" + metric + ".csv", "w") as f:
+                w = csv.DictWriter(f, fields)
+                w.writeheader()
+                for k in IP_library_dict_ordered:
+                    w.writerow({field: IP_library_dict_ordered[k].get(field) or k for field in fields})
+
+    # ---------------
+    # Functionality:
+    #       prepare the exploration by either generating an initial design (mode = "from_scratch")
+    #       or using a check-pointed design
+    # Variables:
+    #       init_des: design point to bootstrap the exploration with
+    #       boost_SOC: choose a better SOC (this is for multiple-SOC designs; not activated yet)
+    #       starting_exploration_mode: whether to bootstrap the exploration from scratch or from an already existing design.
+ # --------------- + def prepare_for_exploration(self, boost_SOC, starting_exploration_mode="from_scratch", init_des = ""): + # either generate an initial design point(dh.gen_init_des()) or use a check_pointed one + self.dse.gen_init_ex_dp(starting_exploration_mode, init_des) + self.dse.dh.boos_SOC = boost_SOC + + # --------------- + # Functionality: + # explore the design space + # --------------- + def explore(self): + exploration_start_time = time.time() # time hook (data collection) + if config.heuristic_type in ["FARSI", "SA"]: + self.dse.explore_ds() + if config.heuristic_type == "moos": + self.dse.explore_ds_with_moos() + if config.heuristic_type == "simple_greedy_one_sample": + self.dse.explore_simple_greedy_one_sample(self.dse.init_ex_dp) + + # --------------- + # Functionality: + # explore the one design. Basically simulate the design and profile + # --------------- + def explore_one_design(self): + exploration_start_time = time.time() # time hook (data collection) + self.dse.explore_one_design() + + # copy the DSE results to the result dir + def copy_DSE_data(self, result_dir): + # result_dir_specific = os.path.join(result_dirresult_summary") + os.system("cp " + config.latest_visualization + "/*" + " " + result_dir) + + # ------------------------------ + # Functionality: + # write the results into a file + # Variables: + # sim_dp: design point simulation + # result_dir: result directory + # unique_number: a number to differentiate between designs + # file_name: output file name + # ------------------------------ + def write_one_results(self, sim_dp, dse, reason_to_terminate, case_study, result_dir_specific, unique_number, file_name): + """ + def convert_dict_to_parsable_csv(dict_): + list = [] + for k,v in dict_.items(): + list.append(str(k)+"="+str(v)) + return list + """ + + def convert_tuple_list_to_parsable_csv(list_): + result = "" + for k, v in list_: + result += str(k) + "=" + str(v) + "___" + return result + + def convert_dictionary_to_parsable_csv_with_semi_column(dict_): + result = "" + for k, v in dict_.items(): + result += str(k) + "=" + str(v) + ";" + return result + + if not os.path.isdir(result_dir_specific): + os.makedirs(result_dir_specific) + + compute_system_attrs = sim_dp.dp_stats.get_compute_system_attr() + bus_system_attrs = sim_dp.dp_stats.get_bus_system_attr() + memory_system_attrs = sim_dp.dp_stats.get_memory_system_attr() + speedup_dict, speedup_attrs = sim_dp.dp_stats.get_speedup_analysis(dse) + + output_file_minimal = os.path.join(result_dir_specific, file_name + ".csv") + + base_budget_scaling = sim_dp.database.db_input.sw_hw_database_population["misc_knobs"]["base_budget_scaling"] + + # minimal output + if os.path.exists(output_file_minimal): + output_fh_minimal = open(output_file_minimal, "a") + else: + output_fh_minimal = open(output_file_minimal, "w") + for metric in config.all_metrics: + output_fh_minimal.write(metric + ",") + if metric in sim_dp.database.db_input.get_budget_dict("glass").keys(): + output_fh_minimal.write(metric + "_budget" + ",") + output_fh_minimal.write("sampling_mode,") + output_fh_minimal.write("sampling_reduction" + ",") + for metric, accuracy_percentage in sim_dp.database.hw_sampling["accuracy_percentage"]["ip"].items(): + output_fh_minimal.write( + metric + "_accuracy" + ",") # for now only write the latency accuracy as the other + for block_type, porting_effort in sim_dp.database.db_input.porting_effort.items(): + output_fh_minimal.write( + block_type + "_effort" + ",") # for now only write the latency accuracy as the 
other + + output_fh_minimal.write( + "output_design_status" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("case_study" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("heuristic_type" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("unique_number" + ",") # for now only write the latency accuracy as the other + + output_fh_minimal.write("SA_total_depth,") + output_fh_minimal.write("reason_to_terminate" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "population generation cnt" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("iteration cnt" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("workload_set" + ",") # for now only write the latency accuracy as the other + # output_fh_minimal.write("iterationxdepth number" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("simulation time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "move generation time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "kernel selection time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "block selection time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "transformation selection time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "transformation_selection_mode" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("dist_to_goal_all" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "dist_to_goal_non_cost" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system block count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system PE count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system bus count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system memory count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("routing complexity" + ",") # for now only write the latency accuracy as the other + # output_fh_minimal.write("area_breakdown_subtype" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("block_impact_sorted" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "kernel_impact_sorted" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "metric_impact_sorted" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_metric" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "move_transformation_name" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_kernel" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_block_name" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_block_type" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_dir" + ",") # for now only write the latency 
accuracy as the other + output_fh_minimal.write("comm_comp" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "high_level_optimization" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + "architectural_principle" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("area_dram" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("area_non_dram" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("channel_cnt" + ",") # for now only write the latency accuracy as the other + for key, val in compute_system_attrs.items(): + output_fh_minimal.write(str(key) + ",") + for key, val in bus_system_attrs.items(): + output_fh_minimal.write(str(key) + ",") + for key, val in memory_system_attrs.items(): + output_fh_minimal.write(str(key) + ",") + + for key, val in speedup_attrs.items(): + output_fh_minimal.write(str(key) + ",") + + for key, val in speedup_dict.items(): + output_fh_minimal.write(str(key) + "_speedup_analysis" + ",") + + for key, val in base_budget_scaling.items(): + output_fh_minimal.write("budget_scaling_" + str(key) + ",") + + output_fh_minimal.write("\n") + for metric in config.all_metrics: + data_ = sim_dp.dp_stats.get_system_complex_metric(metric) + if isinstance(data_, dict): + data__ = convert_dictionary_to_parsable_csv_with_semi_column(data_) + else: + data__ = data_ + + output_fh_minimal.write(str(data__) + ",") + + if metric in sim_dp.database.db_input.get_budget_dict("glass").keys(): + data_ = sim_dp.database.db_input.get_budget_dict("glass")[metric] + if isinstance(data_, dict): + data__ = convert_dictionary_to_parsable_csv_with_semi_column(data_) + else: + data__ = data_ + output_fh_minimal.write(str(data__) + ",") + + output_fh_minimal.write(sim_dp.database.hw_sampling["mode"] + ",") + output_fh_minimal.write(sim_dp.database.hw_sampling["reduction"] + ",") + for metric, accuracy_percentage in sim_dp.database.hw_sampling["accuracy_percentage"]["ip"].items(): + output_fh_minimal.write( + str(accuracy_percentage) + ",") # for now only write the latency accuracy as the other + for block_type, porting_effort in sim_dp.database.db_input.porting_effort.items(): + output_fh_minimal.write(str(porting_effort) + ",") # for now only write the latency accuracy as the other + + if sim_dp.dp_stats.fits_budget(1): + output_fh_minimal.write("budget_met" + ",") # for now only write the latency accuracy as the other + else: + output_fh_minimal.write("budget_not_met" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(case_study + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(config.heuristic_type+ ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str(unique_number) + ",") # for now only write the latency accuracy as the other + + output_fh_minimal.write(str(config.SA_depth) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str(reason_to_terminate) + ",") # for now only write the latency accuracy as the other + + ma = sim_dp.get_move_applied() # move applied + if not ma == None: + sorted_metrics = convert_tuple_list_to_parsable_csv([(el, val) for el, val in ma.sorted_metrics.items()]) + metric = ma.get_metric() + transformation_name = ma.get_transformation_name() + task_name = ma.get_kernel_ref().get_task_name() + block_type = ma.get_block_ref().type + dir = 
ma.get_dir() + generation_time = ma.get_generation_time() + sorted_blocks = convert_tuple_list_to_parsable_csv( + [(el.get_generic_instance_name(), val) for el, val in ma.sorted_blocks]) + sorted_kernels = convert_tuple_list_to_parsable_csv( + [(el.get_task_name(), val) for el, val in ma.sorted_kernels.items()]) + blk_instance_name = ma.get_block_ref().get_generic_instance_name() + blk_type = ma.get_block_ref().type + + comm_comp = (ma.get_system_improvement_log())["comm_comp"] + high_level_optimization = (ma.get_system_improvement_log())["high_level_optimization"] + exact_optimization = (ma.get_system_improvement_log())["exact_optimization"] + architectural_variable_to_improve = (ma.get_system_improvement_log())["architectural_principle"] + block_selection_time = ma.get_logs("block_selection_time") + kernel_selection_time = ma.get_logs("kernel_selection_time") + transformation_selection_time = ma.get_logs("transformation_selection_time") + else: # happens at the very fist iteration + sorted_metrics = "" + metric = "" + transformation_name = "" + task_name = "" + block_type = "" + dir = "" + generation_time = '' + sorted_blocks = '' + sorted_kernels = {} + blk_instance_name = '' + blk_type = '' + comm_comp = "" + high_level_optimization = "" + architectural_variable_to_improve = "" + block_selection_time = "" + kernel_selection_time = "" + transformation_selection_time = "" + + routing_complexity = sim_dp.dp_rep.get_hardware_graph().get_routing_complexity() + simple_topology = sim_dp.dp_rep.get_hardware_graph().get_simplified_topology_code() + blk_cnt = sum([int(el) for el in simple_topology.split("_")]) + bus_cnt = [int(el) for el in simple_topology.split("_")][0] + mem_cnt = [int(el) for el in simple_topology.split("_")][1] + pe_cnt = [int(el) for el in simple_topology.split("_")][2] + # itr_depth_multiplied = sim_dp.dp_rep.get_iteration_number()*config.SA_depth + sim_dp.dp_rep.get_depth_number() + + output_fh_minimal.write(str( + sim_dp.dp_rep.get_population_generation_cnt()) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + str(dse.get_total_iteration_cnt()) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write('_'.join(sim_dp.database.db_input.workload_tasks.keys()) + ",") + # output_fh_minimal.write(str(itr_depth_multiplied)+ ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + str(sim_dp.dp_rep.get_simulation_time()) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str(generation_time) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + str(kernel_selection_time) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str(block_selection_time) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + str(transformation_selection_time) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write( + str(config.transformation_selection_mode) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str( + sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power", "cost"], + mode="eliminate")) + ",") + output_fh_minimal.write(str( + sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power"], mode="eliminate")) + ",") + output_fh_minimal.write(str(blk_cnt) + ",") # for now only write the latency accuracy as the other + 
output_fh_minimal.write(str(pe_cnt) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str(bus_cnt) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str(mem_cnt) + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(str(routing_complexity) + ",") # for now only write the latency accuracy as the other + # output_fh_minimal.write(convert_dictionary_to_parsable_csv_with_semi_column(sim_dp.dp_stats.SOC_area_subtype_dict.keys()) + ",") + output_fh_minimal.write(str(sorted_blocks) + ",") + output_fh_minimal.write(str(sorted_kernels) + ",") + output_fh_minimal.write(str(sorted_metrics) + ",") + output_fh_minimal.write(str(metric) + ",") + output_fh_minimal.write(transformation_name + ",") + output_fh_minimal.write(task_name + ",") + output_fh_minimal.write(blk_instance_name + ",") + output_fh_minimal.write(blk_type + ",") + output_fh_minimal.write(str(dir) + ",") + output_fh_minimal.write(str(comm_comp) + ",") + output_fh_minimal.write(str(high_level_optimization) + ",") + output_fh_minimal.write(str(architectural_variable_to_improve) + ",") + output_fh_minimal.write(str(sim_dp.dp_stats.get_system_complex_area_stacked_dram()["dram"]) + ",") + output_fh_minimal.write(str(sim_dp.dp_stats.get_system_complex_area_stacked_dram()["non_dram"]) + ",") + output_fh_minimal.write(str(sim_dp.dp_rep.get_hardware_graph().get_number_of_channels()) + ",") + for key, val in compute_system_attrs.items(): + output_fh_minimal.write(str(val) + ",") + for key, val in bus_system_attrs.items(): + output_fh_minimal.write(str(val) + ",") + for key, val in memory_system_attrs.items(): + output_fh_minimal.write(str(val) + ",") + + for key, val in speedup_attrs.items(): + output_fh_minimal.write(str(val) + ",") + + for key, val in speedup_dict.items(): + output_fh_minimal.write(convert_dictionary_to_parsable_csv_with_semi_column(val) + ",") + + for key, val in base_budget_scaling.items(): + output_fh_minimal.write(str(val) + ",") + + output_fh_minimal.close() + + def write_data(self, unique_number, result_folder, case_study, current_process_id, total_process_cnt, ctr): + # write the results in the general folder + result_dir_specific = os.path.join(result_folder, "result_summary") + self.write_one_results(self.dse.so_far_best_sim_dp, self.dse, self.dse.reason_to_terminate, case_study, + result_dir_specific, unique_number, + config.FARSI_simple_run_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) + self.dse.write_data_log(list(self.dse.get_log_data()), self.dse.reason_to_terminate, case_study, result_dir_specific, unique_number, + config.FARSI_simple_run_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) + + # write the results in the specific folder + result_folder_modified = result_folder+ "/runs/" + str(ctr) + "/" + os.system("mkdir -p " + result_folder_modified) + self.copy_DSE_data(result_folder_modified) + self.write_one_results(self.dse.so_far_best_sim_dp, self.dse, self.dse.reason_to_terminate, case_study, result_folder_modified, unique_number, + config.FARSI_simple_run_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) + + os.system("cp " + config.home_dir+"/settings/config.py"+ " "+ result_folder) + + # --------------- + # Functionality: + # check point the best design. Check pointing allows to iteratively improve the design by + # using the best of the last iteration design. 
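Check-pointing below only covers saving; a later run would restore a pickled design roughly like this sketch (hypothetical helper and paths, mirroring the file names written by the method that follows; FARSI's own loading path is not part of this diff):

    # hypothetical loader for a check-pointed example design
    import os, dill
    from zipfile import ZipFile

    def load_check_pointed_ex_dp(check_point_dir):
        with ZipFile(os.path.join(check_point_dir, "ex_dp_pickled.zip")) as z:
            z.extract("ex_dp_pickled.txt", check_point_dir)
        with open(os.path.join(check_point_dir, "ex_dp_pickled.txt"), "rb") as f:
            return dill.load(f)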
+    # ---------------
+    def check_point_best_design(self, unique_number):
+        # deactivate check pointing to prevent running out of memory
+        if not config.check_pointing_allowed:
+            return
+
+        # pickle the results for (out-of-run) verifications.
+        # make a directory according to the date/time
+        date_time = datetime.now().strftime('%m-%d_%H-%M_%S')
+        result_folder = os.path.join(self.result_dir, self.check_point_folder_name)
+        # date_time + "_" + str(unique_number))
+        if not os.path.exists(result_folder):
+            os.makedirs(result_folder)
+        # pickle the results in it
+        if "ex" in config.check_point_list:
+            zip_file_name = 'ex_dp_pickled.zip'
+            zip_file_addr = os.path.join(result_folder, zip_file_name)
+            pickle_file_name = "ex_dp_pickled" + ".txt"
+            pickle_file_addr = os.path.join(result_folder, pickle_file_name)
+            ex_dp_pickled_file = open(pickle_file_addr, "wb")
+            dill.dump(self.dse.so_far_best_ex_dp, ex_dp_pickled_file)
+            ex_dp_pickled_file.close()
+
+            # remove the old zip file
+            if os.path.isfile(zip_file_addr):
+                os.remove(zip_file_addr)
+
+            zipObj = ZipFile(zip_file_addr, 'w')
+            # add the pickle file to the zip
+            zipObj.write(pickle_file_addr, basename(pickle_file_addr))
+            # close the zip file
+            zipObj.close()
+
+            # remove the pickle file
+            os.remove(pickle_file_addr)
+
+        if "db" in config.check_point_list:
+            #database_pickled_file = open(os.path.join(result_folder, "database_pickled"+".txt"), "wb")
+            #dill.dump(self.database, database_pickled_file)
+            #database_pickled_file.close()
+            zip_file_name = 'database_pickled.zip'
+            zip_file_addr = os.path.join(result_folder, zip_file_name)
+            pickle_file_name = "database_pickled" + ".txt"
+            pickle_file_addr = os.path.join(result_folder, pickle_file_name)
+            database_pickled_file = open(pickle_file_addr, "wb")
+            dill.dump(self.database, database_pickled_file)
+            database_pickled_file.close()
+
+            # remove the old zip file
+            if os.path.isfile(zip_file_addr):
+                os.remove(zip_file_addr)
+
+            zipObj = ZipFile(zip_file_addr, 'w')
+            # add the pickle file to the zip
+            zipObj.write(pickle_file_addr, basename(pickle_file_addr))
+            # close the zip file
+            zipObj.close()
+
+            # remove the pickle file
+            os.remove(pickle_file_addr)
+
+        if "sim" in config.check_point_list:
+            sim_dp_pickled_file = open(os.path.join(result_folder, "sim_dp_pickled" + ".txt"), "wb")
+            dill.dump(self.dse.so_far_best_sim_dp, sim_dp_pickled_file)
+            sim_dp_pickled_file.close()
+            vis_hardware.vis_hardware(self.dse.so_far_best_ex_dp, config.hw_graphing_mode, result_folder)
+
+        if "counters" in config.check_point_list:
+            counters_pickled_file = open(os.path.join(result_folder, "counters_pickled" + ".txt"), "wb")
+            dill.dump(self.dse.counters, counters_pickled_file)
+            counters_pickled_file.close()
+            #vis_hardware.vis_hardware(self.dse.so_far_best_ex_dp, config.hw_graphing_mode, result_folder)
+
+        for key, val in self.dse.so_far_best_sim_dp.dp_stats.SOC_metric_dict["latency"]["glass"][0].items():
+            print("lat is {} for {}".format(val, key))
+        burst_size = config.default_burst_size
+        queue_size = config.default_data_queue_size
+        print("burst size is {}".format(burst_size))
+        print("queue size is {}".format(queue_size))
+
+        #self.dse.write_data_log(list(self.dse.get_log_data()), self.dse.reason_to_terminate, "", result_folder, self.check_point_ctr,
+        #                        config.FARSI_simple_run_prefix)
+        self.check_point_ctr += 1
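The run summaries written by write_one_results above encode dictionaries as "k=v;k=v;" and tuple lists as "k=v___k=v___"; a downstream script could decode those fields with a sketch like this (hypothetical helpers, not part of the original file):

    # hypothetical decoders for the CSV field encodings used above
    def parse_semi_column_dict(field):
        return dict(item.split("=", 1) for item in field.split(";") if item)

    def parse_tuple_list(field):
        return [tuple(item.split("=", 1)) for item in field.split("___") if item]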
diff --git a/Project_FARSI/DSE_utils/exhaustive_DSE.py b/Project_FARSI/DSE_utils/exhaustive_DSE.py
new file mode 100644
index 00000000..0f9dc4ab
--- /dev/null
+++ b/Project_FARSI/DSE_utils/exhaustive_DSE.py
@@ -0,0 +1,1190 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+from sympy.functions.combinatorial import numbers
+from sympy import factorial
+import math
+import itertools
+from copy import *
+import time
+import numpy as np
+import operator
+import collections
+
+# ------------------------------
+# Functionality:
+#       calculate Stirling numbers, i.e., the number of ways to partition a set. For the mathematics, refer to
+#       https://en.wikipedia.org/wiki/Stirling_numbers_of_the_second_kind
+# Variables:
+#       n, k are both Stirling inputs; refer to:
+#       https://en.wikipedia.org/wiki/Stirling_numbers_of_the_second_kind
+# ------------------------------
+# n: balls, k: bins
+def calc_stirling(n, k):
+    # multiply by k! if the bins should be distinguishable
+    return numbers.stirling(n, k, d=None, kind=2, signed=False)
+
+
+# ------------------------------
+# Functionality:
+#       calculate the migration cardinality, i.e., the number of ways tasks can be distributed
+#       across n+1 blocks when we introduce n new blocks (with the restriction that each block
+#       needs to have at least one task on it)
+# Variables:
+#       num_task: number of tasks within the workload.
+#       num_blcks_to_split_to: number of blocks to map the tasks to.
+# ------------------------------
+def calc_mig_comb_cnt(num_task, num_blcks_to_split_to):
+    return factorial(num_blcks_to_split_to) * calc_stirling(num_task, num_blcks_to_split_to)
+
+
+def calc_mig_comb_idenitical_blocks_cnt(num_task, num_blcks_to_split_to):
+    return calc_stirling(num_task, num_blcks_to_split_to)
+
+# ------------------------------
+# Functionality:
+#       calculate an upper bound on the combination count associated with the reduced-contention options
+# Variables:
+#       num_tasks: number of tasks within the workload.
+#       num_blcks_to_split_to: number of blocks to map the tasks to.
+#       blocks_to_choose_from: total number of blocks that we can choose our num_blcks_to_split_to set from.
+# ------------------------------
+def calc_red_cont_up_comb_cnt(num_tasks, num_blcks_to_split_to, blocks_to_choose_from):
+    allocation_cardinality = (num_blcks_to_split_to - 1)**blocks_to_choose_from
+    # the reason this is an upper bound is that if a hardware configuration
+    # uses two blocks of the same kind, then migration can result in a setup that has already been seen.
+    return allocation_cardinality * calc_mig_comb_cnt(num_tasks, num_blcks_to_split_to)
+
+# ------------------------------
+# Functionality:
+#       given the number of draws from a population, compute the statistically expected coverage value.
+# Variables:
+#       population_size: statistical population cardinality.
+#       num_of_draws: number of samples drawn from the population.
+# ------------------------------
+def calc_coverage_exepctation_value(population_size, num_of_draws):
+    expected_value_of_coverage = population_size * (1 - ((population_size-1)/population_size)**num_of_draws)
+    return float(expected_value_of_coverage)
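As a quick numeric check of the helpers above (a hypothetical sanity snippet, not part of the original file):

    # S(4, 2) = 7 ways to split 4 tasks into 2 non-empty groups, so spreading
    # 4 tasks over 2 distinct blocks gives 2! * 7 = 14 migration combinations.
    assert calc_mig_comb_cnt(4, 2) == 14
    # expected coverage after 100 draws from a population of 50:
    # 50 * (1 - (49/50)**100), roughly 43.4 distinct samples on average
    print(calc_coverage_exepctation_value(50, 100))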
+# ------------------------------
+# Functionality:
+#       find how many samples we need to draw from the population to achieve a certain coverage. Note that
+#       we use the expected value of the coverage, that is, on AVERAGE, what the coverage is if we make a
+#       certain number of draws.
+# Variables:
+#       population_size: statistical population cardinality.
+#       desired_coverage: which percentage of the population size we'd like to cover.
+#       num_of_draws_threshold: bound on the number of samples drawn from the population.
+#       num_of_draw_incr: step by which to increase the number of draws until the desired coverage is met.
+# ------------------------------
+def find_num_draws_to_satisfy_coverage(population_size, desired_coverage, num_of_draws_threshold, num_of_draw_incr):
+    coverage_satisfied = False
+    for num_of_draws in range(0, num_of_draws_threshold, num_of_draw_incr):
+        coverage_expectation = calc_coverage_exepctation_value(population_size, num_of_draws)
+        if coverage_expectation >= desired_coverage:
+            coverage_satisfied = True
+            break
+    return coverage_satisfied, coverage_expectation, num_of_draws
+
+# ------------------------------
+# Functionality:
+#       simple test for sanity check.
+# ------------------------------
+def simple_test():
+    pop_size = calc_red_cont_up_comb_cnt(30, 2, 5)
+    coverage_percentage = 0.50
+    x, y, z = find_num_draws_to_satisfy_coverage(pop_size, coverage_percentage*pop_size, 4000, 500)
+    print(x)
+
+# we don't consider the infeasibility of allocating
+# extra blocks despite a lack of parallelism
+def system_variation_count_1(gen_config):
+    MAX_TASK_CNT = gen_config["MAX_TASK_CNT"]
+    DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"]
+    DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"]
+    DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"]
+
+    # assuming that we have 5 different tasks (and hence can have up to 5 different blocks),
+    # we'd like to know how many different migration/allocation combinations are out there.
+    # Assumptions:
+    #          PEs are identical.
+    #          Buses are identical.
+    #          Memory is ignored for now.
+
+    # topologies
+    MAX_PE_CNT = MAX_TASK_CNT
+    MAX_BUS_CNT = MAX_TASK_CNT
+    task_cnt = MAX_TASK_CNT
+    system_variations = 0
+    for bus_cnt in range(1, MAX_BUS_CNT+1):
+        for pe_cnt in range(bus_cnt, MAX_PE_CNT+1):
+            for mem_cnt in range(bus_cnt, pe_cnt+1):
+                # first calculate the topological variations (split)
+                topo_incr_1 = calc_stirling(pe_cnt, bus_cnt)*factorial(bus_cnt)  # at least one pe per bus;
+                                                                                 # factorial is used because we
+                                                                                 # assume distinct buses, as the
+                                                                                 # relative positioning of the buses
+                                                                                 # impacts the topology
+                topo_incr_2 = calc_stirling(mem_cnt, bus_cnt)*factorial(bus_cnt)  # at least one memory for
+                                                                                  # each bus. Note that this estimate
+                                                                                  # is a bit conservative: scenarios
+                                                                                  # where the number of mems exceeds
+                                                                                  # the number of pes connected to a
+                                                                                  # bus are not really reasonable;
+                                                                                  # however, they are considered here.
+
+                # then calculate mapping (migrate)
+                mapping = calc_stirling(task_cnt, pe_cnt)*factorial(pe_cnt)
+                # then calculate customization (swap)
+                swap = math.pow(2, (DB_MAX_BUS_CNT + DB_MAX_PE_CNT + DB_MAX_MEM_CNT))
+                system_variations += topo_incr_1*topo_incr_2*mapping*swap
+
+    #print("number of system variations: " + str(float(system_variations)))
+    #print("exhaustive simulation time (hours): " + str(float(system_variations)/(20*3600)))
+    return system_variations
+    #print("{:e}".format(system_variations))
+
+
+class SYSTEM_():
+    def __init__(self):
+        self.bus_cnt = 0
+        self.pe_cnt = 0
+        self.mem_cnt = 0
+        self.similar_system_cnt = 0
+        self.task_cnt = 0
+        self.par_task_cnt = 0
+        self.pe_set = []  # lay buses as the reference and decorate with mem
+        self.mem_set = []
+        self.mapping_variation_cnt = 0
+        self.PE_list = []
+        self.MEM_list = []
+        self.BUS_list = []
+        self.BUS_PE_list = {}  # dictionary of bus index and the PEs hanging from them.
buses are indexed based on BUS_list + self.BUS_MEM_list = {} # dictionary of bus index and the MEMs hanging from them. buses are indexed based on BUS_list + + def set_bus_cnt(self, bus_cnt): + self.bus_cnt = bus_cnt + + def set_pe_cnt(self, pe_cnt): + self.pe_cnt = pe_cnt + + def set_mem_cnt(self, mem_cnt): + self.mem_cnt = mem_cnt + + def get_bus_cnt(self): + return self.bus_cnt + + def get_pe_cnt(self): + return self.pe_cnt + + def get_mem_cnt(self): + return self.mem_cnt + + def append_pe_set(self, pe_cnt): + self.pe_set.append(pe_cnt) + + def set_pe_set(self, pe_set): + self.pe_set = pe_set + self.pe_cnt = sum(self.get_pe_set()) + + def get_mem_set(self): + return self.mem_set + + def set_mem_set(self, mem_set): + self.mem_set = mem_set + self.mem_cnt = sum(self.get_mem_set()) + + def get_pe_set(self): + return self.pe_set + + def set_task_cnt(self, task_cnt): + self.task_cnt = task_cnt + + def set_par_task_cnt(self, par_task_cnt): + self.par_task_cnt = par_task_cnt + + def parallelism_check(self, gen_config): + MAX_TASK_CNT = gen_config["DB_MAX_TASK_CNT"] + DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"] + + par_task_cnt = self.par_task_cnt + for pe_cnt in self.get_pe_set(): + if pe_cnt - DB_MAX_PE_CNT > 0: + par_task_cnt -= (pe_cnt-DB_MAX_PE_CNT) + if par_task_cnt < 0: + return False + + return True + + # can't have more memory hanging from bus than it's pe's + def pe_mem_check(self): + for idx in range(0, len(self.get_pe_set())): + if self.get_mem_set()[idx] > self.get_pe_set()[idx]: + return False + + return True + + # can't have more pe's than the number of tasks + def pe_cnt_check(self, gen_config): + MAX_TASK_CNT = gen_config["DB_MAX_TASK_CNT"] + if self.get_pe_cnt() > MAX_TASK_CNT: + return False + return True + + def set_pe_task_set(self, pe_task_set): + self.pe_task_set = pe_task_set + + def set_mem_task_set(self, mem_task_set): + self.mem_task_set = mem_task_set + + def get_mem_task_set(self): + return self.mem_task_set + + # get tasks of a bus + def get_task_per_bus(self): + bus_tasks = [] + flattened_indecies = self.flatten_indecies(self.pe_set) + for bus_idx in range(0, len(self.get_pe_set())): + list_unflattened = self.get_pe_task_set()[flattened_indecies[bus_idx]:flattened_indecies[bus_idx+1]] + list_flattened = list(itertools.chain(*list_unflattened)) + bus_tasks.append(list_flattened) + + return bus_tasks + + def get_pe_task_set(self): + return self.pe_task_set + + def get_task_s_pe(self, task_name): + for idx, el in enumerate(self.get_pe_task_set()): + if task_name in el: + return idx + + print("task not found") + return -1 + + def get_task_s_mem(self, task_name): + for idx, el in enumerate(self.get_mem_task_set()): + if task_name in el: + return idx + + print("task not found") + return -1 + + def set_PE_list(self, PE_list): + self.PE_list = PE_list + + def flatten_indecies(self, list): + result = [0] + for el in list: + result.append(result[-1]+el) + return result + + def set_BUS_PE_list(self, PE_list): + flattened_indecies = self.flatten_indecies(self.pe_set) + for idx, BUS in enumerate(self.BUS_list): + self.BUS_PE_list[idx] = PE_list[flattened_indecies[idx]: flattened_indecies[idx + 1]] + + def set_BUS_MEM_list(self, MEM_list): + flattened_indecies = self.flatten_indecies(self.mem_set) + for idx, BUS in enumerate(self.BUS_list): + self.BUS_MEM_list[idx] = MEM_list[flattened_indecies[idx]: flattened_indecies[idx + 1]] + + def get_BUS_list(self): + return self.BUS_list + + def get_BUS_PE_list(self): + return self.BUS_PE_list + + def get_BUS_MEM_list(self): + return 
self.BUS_MEM_list
+
+    # idx: index of the bus whose neighbours we need
+    def get_bus_s_pe_neighbours(self, idx):
+        return self.BUS_PE_list[idx]
+
+    def get_bus_s_mem_neighbours(self, idx):
+        return self.BUS_MEM_list[idx]
+
+    def set_MEM_list(self, MEM_list):
+        self.MEM_list = MEM_list
+
+    def set_BUS_list(self, PE_list):
+        self.BUS_list = PE_list
+
+    # -----------------
+    # system counters
+    # -----------------
+    # at the moment, only considering bus and PE (assuming that mem would scale with the bus; not the size but the bandwidth)
+    def simple_customization_variation_cnt(self, gen_config):
+        DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"]
+        DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"]
+        DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"]
+
+        return pow(DB_MAX_PE_CNT, self.get_pe_cnt())*pow(DB_MAX_BUS_CNT, self.get_bus_cnt())*pow(DB_MAX_MEM_CNT, self.get_mem_cnt())
+
+    def calc_customization_variation_cnt(self, gen_config):
+        return self.simple_customization_variation_cnt(gen_config)
+
+    # at the moment, we are considering upper bounds for the cuts,
+    # so in reality the cuts are going to be smaller.
+    def design_space_reduction(self, gen_config, task_cnt=1):
+        DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"]
+        DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"]
+        DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"]
+        MAX_PAR_TASK_CNT = gen_config["MAX_PAR_TASK_CNT"]
+        MAX_TASK_CNT = gen_config["MAX_TASK_CNT"]
+
+        if DB_MAX_BUS_CNT - task_cnt <= 0:
+            return pow(DB_MAX_BUS_CNT - 1, task_cnt)/pow(DB_MAX_BUS_CNT, task_cnt)
+        else:
+            return (DB_MAX_BUS_CNT-task_cnt)/DB_MAX_BUS_CNT  # the task resides in one of the PEs,
+                                                             # hence the decrement. You need to look at it
+                                                             # per mapping scenario.
+                                                             # note that if task_cnt > 1, the
+                                                             # calculation is a bit harder, as it depends on
+                                                             # whether tasks are bound to the same PE or not.
+                                                             # We make this assumption since it results in a
+                                                             # more conservative (smaller) DS reduction.
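For a concrete feel for the customization count, a hypothetical check (values invented for illustration):

    # one bus carrying two PEs and one memory; the database offers 3 PE,
    # 2 memory, and 2 bus variants to swap in
    gen_config = {"DB_MAX_PE_CNT": 3, "DB_MAX_MEM_CNT": 2, "DB_MAX_BUS_CNT": 2}
    system_ = SYSTEM_()
    system_.set_bus_cnt(1)
    system_.set_pe_set([2])   # pe_cnt becomes 2
    system_.set_mem_set([1])  # mem_cnt becomes 1
    # 3^2 PE choices * 2^1 bus choices * 2^1 memory choices = 36 customizations
    assert system_.simple_customization_variation_cnt(gen_config) == 36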
+    # map the tasks to the pes;
+    # each pe needs to have at least one task.
+    # we get rid of per-bus scheduling combinations since PEs
+    # are not distinguished yet (note that we can't do this across buses,
+    # as the topology (placement) already makes PEs within each bus
+    # different relative to others)
+    def simple_mapping_variation_cnt(self, gen_config):
+        MAX_TASK_CNT = gen_config["DB_MAX_TASK_CNT"]
+
+        bins = self.get_pe_cnt()
+        balls = MAX_TASK_CNT
+        comb_each_pe_at_least_one_task = calc_stirling(balls, bins)*factorial(bins)
+        for pe_cnt in self.get_pe_set():  # since PEs are still the same, get rid of the combinations per bus
+            comb_each_pe_at_least_one_task /= factorial(pe_cnt)
+
+        return comb_each_pe_at_least_one_task
+
+    def calc_mapping_variation_cnt(self, gen_config):
+        return self.simple_mapping_variation_cnt(gen_config)
+
+    def system_get_mapping_variation_cnt(self):
+        return self.mapping_variation_cnt
+
+#-------------------------
+# system counters
+#-------------------------
+def system_variation_count_2(gen_config):
+    full_potential_tasks_list = gen_config["full_potential_tasks_list"]
+    DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"]
+    DB_PE_list = gen_config["DB_PE_list"]
+    DB_MEM_list = gen_config["DB_MEM_list"]
+    DB_BUS_list = gen_config["DB_BUS_list"]
+    DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"]
+    DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"]
+    MAX_PAR_TASK_CNT = gen_config["MAX_PAR_TASK_CNT"]
+    MAX_TASK_CNT = gen_config["MAX_TASK_CNT"]
+
+    # mapping-assisted topology
+    task_cnt = MAX_TASK_CNT
+    task_list = full_potential_tasks_list[0:MAX_PAR_TASK_CNT-1]
+
+    system_variations = 0
+    system_list = []
+
+    #------------------
+    # spawn different system topologies
+    #------------------
+    all_lists = []
+    for bus_cnt in range(1, max((DB_MAX_PE_CNT*DB_MAX_BUS_CNT), MAX_PAR_TASK_CNT) + 1):
+        # generate all the permutations of the pe values
+        all_lists = [list(range(1, max(DB_MAX_PE_CNT, MAX_PAR_TASK_CNT) + 1))] * bus_cnt
+        all_lists_permuted = list(itertools.product(*all_lists))
+
+        # now generate a system for each permutation
+        for pe_set in all_lists_permuted:
+            system_ = SYSTEM_()
+            system_.set_bus_cnt(bus_cnt)
+            system_.set_pe_set(list(pe_set))
+            system_list.append(system_)
+
+    # generate software
+    for system in system_list:
+        system.set_task_cnt(MAX_TASK_CNT)
+        system.set_par_task_cnt(MAX_PAR_TASK_CNT)
+
+    #------------------
+    # filter out infeasible/unreasonable systems
+    #------------------
+    filtered_system_list = []
+    for system in system_list:
+        if not system.pe_cnt_check(gen_config):
+            continue
+        if not system.parallelism_check(gen_config):
+            continue
+
+        filtered_system_list.append(system)
+
+    #------------------
+    # count the different mappings associated with the system topologies
+    #------------------
+    total_system_variation = 0
+    for system in filtered_system_list:
+        mapping_variations = system.calc_mapping_variation_cnt(gen_config)
+        total_system_variation += mapping_variations  # accumulate the mapping variations across topologies
+
+    #total_systems = sum([system_.system_get_mapping_variation_cnt() for system_ in system_list])
+    #print("number of system variations :" + str(total_system_variation))
+    return total_system_variation
+    #print("exhaustive simulation time (hours):" + str(float(total_system_variation)/(20*3600)))
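To size the naive space before committing to a sweep, one could call the first counter directly (hypothetical knob values; the 20-designs-per-second figure echoes the commented-out estimate above):

    gen_config = {"MAX_TASK_CNT": 3, "DB_MAX_PE_CNT": 2,
                  "DB_MAX_MEM_CNT": 2, "DB_MAX_BUS_CNT": 2}
    variations = system_variation_count_1(gen_config)
    print("number of system variations: " + str(float(variations)))
    print("exhaustive simulation time (hours): " + str(float(variations)/(20*3600)))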
+def get_DS_cnt(DS_mode, gen_config, DS_output_mode):
+    if DS_mode == "exhaustive_naive":
+        DS_size = system_variation_count_1(gen_config)
+    elif DS_mode == "exhaustive_reduction_DB_semantics":
+        DS_size = system_variation_count_2(gen_config)
+
+    if DS_output_mode == "DS_time":
+        # seconds per simulated design; assumed to be provided through gen_config,
+        # as it is not defined anywhere else in this file
+        sim_time_per_design = gen_config.get("sim_time_per_design", 1)
+        return DS_size*sim_time_per_design
+    else:
+        return DS_size
+
+
+#-------------------------
+# system generators
+#--------------------------
+# binning balls into bins such that there is
+# at least one ball in every bin
+def binning(bin_cnt, balls):
+    # use product
+    for indices in itertools.product(range(0, bin_cnt), repeat=len(balls)):
+        result = [[] for _ in range(0, bin_cnt)]
+        for ball_index, bin_index in enumerate(indices):
+            result[bin_index].append(balls[ball_index])
+
+        # discard if any of the bins are empty
+        res_valid = True
+        for el in result:
+            if len(el) == 0:
+                res_valid = False
+        if res_valid:
+            yield result
+            # yield indices
+        else:
+            continue
+
+# this is for parallelization. it attempts to spread the workload across the processes equally.
+# this cannot be perfectly equal, as we parallelize based on mapped tasks (and not customized tasks), so
+# there are going to be imbalances, but this is still better than no parallelization
+def shard_work_equally(mapping_customization_variation_cnt_list, process_cnt):
+    process_id_work_bound_dict = {}
+    total_mapping = sum([el[0] for el in mapping_customization_variation_cnt_list])
+    ideal_work_per_process = sum([el[0]*el[1] for el in mapping_customization_variation_cnt_list])/process_cnt
+    map_idx_list = [0]
+    total_system_accumulated = 0
+    mapped_system_accumulated = 0
+    for map_cnt, cust in mapping_customization_variation_cnt_list:
+        for i in range(1, map_cnt+1):
+            if total_system_accumulated + cust > ideal_work_per_process:
+                map_idx_list.append(mapped_system_accumulated + max((i-1), 0))
+                total_system_accumulated = 0
+            else:
+                total_system_accumulated += cust
+
+        mapped_system_accumulated += map_cnt
+
+    if len(map_idx_list) < process_cnt + 1:
+        map_idx_list.append(total_mapping)
+
+    for process_id in range(0, len(map_idx_list)-1):
+        process_id_work_bound_dict[process_id] = (map_idx_list[process_id], map_idx_list[process_id+1])
+
+    return process_id_work_bound_dict
+
+
+# second, map the tasks to the mems
+def generate_mem_mapping(mapped_systems, process_id_work_dictionary, process_id):
+    mapped_systems_completed = []
+    if not (process_id in process_id_work_dictionary.keys()):  # if we couldn't shard properly
+        exit(0)
+    system_lower_bound = process_id_work_dictionary[process_id][0]
+    system_upper_bound = process_id_work_dictionary[process_id][1]
+    for system in mapped_systems[system_lower_bound: system_upper_bound]:
+        excessive_memory = False  # flags scenarios where there are not enough tasks to map to the memories
+        isolated_siink = False
+        all_task_mappings = []  # all the mappings to the memories
+        for tasks, mem_cnt in zip(system.get_task_per_bus(), system.get_mem_set()):
+
+            mappings = list(binning(mem_cnt, tasks))  # generating half designs
+
+            # this covers scenarios where there are too many memories, so
+            # we cannot distribute tasks across all of them
+            if len(mappings) == 0:
+                excessive_memory = True
+                break
+            else:
+                task_names_ = mappings[0][0]
+                if any([True for name in task_names_ if "siink" in name]) and len(mappings[0][0]) == 1:  # siink can't be occupying any memory in isolation as it doesn't use any memory
+                    isolated_siink = True
+                    break
+                else:
+                    all_task_mappings.append(mappings)
+
+        #task_mapping_filtered.append(remove_permutation_per_bus_2(all_task_mappings, system.get_mem_set()))
+
+        if excessive_memory or isolated_siink:  # too many memories.
One memory would be task less + continue + + all_permutations_tuples = list(itertools.product(*all_task_mappings)) + all_permutations_listified = []#all_permutations_tuples + + + for design in all_permutations_tuples: + new_design = list(design) + #if len(design) == 1: # tuple of size 1: + # new_design.append([[]]) + all_permutations_listified.append(new_design) + + all_permutations_listified_flattened = [] + for design in all_permutations_listified: + #all_tasks = [] # for debugging + design_fixed = [] + for bus_s_mems in design: + for mem in bus_s_mems: + design_fixed.append(mem) + #all_tasks.extend(mem) + all_permutations_listified_flattened.append(design_fixed) + + """ + for task in full_potential_tasks_list: + if not task in all_tasks: + print("what") + """ + + task_mapping_filtered = remove_permutation_per_bus_2(all_permutations_listified_flattened, system.get_mem_set()) + for task_mapping in task_mapping_filtered: + system_ = deepcopy(system) + system_.set_mem_task_set(task_mapping) + mapped_systems_completed.append(system_) + + mapped_systems_completed_filtered = [] # filter scenarios that siink is by itself on memory + for system in mapped_systems_completed: + isolated_siink = False + for tasks_on_mem in system.get_mem_task_set(): + if any([True for task in tasks_on_mem if ("siink" in task and len(tasks_on_mem) == 1)]): + isolated_siink = True + break + if isolated_siink: + continue + mapped_systems_completed_filtered.append(system) + return mapped_systems_completed_filtered + # """ # comment in this if you don't care about task to memory mapping + + +# generate all customization scenarios +def generate_customization(mapped_systems, gen_config): + full_potential_tasks_list = gen_config["full_potential_tasks_list"] + DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"] + DB_PE_list = gen_config["DB_PE_list"] + DB_MEM_list = gen_config["DB_MEM_list"] + DB_BUS_list = gen_config["DB_BUS_list"] + DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"] + DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"] + + customized_systems = [] + + for idx, system in enumerate(mapped_systems): + pe_cnt = system.get_pe_cnt() + mem_cnt = system.get_mem_cnt() + bus_cnt = system.get_bus_cnt() + PE_scenarios = list(itertools.product(DB_PE_list, repeat=pe_cnt)) + MEM_scenarios = list(itertools.product(DB_MEM_list, repeat=mem_cnt)) + BUS_scenarios = list(itertools.product(DB_BUS_list, repeat=bus_cnt)) + customization_scenarios = list(itertools.product(PE_scenarios, MEM_scenarios, BUS_scenarios)) + for customized_scn in customization_scenarios: + system_ = deepcopy(system) + + system_.set_BUS_list(list(customized_scn[2])) + system_.set_BUS_PE_list(list(customized_scn[0])) + system_.set_BUS_MEM_list(list(customized_scn[1])) + + #system_.set_PE_list(list(customized_scn[0])) + #system_.set_MEM_list(list(customized_scn[1])) + #system_.set_BUS_list(list(customized_scn[2])) + + # quick sanity check + for task_name in full_potential_tasks_list: + if system.get_task_s_mem(task_name) == -1: + print("something went wrong") + exit(0) + customized_systems.append(system_) + return customized_systems + + +def generate_topologies(gen_config): + DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"] + DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"] + DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"] + MAX_PAR_TASK_CNT = gen_config["DB_MAX_PAR_TASK_CNT"] + MAX_TASK_CNT = gen_config["DB_MAX_TASK_CNT"] + + DB_MIN_PE_CNT = gen_config["DB_MIN_PE_CNT"] + DB_MIN_MEM_CNT = gen_config["DB_MIN_MEM_CNT"] + DB_MIN_BUS_CNT = gen_config["DB_MIN_BUS_CNT"] + 
#DB_MAX_SYSTEM_to_investigate = gen_config["DB_MAX_SYSTEM_to_investigate"] + + #DB_MIN_PE_CNT = 1 + #DB_MIN_MEM_CNT = 1 + #DB_MIN_BUS_CNT = 1 + + + + system_list = [] + all_lists = [] + for bus_cnt in range(DB_MIN_BUS_CNT, min(DB_MAX_BUS_CNT, MAX_PAR_TASK_CNT) + 1): + # generate all the permutations of the pe values + all_lists = [list(range(1, min(DB_MAX_PE_CNT, MAX_PAR_TASK_CNT) + 1))] * bus_cnt + pe_dist_perm = list(itertools.product(*all_lists)) + + # some filtering, otherwise never finish + pe_dist_perm_filtered = [] + for pe_dist in pe_dist_perm: + if sum(pe_dist) <= MAX_TASK_CNT and sum(pe_dist) >= DB_MIN_PE_CNT and sum(pe_dist)<= DB_MAX_PE_CNT: + pe_dist_perm_filtered.append(pe_dist) + + pe_mem_dist_perm = list(itertools.product(pe_dist_perm_filtered, repeat=2)) + + # now generated a system with each permutation + for pe_set, mem_set in pe_mem_dist_perm: + system_ = SYSTEM_() + system_.set_bus_cnt(bus_cnt) + system_.set_pe_set(list(pe_set)) + system_.set_mem_set(list(mem_set)) + system_list.append(system_) + + # add some sw information + for system in system_list[:len(system_list)]: + system.set_task_cnt(MAX_TASK_CNT) + system.set_par_task_cnt(MAX_PAR_TASK_CNT) + + #------------------ + # filter out infeasible/unreasonable topologies: using parallelism/customization at the moment + #------------------ + filtered_system_list = [] + for system in system_list: + if not system.pe_cnt_check(gen_config): + continue + if not system.parallelism_check(gen_config): + continue + #if not system.pe_mem_check(): + # continue + filtered_system_list.append(system) + + return filtered_system_list[:min(len(filtered_system_list), gen_config["DB_MAX_SYSTEM_to_investigate"])] + +# count all the mapping and customizations +def count_mapping_customization(system_topologies, gen_config): + mapping_customization_variation_cnt_list = [] # keep track of how many mapping variations per topology + total_system_variation_cnt = 0 + for system in system_topologies: + # get system counts per topology (for sanity checks) + mapping_variation_cnt = system.calc_mapping_variation_cnt(gen_config) + customization_variation_cnt = system.calc_customization_variation_cnt(gen_config) + # customization_variation_cnt = 1 + total_system_variation_cnt += mapping_variation_cnt*customization_variation_cnt + mapping_customization_variation_cnt_list.append((mapping_variation_cnt, customization_variation_cnt)) + #print(sum([el[0] for el in mapping_variation_cnt_list])) + return mapping_customization_variation_cnt_list, total_system_variation_cnt + + +# generate all the mappings for the topologies specified in filtered_system_list +def generate_pe_mapping(filtered_system_list, gen_config): + full_potential_tasks_list = gen_config["full_potential_tasks_list"] + DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"] + DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"] + DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"] + MAX_PAR_TASK_CNT = gen_config["DB_MAX_PAR_TASK_CNT"] + MAX_TASK_CNT = gen_config["DB_MAX_TASK_CNT"] + DB_MAX_PE_CNT_range = gen_config["DB_MAX_PE_CNT_range"] + MAX_TASK_range = gen_config["DB_MAX_TASK_CNT_range"] + MAX_TASK_CNT_range = gen_config["DB_MAX_TASK_CNT_range"] + MAX_PAR_TASK_CNT_range = gen_config["DB_MAX_PAR_TASK_CNT_range"] + DS_output_mode = gen_config["DS_output_mode"] + DB_MAX_SYSTEM_to_investigate = gen_config["DB_MAX_SYSTEM_to_investigate"] + + + mapping_customization_variation_cnt_list = [] + + # ------------------ + # helpers + # ------------------ + # find a the pe idx for a task + def get_tasks_pe_helper(mapping, 
task):
+        for idx, tasks in enumerate(mapping):
+            if task in tasks:
+                return idx
+        print("task is not found")
+        return -1
+
+    # filter mapping scenarios where source/sink don't map to the same pe as the second/second-to-last tasks respectively.
+    # Since source and sink are dummies, it doesn't make a difference if they are mapped to other pes
+    def filter_sink_source(all_task_mappings, gen_config):
+        full_potential_tasks_list = gen_config["full_potential_tasks_list"]
+        MAX_TASK_CNT = gen_config["DB_MAX_TASK_CNT"]
+
+        result = []
+        for mapping in all_task_mappings:
+            source_idx = get_tasks_pe_helper(mapping, full_potential_tasks_list[0])
+            task_1_idx = get_tasks_pe_helper(mapping, full_potential_tasks_list[1])
+            if source_idx == -1 or task_1_idx == -1:
+                print("something went wrong")
+                exit(0)
+            if not (source_idx == task_1_idx):
+                continue
+
+            sink_idx = get_tasks_pe_helper(mapping, full_potential_tasks_list[-1])
+            last_task_idx = get_tasks_pe_helper(mapping, full_potential_tasks_list[-2])
+            if sink_idx == -1 or last_task_idx == -1:
+                print("something went wrong")
+                exit(0)
+            if not (sink_idx == last_task_idx):
+                continue
+
+            result.append(mapping)
+
+        if len(result) == 0:  # there are cases where, because of the setup, the criteria cannot be met
+            return all_task_mappings
+        return result
+
+
+    mapped_systems = []
+    start = time.time()
+    for idx, system in enumerate(filtered_system_list):
+        if len(mapped_systems) > DB_MAX_SYSTEM_to_investigate:
+            break
+
+        all_task_mappings = list(binning(system.get_pe_cnt(), full_potential_tasks_list[0:MAX_TASK_CNT]))
+
+        # Filtering
+        # filter 1: filter mapping scenarios where source/sink don't map to the same pe as the second/second-to-last tasks respectively.
+        # Since source and sink are dummies, it doesn't make a difference if they are mapped to other pes
+        #task_mapping_filtered_1 = all_task_mappings  # uncomment if you don't care about source/sink
+        task_mapping_filtered_1 = filter_sink_source(all_task_mappings, gen_config)
+        # filter 2: filter the scenarios where tasks are mapped similarly to the same bus
+        #task_mapping_filtered_2 = remove_permutation_per_bus_2(task_mapping_filtered_1, system.get_pe_set())
+        task_mapping_filtered_2 = task_mapping_filtered_1
+
+        use_balanced_mapping = True
+        # just to filter out certain scenarios for debugging. 
not necessary + if (use_balanced_mapping): + task_mapping_filtered_3 = [] + # find the most balanced design + task_mapping_mapping_std = {} + for idx, task_mapping in enumerate(task_mapping_filtered_2): + task_mapping_mapping_std[idx] = np.std([len(el) for el in task_mapping]) + sorted_x = collections.OrderedDict(sorted(task_mapping_mapping_std.items(), key=operator.itemgetter(1))) + + for idx in range(0, min(1, len(task_mapping_filtered_2))): + blah = task_mapping_filtered_2[list(sorted_x.keys())[idx]] + task_mapping_filtered_3.append(blah) + task_mapping_filtered_2 = task_mapping_filtered_3 + + # populate + for task_mapping in task_mapping_filtered_2: + system_ = deepcopy(system) + system_.set_pe_task_set(task_mapping) + mapped_systems.append(system_) + mapping_customization_variation_cnt_list.append((1, + pow(system_.get_pe_cnt(), DB_MAX_PE_CNT)* + pow(system_.get_mem_cnt(), DB_MAX_MEM_CNT)* + pow(system_.get_bus_cnt(), DB_MAX_BUS_CNT))) + + return mapped_systems,mapping_customization_variation_cnt_list + + +def exhaustive_system_generation(system_workers, gen_config): + mapping_process_cnt = system_workers[0] + mapping_process_id = system_workers[1] + FARSI_gen_process_id = system_workers[1] + #customization_process_cnt = system_workers[2] + #customization_process_id = system_workers[3] + + # mapping assisted topology + #task_cnt = MAX_TASK_CNT + system_variations = 0 + + # generate systems with various topologies + system_topologies = generate_topologies(gen_config) + + # count all the valid mappings and customization (together) + mapping_customization_variation_cnt_list, total_system_variation_cnt = count_mapping_customization(system_topologies, gen_config) + + # generate different mapping + start = time.time() + mapped_pe_systems, mapping_customization_variation_cnt_list = generate_pe_mapping(system_topologies, gen_config) # map tasks to PEs + print("pe mapping time" + str(time.time() - start)) + end = time.time() + #mapped_systems_completed = mapped_pe_systems + + # parallelize + process_id_work_dictionary = shard_work_equally(mapping_customization_variation_cnt_list, mapping_process_cnt) + + # map tasks to MEMs. 
comment out if you don't care about this
+    start = time.time()
+    mapped_systems_completed = generate_mem_mapping(mapped_pe_systems, process_id_work_dictionary, mapping_process_id)
+    print("mem mapping time: " + str(time.time() - start))
+
+    # generate different customizations
+    start = time.time()
+    customized_systems = generate_customization(mapped_systems_completed, gen_config)
+    print("customization time: " + str(time.time() - start))
+
+    print("total systems to explore for process id " + str(mapping_process_id) + "_" + str(FARSI_gen_process_id) + ": " + str(len(customized_systems)))
+    return customized_systems[: min(len(customized_systems), gen_config["DB_MAX_SYSTEM_to_investigate"])]
+
+
+#for DS_type in ["naive", "DB_semantics"]:
+#---------------------------
+# sweepers
+#---------------------------
+def sweep_DS_info(gen_config):
+    DB_MAX_PE_CNT = gen_config["DB_MAX_PE_CNT"]
+    DB_MAX_MEM_CNT = gen_config["DB_MAX_MEM_CNT"]
+    DB_MAX_BUS_CNT = gen_config["DB_MAX_BUS_CNT"]
+    MAX_PAR_TASK_CNT = gen_config["MAX_PAR_TASK_CNT"]
+    MAX_TASK_CNT = gen_config["MAX_TASK_CNT"]
+    DB_MAX_PE_CNT_range = gen_config["DB_MAX_PE_CNT_range"]
+    MAX_TASK_range = gen_config["MAX_TASK_CNT_range"]
+    MAX_TASK_CNT_range = gen_config["MAX_TASK_CNT_range"]
+    MAX_PAR_TASK_CNT_range = gen_config["MAX_PAR_TASK_CNT_range"]
+    DS_output_mode = gen_config["DS_output_mode"]
+
+    DS_type = gen_config["DS_type"]
+    print("DS_type:" + DS_type)
+    for DB_MAX_PE_CNT in DB_MAX_PE_CNT_range:
+
+        # printing stuff
+        print("-----------------------------------------")
+        print("PE DB count:" + str(DB_MAX_PE_CNT))
+        print("-----------------------------------------")
+        print(" ,", end=" ")
+        for MAX_PAR_TASK_CNT in MAX_PAR_TASK_CNT_range:
+            print(str(MAX_PAR_TASK_CNT) + ",", end=" ")
+        print("\n")
+
+        DB_MAX_BUS_CNT = DB_MAX_MEM_CNT = DB_MAX_PE_CNT
+        for MAX_TASK_CNT in MAX_TASK_CNT_range:
+            print(str(MAX_TASK_CNT) + ",", end=" ")
+            for MAX_PAR_TASK_CNT in MAX_PAR_TASK_CNT_range:
+                print(str(get_DS_cnt(DS_type, gen_config, DS_output_mode)) + ",", end=" ")
+            print("\n")
+
+def lists_of_lists_equal(lol1, lol2):
+    if not (len(lol1) == len(lol2)):
+        return False
+    for lol1_el in lol1:
+        if not(lol1_el in lol2):
+            return False
+    return True
+
+def listify_tuples(list_):
+    result = []
+    for el in list_:
+        if isinstance(el, tuple):
+            result.extend(list(el))
+        else:
+            result.extend(el)
+    return result
+
+# two mappings are equal if the tasks under the PEs mapped to the same bus
+# are equal
+def mapping_equal(system_1, system_2, IP_set_per_bus):
+    if not len(system_1) == len(system_2):
+        return False
+
+    IP_set_per_bus_acc = [0]
+    for idx, IP_set_per_bus_el in enumerate(IP_set_per_bus):
+        IP_set_per_bus_acc.append(IP_set_per_bus_acc[idx] + IP_set_per_bus_el)
+
+    for idx in range(0, len(IP_set_per_bus_acc) - 1):
+        idx_low = IP_set_per_bus_acc[idx]
+        idx_up = IP_set_per_bus_acc[idx + 1]
+        #system_1_listified = listify_tuples(system_1)
+        #system_2_listified = listify_tuples(system_2)
+        #if not lists_of_lists_equal(system_1_listified[idx_low:idx_up], system_2_listified[idx_low:idx_up]):
+        if not lists_of_lists_equal(system_1[idx_low:idx_up], system_2[idx_low:idx_up]):
+            return False
+
+    return True
+
+# since permutations of the blocks per bus do not generate new topologies (since we haven't
+# assigned an IP to them yet), we need to remove them
+def remove_permutation_per_bus_2(PE_with_task_mapping_list, IP_set_per_bus):
+
+    duplicated_systems_idx = []
+    # iterate through and check the equality of each design
+    for idx_x in range(0, len(PE_with_task_mapping_list)):
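+        # note: this pairwise comparison is quadratic in the number of
+        # candidate mappings, so the callers keep the candidate lists small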
+        if idx_x in duplicated_systems_idx:  # skip entries already marked as duplicates
+            continue
+        for idx_y in range(idx_x+1, len(PE_with_task_mapping_list)):
+            if mapping_equal(PE_with_task_mapping_list[idx_x], PE_with_task_mapping_list[idx_y], IP_set_per_bus):
+                duplicated_systems_idx.append(idx_y)
+
+    non_duplicates = [PE_with_task_mapping_list[idx] for idx in range(0, len(PE_with_task_mapping_list))
+                      if not(idx in duplicated_systems_idx)]
+
+
+    return non_duplicates
+
+# ------------------
+# some unit tests. keep for unit testing
+# ------------------
+"""
+system_1 = [[1,2,3], [4],[5]]
+system_2 = [[1,3,2], [4]]
+print(mapping_equal(system_1, system_2, [2,1]) == False)
+
+system_1 = [[1,2,3], [4,5], [5]]
+system_2 = [[4, 5], [1,2,3], [5]]
+print(mapping_equal(system_1, system_2, [2,1]) == True)
+
+system_1 = [[1,2,3], [4,5], [5], [6]]
+system_2 = [[4, 5], [1,2,3], [6], [5]]
+print(mapping_equal(system_1, system_2, [2,2]) == True)
+
+system_1 = [[1,2,3], [4,5], [5], [6]]
+system_2 = [[4, 5], [1,2,3], [6], [5]]
+print(mapping_equal(system_1, system_2, [3,1]) == False)
+
+
+system_1 = [[1,2,3], [4,5], [5], [6]]
+system_2 = [[1, 2,3], [5], [6], [4,5]]
+print(mapping_equal(system_1, system_2, [1,3]) == True)
+
+
+system_1 = [[1,2,3], [4,5], [5], [6]]
+system_2 = [[2,3], [5], [6], [4,5]]
+print(mapping_equal(system_1, system_2, [1,3]) == False)
+
+
+system_1 = [[1,2,3], [4,5], [5], [6]]
+system_2 = [[2,3, 1], [5], [6], [4,5]]
+print(mapping_equal(system_1, system_2, [1,3]) == False)
+
+system_1 = [[1,2,3], [4,5]]
+system_2 = [[4,5], [1,2,3]]
+print(mapping_equal(system_1, system_2, [1,1]) == False)
+
+system_1 = [[1,2,3], [4,5]]
+system_2 = [[4,5], [1,2,3]]
+print(mapping_equal(system_1, system_2, [2,0]) == True)
+print("ok")
+"""
+"""
+# binning tasks to PEs
+total_PEs = 3
+PE_with_task_mapping = list(binning(total_PEs, full_potential_tasks_list[:4]))
+IP_set_per_bus = [1,1,1]
+remove_permutation_per_bus_2(PE_with_task_mapping, IP_set_per_bus)
+#a = bin.combinations()
+print(results)
+"""
+
+#exhaustive_system_generation()
+"""
+def gen_test():
+    for i in range(0,10):
+        yield i
+
+    return None
+
+gen = gen_test()
+for _ in gen:
+    print(_)
+"""
+#for MAX_TASK_CNT in range(5, 8):
+#    for MAX_PAR_TASK_CNT in range(1,3):
+#        system_variation_count_2()
+
+# all the combinations of balls and bins:
+# https://www.careerbless.com/aptitude/qa/permutations_combinations_imp7.php
+
+import matplotlib.pyplot as plt
+plt.style.use('seaborn-whitegrid')
+import numpy as np
+
+
+def plot_design_space_size():
+    pe_range = [10, 12, 14, 16, 18, 20]
+    pe_count_design_space_size = {}
+    knob_count = 4
+    for pe_cnt in pe_range:
+        pe_count_design_space_size[pe_cnt] = count_system_variation(pe_cnt, knob_count, "customization")
+
+
+    #colors = {'Sim Time: Sub Sec':'lime', 'Sim Time: Sec':'darkgreen', 'Sim Time: Minute':'darkgreen', 'Sim Time: Hour':'olivedrab', 'Sim Time: Day':'darkgreen'}
+    colors = {'Sim Time: Sub Sec':'lime', 'Sim Time: Sec':'darkgreen', 'Sim Time: Minute':'darkgreen', 'Sim Time: Hour':'white', 'Sim Time: Day':'white'}
+    simulation_time = {"Sim Time: Sub Sec": .1*1/60, "Sim Time: Sec": 1/60, "Sim Time: Minute": 2, 'Sim Time: Hour': 60, 'Sim Time: Day': 60*24}
+    #simulation_time = {'mili-sec':.001*1/60, 'hour': 60, 'day': 60*24}
+    selected_simulation_time = {'Sim Time: Hour': 60, 'Sim Time: Day': 60*24}
+
+
+    # budget
+
+    font_size = 25
+
+    # cardinality
+    fig, ax1 = plt.subplots()
+    #ax1 = ax3.twinx()
+    ax1.set_xlabel('Number of Processing Elements', fontsize=font_size)
+    ax1.set_ylabel('Design Space Size', fontsize=font_size)
+    
ax1.plot(list(pe_count_design_space_size.keys()), list(pe_count_design_space_size.values()), label="Cardinality", color='red') + #plt.legend() #loc="upper left") + + + # exploration time + ax2 = ax1.twinx() + cntr = 0 + for k ,v in selected_simulation_time.items(): + ax2.plot(list(pe_count_design_space_size.keys()), + [(el * v)/ (365 * 24 * 60) for el in list(pe_count_design_space_size.values())], label=k, color=colors[k], linestyle=":") + cntr +=1 + + + """ + k = 'Sim Time: Sub Sec' + v = simulation_time[k] + percentage = .0001 + ax2.plot(list(pe_count_design_space_size.keys()), + [(el * v*percentage) / (30 * 24 * 60) for el in list(pe_count_design_space_size.values())], label=k + "+ selected points", color="gold", + linestyle=":") + """ + + #ax2.plot(list(pe_count_design_space_size.keys()), + # [(el * v) / (30 * 24 * 60) for el in list(pe_count_design_space_size.values())], label=k + " + selected points", color="yellow", linestyle=":") + ax2.set_ylabel('Exploration Time (Year)', fontsize=font_size) + + ax3 = ax1.twinx() + #ax3.hlines(y=100, color='blue', linestyle='-', xmin=min(pe_range), xmax=20, label="Exploration Time Budget") + ax3.hlines(y=100, color='white', linestyle='-', xmin=min(pe_range), xmax=20, label="Exploration Time Budget") + + #ax1.hlines(y=.00000005, color='r', linestyle='-', xmin=8, xmax=20) + + + + # ticks and such + ax1.tick_params(axis='y', labelsize=font_size) + ax1.tick_params(axis='x', labelsize=font_size) + + #ax2.tick_params(axis='y', which="both", bottom=False, top=False, right=False, left=False, labelbottom=False) + #ax3.tick_params(axis='y', which="both", bottom=False, top=False, right=False, left=False, labelbottom=False) + #ax3.tick_params(None) + ax1.set_yscale('log') + ax2.set_yscale('log') + ax3.set_yscale('log') + + ax1.set_xticklabels([10, 12, 15, 17, 20]) + ax2.set_yticklabels([]) + ax3.set_yticklabels([]) + + ax2.xaxis.set_ticks_position('none') + ax2.yaxis.set_ticks_position('none') + ax3.xaxis.set_ticks_position('none') + ax3.yaxis.set_ticks_position('none') + + ax1.set_ylim((1, ax1.get_ybound()[1])) + ax2.set_ylim((1, ax1.get_ybound()[1])) + ax3.set_ylim((1, ax1.get_ybound()[1])) + ax1.set_xlim((10, 20)) + ax2.set_xlim((10, 20)) + ax3.set_xlim((10, 20)) + + + + + + # show + #plt.legend() #loc="upper left") + fig.tight_layout() + plt.show() + fig.savefig("exploration_time.png") + + print("ok") +# print("number of system variations: " + str(float(system_variations))) +# print("exhaustive simulation time (hours): " + str(float(system_variations)/(20*3600))) + +def count_system_variation(task_cnt, knob_count, mode="all"): + MAX_TASK_CNT = task_cnt + + # assuming that we have 5 different tasks and hence (can have up to 5 different blocks). + # we'd like to know how many different migration/allocation combinations are out there. + # Assumptions: + # PE's are identical. 
+ # Buses are identical + # memory is ignored for now + + MAX_PE_CNT = MAX_TASK_CNT + task_cnt = MAX_TASK_CNT + system_variations = 0 + num_of_knobs = knob_count + + topology_dict = [1, + 1, + 4, + 38, + 728, + 26704, + 1866256, + 251548592, + 66296291072, + 34496488594816, + 35641657548953344, + 73354596206766622208, + 301272202649664088951808, + 2471648811030443735290891264, + 40527680937730480234609755344896, + 1328578958335783201008338986845427712, + 87089689052447182841791388989051400978432, + 11416413520434522308788674285713247919244640256, + 2992938411601818037370034280152893935458466172698624, + 1569215570739406346256547210377768575765884983264804405248, + 1645471602537064877722485517800176164374001516327306287561310208] + + for pe_cnt in range(1, MAX_PE_CNT): + + topology = topology_dict[pe_cnt-1] + # then calculate mapping (migrate) + mapping = calc_stirling(task_cnt, pe_cnt)*factorial(pe_cnt) + # then calculate customization (swap) + swap = math.pow(num_of_knobs, (pe_cnt)) + + if mode == "all": + system_variations += topology*mapping*swap + if mode == "customization": + system_variations += swap + if mode == "mapping": + system_variations += mapping + if mode == "topology": + system_variations += topology + + + + + + return system_variations + #print("{:e}".format(system_variations)) + +pe_cnt = 20 +knob_count = 4 +ds_size = {} +ds_size_digits = {} +for design_stage in ["topology", "mapping", "customization", "all"]: + ds_size[design_stage] = count_system_variation(pe_cnt, knob_count, design_stage) + ds_size_digits[design_stage] = math.log10(count_system_variation(pe_cnt, knob_count, design_stage)) + + + +#plot_design_space_size() + diff --git a/Project_FARSI/DSE_utils/hill_climbing.py b/Project_FARSI/DSE_utils/hill_climbing.py new file mode 100644 index 00000000..7215bbb5 --- /dev/null +++ b/Project_FARSI/DSE_utils/hill_climbing.py @@ -0,0 +1,3870 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
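+#
+# Hill-climbing-based design space exploration (DSE) for FARSI: the
+# explorer below iteratively generates neighbouring design points through
+# moves (e.g., swap to improve a block, duplicate to relax contention on
+# the bottleneck) and keeps the best design seen so far.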
+from copy import * +from decimal import Decimal +import zipfile +import csv +import _pickle as cPickle +#import ujson +from design_utils.components.hardware import * +from design_utils.components.workload import * +from design_utils.components.mapping import * +from design_utils.components.scheduling import * +from SIM_utils.SIM import * +from design_utils.design import * +from design_utils.des_handler import * +from design_utils.components.krnel import * +from typing import Dict, Tuple, List +from settings import config +from visualization_utils import vis_hardware, vis_stats, plot +from visualization_utils import vis_sim +#from data_collection.FB_private.verification_utils.common import * +import dill +import pickle +import importlib +import gc +import difflib +#from pygmo import * +#from pygmo.util import * +import psutil + + +class Counters(): + def __init__(self): + self.krnel_rnk_to_consider = 0 + self.krnel_stagnation_ctr = 0 + self.fitted_budget_ctr = 0 + self.des_stag_ctr = 0 + self.krnels_not_to_consider = [] + self.population_generation_cnt = 0 + self.found_any_improvement = False + self.total_iteration_ctr = 0 + + def reset(self): + self.krnel_rnk_to_consider = 0 + self.krnel_stagnation_ctr = 0 + self.fitted_budget_ctr = 0 + self.des_stag_ctr = 0 + self.krnels_not_to_consider = [] + self.population_generation_cnt = 0 + self.found_any_improvement = False + + + def update(self, krnel_rnk_to_consider, krnel_stagnation_ctr, fitted_budget_ctr, des_stag_ctr, krnels_not_to_consider, population_generation_cnt, found_any_improvement, total_iteration_ctr): + self.krnel_rnk_to_consider = krnel_rnk_to_consider + self.krnel_stagnation_ctr = krnel_stagnation_ctr + self.fitted_budget_ctr = fitted_budget_ctr + self.des_stag_ctr = des_stag_ctr + self.krnels_not_to_consider = krnels_not_to_consider[:] + self.population_generation_cnt = population_generation_cnt + self.found_any_improvement = found_any_improvement + self.total_iteration_ctr = total_iteration_ctr + #def update_improvement(self, improvement): + # self.found_any_improvement = self.found_any_improvement or improvement + + + # ------------------------------ +# This class is responsible for design space exploration using our proprietary hill-climbing algorithm. +# Our Algorithm currently uses swap (improving the current design) and duplicate (relaxing the contention on the +# current bottleneck) as two main exploration move. +# ------------------------------ +class HillClimbing: + def __init__(self, database, result_dir): + + # parameters (to configure) + self.counters = Counters() + self.found_any_improvement = False + self.result_dir = result_dir + self.fitted_budget_ctr = 0 # counting the number of times that we were able to find a design to fit the budget. Used to terminate the search + self.name_ctr = 0 + self.DES_STAG_THRESHOLD = config.DES_STAG_THRESHOLD # Acceptable iterations count without improvement before termination. + self.TOTAL_RUN_THRESHOLD = config.TOTAL_RUN_THRESHOLD # Total number of iterations to terminate with. + self.neigh_gen_mode = config.neigh_gen_mode # Neighbouring design pts generation mode ("all" "random_one"). + self.num_neighs_to_try = config.num_neighs_to_try # How many neighs to try around a current design point. + self.neigh_sel_mode = config.neigh_sel_mode # Neighbouring design selection mode (best, sometimes, best ...) + self.dp_rank_obj = config.dp_rank_obj # Design point ranking object function(best, sometimes, best ...) 
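+        # knobs below shape how kernels are clustered when splitting and
+        # how the budget coefficient is scaled across exploration iterations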
+ self.num_clusters = config.num_clusters # How many clusters to create everytime we split. + self.budget_coeff = config.max_budget_coeff + self.move_profile = [] + # variables (to initialize) + self.area_explored = [] # List containing area associated with the all explored designs areas. + self.latency_explored = [] # List containing latency associated with all explored designs latency. + self.power_explored = [] # List containing Power associated with all explored designs latency. + self.design_itr = [] # Design iteration counter. Simply indexing (numbering) the designs explored. + self.space_distance = 2 + self.database = database # hw/sw database to use for exploration. + self.dh = DesignHandler(self.database) # design handler for design modification. + self.so_far_best_sim_dp = None # best design found so far through out all iterations. + # For iteratively improvements. + self.cur_best_ex_dp, self.cur_best_sim_dp = None, None # current iteration's best design. + self.last_des_trail = None # last design (trail) + self.last_move = None # last move applied + self.init_ex_dp = None # Initial exploration design point. (Staring design point for the whole algorithm) + self.coeff_slice_size = int(self.TOTAL_RUN_THRESHOLD/config.max_budget_coeff) + #self.hot_krnl_pos = 0 # position of the kernel among the kernel list. Used to found and improve the + # corresponding occupying block. + + self.min_cost_to_consider = .000001 + # Counters: to determine which control path the exploration should take (e.g., terminate, pick another block instead + # of hotblock, ...). + self.des_stag_ctr = 0 # Iteration count since seen last design improvement. + self.population_generation_cnt = 0 # Total iteration count (for termination purposes). + + self.vis_move_trail_ctr = 0 + # Sanity checks (preventing bad configuration setup) + if self.neigh_gen_mode not in ["all", "some"]: raise ValueError() + # TODO: sel_cri needs to be fixed to include combinations of the objective functions + if self.dp_rank_obj not in ["all", "latency", "throughput", "power", "design_cost"]: raise ValueError() + if self.neigh_sel_mode not in ["best", "best_sometime"]: raise ValueError() + self.des_trail_list = [] + self.krnel_rnk_to_consider = 0 # this rank determines which kernel (among the sorted kernels to consider). + # we use this counter to avoid getting stuck + self.krnel_stagnation_ctr = 0 # if the same kernel is selected across iterations and no improvement observed, + # count up + + self.recently_seen_design_ctr = 0 + self.recently_cached_designs = {} + self.cleanup_ctr = 0 # use this to invoke the cleaner once we pass a threshold + + self.SA_current_breadth = -1 # which breadth is current move on + self.SA_current_mini_breadth = 0 # which breadth is current move on + self.SA_current_depth = -1 # which depth is current move on + self.check_point_folder = config.check_point_folder + + self.seen_SOC_design_codes = [] # config code of all the designs seen so far (this is mainly for debugging, concretely + # simulation validation + + self.cached_SOC_sim = {} # cache of designs simulated already. 
index is a unique code base on allocation and mapping + + self.move_s_krnel_selection = config.move_s_krnel_selection + self.krnels_not_to_consider = [] + self.all_itr_ex_sim_dp_dict: Dict[ExDesignPoint: SimDesignPoint] = {} # all the designs look at + self.reason_to_terminate = "" + self.log_data_list = [] + self.population_observed_ctr = 0 # + self.neighbour_selection_time = 0 + self.total_iteration_cnt = 0 + self.total_iteration_ctr = 0 + self.moos_tree = moosTreeModel(config.budgetted_metrics) # only used for moos heuristic + self.ctr_l = 0 + + def get_total_iteration_cnt(self): + return self.total_iteration_cnt + + def set_check_point_folder(self, check_point_folder): + self.check_point_folder = check_point_folder + + # retrieving the pickled check pointed file + def get_pickeld_file(self, file_addr): + if not os.path.exists(file_addr): + file_name = os.path.basename(file_addr) + file_name_modified = file_name.split(".")[0]+".zip" + dir_name = os.path.dirname(file_addr) + zip_file_addr = os.path.join(dir_name, file_name_modified) + if os.path.exists(zip_file_addr): + with zipfile.ZipFile(zip_file_addr) as thezip: + with thezip.open(file_name, mode='r') as f: + obj = pickle.load(f) + else: + print(file_addr +" does not exist for unpickling") + exit(0) + else: + with open(file_addr, 'rb') as f: # will close() when we leave this block + obj = pickle.load(f) + return obj + + def populate_counters(self, counters): + self.counters = counters + self.krnel_rnk_to_consider = counters.krnel_rnk_to_consider + self.krnel_stagnation_ctr= counters.krnel_stagnation_ctr + self.fitted_budget_ctr = counters.fitted_budget_ctr + self.des_stag_ctr = counters.des_stag_ctr + self.krnels_not_to_consider = counters.krnels_not_to_consider[:] + self.population_generation_cnt = counters.population_generation_cnt + self.found_any_improvement = counters.found_any_improvement + self.total_iteration_ctr = counters.total_iteration_ctr + + # ------------------------------ + # Functionality + # generate initial design point to start the exploration from. + # If mode is from_scratch, the default behavior is to pick the cheapest design. + # If mode is check_pointed, we start from a previously check pointed design. + # If mode is hardcode, we pick a design that is hardcoded. 
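+    #       A usage sketch (illustrative only; the check-point folder name
+    #       is hypothetical):
+    #           hc.set_check_point_folder("check_points/run_0")
+    #           hc.gen_init_ex_dp(mode="generated_from_check_point")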
+ # Variables + # init_des_point: initial design point + # mode: starting point mode (from scratch or from check point) + # ------------------------------ + def gen_init_ex_dp(self, mode="generated_from_scratch", init_des=""): + if mode == "generated_from_scratch": # start from the simplest design possible + self.init_ex_dp = self.dh.gen_init_des() + elif mode == "generated_from_check_point": + pickled_file_addr = self.check_point_folder + "/" + "ex_dp_pickled.txt" + database_file_addr = self.check_point_folder + "/" + "database_pickled.txt" + counters_file_addr = self.check_point_folder + "/" + "counters_pickled.txt" + #sim_pickled_file_addr = self.check_point_folder + "/" + "sim_dp_pickled.txt" + if "db" in config.check_point_list: + self.database = self.get_pickeld_file(database_file_addr) + if "counters" in config.check_point_list: + self.counters = self.get_pickeld_file(counters_file_addr) + self.init_ex_dp = self.get_pickeld_file(pickled_file_addr) + self.populate_counters(self.counters) + elif mode == "FARSI_des_passed_in": + self.init_ex_dp = init_des + elif mode == "hardcoded": + self.init_ex_dp = self.dh.gen_specific_hardcoded_ex_dp(self.dh.database) + elif mode == "parse": + self.init_ex_dp = self.dh.gen_specific_parsed_ex_dp(self.dh.database) + elif mode == "hop_mode": + self.init_ex_dp = self.dh.gen_specific_design_with_hops_and_stars(self.dh.database) + elif mode == "star_mode": + self.init_ex_dp = self.dh.gen_specific_design_with_a_star_noc(self.dh.database) + else: raise Exception("mode:" + mode + " is not supported") + + + # ------------------------------ + # Functionality: + # Generate one neighbouring design based on the moves available. + # To do this, we first specify a move and then apply it. + # A move is specified by a metric, direction kernel, block, and transformation. + # look for move definition in the move class + # Variables + # des_tup: design tuple. Contains a design tuple (ex_dp, sim_dp). ex_dp: design to find neighbours for. + # sim_dp: simulated ex_dp. + # ------------------------------ + def gen_one_neigh(self, des_tup): + ex_dp, sim_dp = des_tup + + # Copy to avoid modifying the current designs. + #new_ex_dp_pre_mod = copy.deepcopy(ex_dp) # getting a copy before modifying + #new_sim_dp_pre_mod = copy.deepcopy(sim_dp) # getting a copy before modifying + #new_ex_dp = copy.deepcopy(ex_dp) + t1 = time.time() + gc.disable() + new_ex_dp = cPickle.loads(cPickle.dumps(ex_dp, -1)) + gc.enable() + t2 = time.time() + #new_sim_dp = copy.deepcopy(sim_dp) + new_des_tup = (new_ex_dp, sim_dp) + + # ------------------------ + # select (generate) a move + # ------------------------ + # It's important that we do analysis of move selection on the copy (and not the original) because + # 1. we'd like to keep original for further modifications + # 2. 
for block identification/comparison of the move and the copied design + safety_chk_passed = False + # iterate and continuously generate moves, until one passes some sanity check + while not safety_chk_passed: + move_to_try, total_transformation_cnt = self.sel_moves(new_des_tup, "dist_rank") + safety_chk_passed = move_to_try.safety_check(new_ex_dp) + move_to_try.populate_system_improvement_log() + + move_to_try.set_logs(t2-t1, "pickling_time") + + # ------------------------ + # apply the move + # ------------------------ + # while conduction various validity/sanity checks + try: + self.dh.unload_read_mem(new_des_tup[0]) # unload read memories + move_to_try.validity_check() # call after unload rad mems, because we need to check the scenarios where + # task is unloaded from the mem, but was decided to be migrated/swapped + new_ex_dp_res, succeeded = self.dh.apply_move(new_des_tup, move_to_try) + move_to_try.set_before_after_designs(new_des_tup[0], new_ex_dp_res) + new_ex_dp_res.sanity_check() # sanity check + move_to_try.sanity_check() + self.dh.load_tasks_to_read_mem_and_ic(new_ex_dp_res) # loading the tasks on to memory and ic + new_ex_dp_res.hardware_graph.pipe_design() + new_ex_dp_res.sanity_check() + except Exception as e: + # if the error is already something that we are familiar with + # react appropriately, otherwise, simply raise it. + if e.__class__.__name__ in errors_names: + print("Error: " + e.__class__.__name__) + # TODOs + # for now, just return the previous design, but this needs to be fixed immediately + new_ex_dp_res = ex_dp + #raise e + elif e.__class__.__name__ in exception_names: + print("Exception: " + e.__class__.__name__) + new_ex_dp_res = ex_dp + move_to_try.set_validity(False) + else: + raise e + + + return new_ex_dp_res, move_to_try, total_transformation_cnt + + # ------------------------------ + # Functionality: + # Select a item given a probability distribution. + # The input provides a list of values and their probabilities/fitness,... + # and this function randomly picks a value based on the fitness/probability, ... + # Used for random but prioritized selections (of for example blocks, or kernels) + # input: item_prob_dict {} (item, probability) + # ------------------------------ + def pick_from_prob_dict(self, item_prob_dict): + # now encode this priorities into a encoded histogram (for example with performance + # encoded as 1 and power as 2 and ...) with frequencies + if config.DEBUG_FIX: random.seed(0) + else: time.sleep(.00001), random.seed(datetime.now().microsecond) + + item_encoding = np.arange(0, len(item_prob_dict.keys())) # encoding the metrics from 1 to ... 
clusters + rand_var_dis = list(item_prob_dict.values()) # distribution + + encoded_metric = np.random.choice(item_encoding, p=rand_var_dis) # cluster (metric) selected + selected_item = list(item_prob_dict.keys())[encoded_metric] + return selected_item + + # ------------------------------ + # Functionality: + # return all the hardware blocks that have same hardware characteristics (e.g, same type and same work-rate and mappability) + # ------------------------------ + def find_matching_blocks(self, blocks): + matched_idx = [] + matching_blocks = [] + for idx, _ in enumerate(blocks): + if idx in matched_idx: + continue + matched_idx.append(idx) + for idx_2 in range(idx+1, len(blocks)): + if blocks[idx].get_generic_instance_name() == blocks[idx_2].get_generic_instance_name(): # for PEs + matching_blocks.append((blocks[idx], blocks[idx_2])) + matched_idx.append(idx_2) + elif blocks[idx].type in ["mem", "ic"] and blocks[idx_2].type in ["mem", "ic"]: # for mem and ic + if blocks[idx].subtype == blocks[idx_2].subtype: + matching_blocks.append((blocks[idx], blocks[idx_2])) + matched_idx.append(idx_2) + return matching_blocks + + def get_task_parallelism_type(self, sim_dp, task, parallel_task): + workload_tasks = sim_dp.database.db_input.workload_tasks + task_s_workload = sim_dp.database.db_input.task_workload[task] + parallel_task_s_workload = sim_dp.database.db_input.task_workload[parallel_task] + if task_s_workload == parallel_task_s_workload: + return "task_level_parallelism" + else: + return "workload_level_parallelism" + + # check if there is another task (on the block that can run in parallel with the task of interest + def check_if_task_can_run_with_any_other_task_in_parallel(self, sim_dp, task, block): + parallelism_type = [] + if task.get_name() in ["souurce", "siink", "dummy_last"]: + return False, parallelism_type + if block.type == "pe": + task_dir = "loop_back" + else: + task_dir = "write" + tasks_of_block = [task_ for task_ in block.get_tasks_of_block_by_dir(task_dir) if + (not ("souurce" in task_.name) or not ("siink" in task_.name))] + if config.parallelism_analysis == "static": + for task_ in tasks_of_block: + if sim_dp.get_dp_rep().get_hardware_graph().get_task_graph().tasks_can_run_in_parallel(task_, task): + return True, _ + elif config.parallelism_analysis == "dynamic": + parallel_tasks_names_ = sim_dp.get_dp_rep().get_tasks_parallel_task_dynamically(task) + tasks_using_diff_pipe_cluster = sim_dp.get_dp_rep().get_tasks_using_the_different_pipe_cluster(task, block) + parallel_tasks_names= list(set(parallel_tasks_names_) - set(tasks_using_diff_pipe_cluster)) + for task_ in tasks_of_block: + if task_.get_name() in parallel_tasks_names: + parallelism_type.append(self.get_task_parallelism_type(sim_dp, task.get_name(), task_.get_name())) + if len(parallelism_type) > 0: + return True, list(set(parallelism_type)) + return False, parallelism_type + + + # ------------------------------ + # Functionality: + # check if there are any tasks across two blocks that can be run in parallel + # this is used for cleaning up, if there is not opportunities for parallelization + # Variables: + # sim_dp: design + # ------------------------------ + def check_if_any_tasks_on_two_blocks_parallel(self, sim_dp, block_1, block_2): + tasks_of_block_1 = [task for task in block_1.get_tasks_of_block_by_dir("write") if not task.is_task_dummy()] + tasks_of_block_2 = [task for task in block_2.get_tasks_of_block_by_dir("write") if not task.is_task_dummy()] + + for idx_1, _ in enumerate(tasks_of_block_1): + tsk_1 
= tasks_of_block_1[idx_1] + parallel_tasks_names_ = sim_dp.get_dp_rep().get_tasks_parallel_task_dynamically(tsk_1) + tasks_using_diff_pipe_cluster = sim_dp.get_dp_rep().get_tasks_using_the_different_pipe_cluster(tsk_1, block_1) + parallel_tasks_names = list(set(parallel_tasks_names_) - set(tasks_using_diff_pipe_cluster)) + + + for idx_2, _ in enumerate(tasks_of_block_2): + tsk_2 = tasks_of_block_2[idx_2] + if tsk_1.get_name() == tsk_2.get_name(): + continue + if config.parallelism_analysis == "static": + if sim_dp.get_dp_rep().get_hardware_graph().get_task_graph().tasks_can_run_in_parallel(tsk_1, tsk_2): + return True + elif config.parallelism_analysis == "dynamic": + if tsk_2.get_name() in parallel_tasks_names: + return True + return False + + # ------------------------------ + # Functionality: + # Return all the blocks that are unnecessarily parallelized, i.e., there are no + # tasks across them that can run in parallel. + # this is used for cleaning up, if there is not opportunities for parallelization + # Variables: + # sim_dp: design + # matching_blocks_list: list of hardware blocks with equivalent characteristics (e.g., two A53, or two + # identical acclerators) + # ------------------------------ + def find_blocks_with_all_serial_tasks(self, sim_dp, matching_blocks_list): + matching_blocks_list_filtered = [] + for matching_blocks in matching_blocks_list: + if self.check_if_any_tasks_on_two_blocks_parallel(sim_dp, matching_blocks[0], matching_blocks[1]): + continue + matching_blocks_list_filtered.append((matching_blocks[0], matching_blocks[1])) + + return matching_blocks_list_filtered + + # ------------------------------ + # Functionality: + # search through all the blocks and return a pair of blocks that cleanup can apply to + # Variables: + # sim_dp: design + # ------------------------------ + def pick_block_pair_to_clean_up(self, sim_dp, block_pairs): + if len(block_pairs) == 0: + return block_pairs + + cleanup_ease_list = [] + block_pairs_sorted = [] # sorting the pairs elements (within each pair) based on the number of tasks on each + for blck_1, blck_2 in block_pairs: + if blck_2.type == "ic": # for now ignore ics + continue + elif blck_2.type == "mem": + if self.database.check_superiority(blck_1, blck_2): + block_pairs_sorted.append((blck_1, blck_2)) + else: + block_pairs_sorted.append((blck_2, blck_1)) + else: + if len(blck_1.get_tasks_of_block()) < len(blck_2.get_tasks_of_block()): + block_pairs_sorted.append((blck_1, blck_2)) + else: + block_pairs_sorted.append((blck_2, blck_1)) + + distance = len(sim_dp.get_dp_rep().get_hardware_graph().get_path_between_two_vertecies(blck_1, blck_2)) + num_tasks_to_move = min(len(blck_1.get_tasks_of_block()), len(blck_2.get_tasks_of_block())) + + cleanup_ease_list.append(distance + num_tasks_to_move) + + # when we need to clean up the ics, ignore for now + if len(cleanup_ease_list) == 0: + return [] + + picked_easiest = False + min_ease = 100000000 + for idx, ease in enumerate(cleanup_ease_list): + if ease < min_ease: + picked_easiest = True + easiest_pair = block_pairs_sorted[idx] + min_ease = ease + return easiest_pair + + # ------------------------------ + # Functionality: + # used to determine if two different task can use the same accelerators. 
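+    #       (two tasks can share an IP if the (task_1, task_2) pair appears,
+    #       in either order, in the database's "same_ip_tasks_list")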
+ # ------------------------------ + def are_same_ip_tasks(self, task_1, task_2): + return (task_1.name, task_2.name) in self.database.db_input.misc_data["same_ip_tasks_list"] or (task_2.name, task_1.name) in self.database.db_input.misc_data["same_ip_tasks_list"] + + # ------------------------------ + # Functionality: + # find all the tasks that can run on the same ip (accelerator) + # Variables: + # des_tup: (design, simulated design) + # ------------------------------ + def find_task_with_similar_mappable_ips(self, des_tup): + ex_dp, sim_dp = des_tup + #krnls = sim_dp.get_dp_stats().get_kernels() + blcks = ex_dp.get_blocks() + pe_blocks = [blck for blck in blcks if blck.type=="pe"] + tasks_sub_ip_type = [] # (task, sub_ip) + matches = [] + for blck in pe_blocks: + tasks_sub_ip_type.extend(zip(blck.get_tasks_of_block(), [blck.subtype]*len(blck.get_tasks_of_block()))) + + for task, sub_ip_type in tasks_sub_ip_type: + check_for_similarity = False + if sub_ip_type == "ip": + check_for_similarity = True + if not check_for_similarity: + continue + + for task_2, sub_ip_type_2 in tasks_sub_ip_type: + if task_2.name == task.name : + continue + if self.are_same_ip_tasks(task, task_2): + for blck in pe_blocks: + if task_2 in blck.get_tasks_of_block(): + block_to_migrate_from = blck + if task in blck.get_tasks_of_block(): + block_to_migrate_to = blck + + if not (block_to_migrate_to == block_to_migrate_from): + matches.append((task_2, block_to_migrate_from, task, block_to_migrate_to)) + + if len(matches) == 0: + return None, None, None, None + else: + return random.choice(matches) + + # ------------------------------ + # Functionality: + # pick a block pair to apply cleaning to + # Variables: + # des_tup: (design, simulated design) + # ------------------------------ + def gen_block_match_cleanup_move(self, des_tup): + ex_dp, sim_dp = des_tup + krnls = sim_dp.get_dp_stats().get_kernels() + blcks = ex_dp.get_blocks() + + # move tasks to already generated IPs + # clean up the matching blocks + matching_blocks_list = self.find_matching_blocks(blcks) + matching_blocks_lists_filtered = self.find_blocks_with_all_serial_tasks(sim_dp, matching_blocks_list) + return self.pick_block_pair_to_clean_up(sim_dp, matching_blocks_lists_filtered) + + # ------------------------------ + # Functionality: + # is current iteration a clean up iteration (ie., should be used for clean up) + # ------------------------------ + def is_cleanup_iter(self): + result = (self.cleanup_ctr % (config.cleaning_threshold)) >= (config.cleaning_threshold- config.cleaning_consecutive_iterations) + return result + + def get_block_attr(self, selected_metric): + if selected_metric == "latency": + selected_metric_to_sort = 'peak_work_rate' + elif selected_metric == "power": + #selected_metric_to_sort = 'work_over_energy' + selected_metric_to_sort = 'one_over_power' + elif selected_metric == "area": + selected_metric_to_sort = 'one_over_area' + else: + print("selected_selected_metric: " + selected_metric + " is not defined") + return selected_metric_to_sort + + def select_block_to_migrate_to(self, ex_dp, sim_dp, hot_blck_synced, selected_metric, sorted_metric_dir, selected_krnl): + # get initial information + locality_type = [] + parallelism_type =[] + task = ex_dp.get_hardware_graph().get_task_graph().get_task_by_name(selected_krnl.get_task_name()) + selected_metric = list(sorted_metric_dir.keys())[-1] + selected_dir = sorted_metric_dir[selected_metric] + # find blocks equal or immeidately better + equal_imm_blocks_present_for_migration = 
self.dh.get_equal_immediate_blocks_present(ex_dp, hot_blck_synced, + selected_metric, selected_dir, [task]) + + + # does parallelism exist in the current occupying block + current_block_parallelism_exist, parallelism_type = self.check_if_task_can_run_with_any_other_task_in_parallel(sim_dp, + task, + hot_blck_synced) + inequality_dir = selected_dir*-1 + results_block = [] # results + + task_s_blocks = ex_dp.get_hardware_graph().get_blocks_of_task(task) + if len(task_s_blocks) == 0: + print("a task must have at lease three blocks") + exit(0) + + + remove_list = [] # list of blocks to remove from equal_imm_blocks_present_for_migration + # improve locality by only allowing migration to the PE/MEM close by + if hot_blck_synced.type == "mem": + # only keep memories that are connected to the IC neighbour of the task's pe + # This is to make sure that we keep data local (to the router), instead of migrating to somewhere far + task_s_pe = [blk for blk in task_s_blocks if blk.type == "pe"][0] # get task's pe + tasks_s_ic = [el for el in task_s_pe.get_neighs() if el.type == "ic"][0] # get pe's ic + potential_mems = [el for el in tasks_s_ic.get_neighs() if el.type == "mem"] # get ic's memories + for el in equal_imm_blocks_present_for_migration: + if el not in potential_mems: + remove_list.append(el) + locality_type = ["spatial_locality"] + for el in remove_list: + equal_imm_blocks_present_for_migration.remove(el) + elif hot_blck_synced.type == "pe": + # only keep memories that are connected to the IC neighbour of the task's pe + # This is to make sure that we keep data local (to the router), instead of migrating to somewhere far + task_s_mems = [blk for blk in task_s_blocks if blk.type == "mem"] # get task's pe + potential_pes = [] + for task_s_mem in task_s_mems: + tasks_s_ic = [el for el in task_s_mem.get_neighs() if el.type == "ic"][0] # get pe's ic + potential_pes.extend([el for el in tasks_s_ic.get_neighs() if el.type == "pe"]) # get ic's memories + for el in equal_imm_blocks_present_for_migration: + if el not in potential_pes: + remove_list.append(el) + locality_type = ["spatial_locality"] + for el in remove_list: + equal_imm_blocks_present_for_migration.remove(el) + + # iterate through the blocks and find the best one + for block_to_migrate_to in equal_imm_blocks_present_for_migration: + # skip yourself + if block_to_migrate_to == hot_blck_synced: + continue + + block_metric_attr = self.get_block_attr(selected_metric) # metric to pay attention to + # iterate and found blocks that are at least as good as the current block + if getattr(block_to_migrate_to, block_metric_attr) == getattr(hot_blck_synced, block_metric_attr): + # blocks have similar attr value + if (selected_metric == "power" and selected_dir == -1) or \ + (selected_metric == "latency" and selected_dir == 1) or (selected_metric == "area"): + # if we want to slow down (reduce latency, improve power), look for parallel task on the other block + block_to_mig_to_parallelism_exist, parallelism_type = self.check_if_task_can_run_with_any_other_task_in_parallel(sim_dp, + task, + block_to_migrate_to) + if (selected_metric == "area" and selected_dir == -1): + # no parallelism possibly allows for theo the other memory to shrink + if not block_to_mig_to_parallelism_exist: + results_block.append(block_to_migrate_to) + parallelism_type = ["serialism"] + else: + if block_to_mig_to_parallelism_exist: + results_block.append(block_to_migrate_to) + parallelism_type = ["serialism"] + else: + # if we want to accelerate (improve latency, get more power), 
look for parallel task on the same block + if current_block_parallelism_exist: + results_block.append(block_to_migrate_to) + elif inequality_dir*getattr(block_to_migrate_to, block_metric_attr) > inequality_dir*getattr(hot_blck_synced, block_metric_attr): + results_block.append(block_to_migrate_to) + break + + # if no block found, just load the results_block with current block + if len(results_block) == 0: + results_block = [hot_blck_synced] + found_block_to_mig_to = False + else: + found_block_to_mig_to = True + + # pick at random to try random scenarios. At the moment, only equal and immeidately better blocks are considered + random.seed(datetime.now().microsecond) + result_block = random.choice(results_block) + + selection_mode = "batch" + if found_block_to_mig_to: + if getattr(result_block, block_metric_attr) == getattr(hot_blck_synced, block_metric_attr): + selection_mode = "batch" + else: + selection_mode = "single" + + + return result_block, found_block_to_mig_to, selection_mode, parallelism_type, locality_type + + + def is_system_ic(self, ex_dp, sim_dp, blck): + if not sim_dp.dp_stats.fits_budget(1): + return False + elif sim_dp.dp_stats.fits_budget(1) and not self.dram_feasibility_check_pass(ex_dp): + return False + else: + for block in ex_dp.get_hardware_graph().get_blocks(): + neighs = block.get_neighs() + if any(el for el in neighs if el.subtype == "dram"): + if block == blck: + return True + return False + + def bus_has_pe_mem_topology_for_split(self, ex_dp, sim_dp, ref_task, block): + if not block.type == "ic" or ref_task.is_task_dummy(): + return False + found_pe_block = False + found_mem_block = False + + migrant_tasks = self.dh.find_parallel_tasks_of_task_in_block(ex_dp, sim_dp, ref_task, block)[0] + migrant_tasks_names = [el.get_name() for el in migrant_tasks] + mem_neighs = [el for el in block.get_neighs() if el.type == "mem"] + pe_neighs = [el for el in block.get_neighs() if el.type == "pe"] + + for neigh in pe_neighs: + neigh_tasks = [el.get_name() for el in neigh.get_tasks_of_block_by_dir("loop_back")] + # if no overlap skip + if len(list(set(migrant_tasks_names) - set(neigh_tasks) )) == len(migrant_tasks_names): + continue + else: + found_pe_block = True + break + + for neigh in mem_neighs: + neigh_tasks = [el.get_name() for el in neigh.get_tasks_of_block_by_dir("write")] + # if no overlap skip + if len(list(set(migrant_tasks_names) - set(neigh_tasks) )) == len(migrant_tasks_names): + continue + else: + found_mem_block = True + break + + + if found_pe_block and found_mem_block : + return True + else: + return False + + def get_feasible_transformations(self, ex_dp, sim_dp, hot_blck_synced, selected_metric, selected_krnl, sorted_metric_dir): + + # if this knob is set, we randomly pick a transformation + # THis is to illustrate the architectural awareness of FARSI a + if config.transformation_selection_mode == "random": + all_transformations = config.all_available_transformations + return all_transformations + + # pick a transformation smartly + imm_block = self.dh.get_immediate_block_multi_metric(hot_blck_synced, selected_metric, sorted_metric_dir, hot_blck_synced.get_tasks_of_block()) + task = ex_dp.get_hardware_graph().get_task_graph().get_task_by_name(selected_krnl.get_task_name()) + feasible_transformations = set(config.metric_trans_dict[selected_metric]) + + # find the block that is at least as good as the block (for migration) + # if can't find any, we return the same block + selected_metric = list(sorted_metric_dir.keys())[-1] + selected_dir = 
sorted_metric_dir[selected_metric] + + equal_imm_block_present_for_migration, found_blck_to_mig_to, selection_mode, parallelism_type, locality_type = self.select_block_to_migrate_to(ex_dp, sim_dp, hot_blck_synced, + selected_metric, sorted_metric_dir, selected_krnl) + + hot_block_type = hot_blck_synced.type + hot_block_subtype = hot_blck_synced.subtype + + parallelism_exist, parallelism_type = self.check_if_task_can_run_with_any_other_task_in_parallel(sim_dp, task, hot_blck_synced) + other_block_parallelism_exist = False + all_transformations = config.metric_trans_dict[selected_metric] + can_improve_locality = self.can_improve_locality(ex_dp, hot_blck_synced, task) + can_improve_routing = self.can_improve_routing(ex_dp, sim_dp, hot_blck_synced, task) + + bus_has_pe_mem_topology_for_split = self.bus_has_pe_mem_topology_for_split(ex_dp, sim_dp, task,hot_blck_synced) + # ------------------------ + # based on parallelism, generate feasible transformations + # ------------------------ + if parallelism_exist: + if selected_metric == "latency": + if selected_dir == -1: + if hot_block_type == "pe": + feasible_transformations = ["migrate", "split"] # only for PE since we wont to be low cost, for IC/MEM cost does not increase if you customize + else: + if hot_block_type == "ic": + mem_neighs = [el for el in hot_blck_synced.get_neighs() if el.type == "mem"] + pe_neighs = [el for el in hot_blck_synced.get_neighs() if el.type == "pe"] + if len(mem_neighs) <= 1 or len(pe_neighs) <= 1 or not bus_has_pe_mem_topology_for_split: + feasible_transformations = ["swap", "split_swap"] # ", "swap", "split_swap"] + else: + feasible_transformations = ["migrate", "split"] # ", "swap", "split_swap"] + else: + feasible_transformations = ["migrate", "split"] #", "swap", "split_swap"] + else: + # we can do better by comparing the advantage disadvantage of migrating + # (Advantage: slowing down by serialization, and disadvantage: accelerating by parallelization) + feasible_transformations = ["swap"] + if selected_metric == "power": + if selected_dir == -1: + # we can do better by comparing the advantage disadvantage of migrating + # (Advantage: slowing down by serialization, and disadvantage: accelerating by parallelization) + feasible_transformations = ["swap", "split_swap"] + else: + feasible_transformations = all_transformations + if selected_metric == "area": + if selected_dir == -1: + if hot_block_subtype == "pe": + feasible_transformations = ["migrate", "swap"] + else: + feasible_transformations = ["migrate", "swap", "split_swap"] + else: + feasible_transformations = all_transformations + elif not parallelism_exist: + if selected_metric == "latency": + if selected_dir == -1: + feasible_transformations = ["swap", "split_swap"] + else: + feasible_transformations = ["swap", "migrate"] + if selected_metric == "power": + if selected_dir == -1: + feasible_transformations = ["migrate", "swap", "split_swap"] + if selected_metric == "area": + if selected_dir == -1: + feasible_transformations = ["migrate", "swap","split_swap"] + else: + feasible_transformations = ["migrate", "swap", "split"] + + # ------------------------ + # based on locality, generate feasible transformations + # ------------------------ + if can_improve_locality and ('transfer' in config.all_available_transformations): + # locality not gonna improve area with the current set up + if not selected_metric == "area" and selected_dir == -1: + feasible_transformations.append("transfer") + + #------------------------ + # there is a on opportunity for routing + 
+        #------------------------
+        # there is an opportunity for routing
+        #------------------------
+        if can_improve_routing and ('routing' in config.all_available_transformations):
+            transformation_list = list(feasible_transformations)
+            transformation_list.append('routing')
+            feasible_transformations = set(transformation_list)
+
+        #------------------------
+        # post processing of the destination blocks to eliminate transformations
+        #------------------------
+        # filter migrate
+        if not found_blck_to_mig_to:
+            # if we can't find a block that is at least as good as the current block, we can't migrate
+            feasible_transformations = set(list(set(feasible_transformations) - set(['migrate'])))
+
+        # filter split
+        number_of_task_on_block = 0
+        if hot_blck_synced.type == "pe":
+            number_of_task_on_block = len(hot_blck_synced.get_tasks_of_block())
+        else:
+            number_of_task_on_block = len(hot_blck_synced.get_tasks_of_block_by_dir("write"))
+        if number_of_task_on_block == 1:  # can't split a block that hosts a single task
+            feasible_transformations = set(list(set(feasible_transformations) - set(['split', 'split_swap'])))
+
+        # filter swap
+        block_metric_attr = self.get_block_attr(selected_metric)  # metric to pay attention to
+        if getattr(imm_block, block_metric_attr) == getattr(hot_blck_synced, block_metric_attr):
+            #if imm_block.get_generic_instance_name() == hot_blck_synced.get_generic_instance_name():
+            # if swap can't improve anything, get rid of it
+            feasible_transformations = set(list(set(feasible_transformations) - set(['swap'])))
+
+        # for ICs, we don't use migrate
+        if hot_blck_synced.type in ["ic"]:
+            # we don't cover migrate for ICs at the moment
+            # TODO: add this feature later
+            feasible_transformations = set(list(set(feasible_transformations) - set(['migrate', 'split_swap'])))
+
+        # if no valid transformation is left, issue the identity transformation (where nothing changes and a simple copy is made)
+        if len(list(set(feasible_transformations))) == 0:
+            feasible_transformations = ["identity"]
+
+        return feasible_transformations
+
+    def set_design_space_size(self, ex_dp, sim_dp):
+        # estimate the size of the neighbouring design space of the current design
+        buses = [el for el in ex_dp.get_blocks() if el.type == "ic"]
+        mems = [el for el in ex_dp.get_blocks() if el.type == "mem"]
+        srams = [el for el in ex_dp.get_blocks() if el.type == "sram"]
+        drams = [el for el in ex_dp.get_blocks() if el.type == "dram"]
+        pes = [el for el in ex_dp.get_blocks() if el.type == "pe"]
+        ips = [el for el in ex_dp.get_blocks() if el.subtype == "ip"]
+        gpps = [el for el in ex_dp.get_blocks() if el.subtype == "gpp"]
+        all_blocks = ex_dp.get_blocks()
+
+        # per block
+        # for PEs
+        for pe in gpps:
+            number_of_task_on_block = len(pe.get_tasks_of_block())
+            #sim_dp.neighbouring_design_space_size["hardening"] += number_of_task_on_block + 1  # +1 for swap, the rest is for split_swap
+            sim_dp.neighbouring_design_space_size += number_of_task_on_block + 1  # +1 for swap, the rest is for split_swap
+        for pe in ips:
+            #sim_dp.neighbouring_design_space_size["softening"] += 1
+            sim_dp.neighbouring_design_space_size += 1
+
+        # for all blocks
+        for blck in all_blocks:
+            for mode in ["frequency_modulation", "bus_width_modulation", "loop_iteration_modulation"]:
+                if not blck.type == "pe":
+                    if mode == "loop_iteration_modulation":
+                        continue
+                #sim_dp.neighbouring_design_space_size[mode] += 2  # going up or down
+                sim_dp.neighbouring_design_space_size += 2  # going up or down
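+
+        # Editor's note (illustrative): every applicable knob adds its option count to the
+        # neighbourhood size; e.g. each modulation mode above contributes 2 (dial the knob up
+        # or down), so a mem block with frequency and bus-width modulation adds 2 + 2 = 4.
+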
+        for blck in all_blocks:
+            for mode in ["allocation"]:
+                if blck.type == "ic":
+                    continue
+                number_of_task_on_block = len(blck.get_tasks_of_block())
+                #sim_dp.neighbouring_design_space_size[mode] += number_of_task_on_block + 1  # +1 is for split, the rest is for split_swap
+                sim_dp.neighbouring_design_space_size += number_of_task_on_block + 1  # +1 is for split, the rest is for split_swap
+
+        for blck in all_blocks:
+            equal_imm_blocks_present_for_migration = self.dh.get_equal_immediate_blocks_present(ex_dp,
+                                                                                                blck,
+                                                                                                "latency",
+                                                                                                -1,
+                                                                                                blck.get_tasks_of_block())
+
+            equal_imm_blocks_present_for_migration.extend(self.dh.get_equal_immediate_blocks_present(ex_dp,
+                                                                                                     blck,
+                                                                                                     "latency",
+                                                                                                     +1,
+                                                                                                     blck.get_tasks_of_block()))
+
+            """
+            imm_blocks_present_for_migration.extend([self.dh.get_immediate_block(
+                                                            blck,
+                                                            "latency",
+                                                            -1,
+                                                            blck.get_tasks_of_block())])
+            """
+            #other_blocks_to_map_to_lengths = len(equal_imm_blocks_present_for_migration) - len(imm_blocks_present_for_migration)  # subtract to avoid double counting
+            other_blocks_to_map_to_lengths = 0
+            for el in equal_imm_blocks_present_for_migration:
+                if el == blck:
+                    continue
+                elif not el.type == blck.type:
+                    continue
+                else:
+                    other_blocks_to_map_to_lengths += 1
+
+            #other_blocks_to_map_to_lengths = len(equal_imm_blocks_present_for_migration)
+            #sim_dp.neighbouring_design_space_size[blck.type+"_"+"mapping"] += len(blck.get_tasks_of_block())*other_blocks_to_map_to_lengths
+            sim_dp.neighbouring_design_space_size += len(blck.get_tasks_of_block())*other_blocks_to_map_to_lengths
+
+    def get_transformation_design_space_size(self, move_to_apply, ex_dp, sim_dp, block_of_interest, selected_metric, sorted_metric_dir):
+        # compute each transformation's contribution to the design space size for the block of interest
+        imm_block = self.dh.get_immediate_block_multi_metric(block_of_interest, selected_metric, sorted_metric_dir, block_of_interest.get_tasks_of_block())
+
+        task = (block_of_interest.get_tasks_of_block())[0]  # any task will do
+        feasible_transformations = set(config.metric_trans_dict[selected_metric])
+
+        # find a block that is at least as good as the current block (for migration);
+        # if we can't find any, we return the same block
+        selected_metric = list(sorted_metric_dir.keys())[-1]
+        selected_dir = sorted_metric_dir[selected_metric]
+
+        equal_imm_blocks_present_for_migration = self.dh.get_equal_immediate_blocks_present(ex_dp, block_of_interest,
+                                                                                            selected_metric, selected_dir, [task])
+        if len(equal_imm_blocks_present_for_migration) == 1 and equal_imm_blocks_present_for_migration[0] == block_of_interest:
+            equal_imm_blocks_present_for_migration = []
+
+        buses = [el for el in ex_dp.get_blocks() if el.type == "ic"]
+        mems = [el for el in ex_dp.get_blocks() if el.type == "mem"]
+        srams = [el for el in ex_dp.get_blocks() if el.type == "sram"]
+        drams = [el for el in ex_dp.get_blocks() if el.type == "dram"]
+        pes = [el for el in ex_dp.get_blocks() if el.type == "pe"]
+        ips = [el for el in ex_dp.get_blocks() if el.subtype == "ip"]
+        gpps = [el for el in ex_dp.get_blocks() if el.subtype == "gpp"]
+
+        # per block
+        # for PEs
+        if block_of_interest.subtype == "gpp":
+            number_of_task_on_block = len(block_of_interest.get_tasks_of_block())
+            move_to_apply.design_space_size["hardening"] += number_of_task_on_block + 1  # +1 for swap, the rest is for split_swap
+            move_to_apply.design_space_size["pe_allocation"] += (number_of_task_on_block + 1)  # +1 is for split, the rest is for split_swap
+        elif block_of_interest.subtype == "ip":
+            move_to_apply.design_space_size["softening"] += 1
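+
+        # Editor's illustration (hypothetical numbers): a gpp hosting 3 tasks contributes
+        # 3 + 1 = 4 "hardening" candidates (one swap plus the split_swap variants), while an
+        # ip contributes a single "softening" candidate (mapping back onto a general-purpose core).
+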
"loop_iteration_modulation", "allocation"]: + if not block_of_interest.type =="pe": + if mode == "loop_iteration_modulation": + continue + value = self.dh.get_all_compatible_blocks_of_certain_char(ex_dp, block_of_interest, + selected_metric, selected_dir, [task], mode) + if mode in ["bus_width_modulation","loop_iteration_modulation"]: + move_to_apply.design_space_size[mode] += len(value) + else: + move_to_apply.design_space_size[block_of_interest.type + "_"+ mode] += len(value) + + + for block_type in ["pe", "mem", "ic"]: + if block_type == block_of_interest.type: + move_to_apply.design_space_size[block_type +"_"+"mapping"] += (len(equal_imm_blocks_present_for_migration) - 1) + else: + move_to_apply.design_space_size[block_type +"_"+"mapping"] += 0 + + can_improve_routing = self.can_improve_routing(ex_dp, sim_dp, block_of_interest, task) + if can_improve_routing: + move_to_apply.design_space_size["routing"] += (len(buses) - 1) + move_to_apply.design_space_size["transfer"] += (len(buses)-1) + move_to_apply.design_space_size["identity"] += 1 + + # pick which transformation to apply + # Variables: + # hot_blck_synced: the block bottleneck + # selected_metric: metric to focus on + # selected_krnl: the kernel to focus on + # ------------------------------ + def select_transformation(self, ex_dp, sim_dp, hot_blck_synced, selected_metric, selected_krnl, sorted_metric_dir): + feasible_transformations = self.get_feasible_transformations(ex_dp, sim_dp, hot_blck_synced, selected_metric, + selected_krnl, sorted_metric_dir) + if config.print_info_regularly: + print(list(feasible_transformations)) + random.seed(datetime.now().microsecond) + # pick randomly at the moment. + # TODO: possibly can do better + transformation = random.choice(list(feasible_transformations)) + + #if len(hot_blck_synced.get_tasks_of_block_by_dir("write")) > 1: + # transformation = "split_swap" + #else: + # transformation = "swap" + if transformation == "migrate": + batch_mode = "single" + transformation_sub_name = "irrelevant" + elif transformation == "split": + # see if any task can run in parallel + batch_mode = "batch" + transformation_sub_name = "irrelevant" + elif transformation == "split_swap": + batch_mode = "single" + transformation_sub_name = "irrelevant" + elif transformation == "transfer": + batch_mode = "irrelevant" + transformation_sub_name = "locality_improvement" + elif transformation == "routing": + batch_mode = "irrelevant" + transformation_sub_name = "routing_improvement" + else: + transformation_sub_name = "irrelevant" + batch_mode = "irrelevant" + + + return transformation, transformation_sub_name, batch_mode, len(list(feasible_transformations)) + + # calculate the cost impact of a kernel improvement + def get_swap_improvement_cost(self, sim_dp, kernels, selected_metric, dir): + def get_subtype_for_cost(block): + if block.type == "pe" and block.subtype == "ip": + return "ip" + if block.type == "pe" and block.subtype == "gpp": + if "A53" in block.instance_name or "ARM" in block.instance_name: + return "arm" + if "G3" in block.instance_name: + return "dsp" + else: + return block.type + + # Figure out whether there is a mapping that improves kernels performance + def no_swap_improvement_possible(sim_dp, selected_metric, metric_dir, krnl): + hot_block = sim_dp.get_dp_stats().get_hot_block_of_krnel(krnl.get_task_name(), selected_metric) + imm_block = self.dh.get_immediate_block_multi_metric(hot_block, metric_dir, [krnl.get_task()]) + blah = hot_block.get_generic_instance_name() + blah2 = 
imm_block.get_generic_instance_name() + return hot_block.get_generic_instance_name() == imm_block.get_generic_instance_name() + + + # find the cost of improvement by comparing the current and accelerated design (for the kernel) + kernel_improvement_cost = {} + kernel_name_improvement_cost = {} + for krnel in kernels: + hot_block = sim_dp.get_dp_stats().get_hot_block_of_krnel(krnel.get_task_name(), selected_metric) + hot_block_subtype = get_subtype_for_cost(hot_block) + current_cost = self.database.db_input.porting_effort[hot_block_subtype] + #if hot_block_subtype == "ip": + # print("what") + imm_block = self.dh.get_immediate_block_multi_metric(hot_block,selected_metric, metric_dir,[krnel.get_task()]) + imm_block_subtype = get_subtype_for_cost(imm_block) + imm_block_cost = self.database.db_input.porting_effort[imm_block_subtype] + improvement_cost = (imm_block_cost - current_cost) + kernel_improvement_cost[krnel] = improvement_cost + + # calcualte inverse so lower means worse + max_val = max(kernel_improvement_cost.values()) # multiply by + kernel_improvement_cost_inverse = {} + for k, v in kernel_improvement_cost.items(): + kernel_improvement_cost_inverse[k] = max_val - kernel_improvement_cost[k] + + # get sum and normalize + sum_ = sum(list(kernel_improvement_cost_inverse.values())) + for k, v in kernel_improvement_cost_inverse.items(): + # normalize + if not (sum_ == 0): + kernel_improvement_cost_inverse[k] = kernel_improvement_cost_inverse[k]/sum_ + kernel_improvement_cost_inverse[k] = max(kernel_improvement_cost_inverse[k], .0000001) + if no_swap_improvement_possible(sim_dp, selected_metric, dir, k): + kernel_improvement_cost_inverse[k] = .0000001 + kernel_name_improvement_cost[k.get_task_name()] = kernel_improvement_cost_inverse[k] + + return kernel_improvement_cost_inverse + + def get_identity_cost(self): + return self.database.db_input.porting_effort["ip"] + + + # calculate the cost impact of a kernel improvement + def get_swap_cost(self, sim_dp, krnl, selected_metric, sorted_metric_dir): + def get_subtype_for_cost(block): + if block.type == "pe" and block.subtype == "ip": + return "ip" + if block.type == "pe" and block.subtype == "gpp": + if "A53" in block.instance_name or "ARM" in block.instance_name: + return "arm" + if "G3" in block.instance_name: + return "dsp" + else: + return block.type + + hot_block = sim_dp.get_dp_stats().get_hot_block_of_krnel(krnl.get_task_name(), selected_metric) + hot_block_subtype = get_subtype_for_cost(hot_block) + current_cost = self.database.db_input.porting_effort[hot_block_subtype] + imm_block = self.dh.get_immediate_block_multi_metric(hot_block,selected_metric, sorted_metric_dir,[krnl.get_task()]) + imm_block_subtype = get_subtype_for_cost(imm_block) + imm_block_cost = self.database.db_input.porting_effort[imm_block_subtype] + improvement_cost = (imm_block_cost - current_cost) + return improvement_cost + + def get_migrate_cost(self): + return 0 + + def get_transfer_cost(self): + return 0 + + def get_routing_cost(self): + return 0 + + def get_split_cost(self): + return 1 + + def get_migration_split_cost(self, transformation): + if transformation == "migrate": + return self.get_migrate_cost() + elif transformation == "split": + return self.get_split_cost() + else: + print("this transformation" + transformation + " is not supported for cost calculation") + exit(0) + + # how much does it cost to improve the kernel for different transformations + def get_krnl_improvement_cost(self, ex_dp, sim_dp, krnls, selected_metric, move_sorted_metric_dir): + # 
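+
+    # Editor's illustration of the cost model above (hypothetical effort numbers): with
+    #     porting_effort = {"arm": 0, "dsp": 10, "ip": 100}
+    # swapping a kernel from an ARM core to an IP costs 100 - 0 = 100 units of effort,
+    # whereas the reverse direction ("softening") yields a negative cost; migrate and
+    # transfer are treated as free above, and split has a fixed unit cost.
+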
+    # how much does it cost to improve the kernel under different transformations
+    def get_krnl_improvement_cost(self, ex_dp, sim_dp, krnls, selected_metric, move_sorted_metric_dir):
+        # cost of applying the transformation to the kernel's block
+        def get_transformation_cost(sim_dp, selected_metric, move_sorted_metric_dir, krnl, transformation):
+            if transformation == "swap":
+                cost = self.get_swap_cost(sim_dp, krnl, selected_metric, move_sorted_metric_dir)
+            elif transformation in ["split", "migrate"]:
+                cost = self.get_migration_split_cost(transformation)
+            elif transformation in ["split_swap"]:
+                cost = self.get_migration_split_cost("split")
+                cost += self.get_swap_cost(sim_dp, krnl, selected_metric, move_sorted_metric_dir)
+            elif transformation in ["identity"]:
+                cost = self.get_identity_cost()
+            elif transformation in ["transfer"]:
+                cost = self.get_transfer_cost()
+            elif transformation in ["routing"]:
+                cost = self.get_routing_cost()
+            if cost == 0:
+                cost = self.min_cost_to_consider
+            return cost
+
+        krnl_improvement_cost = {}
+
+        # iterate through the kernels, find their feasible transformations and
+        # the associated cost
+        for krnl in krnls:
+            hot_block = sim_dp.get_dp_stats().get_hot_block_of_krnel(krnl.get_task_name(), selected_metric)
+            imm_block = self.dh.get_immediate_block_multi_metric(hot_block, selected_metric, move_sorted_metric_dir, [krnl.get_task()])
+            hot_blck_synced = self.dh.find_cores_hot_kernel_blck_bottlneck(ex_dp, hot_block)
+            feasible_trans = self.get_feasible_transformations(ex_dp, sim_dp, hot_blck_synced, selected_metric,
+                                                               krnl, move_sorted_metric_dir)
+            for trans in feasible_trans:
+                cost = get_transformation_cost(sim_dp, selected_metric, move_sorted_metric_dir, krnl, trans)
+                krnl_improvement_cost[(krnl, trans)] = cost
+        return krnl_improvement_cost
+
+    # select a metric to improve on
+    def select_metric(self, sim_dp):
+        # prioritize metrics based on their distance contribution to the goal
+        metric_prob_dict = {}  # (metric: priority value); each value is in the [0, 1] interval
+        for metric in config.budgetted_metrics:
+            metric_prob_dict[metric] = sim_dp.dp_stats.dist_to_goal_per_metric(metric, config.metric_sel_dis_mode)/\
+                                       sim_dp.dp_stats.dist_to_goal(["power", "area", "latency"],
+                                                                    config.metric_sel_dis_mode)
+
+        # sort the metrics based on distance (the sort is either probabilistic or exact).
+        # probabilistic sorting: first sort exactly, then use the exact value as the probability of selection
+        metric_prob_dict_sorted = {k: v for k, v in sorted(metric_prob_dict.items(), key=lambda item: item[1])}
+        if config.move_metric_ranking_mode == "exact":
+            selected_metric = list(metric_prob_dict_sorted.keys())[-1]
+        else:
+            selected_metric = self.pick_from_prob_dict(metric_prob_dict_sorted)
+
+        sorted_low_to_high_metric_dir = {}
+        for metric, prob in metric_prob_dict_sorted.items():
+            move_dir = 1  # try to increase the metric value
+            if not sim_dp.dp_stats.fits_budget_for_metric_for_SOC(metric, 1):
+                move_dir = -1  # try to reduce the metric value
+            sorted_low_to_high_metric_dir[metric] = move_dir
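+
+        # Editor's illustration (hypothetical distances): if the normalized per-metric
+        # distances to the goal are latency: 0.6, power: 0.3, area: 0.1, the "exact" mode
+        # deterministically picks latency, while the probabilistic mode picks latency with
+        # probability 0.6, power with 0.3 and area with 0.1.
+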
+        # Delete later. For now, kept for validation:
+        #selected_metric = "latency"
+        #sorted_low_to_high_metric_dir = {'area': 1, 'power': -1, 'latency': -1}
+        #metric_prob_dict_sorted = {'area': .1, 'power': .1, 'latency': .8}
+
+        return selected_metric, metric_prob_dict_sorted, sorted_low_to_high_metric_dir
+
+    # select a direction for the move
+    def select_dir(self, sim_dp, metric):
+        move_dir = 1  # try to increase the metric value
+        if not sim_dp.dp_stats.fits_budget_for_metric_for_SOC(metric, 1):
+            move_dir = -1  # try to reduce the metric value
+        return move_dir
+
+    # note: despite the name, this keeps the kernels of workloads that have NOT met the budget yet
+    def filter_in_kernels_meeting_budget(self, selected_metric, sim_dp):
+        krnls = sim_dp.get_dp_stats().get_kernels()
+
+        # filter out the workloads that have already met the budget
+        workload_tasks = sim_dp.database.db_input.workload_tasks
+        task_workload = sim_dp.database.db_input.task_workload
+        workloads_to_consider = []
+        for workload in workload_tasks.keys():
+            if sim_dp.dp_stats.workload_fits_budget(workload, 1):
+                continue
+            workloads_to_consider.append(workload)
+
+        krnls_to_consider = []
+        for krnl in krnls:
+            if task_workload[krnl.get_task_name()] in workloads_to_consider and not krnl.get_task().is_task_dummy():
+                krnls_to_consider.append(krnl)
+
+        return krnls_to_consider
+
+    # get each kernel's contribution to the metric of interest
+    def get_kernels_s_contribution(self, selected_metric, sim_dp):
+        krnl_prob_dict = {}  # (kernel: metric_value)
+
+        #krnls = sim_dp.get_dp_stats().get_kernels()
+        # keep only the kernels whose workloads have not met the budget yet
+        krnls = self.filter_in_kernels_meeting_budget(selected_metric, sim_dp)
+        if krnls == []:  # the design meets the budget, hence all kernels can be considered for cost improvement
+            krnls = sim_dp.get_dp_stats().get_kernels()
+
+        metric_total = sum([krnl.stats.get_metric(selected_metric) for krnl in krnls])
+        # weigh kernels based on their contribution to the metric of interest
+        for krnl in krnls:
+            krnl_prob_dict[krnl] = krnl.stats.get_metric(selected_metric)/metric_total
+
+        if not "bottleneck" in self.move_s_krnel_selection:
+            for krnl in krnls:
+                krnl_prob_dict[krnl] = 1
+        return krnl_prob_dict
+
+    # get each kernel's improvement ease (ease = 1/cost)
+    def get_kernels_s_improvement_ease(self, ex_dp, sim_dp, selected_metric, move_sorted_metric_dir):
+        krnls = sim_dp.get_dp_stats().get_kernels()
+        krnl_improvement_ease = {}
+        if not "improvement_ease" in self.move_s_krnel_selection:
+            for krnl in krnls:
+                krnl_improvement_ease[krnl] = 1
+        else:
+            krnl_trans_improvement_cost = self.get_krnl_improvement_cost(ex_dp, sim_dp, krnls, selected_metric, move_sorted_metric_dir)
+            # normalize and reverse (reversed, so that a higher cost ends up with a smaller value)
+            krnl_trans_improvement_ease = {}
+            for krnl_trans, cost in krnl_trans_improvement_cost.items():
+                krnl_trans_improvement_ease[krnl_trans] = 1/cost
+            max_ease = max(krnl_trans_improvement_ease.values())
+            for krnl_trans, ease in krnl_trans_improvement_ease.items():
+                krnl_trans_improvement_ease[krnl_trans] = ease/max_ease
+
+            for krnl in krnls:
+                krnl_improvement_ease[krnl] = 0
+
+            for krnl_trans, ease in krnl_trans_improvement_ease.items():
+                krnl, trans = krnl_trans
+                krnl_improvement_ease[krnl] = max(ease, krnl_improvement_ease[krnl])
+
+        return krnl_improvement_ease
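+
+    # Editor's illustration (hypothetical costs): given per-(kernel, transformation) costs
+    # {(k1, "swap"): 100, (k1, "migrate"): 1, (k2, "swap"): 10}, the eases 1/cost are
+    # .01, 1.0 and .1; after dividing by the max ease, each kernel keeps its easiest
+    # option, so k1 maps to 1.0 and k2 to 0.1.
+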
+    # select the kernel for the move
+    def select_kernel(self, ex_dp, sim_dp, selected_metric, move_sorted_metric_dir):
+        # get each kernel's contribution
+        krnl_contribution_dict = self.get_kernels_s_contribution(selected_metric, sim_dp)
+        # get each kernel's improvement ease
+        krnl_improvement_ease = self.get_kernels_s_improvement_ease(ex_dp, sim_dp, selected_metric, move_sorted_metric_dir)
+
+        # combine the selection methods:
+        # multiply the probabilities for a more complex metric
+        krnl_prob_dict = {}
+        for krnl in krnl_contribution_dict.keys():
+            krnl_prob_dict[krnl] = krnl_contribution_dict[krnl] * krnl_improvement_ease[krnl]
+
+        # give zero probability to the kernels that were filtered out
+        for krnl in sim_dp.get_dp_stats().get_kernels():
+            if krnl not in krnl_prob_dict.keys():
+                krnl_prob_dict[krnl] = 0
+        # sort
+        #krnl_prob_dict_sorted = {k: v for k, v in sorted(krnl_prob_dict.items(), key=lambda item: item[1])}
+        krnl_prob_dict_sorted = sorted(krnl_prob_dict.items(), key=lambda item: item[1], reverse=True)
+
+        # get the worst kernel
+        if config.move_krnel_ranking_mode == "exact":  # for area, to allow picking scenarios that are not necessarily the worst
+            #selected_krnl = list(krnl_prob_dict_sorted.keys())[
+            #    len(krnl_prob_dict_sorted.keys()) - 1 - self.krnel_rnk_to_consider]
+            for krnl, prob in krnl_prob_dict_sorted:
+                if krnl.get_task_name() in self.krnels_not_to_consider:
+                    continue
+                selected_krnl = krnl
+                break
+        else:
+            selected_krnl = self.pick_from_prob_dict(krnl_prob_dict_sorted)
+
+        if config.transformation_selection_mode == "random":
+            krnls = sim_dp.get_dp_stats().get_kernels()
+            random.seed(datetime.now().microsecond)
+            selected_krnl = random.choice(krnls)
+
+        return selected_krnl, krnl_prob_dict, krnl_prob_dict_sorted
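+
+    # Editor's illustration (hypothetical values): a kernel contributing 50% of the latency
+    # with ease 1.0 scores .5, while one contributing 80% with ease .1 scores .08, so the
+    # combined probability favours a slightly smaller bottleneck that is much cheaper to improve.
+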
+    # select the block for the move
+    def select_block(self, sim_dp, ex_dp, selected_krnl, selected_metric):
+        # get the hot block for the kernel. Hot means the block contributing the most to the kernel/metric of interest
+        hot_blck = sim_dp.get_dp_stats().get_hot_block_of_krnel(selected_krnl.get_task_name(), selected_metric)
+
+        # randomly pick one
+        if config.transformation_selection_mode == "random":
+            random.seed(datetime.now().microsecond)
+            hot_blck = random.choice(ex_dp.get_hardware_graph().get_blocks())  # just a dummy pick to prevent breaking the plotting
+
+        # hot_blck_synced is the same block, but we make sure the block instance
+        # is chosen from ex instead of sim, so it can be modified
+        hot_blck_synced = self.dh.find_cores_hot_kernel_blck_bottlneck(ex_dp, hot_blck)
+        block_prob_dict = sim_dp.get_dp_stats().get_hot_block_of_krnel_sorted(selected_krnl.get_task_name(), selected_metric)
+        return hot_blck_synced, block_prob_dict
+
+    def select_block_without_sync(self, sim_dp, selected_krnl, selected_metric):
+        # get the hot block for the kernel. Hot means the block contributing the most to the kernel/metric of interest
+        hot_blck = sim_dp.get_dp_stats().get_hot_block_of_krnel(selected_krnl.get_task_name(), selected_metric)
+        # unlike select_block, the block instance is taken from sim as is (no syncing with ex)
+        block_prob_dict = sim_dp.get_dp_stats().get_hot_block_of_krnel_sorted(selected_krnl.get_task_name(), selected_metric)
+        return hot_blck, block_prob_dict
+
+    def change_read_task_to_write_if_necessary(self, ex_dp, sim_dp, move_to_apply, selected_krnl):
+        tasks_synced = [task__ for task__ in move_to_apply.get_block_ref().get_tasks_of_block() if
+                        task__.name == selected_krnl.get_task_name()]
+        if len(tasks_synced) == 0:  # this condition happens when we have a read task and we have unloaded reads
+            krnl_s_tsk = ex_dp.get_hardware_graph().get_task_graph().get_task_by_name(
+                move_to_apply.get_kernel_ref().get_task_name())
+            parents_s_task = [el.get_name() for el in
+                              ex_dp.get_hardware_graph().get_task_graph().get_task_s_parents(krnl_s_tsk)]
+            tasks_on_block = [el.get_name() for el in move_to_apply.get_block_ref().get_tasks_of_block()]
+            for parent_task in parents_s_task:
+                if parent_task in tasks_on_block:
+                    parents_task_obj = ex_dp.get_hardware_graph().get_task_graph().get_task_by_name(parent_task)
+                    krnl = sim_dp.get_kernel_by_task_name(parents_task_obj)
+                    move_to_apply.set_krnel_ref(krnl)
+                    return
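+
+    # Editor's note on the helper below: candidate blocks are sorted by
+    # len(set(cur_block_tasks) - set(candidate_tasks)), i.e. by how many of the current
+    # block's tasks they are missing, so the candidate sharing the most tasks comes first.
+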
+    def find_block_with_sharing_tasks(self, ex_dp, selected_block, selected_krnl):
+        succeeded = False
+        # not covering ic at the moment
+        if selected_block.type == "ic":
+            return succeeded, "_"
+
+        # get the tasks of the block
+        cur_block_tasks = [el.get_name() for el in selected_block.get_tasks_of_block()]
+        krnl_task = selected_krnl.get_task().get_name()
+        # get the other blocks in the system
+        all_blocks = ex_dp.get_hardware_graph().get_blocks()
+        all_blocks_minus_src_block = list(set(all_blocks) - set([selected_block]))
+        assert (len(all_blocks) == len(all_blocks_minus_src_block) + 1), "all_blocks must have one more block in it"
+
+        if selected_block.type == "pe":
+            blocks_with_sharing_task_type = "mem"
+        elif selected_block.type == "mem":
+            blocks_with_sharing_task_type = "pe"
+
+        blocks_to_look_at = [blck for blck in all_blocks_minus_src_block if blck.type == blocks_with_sharing_task_type]
+
+        # iterate through the candidates (of the opposite type, i.e., for mem look for pe, and for pe look for mem)
+        # and look for shared tasks. If there is no shared task, we should move the block somewhere where there is one.
+        # sort the candidates based on the number of sharings.
+        blocks_sorted_based_on_sharing = sorted(blocks_to_look_at, key=lambda blck: len(list(set(cur_block_tasks) - set([el.get_name() for el in blck.get_tasks_of_block()]))))
+        for block_with_sharing in blocks_sorted_based_on_sharing:
+            block_tasks = [el.get_name() for el in block_with_sharing.get_tasks_of_block()]
+            if krnl_task in block_tasks:
+                return True, block_with_sharing
+        # no candidate hosts the kernel's task
+        return False, "_"
+
+    def find_improve_routing(self, ex_dp, sim_dp, selected_block, selected_krnl_task):
+        if not selected_block.type == "ic":
+            return None, None
+        task_name = selected_krnl_task.get_task_name()
+        task_s_blocks = ex_dp.get_hardware_graph().get_blocks_of_task_by_name(task_name)
+        pe = [blk for blk in task_s_blocks if blk.type == "pe"][0]
+        mems = [blk for blk in task_s_blocks if blk.type == "mem"]
+
+        ic_entry, ic_exit = None, None
+        for mem in mems:
+            path = sim_dp.dp.get_hardware_graph().get_path_between_two_vertecies(pe, mem)
+            if len(path) > 4:  # more than two ICs on the path
+                ic_entry = path[1]
+                ic_exit = path[-2]
+                break
+
+        success = ic_entry is not None
+        return success, ic_exit
+
+    def can_improve_routing(self, ex_dp, sim_dp, selected_block, selected_krnl_task):
+        if not selected_block.type == "ic":
+            return False
+        task_name = selected_krnl_task.get_name()
+        task_s_blocks = ex_dp.get_hardware_graph().get_blocks_of_task_by_name(task_name)
+        pe = [blk for blk in task_s_blocks if blk.type == "pe"][0]
+        mems = [blk for blk in task_s_blocks if blk.type == "mem"]
+
+        for mem in mems:
+            path = ex_dp.get_hardware_graph().get_path_between_two_vertecies(pe, mem)
+            if len(path) > 4:  # more than two ICs on the path
+                return True
+        return False
+
+    def can_improve_locality(self, ex_dp, selected_block, selected_krnl_task):
+        result = True
+
+        # not covering ic at the moment
+        if selected_block.type == "ic":
+            return False
+
+        # get the tasks of the block
+        cur_block_tasks = list(set([el.get_name() for el in selected_block.get_tasks_of_block()]))
+        # get the neighbouring ic
+        ic = [neigh for neigh in selected_block.get_neighs() if neigh.type == "ic"][0]
+        if selected_block.type == "pe":
+            blocks_with_sharing_task_type = "mem"
+        elif selected_block.type == "mem":
+            blocks_with_sharing_task_type = "pe"
+
+        ic_neighs = [neigh for neigh in ic.get_neighs() if neigh.type == blocks_with_sharing_task_type]
+
+        # iterate through the ic neighbours (of the opposite type, i.e., for mem look for pe, and for pe look for mem)
+        # and look for shared tasks. If there is no shared task, we should move the block somewhere where there is one.
+        for block in ic_neighs:
+            block_tasks = list(set([el.get_name() for el in block.get_tasks_of_block()]))
+            shared_tasks_exist = len(list(set(cur_block_tasks) - set(block_tasks))) < len(list(set(cur_block_tasks)))
+            if shared_tasks_exist:
+                result = False
+                break
+        return result
+
+    def find_absorbing_block_tuple(self, ex_dp, sim_dp, task_name, block):
+        task_pe = None
+        for blk in ex_dp.get_hardware_graph().get_blocks():
+            # only cover absorbing for PEs
+            if not blk.type == "pe":
+                continue
+            blk_tasks = [el.get_name() for el in blk.get_tasks_of_block()]  # fixed: iterated over el instead of blk
+            if task_name in blk_tasks:
+                task_pe = blk
+                break
+
+        ic_neigh = [neigh for neigh in task_pe.get_neighs() if neigh.type == "ic"][0]
+        ic_neigh_neigh = [neigh for neigh in ic_neigh.get_neighs() if neigh.type == "ic"][0]
+
+        # we don't mess with the system ic
+        if self.is_system_ic(ex_dp, sim_dp, ic_neigh):
+            absorbee, absorber = None, None
+        else:
+            # if the ic doesn't have any memory attached to it, it can be absorbed
+            mem_neighs = [neigh for neigh in ic_neigh.get_neighs() if neigh.type == "mem"]
+            if len(mem_neighs) == 0:
+                absorbee, absorber = ic_neigh, ic_neigh_neigh
+            else:
+                absorbee, absorber = None, None
+
+        return absorbee, absorber
+
+    def can_absorb_block(self, ex_dp, sim_dp, task_name):
+        task_pe = None
+        for blk in ex_dp.get_hardware_graph().get_blocks():
+            # only cover absorbing for PEs
+            if not blk.type == "pe":
+                continue
+            blk_tasks = [el.get_name() for el in blk.get_tasks_of_block()]
+            if task_name in blk_tasks:
+                task_pe = blk
+                break
+
+        ic_neigh = [neigh for neigh in task_pe.get_neighs() if neigh.type == "ic"][0]
+        # we don't mess with the system ic
+        if self.is_system_ic(ex_dp, sim_dp, ic_neigh):
+            result = False
+        else:
+            # if the ic doesn't have any memory attached to it, it can be absorbed
+            mem_neighs = [neigh for neigh in ic_neigh.get_neighs() if neigh.type == "mem"]
+            if len(mem_neighs) == 0:
+                result = True
+            else:
+                result = False
+
+        return result
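+
+    # Editor's sketch of the absorb conditions above (names from this file): an ic is
+    # absorbable only when it is not the system ic and has no mem neighbours, in which
+    # case it is merged into its neighbouring ic:
+    #     absorbee, absorber = ic_neigh, ic_neigh_neigh
+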
+    # ------------------------------
+    # Functionality:
+    #       generate a move to apply. A move consists of a metric, direction, kernel, block and transformation.
+    #       At the moment, we target the metric that is furthest from the budget. Kernel and block are chosen
+    #       based on how much they contribute to that distance.
+    #
+    # Variables:
+    #       des_tup: (design, simulated design)
+    # ------------------------------
+    def sel_moves_based_on_dis(self, des_tup):
+        ex_dp, sim_dp = des_tup
+        if config.DEBUG_FIX: random.seed(0)
+        else: time.sleep(.00001), random.seed(datetime.now().microsecond)
+
+        # select the move components
+        t_0 = time.time()
+        selected_metric, metric_prob_dir_dict, sorted_metric_dir = self.select_metric(sim_dp)
+        t_1 = time.time()
+        move_dir = self.select_dir(sim_dp, selected_metric)
+        t_2 = time.time()
+        selected_krnl, krnl_prob_dict, krnl_prob_dir_dict_sorted = self.select_kernel(ex_dp, sim_dp, selected_metric, sorted_metric_dir)
+        t_3 = time.time()
+        selected_block, block_prob_dict = self.select_block(sim_dp, ex_dp, selected_krnl, selected_metric)
+        t_4 = time.time()
+        transformation_name, transformation_sub_name, transformation_batch_mode, total_transformation_cnt = self.select_transformation(ex_dp, sim_dp, selected_block, selected_metric, selected_krnl, sorted_metric_dir)
+        t_5 = time.time()
+
+        self.set_design_space_size(des_tup[0], des_tup[1])
+
+        """
+        if sim_dp.dp_stats.fits_budget(1) and self.dram_feasibility_check_pass(ex_dp) and self.can_improve_locality(selected_block, selected_krnl):
+            transformation_sub_name = "transfer_no_prune"
+            transformation_name = "improve_locality"
+            transformation_batch_mode = "single"
+            selected_metric = "cost"
+        """
+        # prepare for the move.
+        # if bus, (forgot which exception); if IP, avoid split.
+        """
+        if sim_dp.dp_stats.fits_budget(1) and not self.dram_feasibility_check_pass(ex_dp):
+            transformation_name = "dram_fix"
+            transformation_sub_name = "dram_fix_no_prune"
+            transformation_batch_mode = "single"
+            selected_metric = "cost"
+        """
+        if self.is_cleanup_iter():
+            transformation_name = "cleanup"
+            transformation_sub_name = "non"
+            transformation_batch_mode = "single"
+            selected_metric = "cost"
+            #config.VIS_GR_PER_GEN = True
+            self.cleanup_ctr += 1
+            #config.VIS_GR_PER_GEN = False
+
+        # log the data for future profiling/data collection/debugging
+        move_to_apply = move(transformation_name, transformation_sub_name, transformation_batch_mode, move_dir, selected_metric, selected_block, selected_krnl, krnl_prob_dir_dict_sorted)
+        move_to_apply.set_sorted_metric_dir(sorted_metric_dir)
+        move_to_apply.set_logs(sim_dp.database.db_input.task_workload[selected_krnl.get_task_name()], "workload")
+        move_to_apply.set_logs(sim_dp.dp_stats.get_system_complex_metric("cost"), "cost")
+        move_to_apply.set_logs(krnl_prob_dict, "kernels")
+        move_to_apply.set_logs(metric_prob_dir_dict, "metrics")
+        move_to_apply.set_logs(block_prob_dict, "blocks")
+        move_to_apply.set_logs(self.krnel_rnk_to_consider, "kernel_rnk_to_consider")
+        move_to_apply.set_logs(sim_dp.dp_stats.dist_to_goal(["power", "area", "latency", "cost"],
+                                                            config.metric_sel_dis_mode), "ref_des_dist_to_goal_all")
+        move_to_apply.set_logs(sim_dp.dp_stats.dist_to_goal(["power", "area", "latency"],
+                                                            config.metric_sel_dis_mode), "ref_des_dist_to_goal_non_cost")
+
+        for blck_of_interest in ex_dp.get_blocks():
+            self.get_transformation_design_space_size(move_to_apply, ex_dp, sim_dp, blck_of_interest, selected_metric, sorted_metric_dir)
+
+        move_to_apply.set_logs(t_1 - t_0, "metric_selection_time")
+        move_to_apply.set_logs(t_2 - t_1, "dir_selection_time")
+        move_to_apply.set_logs(t_3 - t_2, "kernel_selection_time")
+        move_to_apply.set_logs(t_4 - t_3, "block_selection_time")
+        move_to_apply.set_logs(t_5 - t_4, "transformation_selection_time")
+
+        blck_ref = move_to_apply.get_block_ref()
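+
+        # Editor's note (illustrative): the pickle round-trip below is a common fast deep-copy
+        # idiom, and disabling the garbage collector around it avoids GC overhead during the
+        # dump/load:
+        #     clone = cPickle.loads(cPickle.dumps(obj, -1))   # -1 selects the highest protocol
+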
+        gc.disable()
+        blck_ref_cp = cPickle.loads(cPickle.dumps(blck_ref, -1))
+        gc.enable()
+        # ------------------------
+        # prepare for the move
+        # ------------------------
+        if move_to_apply.get_transformation_name() == "identity":
+            pass
+            #move_to_apply.set_validity(False, "NoValidTransformationException")
+        if move_to_apply.get_transformation_name() == "swap":
+            self.dh.unload_read_mem(ex_dp)  # unload memories
+            if not blck_ref.type == "ic":
+                self.dh.unload_buses(ex_dp)  # unload buses
+            else:
+                self.dh.unload_read_buses(ex_dp)  # unload buses
+            # get the immediate superior/inferior block (based on the desired direction)
+            imm_block = self.dh.get_immediate_block_multi_metric(blck_ref,
+                                                                 move_to_apply.get_metric(), move_to_apply.get_sorted_metric_dir(),
+                                                                 blck_ref.get_tasks_of_block())  # immediate block, either superior or inferior
+            move_to_apply.set_dest_block(imm_block)
+            move_to_apply.set_customization_type(blck_ref, imm_block)
+
+            move_to_apply.set_tasks(blck_ref.get_tasks_of_block())
+        elif move_to_apply.get_transformation_name() in ["split_swap"]:
+            self.dh.unload_buses(ex_dp)  # unload buses
+            # get the immediate superior/inferior block (based on the desired direction)
+            succeeded, migrant = blck_ref.get_tasks_by_name(move_to_apply.get_kernel_ref().get_task_name())
+            if not succeeded:
+                move_to_apply.set_validity(False, "NoMigrantException")
+            else:
+                imm_block = self.dh.get_immediate_block_multi_metric(blck_ref,
+                                                                     move_to_apply.get_metric(), move_to_apply.get_sorted_metric_dir(),
+                                                                     [migrant])  # immediate block, either superior or inferior
+                move_to_apply.set_dest_block(imm_block)
+
+                self.dh.unload_read_mem(ex_dp)  # unload memories
+                self.change_read_task_to_write_if_necessary(ex_dp, sim_dp, move_to_apply, selected_krnl)
+                migrant_tasks = self.dh.migrant_selection(ex_dp, sim_dp, blck_ref, blck_ref_cp, move_to_apply.get_kernel_ref(),
+                                                          move_to_apply.get_transformation_batch())
+                #migrant_tasks = list(set(move_to_apply.get_block_ref().get_tasks()) - set(migrant_tasks_))  # reverse the order to allow the swap to happen on the ref_block
+                move_to_apply.set_tasks(migrant_tasks)
+                move_to_apply.set_customization_type(blck_ref, imm_block)
+        elif move_to_apply.get_transformation_name() in ["split"]:
+            # select the tasks to migrate
+            #self.change_read_task_to_write_if_necessary(ex_dp, sim_dp, move_to_apply, selected_krnl)
+            migrant_tasks = self.dh.migrant_selection(ex_dp, sim_dp, blck_ref, blck_ref_cp, move_to_apply.get_kernel_ref(),
+                                                      move_to_apply.get_transformation_batch())
+
+            # determine the parallelism type
+            parallelism_type_ = []  # with repetition
+            parallelism_type = []
+            migrant_tasks_names = [el.get_name() for el in migrant_tasks]
+            for task_ in migrant_tasks_names:
+                parallelism_type_.append(self.get_task_parallelism_type(sim_dp, task_, selected_krnl.get_task_name()))
+            parallelism_type = list(set(parallelism_type_))
+
+            move_to_apply.set_parallelism_type(parallelism_type)
+            move_to_apply.set_tasks(migrant_tasks)
+            if len(migrant_tasks) == 0:
+                move_to_apply.set_validity(False, "NoParallelTaskException")
+            if blck_ref.subtype == "ip":  # it makes no sense to split IPs;
+                                          # it can actually cause problems where
+                                          # we end up duplicating the hardware
+                move_to_apply.set_validity(False, "IPSplitException")
+        elif move_to_apply.get_transformation_name() == "migrate":
+            if not selected_block.type == "ic":  # ic migration is not supported
+                # check and see if the tasks exist (if not, it was a read)
+                imm_block_present, found_blck_to_mig_to, mig_selection_mode, parallelism_type, locality_type = self.select_block_to_migrate_to(ex_dp,
+                                                                                                                                               sim_dp,
+                                                                                                                                               blck_ref_cp,
+                                                                                                                                               move_to_apply.get_metric(),
+                                                                                                                                               move_to_apply.get_sorted_metric_dir(),
+                                                                                                                                               move_to_apply.get_kernel_ref())
+
+                move_to_apply.set_parallelism_type(parallelism_type)
+                move_to_apply.set_locality_type(locality_type)
+                self.dh.unload_buses(ex_dp)  # unload buses
+                self.dh.unload_read_mem(ex_dp)  # unload memories
+                if not imm_block_present.subtype == "ip":
+                    self.change_read_task_to_write_if_necessary(ex_dp, sim_dp, move_to_apply, selected_krnl)
+                if not found_blck_to_mig_to:
+                    move_to_apply.set_validity(False, "NoMigrantException")
+                    imm_block_present = blck_ref
+                elif move_to_apply.get_kernel_ref().get_task_name() in ["souurce", "siink", "dummy_last"]:  # task names spelled as in the task graph
+                    move_to_apply.set_validity(False, "NoMigrantException")
+                    imm_block_present = blck_ref
+                else:
+                    migrant_tasks = self.dh.migrant_selection(ex_dp, sim_dp, blck_ref, blck_ref_cp, move_to_apply.get_kernel_ref(),
+                                                              mig_selection_mode)
+                    move_to_apply.set_tasks(migrant_tasks)
+                move_to_apply.set_dest_block(imm_block_present)
+            else:
+                move_to_apply.set_validity(False, "ICMigrationException")
+        elif move_to_apply.get_transformation_name() == "dram_fix":
+            any_block = ex_dp.get_hardware_graph().get_blocks()[0]  # just a dummy to prevent breaking the plotting
+            any_task = any_block.get_tasks_of_block()[0]
+            move_to_apply.set_tasks([any_task])  # just a dummy to prevent breaking the plotting
+            move_to_apply.set_dest_block(any_block)
+        elif move_to_apply.get_transformation_name() == "transfer":
+            if move_to_apply.get_transformation_sub_name() == "locality_improvement":
+                succeeded, dest_block = self.find_block_with_sharing_tasks(ex_dp, selected_block, selected_krnl)
+                if succeeded:
+                    move_to_apply.set_dest_block(dest_block)
+                    move_to_apply.set_tasks([move_to_apply.get_kernel_ref().get_task()])
+                else:
+                    move_to_apply.set_validity(False, "TransferException")
+            else:
+                move_to_apply.set_validity(False, "TransferException")
+        elif move_to_apply.get_transformation_name() == "routing":
+            if move_to_apply.get_transformation_sub_name() == "routing_improvement":
+                succeeded, dest_block = self.find_improve_routing(ex_dp, sim_dp, selected_block, selected_krnl)
+                if succeeded:
+                    move_to_apply.set_dest_block(dest_block)
+                    move_to_apply.set_tasks([move_to_apply.get_kernel_ref().get_task()])
+                else:
+                    move_to_apply.set_validity(False, "RoutingException")
+            else:
+                move_to_apply.set_validity(False, "RoutingException")
+        elif move_to_apply.get_transformation_name() == "cleanup":
+            if self.can_absorb_block(ex_dp, sim_dp, move_to_apply.get_kernel_ref().get_task_name()):
+                move_to_apply.set_transformation_sub_name("absorb")
+                absorbee, absorber = self.find_absorbing_block_tuple(ex_dp, sim_dp, move_to_apply.get_kernel_ref().get_task_name())
+                if absorber is None or absorbee is None:  # fixed: was compared against the string "None"
+                    move_to_apply.set_validity(False, "NoAbsorbee(er)Exception")
+                else:
+                    move_to_apply.set_ref_block(absorbee)
+                    move_to_apply.set_dest_block(absorber)
+            else:
+                move_to_apply.set_validity(False, "CostPairingException")
+                self.dh.unload_buses(ex_dp)  # unload buses
+                self.dh.unload_read_mem(ex_dp)  # unload memories
+                task_1, block_task_1, task_2, block_task_2 = self.find_task_with_similar_mappable_ips(des_tup)
+                # we also randomize
+                if not (task_1 is None) and (random.choice(np.arange(0, 1, .1)) > .5):
+                    move_to_apply.set_ref_block(block_task_1)
+                    migrant_tasks = [task_1]
+                    imm_block_present = block_task_2
+                    move_to_apply.set_tasks(migrant_tasks)
+                    move_to_apply.set_dest_block(imm_block_present)
+                else:
+                    pair = self.gen_block_match_cleanup_move(des_tup)
+                    if len(pair) == 0:
+                        move_to_apply.set_validity(False, "CostPairingException")
+                    else:
+                        ref_block = pair[0]
+                        if not ref_block.type == "ic":  # ic migration is not supported
+                            move_to_apply.set_ref_block(ref_block)
+                            migrant_tasks = ref_block.get_tasks_of_block()
+                            imm_block_present = pair[1]
+                            move_to_apply.set_tasks(migrant_tasks)
+                            move_to_apply.set_dest_block(imm_block_present)
+
+        move_to_apply.set_breadth_depth(self.SA_current_breadth, self.SA_current_depth, self.SA_current_mini_breadth)  # set depth and breadth (for debugging/plotting)
+        return move_to_apply, total_transformation_cnt
+
+    # ------------------------------
+    # Functionality:
+    #       how to choose the move.
+    # Variables:
+    #       des_tup: (design, simulated design)
+    # ------------------------------
+    def sel_moves(self, des_tup, mode="dist_rank"):  # TODO: add mode
+        if mode == "dist_rank":  # rank and choose probabilistically based on distance
+            return self.sel_moves_based_on_dis(des_tup)
+        else:
+            print("mode " + mode + " is not supported")
+            exit(0)
+
+    # ------------------------------
+    # Functionality:
+    #       calculate possible neighbours, though not randomly.
+    #       des_tup: design tuple. Contains a design tuple (ex_dp, sim_dp). ex_dp: design to find neighbours for.
+    #                sim_dp: simulated ex_dp.
+    # ------------------------------
+    def gen_some_neighs_orchestrated(self, des_tup):
+        all_possible_moves = config.navigation_moves
+        ctr = 0
+        kernel_pos_to_hndl = self.hot_krnl_pos  # for now, but change it
+
+        # generate neighbours until you hit the threshold
+        while ctr < self.num_neighs_to_try:
+            ex_dp, sim_dp = des_tup
+            # copy to avoid modifying the current designs
+            new_ex_dp_1 = copy.deepcopy(ex_dp)
+            new_sim_dp_1 = copy.deepcopy(sim_dp)
+            new_ex_dp = copy.deepcopy(new_ex_dp_1)
+            new_sim_dp = copy.deepcopy(new_sim_dp_1)
+
+            # apply the move
+            yield self.dh.apply_move(new_ex_dp, new_sim_dp, all_possible_moves[ctr % len(all_possible_moves)], kernel_pos_to_hndl)
+            ctr += 1
+        return 0
+
+    def simulated_annealing_energy(self, sim_dp_stats):
+        return ()
+
+    def find_best_design(self, sim_stat_ex_dp_dict, sim_dp_stat_ann_delta_energy_dict, sim_dp_stat_ann_delta_energy_dict_all_metrics, best_sim_dp_stat_so_far, best_ex_dp_so_far):
+        if config.heuristic_type == "SA" or config.heuristic_type == "moos":
+            return self.find_best_design_SA(sim_stat_ex_dp_dict, sim_dp_stat_ann_delta_energy_dict, sim_dp_stat_ann_delta_energy_dict_all_metrics, best_sim_dp_stat_so_far, best_ex_dp_so_far)
+        else:
+            return self.find_best_design_others(sim_stat_ex_dp_dict, sim_dp_stat_ann_delta_energy_dict, sim_dp_stat_ann_delta_energy_dict_all_metrics, best_sim_dp_stat_so_far, best_ex_dp_so_far)
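+
+    # Editor's note (illustrative): the "delta energy" dictionaries passed in below map each
+    # candidate to its distance-to-goal minus the best design's distance, so negative values
+    # mean improvement; e.g. {dp_a: -0.2, dp_b: 0.0, dp_c: 0.3} sorts dp_a first and accepts
+    # it outright.
+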
+    # find the best design from a list
+    def find_best_design_SA(self, sim_stat_ex_dp_dict, sim_dp_stat_ann_delta_energy_dict, sim_dp_stat_ann_delta_energy_dict_all_metrics, best_sim_dp_stat_so_far, best_ex_dp_so_far):
+        # for all metrics, we only return true if there is an improvement;
+        # it does not make sense to look at block equality (as the energy won't be zero in cases where there is a difficult trade-off)
+        if config.sel_next_dp == "all_metrics":
+            sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics = sorted(sim_dp_stat_ann_delta_energy_dict_all_metrics.items(),
+                                                                          key=lambda x: x[1])
+
+            if sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics[0][1] < -.0001:  # a very small number
+                # if there is a better design (than the best so far), return it
+                return sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics[0], True
+            else:
+                return sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics[0], False
+
+        # two blocks are equal if they have the same generic instance name
+        # and have been scheduled the same tasks
+        def blocks_are_equal(block_1, block_2):
+            # fixed: compared names from the enclosing scope and read block_2's tasks from block_1
+            if not block_1.get_generic_instance_name() == block_2.get_generic_instance_name():
+                return False
+            else:
+                # make sure the tasks are the same.
+                # this is to avoid scenarios where a block is improved, but its generic name equals the
+                # next block bottleneck's. Here we make sure the tasks are different
+                block_1_tasks = [tsk.name for tsk in block_1.get_tasks_of_block()]
+                block_2_tasks = [tsk.name for tsk in block_2.get_tasks_of_block()]
+                task_diff = list(set(block_1_tasks) - set(block_2_tasks))
+                return len(task_diff) == 0
+
+        # get the best_sim info
+        sim_dp = best_sim_dp_stat_so_far.dp
+        ex_dp = best_ex_dp_so_far
+
+        best_sim_selected_metric, metric_prob_dict, best_sorted_metric_dir = self.select_metric(sim_dp)
+        best_sim_move_dir = self.select_dir(sim_dp, best_sim_selected_metric)
+        best_sim_selected_krnl, krnl_prob_dict, _ = self.select_kernel(ex_dp, sim_dp, best_sim_selected_metric, best_sorted_metric_dir)  # fixed: select_kernel returns three values
+        best_sim_selected_block, block_prob_dict = self.select_block_without_sync(sim_dp, best_sim_selected_krnl, best_sim_selected_metric)
+
+        # sort the designs based on distance
+        sorted_sim_dp_stat_ann_delta_energy_dict = sorted(sim_dp_stat_ann_delta_energy_dict.items(), key=lambda x: x[1])
+
+        #best_neighbour_stat, best_neighbour_delta_energy = sorted_sim_dp_stat_ann_delta_energy_dict[0]  # here we can be smarter
+        if sorted_sim_dp_stat_ann_delta_energy_dict[0][1] < 0:
+            # if there is a better design (than the best so far), return it
+            return sorted_sim_dp_stat_ann_delta_energy_dict[0], True
+        elif sorted_sim_dp_stat_ann_delta_energy_dict[0][1] == 0:
+            # if there is no better design
+            if len(sorted_sim_dp_stat_ann_delta_energy_dict) == 1:
+                # only one design means that our original design is the one
+                return sorted_sim_dp_stat_ann_delta_energy_dict[0], False
+            else:
+                # keep only the designs that haven't seen a distance improvement
+                sim_dp_to_select_from = []
+                for sim_dp_stat, energy in sorted_sim_dp_stat_ann_delta_energy_dict:
+                    if energy == 0:
+                        sim_dp_to_select_from.append((sim_dp_stat, energy))
+
+                # among those, prefer the first design whose bottleneck (kernel or block) has shifted
+                designs_to_consider = []
+                for sim_dp_stat, energy in sim_dp_to_select_from:
+                    sim_dp = sim_dp_stat.dp
+                    ex_dp = sim_stat_ex_dp_dict[sim_dp_stat]
+                    selected_metric, metric_prob_dict, sorted_metric_dir = self.select_metric(sim_dp)
+                    move_dir = self.select_dir(sim_dp, selected_metric)
+                    selected_krnl, krnl_prob_dict, _ = self.select_kernel(ex_dp, sim_dp, selected_metric, sorted_metric_dir)
+                    selected_block, block_prob_dict = self.select_block_without_sync(sim_dp, selected_krnl,
+                                                                                     selected_metric)
+                    if not selected_krnl.get_task_name() == best_sim_selected_krnl.get_task_name():
+                        designs_to_consider.append((sim_dp_stat, energy))
+                        return designs_to_consider[0], True
+                    elif not blocks_are_equal(selected_block, best_sim_selected_block):
+                        #elif not selected_block.get_generic_instance_name() == best_sim_selected_block.get_generic_instance_name():
+                        designs_to_consider.append((sim_dp_stat, energy))
+                        return designs_to_consider[0], True
+
+                return sim_dp_to_select_from[0], False
+        # fallback: every candidate is strictly worse than the best design so far
+        return sorted_sim_dp_stat_ann_delta_energy_dict[0], False
+
+    # find the best design from a list
+    def find_best_design_others(self, sim_stat_ex_dp_dict, sim_dp_stat_ann_delta_energy_dict, sim_dp_stat_ann_delta_energy_dict_all_metrics, best_sim_dp_stat_so_far, best_ex_dp_so_far):
+        # for all metrics, we only return true if there is an improvement;
+        # it does not make sense to look at block equality (as the energy won't be zero in cases where there is a difficult trade-off)
+        if config.sel_next_dp == "all_metrics":
+            sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics = sorted(sim_dp_stat_ann_delta_energy_dict_all_metrics.items(),
+                                                                          key=lambda x: x[1])
+
+            if sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics[0][1] < -.0001:  # a very small number
+                # if there is a better design (than the best so far), return it
+                return sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics[0], True
+            else:
+                return sorted_sim_dp_stat_ann_delta_energy_dict_all_metrics[0], False
+
+        # two blocks are equal if they have the same generic instance name
+        # and have been scheduled the same tasks
+        def blocks_are_equal(block_1, block_2):
+            # fixed: compared names from the enclosing scope and read block_2's tasks from block_1
+            if not block_1.get_generic_instance_name() == block_2.get_generic_instance_name():
+                return False
+            else:
+                # make sure the tasks are the same.
+                # this is to avoid scenarios where a block is improved, but its generic name equals the
+                # next block bottleneck's. Here we make sure the tasks are different
+                block_1_tasks = [tsk.name for tsk in block_1.get_tasks_of_block()]
+                block_2_tasks = [tsk.name for tsk in block_2.get_tasks_of_block()]
+                task_diff = list(set(block_1_tasks) - set(block_2_tasks))
+                return len(task_diff) == 0
+
+        # get the best_sim info
+        sim_dp = best_sim_dp_stat_so_far.dp
+        ex_dp = best_ex_dp_so_far
+
+        best_sim_selected_metric, metric_prob_dict, best_sorted_metric_dir = self.select_metric(sim_dp)
+        best_sim_move_dir = self.select_dir(sim_dp, best_sim_selected_metric)
+        best_sim_selected_krnl, krnl_prob_dict, _ = self.select_kernel(ex_dp, sim_dp, best_sim_selected_metric, best_sorted_metric_dir)  # fixed: select_kernel returns three values
+        best_sim_selected_block, block_prob_dict = self.select_block_without_sync(sim_dp, best_sim_selected_krnl, best_sim_selected_metric)
+
+        # sort the designs based on distance
+        sorted_sim_dp_stat_ann_delta_energy_dict = sorted(sim_dp_stat_ann_delta_energy_dict.items(), key=lambda x: x[1])
+
+        #best_neighbour_stat, best_neighbour_delta_energy = sorted_sim_dp_stat_ann_delta_energy_dict[0]  # here we can be smarter
+        if sorted_sim_dp_stat_ann_delta_energy_dict[0][1] < 0:
+            # if there is a better design (than the best so far), return it
+            return sorted_sim_dp_stat_ann_delta_energy_dict[0], True
+        elif sorted_sim_dp_stat_ann_delta_energy_dict[0][1] == 0:
+            # if there is no better design
+            if len(sorted_sim_dp_stat_ann_delta_energy_dict) == 1:
+                # only one design means that our original design is the one
+                return sorted_sim_dp_stat_ann_delta_energy_dict[0], False
+            else:
+                # keep only the designs that haven't seen a distance improvement
+                sim_dp_to_select_from = []
+                for sim_dp_stat, energy in sorted_sim_dp_stat_ann_delta_energy_dict:
+                    if energy == 0:
+                        sim_dp_to_select_from.append((sim_dp_stat, energy))
+
+                designs_to_consider = []
+                for sim_dp_stat, energy in sim_dp_to_select_from:
+                    sim_dp = sim_dp_stat.dp
+                    ex_dp = sim_stat_ex_dp_dict[sim_dp_stat]
+                    selected_metric, metric_prob_dict, sorted_metric_dir = self.select_metric(sim_dp)
+                    move_dir = self.select_dir(sim_dp, selected_metric)
+                    selected_krnl, krnl_prob_dict, _ = self.select_kernel(ex_dp, sim_dp, selected_metric, sorted_metric_dir)
+                    selected_block, block_prob_dict = self.select_block_without_sync(sim_dp, selected_krnl,
+                                                                                     selected_metric)
+                    if not selected_krnl.get_task_name() == best_sim_selected_krnl.get_task_name():
+                        designs_to_consider.append((sim_dp_stat, energy))
+                    elif not blocks_are_equal(selected_block, best_sim_selected_block):
+                        #elif not selected_block.get_generic_instance_name() == best_sim_selected_block.get_generic_instance_name():
+                        designs_to_consider.append((sim_dp_stat, energy))
+
+                if len(designs_to_consider) == 0:
+                    return sim_dp_to_select_from[0], False
+                else:
+                    return designs_to_consider[0], True  # can be smarter here
+        # fallback: every candidate is strictly worse than the best design so far
+        return sorted_sim_dp_stat_ann_delta_energy_dict[0], False
+
+    # use simulated annealing to pick the next design(s).
+    # Use this link to understand simulated annealing (SA): http://www.cs.cmu.edu/afs/cs.cmu.edu/project/learn-43/lib/photoz/.g/web/glossary/anneal.html
+    # cur_temp: current temperature for simulated annealing
+    def moos_greedy_design_selection(self, sim_stat_ex_dp_dict, sim_dp_stat_list, best_ex_dp_so_far, best_sim_dp_so_far_stats, cur_temp):
+        def get_kernel_not_to_consider(krnels_not_to_consider, cur_best_move_applied, random_move_applied):
+            if cur_best_move_applied == None:  # only None for the first iteration
+                move_applied = random_move_applied
+            else:
+                move_applied = cur_best_move_applied
+            if move_applied == None:
+                return None
+
+            krnl_prob_dict_sorted = move_applied.krnel_prob_dict_sorted
+            for krnl, prob in krnl_prob_dict_sorted:
+                if krnl.get_task_name() in krnels_not_to_consider:
+                    continue
+                return krnl.get_task_name()
+
+        # get the kernel of interest; using this for now to collect cached designs
+        best_sim_selected_metric, metric_prob_dict, best_sorted_metric_dir = self.select_metric(best_sim_dp_so_far_stats.dp)
+        best_sim_move_dir = self.select_dir(best_sim_dp_so_far_stats.dp, best_sim_selected_metric)
+        best_sim_selected_krnl, _, _ = self.select_kernel(best_ex_dp_so_far, best_sim_dp_so_far_stats.dp, best_sim_selected_metric, best_sorted_metric_dir)
+        if best_sim_selected_krnl.get_task_name() not in self.recently_cached_designs.keys():
+            self.recently_cached_designs[best_sim_selected_krnl.get_task_name()] = []
+
+        # get the worst-case cost for normalizing the cost when calculating the distance
+        best_cost = min([sim_dp.get_system_complex_metric("cost") for sim_dp in (sim_dp_stat_list + [best_sim_dp_so_far_stats])])
+        self.database.set_ideal_metric_value("cost", "glass", best_cost)
+
+        # find whether any of the new designs meet the budget
+        new_designs_meeting_budget = []  # designs that meet the budget
+        for sim_dp_stat in sim_dp_stat_list:
+            if sim_dp_stat.fits_budget(1):
+                new_designs_meeting_budget.append(sim_dp_stat)
+
+        new_designs_meeting_budget_with_dram = []  # designs that meet the budget and have a system bus
+        for sim_dp_stat in sim_dp_stat_list:
+            if sim_dp_stat.fits_budget(1):
+                ex_dp = sim_stat_ex_dp_dict[sim_dp_stat]
+                if ex_dp.has_system_bus():
+                    new_designs_meeting_budget_with_dram.append(sim_dp_stat)
+        dram_fixed = False
+        if len(new_designs_meeting_budget_with_dram) > 0 and not self.dram_feasibility_check_pass(best_ex_dp_so_far):
+            dram_fixed = True
+
+        # find each design's simulated annealing energy difference with the best design's energy.
+        # if any of the designs meet the budget or it's a cleanup iteration, include cost in the distance calculation.
+        # note that when we compare, we need to use the same dist_to_goal calculation, hence
+        # ann_energy_best_dp_so_far needs to use the same calculation
+        metric_to_target, metric_prob_dict, sorted_metric_dir = self.select_metric(best_sim_dp_so_far_stats.dp)
+        include_cost_in_distance = best_sim_dp_so_far_stats.fits_budget(1) or (len(new_designs_meeting_budget) > 0) or self.is_cleanup_iter() or (len(new_designs_meeting_budget_with_dram) > 0)
+        if include_cost_in_distance:
+            ann_energy_best_dp_so_far = best_sim_dp_so_far_stats.dist_to_goal(["cost", "latency", "power", "area"],
+                                                                              "eliminate")
+            ann_energy_best_dp_so_far_all_metrics = best_sim_dp_so_far_stats.dist_to_goal(["cost", "latency", "power", "area"],
+                                                                                          "eliminate")
+        else:
+            ann_energy_best_dp_so_far = best_sim_dp_so_far_stats.dist_to_goal([metric_to_target], "dampen")
+            ann_energy_best_dp_so_far_all_metrics = best_sim_dp_so_far_stats.dist_to_goal(["power", "area", "latency"],
+                                                                                          "dampen")
+        sim_dp_stat_ann_delta_energy_dict = {}
+        sim_dp_stat_ann_delta_energy_dict_all_metrics = {}
+        # TODO: delete the following debugging lines
+        if config.print_info_regularly:
+            print("--------%%%%%%%%%%%---------------")
+            print("--------%%%%%%%%%%%---------------")
+            print("first the best design from the previous iteration")
+            print("        des" + " latency:" + str(best_sim_dp_so_far_stats.get_system_complex_metric("latency")))
+            print("        des" + " power:" + str(
+                best_sim_dp_so_far_stats.get_system_complex_metric("power")))
+            print("energy :" + str(ann_energy_best_dp_so_far))
+
+        sim_dp_to_look_at = []  # which designs to look at.
+        # only look at the designs that meet the budget (if any); basically prioritize these designs first
+        if len(new_designs_meeting_budget_with_dram) > 0:
+            sim_dp_to_look_at = new_designs_meeting_budget_with_dram
+        elif len(new_designs_meeting_budget) > 0:
+            sim_dp_to_look_at = new_designs_meeting_budget
+        else:
+            sim_dp_to_look_at = sim_dp_stat_list
+
+        for sim_dp_stat in sim_dp_to_look_at:
+            if include_cost_in_distance:
+                sim_dp_stat_ann_delta_energy_dict[sim_dp_stat] = sim_dp_stat.dist_to_goal(
+                    ["cost", "latency", "power", "area"], "eliminate") - ann_energy_best_dp_so_far
+                sim_dp_stat_ann_delta_energy_dict_all_metrics[sim_dp_stat] = sim_dp_stat.dist_to_goal(
+                    ["cost", "latency", "power", "area"], "eliminate") - ann_energy_best_dp_so_far_all_metrics
+            else:
+                new_design_energy = sim_dp_stat.dist_to_goal([metric_to_target], "dampen")
+                sim_dp_stat_ann_delta_energy_dict[sim_dp_stat] = new_design_energy - ann_energy_best_dp_so_far
+                new_design_energy_all_metrics = sim_dp_stat.dist_to_goal(["power", "latency", "area"], "dampen")
+                sim_dp_stat_ann_delta_energy_dict_all_metrics[sim_dp_stat] = new_design_energy_all_metrics - ann_energy_best_dp_so_far_all_metrics
+
+        # changing the seed for random selection
+        if config.DEBUG_FIX: random.seed(0)
+        else: time.sleep(.00001), random.seed(datetime.now().microsecond)
+
+        result, design_improved = self.find_best_design(sim_stat_ex_dp_dict, sim_dp_stat_ann_delta_energy_dict,
+                                                        sim_dp_stat_ann_delta_energy_dict_all_metrics, best_sim_dp_so_far_stats, best_ex_dp_so_far)
+        best_neighbour_stat, best_neighbour_delta_energy = result
+
+        if config.print_info_regularly:
+            print("all the designs tried")
+            for el, energy in sim_dp_stat_ann_delta_energy_dict_all_metrics.items():
+                print("----------------")
+                sim_dp_ = el.dp
+                if not sim_dp_.move_applied == None:
+                    sim_dp_.move_applied.print_info()
+                print("energy" + str(energy))
+                print("design's latency: " + str(el.get_system_complex_metric("latency")))
+ print("design's power: " + str(el.get_system_complex_metric("power"))) + print("design's area: " + str(el.get_system_complex_metric("area"))) + print("design's sub area: " + str(el.get_system_complex_area_stacked_dram())) + + # if any negative (desired move) value is detected or there is a design in the new batch + # that meet the budget, but the previous best design didn't, we have at least one improved solution + found_an_improved_solution = (len(new_designs_meeting_budget)>0 and not(best_sim_dp_so_far_stats).fits_budget(1)) or design_improved or dram_fixed + + + # for debugging. delete later + if (len(new_designs_meeting_budget)>0 and not(best_sim_dp_so_far_stats).fits_budget(1)): + print("what") + + if not found_an_improved_solution: + # avoid not improving + self.krnel_stagnation_ctr +=1 + self.des_stag_ctr += 1 + if self.krnel_stagnation_ctr > config.max_krnel_stagnation_ctr: + self.krnel_rnk_to_consider = min(self.krnel_rnk_to_consider + 1, len(best_sim_dp_so_far_stats.get_kernels()) -1) + krnel_not_to_consider = get_kernel_not_to_consider(self.krnels_not_to_consider, best_sim_dp_so_far_stats.dp.move_applied, sim_dp_to_look_at[-1].dp.move_applied) + if not krnel_not_to_consider == None: + self.krnels_not_to_consider.append(krnel_not_to_consider) + #self.krnel_stagnation_ctr = 0 + #self.recently_seen_design_ctr = 0 + elif best_neighbour_stat.dp.dp_rep.get_hardware_graph().get_SOC_design_code() in self.recently_cached_designs[best_sim_selected_krnl.get_task_name()] and False: + # avoid circular exploration + self.recently_seen_design_ctr += 1 + self.des_stag_ctr += 1 + if self.recently_seen_design_ctr > config.max_recently_seen_design_ctr: + self.krnel_rnk_to_consider = min(self.krnel_rnk_to_consider + 1, + len(best_sim_dp_so_far_stats.get_kernels()) - 1) + self.krnel_stagnation_ctr = 0 + #self.recently_seen_design_ctr = 0 + else: + self.krnel_stagnation_ctr = max(0, self.krnel_stagnation_ctr -1) + if self.krnel_stagnation_ctr == 0: + if not len(self.krnels_not_to_consider) == 0: + self.krnels_not_to_consider = self.krnels_not_to_consider[:-1] + self.krnel_rnk_to_consider = max(0, self.krnel_rnk_to_consider - 1) + self.cleanup_ctr +=1 + self.des_stag_ctr = 0 + self.recently_seen_design_ctr = 0 + + # initialize selected_sim_dp + selected_sim_dp = best_sim_dp_so_far_stats.dp + if found_an_improved_solution: + selected_sim_dp = best_neighbour_stat.dp + else: + try: + #if math.e**(best_neighbour_delta_energy/max(cur_temp, .001)) < random.choice(range(0, 1)): + # selected_sim_dp = best_neighbour_stat.dp + if random.choice(range(0, 1)) < math.e**(best_neighbour_delta_energy/max(cur_temp, .001)): + selected_sim_dp = best_neighbour_stat.dp + except: + selected_sim_dp = best_neighbour_stat.dp + + # cache the best design + if len(self.recently_cached_designs[best_sim_selected_krnl.get_task_name()]) < config.recently_cached_designs_queue_size: + self.recently_cached_designs[best_sim_selected_krnl.get_task_name()].append(selected_sim_dp.dp_rep.get_hardware_graph().get_SOC_design_code()) + else: + self.recently_cached_designs[best_sim_selected_krnl.get_task_name()][self.population_generation_cnt%config.recently_cached_designs_queue_size] = selected_sim_dp.dp_rep.get_hardware_graph().get_SOC_design_code() + + return selected_sim_dp, found_an_improved_solution + + # use simulated annealing to pick the next design(s). 
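+    # The acceptance step above is the classic Metropolis criterion: an improving
+    # neighbour (negative delta energy) is always taken, and a worsening one is
+    # taken with probability exp(-delta_energy / T). A minimal, self-contained
+    # sketch of that rule (illustrative only; the names below are not part of
+    # this class):
+    #
+    #     import math, random
+    #
+    #     def metropolis_accept(delta_energy, temperature):
+    #         if delta_energy < 0:          # strictly better: always accept
+    #             return True
+    #         # worse: accept with probability exp(-delta/T); high T accepts freely,
+    #         # low T almost never (e.g. delta=0.2, T=1.0 accepts ~82% of the time)
+    #         return random.random() < math.exp(-delta_energy / max(temperature, 1e-3))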
+    # use simulated annealing to pick the next design(s).
+    # Use this link to understand simulated annealing (SA):
+    # http://www.cs.cmu.edu/afs/cs.cmu.edu/project/learn-43/lib/photoz/.g/web/glossary/anneal.html
+    # cur_temp: current temperature for simulated annealing
+    def SA_design_selection(self, sim_stat_ex_dp_dict, sim_dp_stat_list, best_ex_dp_so_far, best_sim_dp_so_far_stats, cur_temp):
+
+        def get_kernel_not_to_consider(krnels_not_to_consider, cur_best_move_applied, random_move_applied):
+            if cur_best_move_applied is None:  # only None on the first iteration
+                move_applied = random_move_applied
+            else:
+                move_applied = cur_best_move_applied
+            if move_applied is None:
+                return None
+
+            # return the first kernel that is not already excluded
+            krnl_prob_dict_sorted = move_applied.krnel_prob_dict_sorted
+            for krnl, prob in krnl_prob_dict_sorted:
+                if krnl.get_task_name() in krnels_not_to_consider:
+                    continue
+                return krnl.get_task_name()
+
+        # get the kernel of interest; for now this is used to collect cached designs
+        best_sim_selected_metric, metric_prob_dict, best_sorted_metric_dir = self.select_metric(best_sim_dp_so_far_stats.dp)
+        best_sim_move_dir = self.select_dir(best_sim_dp_so_far_stats.dp, best_sim_selected_metric)
+        best_sim_selected_krnl, _, _ = self.select_kernel(best_ex_dp_so_far, best_sim_dp_so_far_stats.dp, best_sim_selected_metric, best_sorted_metric_dir)
+        if best_sim_selected_krnl.get_task_name() not in self.recently_cached_designs.keys():
+            self.recently_cached_designs[best_sim_selected_krnl.get_task_name()] = []
+
+        # get the best (minimum) cost seen so far and register it as the ideal value,
+        # so cost is normalized against it when calculating the distance
+        best_cost = min([sim_dp.get_system_complex_metric("cost") for sim_dp in (sim_dp_stat_list + [best_sim_dp_so_far_stats])])
+        self.database.set_ideal_metric_value("cost", "glass", best_cost)
+
+        # find whether any of the new designs meet the budget
+        new_designs_meeting_budget = []  # designs that meet the budget
+        for sim_dp_stat in sim_dp_stat_list:
+            if sim_dp_stat.fits_budget(1):
+                new_designs_meeting_budget.append(sim_dp_stat)
+
+        new_designs_meeting_budget_with_dram = []  # designs that meet the budget and also have a system bus (DRAM)
+        for sim_dp_stat in sim_dp_stat_list:
+            if sim_dp_stat.fits_budget(1):
+                ex_dp = sim_stat_ex_dp_dict[sim_dp_stat]
+                if ex_dp.has_system_bus():
+                    new_designs_meeting_budget_with_dram.append(sim_dp_stat)
+        dram_fixed = False
+        if len(new_designs_meeting_budget_with_dram) > 0 and not self.dram_feasibility_check_pass(best_ex_dp_so_far):
+            dram_fixed = True
+
+        # find each design's simulated-annealing energy difference with the best design's energy.
+        # if any of the designs meet the budget, or it's a cleanup iteration, include cost in the distance calculation.
+        # note that when we compare, we need to use the same dist_to_goal calculation, hence
+        # ann_energy_best_dp_so_far needs to use the same calculation
+        metric_to_target, metric_prob_dict, sorted_metric_dir = self.select_metric(best_sim_dp_so_far_stats.dp)
+        include_cost_in_distance = best_sim_dp_so_far_stats.fits_budget(1) or (len(new_designs_meeting_budget) > 0) or self.is_cleanup_iter() or (len(new_designs_meeting_budget_with_dram) > 0)
+        if include_cost_in_distance:
+            ann_energy_best_dp_so_far = best_sim_dp_so_far_stats.dist_to_goal(["cost", "latency", "power", "area"],
+                                                                              "eliminate")
+            ann_energy_best_dp_so_far_all_metrics = best_sim_dp_so_far_stats.dist_to_goal(["cost", "latency", "power", "area"],
+                                                                                          "eliminate")
+        else:
+            ann_energy_best_dp_so_far = best_sim_dp_so_far_stats.dist_to_goal([metric_to_target], "dampen")
+            ann_energy_best_dp_so_far_all_metrics = best_sim_dp_so_far_stats.dist_to_goal(["power", "area", "latency"],
+                                                                                          "dampen")
+        sim_dp_stat_ann_delta_energy_dict = {}
+        sim_dp_stat_ann_delta_energy_dict_all_metrics = {}
+        # TODO: delete the following debugging lines
+        print("--------%%%%%%%%%%%---------------")
+        print("--------%%%%%%%%%%%---------------")
+        print("first the best design from the previous iteration")
+        print(" des" + " latency:" + str(best_sim_dp_so_far_stats.get_system_complex_metric("latency")))
+        print(" des" + " power:" + str(best_sim_dp_so_far_stats.get_system_complex_metric("power")))
+        print("energy :" + str(ann_energy_best_dp_so_far))
+
+        # which designs to look at; prioritize the designs that meet the budget (if any)
+        sim_dp_to_look_at = []
+        if len(new_designs_meeting_budget_with_dram) > 0:
+            sim_dp_to_look_at = new_designs_meeting_budget_with_dram
+        elif len(new_designs_meeting_budget) > 0:
+            sim_dp_to_look_at = new_designs_meeting_budget
+        else:
+            sim_dp_to_look_at = sim_dp_stat_list
+
+        for sim_dp_stat in sim_dp_to_look_at:
+            if include_cost_in_distance:
+                sim_dp_stat_ann_delta_energy_dict[sim_dp_stat] = sim_dp_stat.dist_to_goal(
+                    ["cost", "latency", "power", "area"], "eliminate") - ann_energy_best_dp_so_far
+                sim_dp_stat_ann_delta_energy_dict_all_metrics[sim_dp_stat] = sim_dp_stat.dist_to_goal(
+                    ["cost", "latency", "power", "area"], "eliminate") - ann_energy_best_dp_so_far_all_metrics
+            else:
+                new_design_energy = sim_dp_stat.dist_to_goal([metric_to_target], "dampen")
+                sim_dp_stat_ann_delta_energy_dict[sim_dp_stat] = new_design_energy - ann_energy_best_dp_so_far
+                new_design_energy_all_metrics = sim_dp_stat.dist_to_goal(["power", "latency", "area"], "dampen")
+                sim_dp_stat_ann_delta_energy_dict_all_metrics[sim_dp_stat] = new_design_energy_all_metrics - ann_energy_best_dp_so_far_all_metrics
+
+        # change the seed for random selection
+        if config.DEBUG_FIX:
+            random.seed(0)
+        else:
+            time.sleep(.00001)
+            random.seed(datetime.now().microsecond)
+
+        result, design_improved = self.find_best_design(sim_stat_ex_dp_dict, sim_dp_stat_ann_delta_energy_dict,
+                                                        sim_dp_stat_ann_delta_energy_dict_all_metrics, best_sim_dp_so_far_stats, best_ex_dp_so_far)
+        best_neighbour_stat, best_neighbour_delta_energy = result
+
+        print("all the designs tried")
+        for el, energy in sim_dp_stat_ann_delta_energy_dict_all_metrics.items():
+            print("----------------")
+            sim_dp_ = el.dp
+            if sim_dp_.move_applied is not None:
+                sim_dp_.move_applied.print_info()
+            print("energy: " + str(energy))
+            print("design's latency: " + str(el.get_system_complex_metric("latency")))
+            print("design's power: " + str(el.get_system_complex_metric("power")))
+            print("design's area: " + str(el.get_system_complex_metric("area")))
+            print("design's sub area: " + str(el.get_system_complex_area_stacked_dram()))
+
+        # if any negative (desired move) delta energy is detected, or there is a design in the
+        # new batch that meets the budget while the previous best design didn't, we have at
+        # least one improved solution
+        found_an_improved_solution = (len(new_designs_meeting_budget) > 0 and not best_sim_dp_so_far_stats.fits_budget(1)) or design_improved or dram_fixed
+
+        # for debugging; delete later
+        if len(new_designs_meeting_budget) > 0 and not best_sim_dp_so_far_stats.fits_budget(1):
+            print("a new design meets the budget while the previous best did not")
+
+        if not found_an_improved_solution:
+            # avoid not improving
+            self.krnel_stagnation_ctr += 1
+            self.des_stag_ctr += 1
+            if self.krnel_stagnation_ctr > config.max_krnel_stagnation_ctr:
+                self.krnel_rnk_to_consider = min(self.krnel_rnk_to_consider + 1, len(best_sim_dp_so_far_stats.get_kernels()) - 1)
+                krnel_not_to_consider = get_kernel_not_to_consider(self.krnels_not_to_consider, best_sim_dp_so_far_stats.dp.move_applied, sim_dp_to_look_at[-1].dp.move_applied)
+                if krnel_not_to_consider is not None:
+                    self.krnels_not_to_consider.append(krnel_not_to_consider)
+                #self.krnel_stagnation_ctr = 0
+                #self.recently_seen_design_ctr = 0
+        elif best_neighbour_stat.dp.dp_rep.get_hardware_graph().get_SOC_design_code() in self.recently_cached_designs[best_sim_selected_krnl.get_task_name()] and False:  # 'and False' intentionally disables this branch
+            # avoid circular exploration
+            self.recently_seen_design_ctr += 1
+            self.des_stag_ctr += 1
+            if self.recently_seen_design_ctr > config.max_recently_seen_design_ctr:
+                self.krnel_rnk_to_consider = min(self.krnel_rnk_to_consider + 1,
+                                                 len(best_sim_dp_so_far_stats.get_kernels()) - 1)
+                self.krnel_stagnation_ctr = 0
+                #self.recently_seen_design_ctr = 0
+        else:
+            self.krnel_stagnation_ctr = max(0, self.krnel_stagnation_ctr - 1)
+            if self.krnel_stagnation_ctr == 0:
+                if len(self.krnels_not_to_consider) > 0:
+                    self.krnels_not_to_consider = self.krnels_not_to_consider[:-1]
+                self.krnel_rnk_to_consider = max(0, self.krnel_rnk_to_consider - 1)
+            self.cleanup_ctr += 1
+            self.des_stag_ctr = 0
+            self.recently_seen_design_ctr = 0
+
+        # initialize selected_sim_dp
+        selected_sim_dp = best_sim_dp_so_far_stats.dp
+        if found_an_improved_solution:
+            selected_sim_dp = best_neighbour_stat.dp
+        else:
+            try:
+                # Metropolis criterion: accept a worse neighbour with probability exp(-delta_energy / T)
+                if random.random() < math.e ** (-best_neighbour_delta_energy / max(cur_temp, .001)):
+                    selected_sim_dp = best_neighbour_stat.dp
+            except Exception:
+                selected_sim_dp = best_neighbour_stat.dp
+
+        # cache the best design
+        if len(self.recently_cached_designs[best_sim_selected_krnl.get_task_name()]) < config.recently_cached_designs_queue_size:
+            self.recently_cached_designs[best_sim_selected_krnl.get_task_name()].append(selected_sim_dp.dp_rep.get_hardware_graph().get_SOC_design_code())
+        else:
+            self.recently_cached_designs[best_sim_selected_krnl.get_task_name()][self.population_generation_cnt % config.recently_cached_designs_queue_size] = selected_sim_dp.dp_rep.get_hardware_graph().get_SOC_design_code()
+
+        return selected_sim_dp, found_an_improved_solution
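+    # Both selection routines above rank designs by dist_to_goal, i.e. how far a
+    # design's metrics sit from their budgets. The real implementation (with its
+    # "eliminate"/"dampen" modes) lives on the stats object elsewhere in FARSI; as
+    # a rough mental model only, a normalized budget-overshoot distance looks like
+    # this (toy_dist_to_goal is not a FARSI API):
+    #
+    #     def toy_dist_to_goal(metrics, budgets):
+    #         # 0.0 when every budget is met; grows with the relative overshoot
+    #         return sum(max(0.0, (metrics[m] - budgets[m]) / budgets[m]) for m in budgets)
+    #
+    #     # toy_dist_to_goal({"latency": .012, "power": .08},
+    #     #                  {"latency": .010, "power": .10}) == 0.2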
+    def find_design_scalarized_value_from_moos_perspective(self, sim, lambdas):
+        value = []
+        for metric_name in config.budgetted_metrics:
+            for type, id in sim.get_designs_SOCs():
+                if metric_name == "latency":
+                    metric_val = sum(list(sim.dp_stats.get_SOC_metric_value(metric_name, type, id).values()))
+                else:
+                    metric_val = sim.dp_stats.get_SOC_metric_value(metric_name, type, id)
+
+                value.append(Decimal(metric_val) * lambdas[metric_name])
+
+        #return max(value)
+        return sum(value)
+
+    # ------------------------------
+    # Functionality:
+    #       select the next best design (from the sorted dp)
+    # Variables
+    #       ex_sim_dp_dict: example_simulate_design_point_list. List of designs to pick from.
+    # ------------------------------
+    def sel_start_dp_moos(self, ex_sim_dp_dict, best_sim_dp_so_far, best_ex_dp_so_far, lambda_list):
+        # convert to stats
+        sim_dp_list = list(ex_sim_dp_dict.values())
+        sim_dp_stat_list = [sim_dp.dp_stats for sim_dp in sim_dp_list]
+        sim_stat_ex_dp_dict = {}
+        for k, v in ex_sim_dp_dict.items():
+            sim_stat_ex_dp_dict[v.dp_stats] = k
+
+        # scalarize every design and pick the one with the smallest value
+        sim_scalarized_value = {}
+        for ex, sim in ex_sim_dp_dict.items():
+            value = self.find_design_scalarized_value_from_moos_perspective(sim, lambda_list)
+            sim_scalarized_value[sim] = value
+
+        min_scalarized_value = float('inf')
+        min_sim = ""
+        for sim, value in sim_scalarized_value.items():
+            if value < min_scalarized_value:
+                min_sim = sim
+                min_ex = sim_stat_ex_dp_dict[sim.dp_stats]
+                min_scalarized_value = value
+
+        if min_sim == "":
+            print("something went wrong; there should be at least one minimum")
+        return min_ex, min_sim
+
+    # ------------------------------
+    # Functionality:
+    #       select the next best design (from the sorted dp)
+    # Variables
+    #       ex_sim_dp_dict: example_simulate_design_point_list. List of designs to pick from.
+    # ------------------------------
+    def sel_next_dp(self, ex_sim_dp_dict, best_sim_dp_so_far, best_ex_dp_so_far, cur_temp):
+        # convert to stats
+        sim_dp_list = list(ex_sim_dp_dict.values())
+        sim_dp_stat_list = [sim_dp.dp_stats for sim_dp in sim_dp_list]
+        sim_stat_ex_dp_dict = {}
+        for k, v in ex_sim_dp_dict.items():
+            sim_stat_ex_dp_dict[v.dp_stats] = k
+
+        # find the ones that fit the expanded budget (note that the budget radius shrinks)
+        selected_sim_dp, found_improved_solution = self.SA_design_selection(sim_stat_ex_dp_dict, sim_dp_stat_list,
+                                                                            best_ex_dp_so_far, best_sim_dp_so_far.dp_stats,
+                                                                            cur_temp)
+
+        self.found_any_improvement = self.found_any_improvement or found_improved_solution
+
+        if not found_improved_solution:
+            selected_sim_dp = self.so_far_best_sim_dp
+            selected_ex_dp = self.so_far_best_ex_dp
+        else:
+            # extract the design
+            for key, val in ex_sim_dp_dict.items():
+                key.sanity_check()
+                if val == selected_sim_dp:
+                    selected_ex_dp = key
+                    break
+
+        # generate verification data
+        if found_improved_solution and config.RUN_VERIFICATION_PER_IMPROVMENT:
+            self.gen_verification_data(selected_sim_dp, selected_ex_dp)
+        return selected_ex_dp, selected_sim_dp
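+    # sel_start_dp_moos above picks the candidate with the smallest weighted sum
+    # sum_m(lambda_m * metric_m) over the budgeted metrics. A worked example with
+    # hypothetical weights and metric values (not taken from any real run):
+    #
+    #     lambdas = {"latency": Decimal("0.5"), "power": Decimal("0.3"), "area": Decimal("0.2")}
+    #     # latency = 0.004 s, power = 0.1 W, area = 1e-6 m^2 gives
+    #     #     0.5 * 0.004 + 0.3 * 0.1 + 0.2 * 1e-6 = 0.0320002
+    #     # the design with the smallest such value wins; a smaller lambda makes
+    #     # its metric matter less in the trade-off.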
+    # ------------------------------
+    # Functionality:
+    #       simulate one design.
+    # Variables
+    #       ex_dp: example design point. Design point to simulate.
+    #       database: hardware/software database to simulate based off of.
+    # ------------------------------
+    def sim_one_design(self, ex_dp, database):
+        if config.simulation_method == "power_knobs":
+            sim_dp = self.dh.convert_to_sim_des_point(ex_dp)
+            power_knob_sim_dp = self.dh.convert_to_sim_des_point(ex_dp)
+            OSA = OSASimulator(sim_dp, database, power_knob_sim_dp)
+        else:
+            sim_dp = self.dh.convert_to_sim_des_point(ex_dp)
+            # simulator initialization
+            OSA = OSASimulator(sim_dp, database)
+
+        # do the actual simulation
+        t = time.time()
+        OSA.simulate()
+        sim_time = time.time() - t
+
+        # profile info
+        sim_dp.set_population_generation_cnt(self.population_generation_cnt)
+        sim_dp.set_population_observed_number(self.population_observed_ctr)
+        sim_dp.set_depth_number(self.SA_current_depth)
+        sim_dp.set_simulation_time(sim_time)
+
+        return sim_dp
+
+    # ------------------------------
+    # Functionality:
+    #       Sampling from the task distribution. This is used for jitter incorporation.
+    # Variables:
+    #       ex_dp: example design point.
+    # ------------------------------
+    def generate_sample(self, ex_dp, hw_sampling):
+        # pickle round-trip instead of copy.deepcopy, for speed
+        gc.disable()
+        new_ex_dp = cPickle.loads(cPickle.dumps(ex_dp, -1))
+        gc.enable()
+        new_ex_dp.sample_hardware_graph(hw_sampling)
+        return new_ex_dp
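+    # generate_sample clones the design with a pickle round-trip, which is
+    # typically much faster than copy.deepcopy for these object graphs; GC is
+    # paused because collector bookkeeping slows the pickling of many small
+    # objects. The idiom in isolation (sketch; `obj` is any picklable object):
+    #
+    #     gc.disable()
+    #     clone = cPickle.loads(cPickle.dumps(obj, -1))   # protocol -1 = highest
+    #     gc.enable()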
+    def transform_to_most_inferior_design(self, ex_dp: ExDesignPoint):
+        new_ex_dp = cPickle.loads(cPickle.dumps(ex_dp, -1))
+        move_to_try = move("swap", "swap", "irrelevant", "-1", "latency", "", "", "")
+        all_blocks = new_ex_dp.get_blocks()
+        for block in all_blocks:
+            self.dh.unload_read_mem(new_ex_dp)  # unload memories
+            if not block.type == "ic":
+                self.dh.unload_buses(new_ex_dp)  # unload buses
+            else:
+                self.dh.unload_read_buses(new_ex_dp)  # unload read buses
+
+            move_to_try.set_ref_block(block)
+            # get the immediate superior/inferior block (based on the desired direction)
+            most_inferior_block = self.dh.get_most_inferior_block(block, block.get_tasks_of_block())
+            move_to_try.set_dest_block(most_inferior_block)
+            move_to_try.set_customization_type(block, most_inferior_block)
+            move_to_try.set_tasks(block.get_tasks_of_block())
+            self.dh.unload_read_mem(new_ex_dp)  # unload read memories
+            # call after unloading read mems, because we need to check the scenarios where a
+            # task is unloaded from the mem but was decided to be migrated/swapped
+            move_to_try.validity_check()
+            new_ex_dp_res, succeeded = self.dh.apply_move([new_ex_dp, ""], move_to_try)
+            new_ex_dp_res.hardware_graph.pipe_design()
+            new_ex_dp_res.sanity_check()
+            new_ex_dp = new_ex_dp_res
+
+        self.dh.load_tasks_to_read_mem_and_ic(new_ex_dp)  # load the tasks onto memory and ic
+        return new_ex_dp
+
+    def transform_to_most_inferior_design_before_loop_unrolling(self, ex_dp: ExDesignPoint):
+        new_ex_dp = cPickle.loads(cPickle.dumps(ex_dp, -1))
+        move_to_try = move("swap", "swap", "irrelevant", "-1", "latency", "", "", "")
+        all_blocks = new_ex_dp.get_blocks()
+        for block in all_blocks:
+            self.dh.unload_read_mem(new_ex_dp)  # unload memories
+            if not block.type == "ic":
+                self.dh.unload_buses(new_ex_dp)  # unload buses
+            else:
+                self.dh.unload_read_buses(new_ex_dp)  # unload read buses
+
+            move_to_try.set_ref_block(block)
+            # get the immediate superior/inferior block (based on the desired direction)
+            most_inferior_block = self.dh.get_most_inferior_block_before_unrolling(block, block.get_tasks_of_block())
+            move_to_try.set_dest_block(most_inferior_block)
+            move_to_try.set_customization_type(block, most_inferior_block)
+            move_to_try.set_tasks(block.get_tasks_of_block())
+            self.dh.unload_read_mem(new_ex_dp)  # unload read memories
+            # call after unloading read mems, because we need to check the scenarios where a
+            # task is unloaded from the mem but was decided to be migrated/swapped
+            move_to_try.validity_check()
+            new_ex_dp_res, succeeded = self.dh.apply_move([new_ex_dp, ""], move_to_try)
+            new_ex_dp_res.hardware_graph.pipe_design()
+            new_ex_dp_res.sanity_check()
+            new_ex_dp = new_ex_dp_res
+
+        self.dh.load_tasks_to_read_mem_and_ic(new_ex_dp)  # load the tasks onto memory and ic
+        return new_ex_dp
+
+    def single_out_workload(self, ex_dp, database, workload, workload_tasks):
+        new_ex_dp = cPickle.loads(cPickle.dumps(ex_dp, -1))
+        database_ = cPickle.loads(cPickle.dumps(database, -1))
+        for block in new_ex_dp.get_blocks():
+            for dir in ["loop_back", "write", "read"]:
+                tasks = block.get_tasks_of_block_by_dir(dir)
+                for task in tasks:
+                    if task.get_name() not in workload_tasks:
+                        block.unload((task, dir))
+
+        tasks = new_ex_dp.get_tasks()
+        for task in tasks:
+            children = task.get_children()[:]
+            parents = task.get_parents()[:]
+            for child in children:
+                if child.get_name() not in workload_tasks:
+                    task.remove_child(child)
+            for parent in parents:
+                if parent.get_name() not in workload_tasks:
+                    task.remove_parent(parent)
+
+        database_.set_workloads_last_task({workload: database_.db_input.workloads_last_task[workload]})
+        return new_ex_dp, database_
+
+    # ------------------------------
+    # Functionality:
+    #       Evaluate the design: 1. simulate 2. collect (profile) data.
+    # Variables:
+    #       ex_dp: example design point.
+    #       database: database containing hardware/software modeled characteristics.
+    # ------------------------------
+    def eval_design(self, ex_dp: ExDesignPoint, database):
+        if config.eval_mode == "singular":
+            print("this mode is deprecated; just use statistical. singular is simply a special case")
+            exit(0)
+        elif config.eval_mode == "statistical":
+            # generate a population (generate_sample), evaluate it, and reduce it to a statistical indicator
+            ex_dp_pop_sample = [self.generate_sample(ex_dp, database.hw_sampling) for i in range(0, self.database.hw_sampling["population_size"])]  # population sample
+            ex_dp.get_tasks()[0].task_id_for_debugging_static += 1
+            sim_dp_pop_sample = list(map(lambda ex_dp_: self.sim_one_design(ex_dp_, database), ex_dp_pop_sample))  # evaluate the population sample
+
+            # collect profiling information
+            sim_dp_statistical = SimDesignPointContainer(sim_dp_pop_sample, database, config.statistical_reduction_mode)
+            return sim_dp_statistical
+        else:
+            print("mode " + config.eval_mode + " is not defined for eval_design")
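+    # The "statistical" path amounts to: clone-and-jitter the design N times,
+    # simulate every clone, then let SimDesignPointContainer reduce the sample
+    # according to config.statistical_reduction_mode. Schematically (sketch;
+    # `reduce` stands in for whatever reduction the container applies):
+    #
+    #     population = [self.generate_sample(ex_dp, hw_sampling) for _ in range(n)]
+    #     sims = [self.sim_one_design(p, database) for p in population]
+    #     estimate = reduce(sims)   # e.g. a mean or worst case over the sample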
+    # ------------------------------
+    # Functionality:
+    #       generate verification (Platform Architect digestible) designs.
+    # Variables:
+    #       sim_dp: simulated design point.
+    # ------------------------------
+    def gen_verification_data(self, sim_dp_, ex_dp_):
+        import_ver = importlib.import_module("data_collection.FB_private.verification_utils.PA_generation.PA_generators")
+        # iterate until we can make a directory
+        while True:
+            date_time = datetime.now().strftime('%m-%d_%H-%M_%S')
+            # one for PA data collection
+            result_folder = os.path.join(self.result_dir + "/../../", "data_per_design",
+                                         date_time + "_" + str(self.name_ctr))
+
+            if not os.path.isdir(result_folder):
+                os.makedirs(result_folder)
+                collection_ctr = self.name_ctr  # used to realize which results to compare
+                break
+
+        ex_with_PA = []
+        pa_ver_obj = import_ver.PAVerGen()  # initialize a PA generator
+        # make all the knob combinations. Since PA has extra knobs, we'd like to sweep
+        # these knobs for verification purposes.
+        knobs_list, knob_order = pa_ver_obj.gen_all_PA_knob_combos(import_ver.PA_knobs_to_explore)
+        knob_ctr = 0
+        # iterate through the knob combos and generate a (PA digestible) design for each
+        for knobs in knobs_list:
+            result_folder_for_knob = os.path.join(result_folder, "PA_knob_ctr_" + str(knob_ctr))
+            for sim_dp in sim_dp_.get_design_point_list():
+                sim_dp.reset_PA_knobs()
+                sim_dp.update_ex_id(date_time + "_" + str(collection_ctr) + "_" + str(self.name_ctr))
+                sim_dp.update_FARSI_ex_id(date_time + "_" + str(collection_ctr))
+                sim_dp.update_PA_knob_ctr_id(str(knob_ctr))
+                sim_dp.update_PA_knobs(knobs, knob_order, import_ver.auto_tuning_knobs)
+                PA_result_folder = os.path.join(result_folder_for_knob, str(sim_dp.id))
+                os.makedirs(PA_result_folder)
+
+                # dump bus-memory connection data (used for bandwidth calculation)
+                sim_dp.dump_mem_bus_connection_bw(PA_result_folder)  # write the results into a file
+                sim_dp.dp_stats.dump_stats(PA_result_folder)
+                if config.VIS_SIM_PROG: vis_sim.plot_sim_data(sim_dp.dp_stats, sim_dp, PA_result_folder)
+                block_names = [block.instance_name for block in sim_dp.hardware_graph.blocks]
+                vis_hardware.vis_hardware(sim_dp, config.hw_graphing_mode, PA_result_folder)
+                pa_obj = import_ver.PAGen(self.database, self.database.db_input.proj_name, sim_dp, PA_result_folder, config.sw_model)
+
+                pa_obj.gen_all()  # generate the PA digestible design
+                sim_dp.dump_props(PA_result_folder)  # write the results into a file
+                # pickle the results for (out of run) verifications
+                with open(os.path.join(PA_result_folder, "ex_dp_pickled.txt"), "wb") as ex_dp_pickled_file:
+                    dill.dump(ex_dp_, ex_dp_pickled_file)
+
+                with open(os.path.join(PA_result_folder, "database_pickled.txt"), "wb") as database_pickled_file:
+                    dill.dump(self.database, database_pickled_file)
+
+                with open(os.path.join(PA_result_folder, "sim_dp_pickled.txt"), "wb") as sim_dp_pickled_file:
+                    dill.dump(sim_dp, sim_dp_pickled_file)
+                self.name_ctr += 1
+            knob_ctr += 1
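+    # gen_all_PA_knob_combos (in the private PA_generators module) enumerates every
+    # combination of the PA-only knobs. The cross-product itself is the standard
+    # itertools pattern (sketch with made-up knob names and values):
+    #
+    #     import itertools
+    #     knobs = {"freq_mhz": [400, 800], "burst_size": [8, 16]}   # hypothetical
+    #     names = list(knobs)
+    #     combos = [dict(zip(names, vals)) for vals in itertools.product(*knobs.values())]
+    #     # -> 4 combos: (400,8), (400,16), (800,8), (800,16)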
+    # ------------------------------
+    # Functionality:
+    #       generate one neighbour and evaluate it.
+    # Variables:
+    #       des_tup: starting point design point tuple (design point, simulated design point)
+    # ------------------------------
+    #@profile
+    def gen_neigh_and_eval(self, des_tup):
+        # TODO: delete this debug print later
+        print("------ depth ------")
+        # generate one neighbour
+        move_strt_time = time.time()
+        ex_dp, move_to_try, total_trans_cnt = self.gen_one_neigh(des_tup)
+        move_end_time = time.time()
+        move_to_try.set_generation_time(move_end_time - move_strt_time)
+
+        # generate a code for the design (that specifies the topology, mapping and scheduling).
+        # look into the cache and see if this design has been seen before. If so, just use the
+        # cached value; otherwise simulate it.
+        design_unique_code = ex_dp.get_hardware_graph().get_SOC_design_code()  # cache index
+        if move_to_try.get_transformation_name() == "identity" or not move_to_try.is_valid():
+            # if nothing has changed, just copy the sim from before
+            sim_dp = des_tup[1]
+        elif design_unique_code not in self.cached_SOC_sim.keys():
+            self.population_observed_ctr += 1
+            sim_dp = self.eval_design(ex_dp, self.database)  # evaluate the design
+            #if config.cache_seen_designs:  # this seems to be slower than just simulating, because of the deepcopy
+            #    self.cached_SOC_sim[design_unique_code] = (ex_dp, sim_dp)
+        else:
+            ex_dp = self.cached_SOC_sim[design_unique_code][0]
+            sim_dp = self.cached_SOC_sim[design_unique_code][1]
+
+        # collect the moves for debugging/visualization
+        if config.DEBUG_MOVE:
+            if (self.population_generation_cnt % config.vis_reg_ctr_threshold) == 0 and self.SA_current_mini_breadth == 0:
+                self.move_profile.append(move_to_try)  # for debugging
+            self.last_move = move_to_try
+            sim_dp.set_move_applied(move_to_try)
+
+        # visualization and verification
+        if config.VIS_GR_PER_GEN:
+            vis_hardware.vis_hardware(sim_dp.get_dp_rep())
+        if config.RUN_VERIFICATION_PER_GEN or \
+                (config.RUN_VERIFICATION_PER_NEW_CONFIG and
+                 not (sim_dp.dp.get_hardware_graph().get_SOC_design_code() in self.seen_SOC_design_codes)):
+            self.gen_verification_data(sim_dp, ex_dp)
+            self.seen_SOC_design_codes.append(sim_dp.dp.get_hardware_graph().get_SOC_design_code())
+
+        if sim_dp.move_applied is not None and config.print_info_regularly:
+            sim_dp.move_applied.print_info()
+            print("design's latency: " + str(sim_dp.dp_stats.get_system_complex_metric("latency")))
+            print("design's power: " + str(sim_dp.dp_stats.get_system_complex_metric("power")))
+            print("design's area: " + str(sim_dp.dp_stats.get_system_complex_metric("area")))
+            print("design's sub area: " + str(sim_dp.dp_stats.get_system_complex_area_stacked_dram()))
+
+        return (ex_dp, sim_dp), total_trans_cnt
+
+    def protected_gen_neigh_and_eval(self, des_tup):
+        ctr = 0
+        while ctr < 100:
+            ctr += 1
+            try:
+                des_tup_new, possible_des_cnt = self.gen_neigh_and_eval(des_tup)
+                return des_tup_new, possible_des_cnt
+            except SystemExit:
+                print("caught an exit; retrying")
+                continue
+            except Exception:
+                print("caught an exception; giving up")
+                exit(0)
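+    # protected_gen_neigh_and_eval is a bounded retry wrapper: a SystemExit raised
+    # inside the simulator is swallowed and the generation is retried (up to 100
+    # times), while any other exception aborts the run. The retry pattern in
+    # isolation (sketch; with_retries is not a FARSI API):
+    #
+    #     def with_retries(fn, attempts=100):
+    #         for _ in range(attempts):
+    #             try:
+    #                 return fn()
+    #             except SystemExit:
+    #                 continue   # transient bail-out; try again
+    #         raise RuntimeError("gave up after %d attempts" % attempts)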
+    # ------------------------------
+    # Functionality:
+    #       generate neighbours and evaluate them.
+    #       neighbours are generated based on the depth and breadth counts set in the config file.
+    #       Depth means vertical exploration (i.e., daisy-chaining of moves); breadth means horizontal exploration.
+    # Variables:
+    #       des_tup: starting point design point tuple (design point, simulated design point)
+    #       breadth: the breadth according to which to generate designs (used for breadth-wise search)
+    #       depth: the depth according to which to generate designs (used for look-ahead)
+    # ------------------------------
+    def gen_some_neighs_and_eval(self, des_tup, breath_length, depth_length, des_tup_list):
+        # base case
+        if depth_length == 0:
+            return [des_tup]
+
+        # iterate on breadth
+        for i in range(0, breath_length):
+            self.SA_current_mini_breadth = 0
+            if not (breath_length == 1):
+                self.SA_current_breadth += 1
+                self.SA_current_depth = -1
+                print("--------breadth--------")
+            # iterate on depth (generate one neighbour and evaluate it)
+            self.SA_current_depth += 1
+            des_tup_new, possible_des_cnt = self.protected_gen_neigh_and_eval(des_tup)
+
+            # collect the generated design in a list and run a sanity check on it
+            des_tup_list.append(des_tup_new)
+
+            # do more coverage if needed
+            """
+            for i in range(0, max(possible_des_cnt, 1) - 1):
+                self.SA_current_mini_breadth += 1
+                des_tup_new_breadth, _ = self.gen_neigh_and_eval(des_tup)
+                des_tup_list.append(des_tup_new_breadth)
+            """
+            # a quick optimization: there is no need to go deeper once we encounter
+            # identity, because we would keep repeating the identity from that point on
+            if des_tup_new[1].move_applied.get_transformation_name() == "identity":
+                break
+
+            self.gen_some_neighs_and_eval(des_tup_new, 1, depth_length - 1, des_tup_list)
+
+        # visualization and sanity checks
+        if config.VIS_MOVE_TRAIL:
+            if (self.population_generation_cnt % config.vis_reg_ctr_threshold) == 0:
+                best_design_sim_cpy = copy.deepcopy(self.so_far_best_sim_dp)
+                self.des_trail_list.append((best_design_sim_cpy, des_tup_list[-1][1]))
+                self.last_des_trail = (best_design_sim_cpy, des_tup_list[-1][1])
+
+        if config.DEBUG_SANITY: des_tup[0].sanity_check()
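+    # The walk shape in miniature (illustrative pseudocode only; gen_neighbour
+    # stands in for one protected generate-and-evaluate step):
+    #
+    #     for b in range(breadth):      # horizontal: independent restarts from des_tup
+    #         d = des_tup
+    #         for _ in range(depth):    # vertical: daisy-chain moves, each starting
+    #             d = gen_neighbour(d)  # from the previous neighbour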
+    # simple simulated annealing
+    def simple_SA(self):
+        # define the result dictionary
+        this_itr_ex_sim_dp_dict: Dict[ExDesignPoint, SimDesignPoint] = {}
+        this_itr_ex_sim_dp_dict[self.so_far_best_ex_dp] = self.so_far_best_sim_dp  # init the result dict
+
+        # navigate the space using the depth and breadth parameters
+        strt = time.time()
+        print("------------------------ itr:" + str(self.population_generation_cnt) + " ---------------------")
+        self.SA_current_breadth = -1
+        self.SA_current_depth = -1
+
+        # generate some neighbouring design points and evaluate them
+        des_tup_list = []
+        self.gen_some_neighs_and_eval((self.so_far_best_ex_dp, self.so_far_best_sim_dp), config.SA_breadth, config.SA_depth, des_tup_list)
+        exploration_and_simulation_approximate_time_per_iteration = (time.time() - strt) / max(len(des_tup_list), 1)
+
+        # convert the (outputted) list to the dictionary of (ex:sim) specified above.
+        # Also run a sanity check on each design, making sure everything is alright.
+        for ex_dp, sim_dp in des_tup_list:
+            sim_dp.add_exploration_and_simulation_approximate_time(exploration_and_simulation_approximate_time_per_iteration)
+            this_itr_ex_sim_dp_dict[ex_dp] = sim_dp
+            if config.DEBUG_SANITY:
+                ex_dp.sanity_check()
+        return this_itr_ex_sim_dp_dict
+
+    def convert_tuple_list_to_parsable_csv(self, list_):
+        result = ""
+        for k, v in list_:
+            result += str(k) + "=" + str(v) + "___"
+        return result
+
+    def convert_dictionary_to_parsable_csv_with_underline(self, dict_):
+        result = ""
+        for k, v in dict_.items():
+            phase_value_dict = list(v.values())[0]
+            value = list(phase_value_dict.values())[0]
+            result += str(k) + "=" + str(value) + "___"
+        return result
+
+    def convert_dictionary_to_parsable_csv_with_semi_column(self, dict_):
+        result = ""
+        for k, v in dict_.items():
+            result += str(k) + "=" + str(v) + ";"
+        return result
+
+    # ------------------------------
+    # Functionality:
+    #       Explore the initial design. Basically just simulate the initial design.
+    # Variables
+    #       it uses the config parameters that were used to instantiate the object.
+    # ------------------------------
+    def explore_one_design(self):
+        self.so_far_best_ex_dp = self.init_ex_dp
+        self.init_sim_dp = self.so_far_best_sim_dp = self.eval_design(self.so_far_best_ex_dp, self.database)
+        this_itr_ex_sim_dp_dict = {}
+        this_itr_ex_sim_dp_dict[self.so_far_best_ex_dp] = self.so_far_best_sim_dp
+
+        # collect statistics about the design
+        self.log_data(this_itr_ex_sim_dp_dict)
+        self.collect_stats(this_itr_ex_sim_dp_dict)
+
+        # visualize/checkpoint/PA generation
+        vis_hardware.vis_hardware(self.so_far_best_sim_dp.get_dp_rep())
+        if config.RUN_VERIFICATION_PER_GEN or config.RUN_VERIFICATION_PER_IMPROVMENT or config.RUN_VERIFICATION_PER_NEW_CONFIG:
+            self.gen_verification_data(self.so_far_best_sim_dp, self.so_far_best_ex_dp)
+
+    def get_log_data(self):
+        return self.log_data_list
+
+    def write_data_log(self, log_data, reason_to_terminate, case_study, result_dir_specific, unique_number, file_name):
+        output_file_all = os.path.join(result_dir_specific, file_name + "_all_reults.csv")
+        csv_columns = list(log_data[0].keys())
+        # minimal output
+        with open(output_file_all, 'w') as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
+            writer.writeheader()
+            for data in log_data:
+                writer.writerow(data)
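+    # The convert_* helpers above flatten structured values into strings that the
+    # downstream CSV parsers split back apart. Concretely (values are made up):
+    #
+    #     self.convert_tuple_list_to_parsable_csv([("latency", 0.5), ("power", 1.2)])
+    #     # -> "latency=0.5___power=1.2___"
+    #     self.convert_dictionary_to_parsable_csv_with_semi_column({"a": 1, "b": 2})
+    #     # -> "a=1;b=2;"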
+    # ------------------------------
+    # Functionality:
+    #       log the data for plotting and such
+    # ------------------------------
+    def log_data(self, this_itr_ex_sim_dp_dict):
+        ctr = len(self.log_data_list)
+        for sim_dp in this_itr_ex_sim_dp_dict.values():
+            sim_dp.add_exploration_and_simulation_approximate_time(self.neighbour_selection_time / len(list(this_itr_ex_sim_dp_dict.keys())))
+            ma = sim_dp.get_move_applied()  # move applied
+            if ma is not None:
+                sorted_metrics = self.convert_tuple_list_to_parsable_csv(
+                    [(el, val) for el, val in ma.sorted_metrics.items()])
+                metric = ma.get_metric()
+                transformation_name = ma.get_transformation_name()
+                task_name = ma.get_kernel_ref().get_task_name()
+                block_type = ma.get_block_ref().type
+                dir = ma.get_dir()
+                generation_time = ma.get_generation_time()
+                sorted_blocks = self.convert_tuple_list_to_parsable_csv(
+                    [(el.get_generic_instance_name(), val) for el, val in ma.sorted_blocks])
+                sorted_kernels = self.convert_tuple_list_to_parsable_csv(
+                    [(el.get_task_name(), val) for el, val in ma.sorted_kernels.items()])
+                blk_instance_name = ma.get_block_ref().get_generic_instance_name()
+                blk_type = ma.get_block_ref().type
+                comm_comp = (ma.get_system_improvement_log())["comm_comp"]
+                high_level_optimization = (ma.get_system_improvement_log())["high_level_optimization"]
+                exact_optimization = (ma.get_system_improvement_log())["exact_optimization"]
+                architectural_principle = (ma.get_system_improvement_log())["architectural_principle"]
+                block_selection_time = ma.get_logs("block_selection_time")
+                kernel_selection_time = ma.get_logs("kernel_selection_time")
+                transformation_selection_time = ma.get_logs("transformation_selection_time")
+                pickling_time = ma.get_logs("pickling_time")
+                metric_selection_time = ma.get_logs("metric_selection_time")
+                dir_selection_time = ma.get_logs("dir_selection_time")
+                move_validity = ma.is_valid()
+                ref_des_dist_to_goal_all = ma.get_logs("ref_des_dist_to_goal_all")
+                ref_des_dist_to_goal_non_cost = ma.get_logs("ref_des_dist_to_goal_non_cost")
+                neighbouring_design_space_size = sim_dp.get_neighbouring_design_space_size()
+                workload = ma.get_logs("workload")
+            else:  # happens on the very first iteration
+                pickling_time = 0
+                sorted_metrics = ""
+                metric = ""
+                transformation_name = ""
+                task_name = ""
+                block_type = ""
+                dir = ""
+                generation_time = ''
+                sorted_blocks = ''
+                sorted_kernels = ""
+                blk_instance_name = ''
+                blk_type = ''
+                comm_comp = ""
+                high_level_optimization = ""
+                exact_optimization = ""
+                architectural_principle = ""
+                block_selection_time = ""
+                kernel_selection_time = ""
+                metric_selection_time = ""
+                dir_selection_time = ""
+                transformation_selection_time = ""
+                move_validity = ""
+                ref_des_dist_to_goal_all = ""
+                ref_des_dist_to_goal_non_cost = ""
+                neighbouring_design_space_size = ""
+                workload = ""
+
+            sub_block_area_break_down = self.convert_dictionary_to_parsable_csv_with_underline(sim_dp.dp_stats.SOC_area_subtype_dict)
+            block_area_break_down = self.convert_dictionary_to_parsable_csv_with_underline(sim_dp.dp_stats.SOC_area_dict)
+            routing_complexity = sim_dp.dp_rep.get_hardware_graph().get_routing_complexity()
+            area_non_dram = sim_dp.dp_stats.get_system_complex_area_stacked_dram()["non_dram"]
+            area_dram = sim_dp.dp_stats.get_system_complex_area_stacked_dram()["dram"]
+            simple_topology = sim_dp.dp_rep.get_hardware_graph().get_simplified_topology_code()
+            channel_cnt = sim_dp.dp_rep.get_hardware_graph().get_number_of_channels()
+            blk_cnt = sum([int(el) for el in simple_topology.split("_")])
+            bus_cnt = [int(el) for el in simple_topology.split("_")][0]
+            mem_cnt = [int(el) for el in simple_topology.split("_")][1]
+            pe_cnt = [int(el) for el in simple_topology.split("_")][2]
+            task_cnt = len(list(sim_dp.dp_rep.krnl_phase_present.keys()))
+
+            self.total_iteration_cnt = ctr
+            data = {
+                "data_number": ctr,
+                "iteration cnt": self.total_iteration_cnt,
+                "exploration_plus_simulation_time": sim_dp.get_exploration_and_simulation_approximate_time(),
+                "observed population number": sim_dp.dp_rep.get_population_observed_number(),
+                "SA_total_depth": str(config.SA_depth),
+                "transformation_selection_mode": str(config.transformation_selection_mode),
+                "workload": workload,
+                "heuristic_type": config.heuristic_type,
+                "population generation cnt": sim_dp.dp_rep.get_population_generation_cnt(),
+                "simulation time": sim_dp.dp_rep.get_simulation_time(),
+                "transformation generation time": generation_time,
+                "metric selection time": metric_selection_time,
+                "dir selection time": dir_selection_time,
+                "kernel selection time": kernel_selection_time,
+                "block selection time": block_selection_time,
+                "transformation selection time": transformation_selection_time,
+                "design duplication time": pickling_time,
+                "neighbour selection time": self.neighbour_selection_time,
+                "dist_to_goal_all": sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power", "cost"],
+                                                                 mode="eliminate"),
+                "dist_to_goal_non_cost": sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power"],
+                                                                      mode="eliminate"),
+                "ref_des_dist_to_goal_all": ref_des_dist_to_goal_all,
+                "ref_des_dist_to_goal_non_cost": ref_des_dist_to_goal_non_cost,
+                "best_des_so_far_dist_to_goal_non_cost": self.so_far_best_sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power"],
+                                                                                                       mode="eliminate"),
+                "best_des_so_far_dist_to_goal_all": self.so_far_best_sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power", "cost"],
+                                                                                                  mode="eliminate"),
+                "best_des_so_far_area_non_dram": self.so_far_best_sim_dp.dp_stats.get_system_complex_area_stacked_dram()["non_dram"],
+                "best_des_so_far_area_dram": self.so_far_best_sim_dp.dp_stats.get_system_complex_area_stacked_dram()["dram"],
+                "system block count": blk_cnt,
+                "system PE count": pe_cnt,
+                "system bus count": bus_cnt,
+                "system memory count": mem_cnt,
+                "routing complexity": routing_complexity,
+                "workload_set": '_'.join(sim_dp.database.db_input.workload_tasks.keys()),
+                "block_impact_sorted": sorted_blocks,
+                "kernel_impact_sorted": sorted_kernels,
+                "metric_impact_sorted": sorted_metrics,
+                "transformation_metric": metric,
+                "move validity": move_validity,
+                "move name": transformation_name,
+                "transformation_kernel": task_name,
+                "transformation_block_name": blk_instance_name,
+                "transformation_block_type": blk_type,
+                "transformation_dir": dir,
+                "comm_comp": comm_comp,
+                "high level optimization name": high_level_optimization,
+                "exact optimization name": exact_optimization,
+                "architectural principle": architectural_principle,
+                "neighbouring design space size": neighbouring_design_space_size,
+                "block_area_break_down": block_area_break_down,
+                "sub_block_area_break_down": sub_block_area_break_down,
+                "task_cnt": task_cnt,
+                "channel_cnt": channel_cnt,
+                "area_dram": area_dram,
+                "area_non_dram": area_non_dram
+            }
+
+            for metric in config.all_metrics:
+                # convert the dictionary to parsable data
+                data_ = sim_dp.dp_stats.get_system_complex_metric(metric)
+                if isinstance(data_, dict):
+                    data__ = self.convert_dictionary_to_parsable_csv_with_semi_column(data_)
+                else:
+                    data__ = data_
+                data[metric] = data__
+
+                if metric in sim_dp.database.db_input.get_budget_dict("glass").keys():
+                    # convert the dictionary to a parsable result
+                    data_ = sim_dp.database.db_input.get_budget_dict("glass")[metric]
+                    if isinstance(data_, dict):
+                        data__ = self.convert_dictionary_to_parsable_csv_with_semi_column(data_)
+                    else:
+                        data__ = data_
+                    data[metric + "_budget"] = data__
+
+            for metric in config.all_metrics:
+                # convert the dictionary to parsable data
+                data_ = self.so_far_best_sim_dp.dp_stats.get_system_complex_metric(metric)
+                if isinstance(data_, dict):
+                    data__ = self.convert_dictionary_to_parsable_csv_with_semi_column(data_)
+                else:
+                    data__ = data_
+                data["best_des_so_far_" + metric] = data__
+
+            ctr += 1
+            self.log_data_list.append(data)
+
+    def sel_next_dp_for_moos(self, ex_sim_dp_dict, best_sim_dp_so_far, best_ex_dp_so_far, cur_temp):
+        # convert to stats
+        sim_dp_list = list(ex_sim_dp_dict.values())
+        sim_dp_stat_list = [sim_dp.dp_stats for sim_dp in sim_dp_list]
+        sim_stat_ex_dp_dict = {}
+        for k, v in ex_sim_dp_dict.items():
+            sim_stat_ex_dp_dict[v.dp_stats] = k
+
+        # find the ones that fit the expanded budget (note that the budget radius shrinks)
+        selected_sim_dp, found_improved_solution = self.moos_greedy_design_selection(sim_stat_ex_dp_dict, sim_dp_stat_list,
+                                                                                     best_ex_dp_so_far, best_sim_dp_so_far.dp_stats,
+                                                                                     cur_temp)
+
+        if not found_improved_solution:
+            selected_sim_dp = self.so_far_best_sim_dp
+            selected_ex_dp = self.so_far_best_ex_dp
+        else:
+            # extract the design
+            for key, val in ex_sim_dp_dict.items():
+                key.sanity_check()
+                if val == selected_sim_dp:
+                    selected_ex_dp = key
+                    break
+
+        # generate verification data
+        if found_improved_solution and config.RUN_VERIFICATION_PER_IMPROVMENT:
+            self.gen_verification_data(selected_sim_dp, selected_ex_dp)
+        return selected_ex_dp, selected_sim_dp, found_improved_solution
+
+    def greedy_for_moos(self, starting_ex_sim, moos_greedy_mode="ctr"):
+        this_itr_ex_sim_dp_dict_all = {}
+        greedy_ctr_run = 0
+        design_collected_ctr = 0
+
+        if moos_greedy_mode == 'ctr':
+            while greedy_ctr_run < config.MOOS_GREEDY_CTR_RUN and design_collected_ctr < config.DESIGN_COLLECTED_PER_GREEDY:
+                this_itr_ex_sim_dp_dict = self.simple_SA()  # run simple simulated annealing
+                self.cur_best_ex_dp, self.cur_best_sim_dp, found_improvement = self.sel_next_dp_for_moos(this_itr_ex_sim_dp_dict,
+                                                                                                         self.so_far_best_sim_dp, self.so_far_best_ex_dp, 1)
+
+                for ex, sim in this_itr_ex_sim_dp_dict.items():
+                    this_itr_ex_sim_dp_dict_all[ex] = sim
+
+                self.so_far_best_sim_dp = self.cur_best_sim_dp
+                self.so_far_best_ex_dp = self.cur_best_ex_dp
+                self.found_any_improvement = self.found_any_improvement or found_improvement
+                design_collected_ctr = len(this_itr_ex_sim_dp_dict_all)
+                greedy_ctr_run += 1
+        elif moos_greedy_mode == 'neighbour':
+            found_improvement = True
+            while found_improvement:
+                this_itr_ex_sim_dp_dict = self.simple_SA()  # run simple simulated annealing
+                self.cur_best_ex_dp, self.cur_best_sim_dp, found_improvement = self.sel_next_dp_for_moos(
+                    this_itr_ex_sim_dp_dict,
+                    self.so_far_best_sim_dp, self.so_far_best_ex_dp, 1)
+
+                for ex, sim in this_itr_ex_sim_dp_dict.items():
+                    this_itr_ex_sim_dp_dict_all[ex] = sim
+
+                self.so_far_best_sim_dp = self.cur_best_sim_dp
+                self.so_far_best_ex_dp = self.cur_best_ex_dp
+                self.found_any_improvement = self.found_any_improvement or found_improvement
+                design_collected_ctr = len(this_itr_ex_sim_dp_dict_all)
+                greedy_ctr_run += 1
+        elif moos_greedy_mode == "phv":
+            phv_improvement = True
+            hyper_volume_ref = [300, 2, 2]
+            local_pareto = {}
+            phv_so_far = 0
+            while phv_improvement and greedy_ctr_run < config.MOOS_GREEDY_CTR_RUN:
+                # run hill climbing
+                this_itr_ex_sim_dp_dict = self.simple_SA()  # run simple simulated annealing
+                # get the best neighbour
+                self.cur_best_ex_dp, self.cur_best_sim_dp, found_improvement = self.sel_next_dp_for_moos(
+                    this_itr_ex_sim_dp_dict,
+                    self.so_far_best_sim_dp, self.so_far_best_ex_dp, 1)
+
+                # find the best neighbour
+                for ex, sim in this_itr_ex_sim_dp_dict.items():
+                    if ex == self.cur_best_ex_dp:
+                        best_neighbour_ex = ex
+                        best_neighbour_sim = sim
+                        break
+
+                for ex, sim in this_itr_ex_sim_dp_dict.items():
+                    this_itr_ex_sim_dp_dict_all[ex] = sim
+
+                # update the pareto front with the new best neighbour
+                new_pareto = {}
+                for ex, sim in local_pareto.items():
+                    new_pareto[ex] = sim
+                new_pareto[best_neighbour_ex] = best_neighbour_sim
+                pareto_designs = self.get_pareto_designs(new_pareto)
+                pareto_with_best_neighbour = self.evaluate_pareto(new_pareto, hyper_volume_ref)
+                phv_improvement = pareto_with_best_neighbour > phv_so_far
+
+                # if the phv improved, add the neighbour to the local pareto front
+                if phv_improvement:
+                    local_pareto = {}
+                    for ex, sim in new_pareto.items():
+                        local_pareto[ex] = sim
+                    phv_so_far = pareto_with_best_neighbour
+
+                greedy_ctr_run += 1
+
+        result = this_itr_ex_sim_dp_dict_all
+        return result
+
+    def get_pareto_designs(self, ex_sim_designs):
+        pareto_designs = {}
+        point_list = []
+        # iterate through the designs and generate points ([latency, power, area] tuples)
+        for ex, sim in ex_sim_designs.items():
+            point = []
+            for metric_name in config.budgetted_metrics:
+                for type, id in sim.dp_stats.get_designs_SOCs():
+                    if metric_name == "latency":
+                        metric_val = sum(list(sim.dp_stats.get_SOC_metric_value(metric_name, type, id).values()))
+                    else:
+                        metric_val = sim.dp_stats.get_SOC_metric_value(metric_name, type, id)
+
+                    point.append(metric_val)
+            point_list.append(point)
+
+        # find the pareto points
+        pareto_points = self.find_pareto_points(point_list)
+        remove_list = []
+        # extract the designs corresponding to the pareto points
+        for ex, sim in ex_sim_designs.items():
+            point = []
+            for metric_name in config.budgetted_metrics:
+                for type, id in sim.dp_stats.get_designs_SOCs():
+                    if metric_name == "latency":
+                        metric_val = sum(list(sim.dp_stats.get_SOC_metric_value(metric_name, type, id).values()))
+                    else:
+                        metric_val = sim.dp_stats.get_SOC_metric_value(metric_name, type, id)
+
+                    point.append(metric_val)
+            if point in pareto_points:
+                pareto_points.remove(point)  # no double counting
+                pareto_designs[ex] = sim
+            else:
+                remove_list.append(ex)
+
+        for el in remove_list:
+            del ex_sim_designs[el]
+
+        if pareto_designs == {}:
+            print("hmm, there should be at least one point in the pareto designs")
+        return pareto_designs
+
+    def find_pareto_points(self, points):
+        def is_pareto_efficient_dumb(costs):
+            is_efficient = np.ones(costs.shape[0], dtype=bool)
+            for i, c in enumerate(costs):
+                # efficient iff every other point is strictly worse in at least one dimension
+                is_efficient[i] = np.all(np.any(costs[:i] > c, axis=1)) and np.all(np.any(costs[i + 1:] > c, axis=1))
+            return is_efficient
+
+        # remove the duplicates (otherwise we get wrong results for the pareto front)
+        points.sort()
+        points = list(k for k, _ in itertools.groupby(points))
+        efficients = is_pareto_efficient_dumb(np.array(points))
+        pareto_points_array = [points[idx] for idx, el in enumerate(efficients) if el]
+
+        return pareto_points_array
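+    # find_pareto_points keeps exactly the non-dominated points, treating every
+    # metric as minimize-is-better. A tiny worked example (made-up points):
+    #
+    #     pts = [[1, 2], [2, 1], [2, 2], [1, 2]]   # the duplicate [1, 2] is dropped first
+    #     self.find_pareto_points(pts)             # -> [[1, 2], [2, 1]]
+    #     # [2, 2] is dominated: [1, 2] is no worse in both dimensions and better in one.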
+    def evaluate_pareto(self, pareto_ex_sim, ref):
+        point_list = []
+        for ex, sim in pareto_ex_sim.items():
+            point = []
+            for metric_name in config.budgetted_metrics:
+                for type, id in sim.get_designs_SOCs():
+                    if metric_name == "latency":
+                        metric_val = sum(list(sim.dp_stats.get_SOC_metric_value(metric_name, type, id).values()))
+                    else:
+                        metric_val = sim.dp_stats.get_SOC_metric_value(metric_name, type, id)
+
+                    point.append(metric_val)
+            point_list.append(point)
+
+        hv = hypervolume(point_list)
+        hv_value = hv.compute(ref)
+
+        return hv_value
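+    # evaluate_pareto scores a pareto front by its hypervolume: the volume of
+    # objective space dominated by the front, measured against a reference point
+    # (the `hypervolume` class is assumed here to follow the pygmo-style
+    # construct-then-compute API used above). Worked 2-D example (made-up points):
+    #
+    #     hv = hypervolume([[1, 2], [2, 1]])
+    #     hv.compute([3, 3])   # == 3.0: two 2x1 rectangles minus their 1x1 overlap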
self.so_far_best_sim_dp = self.cur_best_sim_dp + #this_itr_ex_sim_dp_dict = {self.so_far_best_ex_dp: self.so_far_best_sim_dp} + this_itr_ex_sim_dp_dict = self.greedy_for_moos(self.so_far_best_ex_dp, config.moos_greedy_mode) # run simple simulated annealing + + + self.total_iteration_ctr += len(list(this_itr_ex_sim_dp_dict.keys())) + + """ + # collect profiling information about moves and designs generated + if config.VIS_MOVE_TRAIL and (self.population_generation_cnt% config.vis_reg_ctr_threshold) == 0 and len(self.des_trail_list) > 0: + plot.des_trail_plot(self.des_trail_list, self.move_profile, des_per_iteration) + plot.move_profile_plot(self.move_profile) + """ + + # get new pareto design and merge, and evaluate (update the tree) + self.log_data(this_itr_ex_sim_dp_dict) + self.collect_stats(this_itr_ex_sim_dp_dict) + pareto_designs = self.get_pareto_designs(this_itr_ex_sim_dp_dict) + pareto_global_child.update(pareto_designs) + pareto_global_child = self.get_pareto_designs(pareto_global_child) + pareto_global_child_evaluation = self.evaluate_pareto(pareto_global_child, hyper_volume_ref) + print("pareto evluation" + str(pareto_global_child_evaluation)) + child.update_evaluation(pareto_global_child_evaluation) + + # update pareto global + for ex, sim in pareto_global_child.items(): + self.pareto_global[ex] = sim + #self.pareto_global.update(pareto_global_child) + + """ + if config.VIS_GR_PER_ITR and (self.population_generation_cnt% config.vis_reg_ctr_threshold) == 0: + vis_hardware.vis_hardware(self.cur_best_sim_dp.get_dp_rep()) + """ + # collect statistics about the design + gc.collect() + + # update and check for termination + print("memory usage ===================================== " +str(psutil.virtual_memory())) + # check terminattion status + should_terminate, reason_to_terminate = self.update_ctrs() + mem = psutil.virtual_memory() + mem_used = int(mem.percent) + if mem_used > config.out_of_memory_percentage: + should_terminate, reason_to_terminate = True, "out_of_memory" + + if should_terminate: + # get the best design + self.cur_best_ex_dp, self.cur_best_sim_dp = self.sel_next_dp(self.pareto_global, + self.so_far_best_sim_dp, + self.so_far_best_ex_dp, config.annealing_max_temp) + self.so_far_best_ex_dp = self.cur_best_ex_dp + self.so_far_best_sim_dp = self.cur_best_sim_dp + + print("reason to terminate is:" + reason_to_terminate) + vis_hardware.vis_hardware(self.cur_best_sim_dp.get_dp_rep()) + if not (self.last_des_trail == None): + if self.last_des_trail == None: + self.last_des_trail = ( + copy.deepcopy(self.so_far_best_sim_dp), copy.deepcopy(self.so_far_best_sim_dp)) + # self.last_des_trail = (cPickle.loads(cPickle.dumps(self.so_far_best_sim_dp, -1)),cPickle.loads(cPickle.dumps(self.so_far_best_sim_dp))) + else: + self.des_trail_list.append(self.last_des_trail) + if not (self.last_move == None): + self.move_profile.append(self.last_move) + + if config.VIS_MOVE_TRAIL: + plot.des_trail_plot(self.des_trail_list, self.move_profile, des_per_iteration) + plot.move_profile_plot(self.move_profile) + self.reason_to_terminate = reason_to_terminate + + return + + print(" >>>>> des" + " latency:" + str(self.so_far_best_sim_dp.dp_stats.get_system_complex_metric("latency"))) + + """ + stat_result = self.so_far_best_sim_dp.dp_stats + if stat_result.fits_budget(1): + should_terminate = True + reason_to_terminate = "met the budget" + elif self.population_generation_cnt > self.TOTAL_RUN_THRESHOLD: + should_terminate = True + reason_to_terminate = "exploration (total itr_ctr) iteration threshold 
+    def explore_simple_greedy_one_sample(self, orig_ex_dp, mode="random"):
+        orig_sim_dp = self.eval_design(self.init_ex_dp, self.database)
+        des_tup = [orig_ex_dp, orig_sim_dp]
+        des_tup_new, possible_des_cnt = self.gen_neigh_and_eval(des_tup)
+
+    # ------------------------------
+    # Functionality:
+    #       Explore the design space.
+    # Variables:
+    #       it uses the config parameters that are used to instantiate the object.
+    # ------------------------------
+    def explore_ds(self):
+        self.so_far_best_ex_dp = self.init_ex_dp
+        self.so_far_best_sim_dp = self.eval_design(self.so_far_best_ex_dp, self.database)
+        self.init_sim_dp = self.eval_design(self.so_far_best_ex_dp, self.database)
+
+        # visualize/checkpoint/PA generation
+        vis_hardware.vis_hardware(self.so_far_best_sim_dp.get_dp_rep())
+        if config.RUN_VERIFICATION_PER_GEN or config.RUN_VERIFICATION_PER_IMPROVMENT or config.RUN_VERIFICATION_PER_NEW_CONFIG:
+            self.gen_verification_data(self.so_far_best_sim_dp, self.so_far_best_ex_dp)
+
+        des_per_iteration = [0]
+        start = True
+        cur_temp = config.annealing_max_temp
+
+        while True:
+            this_itr_ex_sim_dp_dict = self.simple_SA()  # run simple simulated annealing
+            self.total_iteration_ctr += len(list(this_itr_ex_sim_dp_dict.keys()))
+
+            # collect profiling information about moves and designs generated
+            if config.VIS_MOVE_TRAIL and (self.population_generation_cnt% config.vis_reg_ctr_threshold) == 0 and len(self.des_trail_list) > 0:
+                plot.des_trail_plot(self.des_trail_list, self.move_profile, des_per_iteration)
+                plot.move_profile_plot(self.move_profile)
+
+            # select the next best design
+            t1 = time.time()
+            self.cur_best_ex_dp, self.cur_best_sim_dp = self.sel_next_dp(this_itr_ex_sim_dp_dict,
+                                                                         self.so_far_best_sim_dp, self.so_far_best_ex_dp, cur_temp)
+            t2 = time.time()
+            self.neighbour_selection_time = t2-t1
+            self.log_data(this_itr_ex_sim_dp_dict)
+            print("-------:):):):):)----------")
+            print("Best design's latency: " + str(self.cur_best_sim_dp.dp_stats.get_system_complex_metric("latency")))
+            print("Best design's power: " + str(self.cur_best_sim_dp.dp_stats.get_system_complex_metric("power")))
+            print("Best design's sub area: " + str(self.cur_best_sim_dp.dp_stats.get_system_complex_area_stacked_dram()))
+
+            if not self.cur_best_sim_dp.move_applied == None:
+                self.cur_best_sim_dp.move_applied.print_info()
+
+            if config.VIS_GR_PER_ITR and (self.population_generation_cnt% config.vis_reg_ctr_threshold) == 0:
+                vis_hardware.vis_hardware(self.cur_best_sim_dp.get_dp_rep())
+
+            # collect statistics about the design
+            self.collect_stats(this_itr_ex_sim_dp_dict)
+
+            # determine whether to terminate: out of memory first, then the usual counters (budget, stagnation, etc.)
+            mem = psutil.virtual_memory()
+            mem_used = int(mem.percent)
+            print("memory usage ===================================== " + str(psutil.virtual_memory()))
+            if mem_used > config.out_of_memory_percentage:
+                should_terminate, reason_to_terminate = True, "out_of_memory"
+            else:
+                should_terminate, reason_to_terminate = self.update_ctrs()
+
+            if should_terminate:
+                print("reason to terminate is: " + reason_to_terminate)
+                vis_hardware.vis_hardware(self.cur_best_sim_dp.get_dp_rep())
+                if self.last_des_trail == None:
+                    # no trail recorded yet; initialize it with the best design seen so far
+                    self.last_des_trail = (copy.deepcopy(self.so_far_best_sim_dp), copy.deepcopy(self.so_far_best_sim_dp))
+                    #self.last_des_trail = (cPickle.loads(cPickle.dumps(self.so_far_best_sim_dp, -1)),cPickle.loads(cPickle.dumps(self.so_far_best_sim_dp)))
+                else:
+                    self.des_trail_list.append(self.last_des_trail)
+                if not (self.last_move == None):
+                    self.move_profile.append(self.last_move)
+
+                if config.VIS_MOVE_TRAIL:
+                    plot.des_trail_plot(self.des_trail_list, self.move_profile, des_per_iteration)
+                    plot.move_profile_plot(self.move_profile)
+                self.reason_to_terminate = reason_to_terminate
+                return
+            cur_temp -= config.annealing_temp_dec
+            self.vis_move_trail_ctr += 1
+
+    # ------------------------------
+    # Functionality:
+    #       generate plots for data analysis
+    # -----------------------------
+    def plot_data(self):
+        iterations = [iter*config.num_neighs_to_try for iter in self.design_itr]
+        if config.DATA_DELIVEYRY == "obfuscate":
+            plot.scatter_plot(iterations, [area/self.area_explored[0] for area in self.area_explored], ("iteration", "area"), self.database)
+            plot.scatter_plot(iterations, [power/self.power_explored[0] for power in self.power_explored], ("iteration", "power"), self.database)
+            latency_explored_normalized = [el/self.latency_explored[0] for el in self.latency_explored]
+            plot.scatter_plot(iterations, latency_explored_normalized, ("iteration", "latency"), self.database)
+        else:
+            plot.scatter_plot(iterations, [1000000*area/self.area_explored[0] for area in self.area_explored], ("iteration", "area"), self.database)
+            plot.scatter_plot(iterations, [1000*power/self.power_explored[0] for power in self.power_explored], ("iteration", "power"), self.database)
+            plot.scatter_plot(iterations, [el/self.latency_explored[0] for el in self.latency_explored], ("iteration", "latency"), self.database)
+
+    # ------------------------------
+    # Functionality:
+    #       report the collected data in a human-readable way.
+    # Variables:
+    #       exploration_start_time: exploration start time, used to determine the end-to-end exploration time.
+    # -----------------------------
+    def report(self, exploration_start_time):
+        exploration_end_time = time.time()
+        total_sim_time = exploration_end_time - exploration_start_time
+        print("*********************************")
+        print("------- Best Designs Metrics ----")
+        print("*********************************")
+        print("Best design's latency: " + str(self.so_far_best_sim_dp.dp_stats.get_system_complex_metric("latency")) + \
+              ", ---- time budget:" + str(config.budgets_dict["glass"]["latency"]))
+        print("Best design's thermal power: " + str(self.so_far_best_sim_dp.dp_stats.get_system_complex_metric("power")) +
+              ", ---- thermal power budget:" + str(config.budgets_dict["glass"]["power"]))
+        print("Best design's area: " + str(self.so_far_best_sim_dp.dp_stats.get_system_complex_metric("area")) + \
+              ", ---- area budget:" + str(config.budgets_dict["glass"]["area"]))
+        print("Best design's energy: " + str(self.so_far_best_sim_dp.dp_stats.get_system_complex_metric("energy")))
+        print("*********************************")
+        print("------- DSE performance --------")
+        print("*********************************")
+        print("Initial design's latency: " + str(self.init_sim_dp.dp_stats.get_system_complex_metric("latency")))
+        print("Speed up: " + str(self.init_sim_dp.dp_stats.get_system_complex_metric("latency")/self.so_far_best_sim_dp.dp_stats.get_system_complex_metric("latency")))
+        print("Number of design points examined: " + str(self.population_generation_cnt*config.num_neighs_to_try))
+        print("Time spent per design point: " + str(total_sim_time/(self.population_generation_cnt*config.num_neighs_to_try)))
+        print("The design meets the latency requirement: " + str(self.so_far_best_sim_dp.dp_stats.get_system_complex_metric("latency") < config.objective_budget))
+        vis_hardware.vis_hardware(self.so_far_best_ex_dp)
+ if config.VIS_FINAL_RES: + vis_hardware.vis_hardware(self.so_far_best_ex_dp, config.hw_graphing_mode) + + # write the output + home_dir = os.getcwd() + FARSI_result_dir = config.FARSI_result_dir + FARSI_result_directory_path = os.path.join(home_dir, 'data_collection/data/', FARSI_result_dir) + output_file_verbose = os.path.join(FARSI_result_directory_path, config.FARSI_outputfile_prefix_verbose +".txt") + output_file_minimal = os.path.join(FARSI_result_directory_path, config.FARSI_outputfile_prefix_minimal +".csv") + + # minimal output + output_fh_minimal = open(output_file_minimal, "w") + for metric in config.all_metrics: + output_fh_minimal.write(metric+ ",") + output_fh_minimal.write("\n") + for metric in config.all_metrics: + output_fh_minimal.write(str(self.so_far_best_sim_dp.dp_stats.get_system_complex_metric(metric))+ ",") + output_fh_minimal.close() + + # verbose + output_fh_verbose = open(output_file_verbose, "w") + output_fh_verbose.write("iter_cnt" + ": ") + for el in range(0, len(self.power_explored)): + output_fh_verbose.write(str(el) +",") + + output_fh_verbose.write("\npower" + ": ") + for el in self.power_explored: + output_fh_verbose.write(str(el) +",") + + output_fh_verbose.write("\nlatency" + ": ") + for el in self.latency_explored: + output_fh_verbose.write(str(el) +",") + output_fh_verbose.write("\narea" + ": ") + for el in self.area_explored: + output_fh_verbose.write(str(el) +",") + + output_fh_verbose.close() + + # ------------------------------ + # Functionality: + # collect the profiling information for all the design generated by the explorer. For data analysis. + # Variables: + # ex_sim_dp_dict: example_design_simulated_design_dictionary. A dictionary containing the + # (example_design, simulated_design) tuple. + # ----------------------------- + def collect_stats(self, ex_sim_dp_dict): + for sim_dp in ex_sim_dp_dict.values(): + self.area_explored.append(sim_dp.dp_stats.get_system_complex_metric("area")) + self.power_explored.append(sim_dp.dp_stats.get_system_complex_metric("power")) + self.latency_explored.append(sim_dp.dp_stats.get_system_complex_metric("latency")) + self.design_itr.append(self.population_generation_cnt) + + # ------------------------------ + # Functionality: + # calculate the budget coefficients. This is used for simulated annealing purposes. + # Concretely, first we use relax budgets to allow wider exploration, and then + # incrementally tighten the budget to direct the explorer more toward the goal. + # ------------------------------ + def calc_budget_coeff(self): + self.budget_coeff = int(((self.TOTAL_RUN_THRESHOLD - self.population_generation_cnt)/self.coeff_slice_size) + 1) + + def reset_ctrs(self): + should_terminate = False + reason_to_terminate = "" + self.fitted_budget_ctr = 0 + self.krnels_not_to_consider = [] + self.des_stag_ctr = 0 + self.krnel_stagnation_ctr = 0 + self.krnel_rnk_to_consider = 0 + self.cleanup_ctr = 0 + self.des_stag_ctr = 0 + self.recently_seen_design_ctr = 0 + self.counters.reset() + self.moos_tree = moosTreeModel(config.budgetted_metrics) # only used for moos heuristic + self.found_any_improvement = False + + + + # ------------------------------ + # Functionality: + # Update the counters to determine the exploration (navigation heuristic) control path to follow. 
+    # ------------------------------
+    def update_ctrs(self):
+        should_terminate = False
+        reason_to_terminate = ""
+
+        self.so_far_best_ex_dp = self.cur_best_ex_dp
+        self.so_far_best_sim_dp = self.cur_best_sim_dp
+
+        self.population_generation_cnt += 1
+        stat_result = self.so_far_best_sim_dp.dp_stats
+
+        tasks_not_meeting_budget = [el.get_task_name() for el in self.filter_in_kernels_meeting_budget("", self.so_far_best_sim_dp)]
+        tsks_left_to_optimize = list(set(tasks_not_meeting_budget) - set(self.krnels_not_to_consider))
+
+        if stat_result.fits_budget(1):
+            config.VIS_GR_PER_GEN = True  # visualize the graph per design point generation
+            config.VIS_SIM_PER_GEN = True  # if true, we visualize the simulation progression
+            self.fitted_budget_ctr += 1
+            if (self.fitted_budget_ctr > config.fitted_budget_ctr_threshold):
+                reason_to_terminate = "met the budget"
+                should_terminate = True
+        elif self.des_stag_ctr > self.DES_STAG_THRESHOLD:
+            reason_to_terminate = "des_stag_ctr exceeded"
+            should_terminate = True
+        elif len(self.krnels_not_to_consider) >= (len(self.so_far_best_sim_dp.get_kernels()) - len(self.so_far_best_sim_dp.get_dummy_tasks())):
+            if stat_result.fits_budget(1):
+                reason_to_terminate = "met the budget"
+            else:
+                reason_to_terminate = "all kernels already targeted without improvement"
+            should_terminate = True
+        elif len(tsks_left_to_optimize) == 0:
+            if stat_result.fits_budget(1):
+                reason_to_terminate = "met the budget"
+            else:
+                reason_to_terminate = "all kernels already targeted without improvement"
+            should_terminate = True
+        elif self.total_iteration_ctr > self.TOTAL_RUN_THRESHOLD:
+            if stat_result.fits_budget(1):
+                reason_to_terminate = "met the budget"
+            else:
+                reason_to_terminate = "exploration (total itr_ctr) iteration threshold reached"
+            should_terminate = True
+
+        self.counters.update(self.krnel_rnk_to_consider, self.krnel_stagnation_ctr, self.fitted_budget_ctr, self.des_stag_ctr,
+                             self.krnels_not_to_consider, self.population_generation_cnt, self.found_any_improvement, self.total_iteration_ctr)
+
+        print(">>>>> total iteration count is: " + str(self.total_iteration_ctr))
+        return should_terminate, reason_to_terminate
+
+
+class moosTreeNode:
+    def __init__(self, k_intervals):
+        self.k_ins = k_intervals
+        self.children = {}
+        self.evaluation = "None"
+
+    def update_evaluation(self, evaluation):
+        self.evaluation = evaluation
+
+    def get_evaluation(self):
+        if self.evaluation == "None":
+            print("must populate the evaluation first")
+            exit(0)
+        return self.evaluation
+
+    def get_k_ins(self):
+        return self.k_ins
+
+    def get_interval(self, metric):
+        return self.get_k_ins()[metric]
+
+
+    def get_interval_length(self, metric):
+        assert(metric in self.k_ins.keys())
+        interval = self.k_ins[metric]
+        length = interval[1] - interval[0]
+        if (length <= 0):
+            print("length" + str(length))
+            print("interval is:")
+            print(str(self.k_ins[metric]))
+        assert(length > 0)
+        return length
+
+    def longest_dimension_name(self):
+        key = list(self.get_k_ins().keys())[0]
+        max_dimension = self.get_interval_length(key)
+        max_key = key
+        #print("max_dimension" + str(max_dimension))
+
+        for k,v in self.get_k_ins().items():
+            if self.get_interval_length(k) > max_dimension:
+                max_dimension = self.get_interval_length(k)
+                max_key = k
+
+        return max_key
+
+    def update_interval(self, key, interval):
+        self.k_ins[key] = interval
+
+    def add_children(self, left_children_, center_children_, right_children_):
+        self.children["left"] = left_children_
+        self.children["center"] = center_children_
+
self.children["right"] = right_children_ + + def get_children_with_position(self, position): + return self.children[position] + + def get_children(self): + return self.children + + def get_lambdas(self): + lambdas = {} + for metric_name, val in self.k_ins.items(): + lambdas[metric_name] = (val[1] - val[0])/2 + return lambdas + + +class moosTreeModel: + def __init__(self, metric_names): + max = Decimal(1000000000) + min = Decimal(0) + node_val = [] + k_ins = {} + for el in metric_names: + k_ins[el] = [min, max] + self.root = moosTreeNode(k_ins) + + def get_root(self): + return self.root + + def get_leaves_with_depth(self): + def get_leaves_helper(node, depth): + result = [] + if node.get_children() == {}: + result = [(node, depth)] + else: + for position, child in node.get_children().items(): + child_result = get_leaves_helper(child, depth+1) + result.extend(child_result) + return result + + leaves_depth = get_leaves_helper(self.root, 0) + return leaves_depth + + def expand(self, node): + # initializing the longest + longest_dimension_key = node.longest_dimension_name() + longest_dimension_interval = node.get_interval(longest_dimension_key) + longest_intr_min = min(longest_dimension_interval) + longest_intr_max = max(longest_dimension_interval) + longest_dimension_incr = Decimal(longest_intr_max - longest_intr_min)/3 + + child_center = moosTreeNode(copy.deepcopy(node.get_k_ins())) + child_left = moosTreeNode(copy.deepcopy(node.get_k_ins())) + child_right = moosTreeNode(copy.deepcopy(node.get_k_ins())) + + if (longest_intr_min - longest_intr_min + 1*longest_dimension_incr == 0): + print("this shouldn't happen. intervals should be the same") + print("intrval" + str(longest_dimension_interval)) + print("incr" + str(longest_dimension_incr)) + print("min" +str(longest_intr_min)) + print("max" + str(longest_intr_max)) + print("first upper" + str(longest_intr_min + 1*longest_dimension_incr)) + print("second upper"+ str(longest_intr_min + 2 * longest_dimension_incr)) + return False + + if (longest_intr_min + 1 * longest_dimension_incr - longest_intr_min + 2 * longest_dimension_incr == 0): + print("this shouldn't happen. intervals should be the same") + print(longest_dimension_interval) + print(longest_dimension_incr) + print(longest_intr_min) + print(longest_intr_min + 1*longest_dimension_incr) + print(longest_intr_min + 2 * longest_dimension_incr) + print(longest_intr_max) + return False + + if (longest_intr_min + 2 * longest_dimension_incr - longest_intr_max == 0): + print("this shouldn't happen. 
intervals should be the same") + print(longest_dimension_interval) + print(longest_dimension_incr) + print(longest_intr_min) + print(longest_intr_min + 1*longest_dimension_incr) + print(longest_intr_min + 2 * longest_dimension_incr) + print(longest_intr_max) + return False + + + child_left.update_interval(longest_dimension_key, [longest_intr_min, longest_intr_min + 1*longest_dimension_incr]) + child_center.update_interval(longest_dimension_key, [longest_intr_min + 1*longest_dimension_incr, longest_intr_min + 2*longest_dimension_incr]) + child_right.update_interval(longest_dimension_key, [longest_intr_min + 2*longest_dimension_incr, longest_intr_max]) + node.add_children(child_left, child_center, child_right) + return True + + def find_node_to_expand(self): + node_star_list = [] # selected node + leaves_with_depth = self.get_leaves_with_depth() + + # find max depth + max_depth = max([depth for leaf,depth in leaves_with_depth]) + + # split leaves to max and non max depth + leaves_with_max_depth = [leaf for leaf,depth in leaves_with_depth if depth == max_depth] + leaves_with_non_max_depth = [leaf for leaf,depth in leaves_with_depth if not depth == max_depth] + + # select node star + for max_leaf in leaves_with_max_depth: + leaf_is_better_list = [] + for non_max_leaf in leaves_with_non_max_depth: + leaf_is_better_list.append(max_leaf.get_evaluation() >= non_max_leaf.get_evaluation()) + if all(leaf_is_better_list): + node_star_list.append(max_leaf) + + + best_node = node_star_list[0] + for node in node_star_list: + if node.get_evaluation() > best_node.get_evaluation(): + best_node = node + """ + # for debugging. delete later + if (len(node_star_list) == 0): + for max_leaf in leaves_with_max_depth: + leaf_is_better_list = [] + for non_max_leaf in leaves_with_non_max_depth: + leaf_is_better_list.append(max_leaf.get_evaluation() >= non_max_leaf.get_evaluation()) + if all(leaf_is_better_list): + node_star_list.append(max_leaf) + + if not (len(node_star_list) == 1): + print("something went wrong") + """ + return best_node + + + + + + + + + + + diff --git a/Project_FARSI/DSE_utils/simple_minimal_change_ga.py b/Project_FARSI/DSE_utils/simple_minimal_change_ga.py new file mode 100644 index 00000000..307fcae9 --- /dev/null +++ b/Project_FARSI/DSE_utils/simple_minimal_change_ga.py @@ -0,0 +1,209 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +# This file is part of DEAP. +# +# DEAP is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of +# the License, or (at your option) any later version. +# +# DEAP is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with DEAP. If not, see . 
+
+
+# example GA driver: it evolves design points of the form [ip, memory, bus]
+# and maximizes the fitness returned by eval_individual_dp below
+import time
+import random
+
+from deap import base
+from deap import creator
+from deap import tools
+from datetime import datetime
+
+creator.create("FitnessMax", base.Fitness, weights=(1.0,))
+creator.create("Individual", list, fitness=creator.FitnessMax)
+
+toolbox = base.Toolbox()
+
+# Attribute generator
+# define 'attr_bool' to be an attribute ('gene') generator: it draws one
+# [ip, memory, bus] triple, with each element sampled uniformly from the
+# corresponding block database below
+bus_DB = range(0, 20)
+memory_DB = range(20, 40)
+ip_DB = range(40, 600)
+block_to_choose_from = ["ip", "memory", "bus"]
+
+def individual_att_generator():
+    global bus_DB
+    global memory_DB
+    global ip_DB
+    time.sleep(.00001)
+    random.seed(datetime.now().microsecond)
+    # for mating or mutation
+    #block_choice = random.choice(block_to_choose_from)
+    #alternate_block_val = random.choice(eval(block_choice+"_DB"))
+
+    memory = random.choice(memory_DB)
+    ip = random.choice(ip_DB)
+    bus = random.choice(bus_DB)
+    return [ip, memory, bus]
+
+toolbox.register("attr_bool", individual_att_generator)
+
+# Structure initializers
+# define 'individual' to be an individual
+# consisting of a single 'attr_bool' element (one [ip, memory, bus] triple)
+toolbox.register("individual", tools.initRepeat, creator.Individual,
+                 toolbox.attr_bool, 1)
+
+# define the population to be a list of individuals
+# should always be a list
+toolbox.register("population", tools.initRepeat, list, toolbox.individual)
+
+# the goal ('fitness') function to be maximized
+def eval_individual_dp(individual):
+    # call OSSIM; the individual is a design point
+    return [max(individual[0])]
+#----------
+# Operator registration
+#----------
+# register the goal / fitness function
+toolbox.register("evaluate", eval_individual_dp)
+
+def mate_dp(individual1, individual2):
+    # crossover operator:
+    # pick a random index and swap the genes at that index
+    time.sleep(.00001)
+    random.seed(datetime.now().microsecond + 10)
+    rand_idx = random.choice(list(range(0, len(individual1[0]))))
+    individual1_el = individual1[0][rand_idx]
+    individual1[0][rand_idx] = individual2[0][rand_idx]
+    individual2[0][rand_idx] = individual1_el
+
+toolbox.register("mate", mate_dp)
+
+
+def mutate_dp(individual1, indpb):
+    # mutation operator:
+    # pick a random index and replace that gene with a fresh random value
+    time.sleep(.00001)
+    random.seed(datetime.now().microsecond + 50)
+    rand_idx = random.choice(list(range(0, len(individual1[0]))))
+    individual1[0][rand_idx] = random.choice(list(range(40, 600)))
+
+# register the mutation operator (note: indpb is currently unused by mutate_dp)
+toolbox.register("mutate", mutate_dp, indpb=0.05)
+
+# operator for selecting individuals for breeding the next
+# generation: each individual of the current generation
+# is replaced by the 'fittest' (best) of three individuals
+# drawn randomly from the current generation.
+toolbox.register("select", tools.selTournament, tournsize=3) + +#---------- + +def main(): + time.sleep(.00001) + random.seed(datetime.now().microsecond + 70) + + # CXPB is the probability with which two individuals + # are crossed + # + # MUTPB is the probability for mutating an individual + CXPB, MUTPB = 0.5, 0.02 + max_num_gen = 10000 + max_pop_count = 20 + + # create an initial population of 300 individuals (where + # each individual is a list of integers) + pop = toolbox.population(n=max_pop_count) + + + print("Start of evolution") + + # Evaluate the entire population + fitnesses = list(map(toolbox.evaluate, pop)) + for ind, fit in zip(pop, fitnesses): + ind.fitness.values = fit + + print(" Evaluated %i individuals" % len(pop)) + + # Extracting all the fitnesses of + fits = [ind.fitness.values[0] for ind in pop] + + # Variable keeping track of the number of generations + num_gen = 0 + + # Begin the evolution + while num_gen < max_num_gen: + # A new generation + num_gen = num_gen + 1 + print("-- Generation %i --" % num_gen) + + # Select the next generation individuals + offspring = toolbox.select(pop, len(pop)) + # Clone the selected individuals + offspring = list(map(toolbox.clone, offspring)) + + # Apply crossover and mutation on the offspring + for child1, child2 in zip(offspring[::2], offspring[1::2]): + + # cross two individuals with probability CXPB + if random.random() < CXPB: + toolbox.mate(child1, child2) + + # fitness values of the children + # must be recalculated later + del child1.fitness.values + del child2.fitness.values + + for mutant in offspring: + + # mutate an individual with probability MUTPB + if random.random() < MUTPB: + toolbox.mutate(mutant) + del mutant.fitness.values + + # Evaluate the individuals with an invalid fitness + invalid_ind = [ind for ind in offspring if not ind.fitness.valid] + fitnesses = map(toolbox.evaluate, invalid_ind) + for ind, fit in zip(invalid_ind, fitnesses): + ind.fitness.values = fit + + print(" Evaluated %i individuals" % len(invalid_ind)) + + # The population is entirely replaced by the offspring + pop[:] = offspring + + # Gather all the fitnesses in one list and print the stats + fits = [ind.fitness.values[0] for ind in pop] + + length = len(pop) + mean = sum(fits) / length + sum2 = sum(x*x for x in fits) + std = abs(sum2 / length - mean**2)**0.5 + + print(" Min %s" % min(fits)) + print(" Max %s" % max(fits)) + print(" Avg %s" % mean) + print(" Std %s" % std) + + print("-- End of (successful) evolution --") + + best_ind = tools.selBest(pop, 1)[0] + print("Best individual is %s, %s" % (best_ind, best_ind.fitness.values)) + +if __name__ == "__main__": + main() diff --git a/Project_FARSI/EXTERNAL_CONTINUOUS_INTEGRATION_RUNS.md b/Project_FARSI/EXTERNAL_CONTINUOUS_INTEGRATION_RUNS.md new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/GHANGELOG.md b/Project_FARSI/GHANGELOG.md new file mode 100644 index 00000000..658c84cf --- /dev/null +++ b/Project_FARSI/GHANGELOG.md @@ -0,0 +1 @@ +change log here diff --git a/Project_FARSI/GITHUB_ISSUE_TEMPLATE_EXISTS.md b/Project_FARSI/GITHUB_ISSUE_TEMPLATE_EXISTS.md new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/GROUP_NOTIFICATIONS.md b/Project_FARSI/GROUP_NOTIFICATIONS.md new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/LICENSE b/Project_FARSI/LICENSE new file mode 100644 index 00000000..b3f94ae5 --- /dev/null +++ b/Project_FARSI/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) Facebook, Inc. and its affiliates. 
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Project_FARSI/PAGES_ENABLED.md b/Project_FARSI/PAGES_ENABLED.md
new file mode 100644
index 00000000..e69de29b
diff --git a/Project_FARSI/README.md b/Project_FARSI/README.md
new file mode 100644
index 00000000..1b7e720c
--- /dev/null
+++ b/Project_FARSI/README.md
@@ -0,0 +1,133 @@
+# Project_FARSI
+FARSI is an agile pre-RTL design space exploration framework. It allows SoC designers to find optimal
+designs given a set of constraints (performance/power/area and development cost).
+
+
+## How Does It Work
+To solve the aforementioned problem, FARSI tackles 3 subproblems simultaneously (figure below) using
+3 main components:
+* (1) A simulator to capture the behavior of the SoC.
+* (2) An exploration heuristic to navigate the design space in an optimal fashion.
+* (3) A database populated by the designer with workloads (e.g., hand tracking, ...) and the
+possible hardware options (e.g., general-purpose processors, accelerators, memory, etc.).
+
+FARSI continuously samples the database to generate a sample design, simulates its fitness, and uses its navigation heuristic to get closer to the optimal design.
+
+![alt text](figures/FARSI_methodology.png "FARSI components")
+![alt text](figures/FARSI_output.png "FARSI Output")
+
+## Why FARSI
+FARSI was developed to overcome the shortcomings of existing DSE frameworks, such as limited scalability and performance.
+To further clarify this, the figure below puts FARSI on the map compared to the other DSEs.
+![alt text](figures/DSE_on_the_map.png "components")
+
+
+## Building/Installing FARSI
+FARSI is a Python-based framework; hence, the relevant Python libraries need to be installed.
+
+
+## FARSI Input
+The software/hardware databases shown above are used as inputs to FARSI's framework. Here we briefly explain their functionality and encoding.
+
+**Software Database:** This includes labeled task dependency graphs (TDGs). A task is the smallest optimization unit and is typically selected from the computationally intensive functions, since these significantly impact the system behavior. The TDG contains the dependency information between tasks, the number of instructions processed within a task, and the data movement between tasks.
+
+**Hardware Database**: This involves power, performance, and area estimates of each task for different hardware mappings (e.g., to general-purpose processors or specialized accelerators).
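+
+For intuition, a TDG boils down to a small labeled graph. The sketch below is illustrative only (the task names and numbers are made up; the real databases are parsed from the spreadsheets described next), but it shows the kind of information the two databases carry:
+
+```python
+# Illustrative software database: nodes carry non-memory instruction counts,
+# edges carry data movement (in bytes) from a producer task to a consumer task.
+task_instruction_count = {"decode": 2_000_000, "filter": 500_000, "sink": 1_000}
+task_data_movement = {
+    ("decode", "filter"): 64_000,   # decode writes 64 KB that filter reads
+    ("filter", "sink"): 16_000,
+}
+
+# Illustrative hardware database: per-task estimates for each mapping option
+# (hypothetical block names), here performance in cycles.
+task_pe_performance_cycles = {
+    ("decode", "gpp_core"): 4_000_000,
+    ("decode", "decode_accelerator"): 250_000,
+}
+```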
+
+### FARSI Input Encoding:
+Although the semantics discussed above can be encoded and provided in various formats, our front-end parsers currently take them in the form of spreadsheets. Here we detail these sheets. Please note that examples of these sheets are provided in the specs/database_data/parsing folder.
+
+Each workload has its own set of spreadsheets, whose names start with $workload name$_database, e.g., audio_decoder_database.
+
+**Software Database Spreadsheets**
+
+*Task Data Movement:* contains information about the data movement between tasks and their execution dependencies. This sheet is in adjacency-matrix format, where the first row and the first column list the workload's tasks. The cell at the coordinate between two tasks shows the data movement between them. Note that data flows from the task shown in the row to the task shown in the column. Also, note that this format implies an execution dependency between tasks whenever said cells are non-empty.
+
+*Task instruction count:* contains information about each task's computation, specifically quantifying its non-memory instruction count.
+
+*Task Itr Count:* each task's loop iteration count.
+
+**Hardware Database Spreadsheets**
+
+*Task PE Performance:* performance (in number of cycles) associated with mapping tasks to different processing elements (PEs).
+
+*Task PE Energy:* energy associated with the accelerator running a task.
+
+*Task Area Performance:* area associated with accelerators.
+
+*misc_database - Budget:* budget (power, performance, area) associated with various workloads.
+
+*misc_database - Block Characteristics:* contains information about the potential IPs used in the system. These are non-task-specific IPs (as opposed to accelerators, which are task-specific and whose information is provided in the Task PE (Area/Energy/Performance) spreadsheets).
+
+*misc_database - Last Tasks.csv:* name of the last task within each workload.
+
+**Mapping Database Spreadsheets**
+
+*Hardware Graph:* contains information about how hardware components are connected. It is an adjacency matrix with the first row and the first column specifying the hardware block names. A **1** in the cell at the coordinate between two blocks indicates a connection between said blocks.
+
+*Task To Hardware Mapping:* contains information about which hardware blocks the various tasks are mapped to. The first row specifies the hardware block names, and the first column specifies the software task names. If a task is mapped onto a hardware block, it is listed under that block. We follow two conventions within this spreadsheet (see the parsing sketch after this list):
+ 1) Under the NoC and Memory blocks, the direction of the accesses (read or write) needs to be specified, and this is denoted by an arrow **->**.
+ For example, a cell that contains **Task1 -> Task2** under a memory **M0** cell indicates that **Task1's data is written into M0, and this data will be read by Task2 (as Task1's child)**. Please note that only the write direction is specified; the read direction is implied from the writes, as shown in the previous example.
+ 2) If multiple tasks are mapped to the same block, we separate them with a semicolon.
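+
+As a loose illustration of conventions (1) and (2), the hypothetical helper below (not FARSI's actual front end) parses one such mapping cell into (writer, reader) pairs plus plainly mapped tasks:
+
+```python
+# Hypothetical sketch: parse a "Task To Hardware Mapping" cell such as
+# "Task1 -> Task2; Task3" (under, say, a memory M0).
+def parse_mapping_cell(cell):
+    writes, mapped = [], []
+    for entry in cell.split(";"):
+        entry = entry.strip()
+        if "->" in entry:
+            writer, reader = [t.strip() for t in entry.split("->")]
+            writes.append((writer, reader))  # writer's data lands in the block; reader consumes it
+        elif entry:
+            mapped.append(entry)             # task simply mapped to the block
+    return writes, mapped
+
+print(parse_mapping_cell("Task1 -> Task2; Task3"))
+# ([('Task1', 'Task2')], ['Task3'])
+```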
+
+## Running FARSI
+
+### Stand Alone Simulation ###
+The following commands allow the user to run the simulation in standalone mode.
+
+Switch into the directory below.
+```shell
+cd data_collection/collection_utils/sim_run/
+```
+Set the workload name in simple_sim_run.py (by default, we choose a simple workload) and run the simulation.
+
+```shell
+python simple_sim_run.py
+```
+
+The output data will be provided under the data_collection/data/simple_sim_run/$date_time$ folder.
+
+
+
+### Simulation + Exploration Heuristic ###
+The following commands allow the user to run the simulation and the exploration heuristic simultaneously.
+
+Switch into the directory below.
+```shell
+cd data_collection/collection_utils/what_ifs/
+```
+Set the workload name properly in FARSI_what_ifs_with_params.py (select among audio_decoder, hpvm_cava, and edge_detection) and run FARSI.
+
+```shell
+python FARSI_what_ifs_with_params.py # run FARSI
+```
+
+PS: To modify the settings, edit the settings/config.py file. This file contains many knobs that determine the exploration heuristic and simulation
+features. Please refer to the in-file documentation for more details.
+
+The output will be provided under data_collection/data/simple_run/$date_time$.
+
+## Main Contributors
+Behzad Boroujerdian\
+Ying Jing
+
+
+## How to Cite
+@article{10.1145/3544016,
+author = {Boroujerdian, Behzad and Jing, Ying and Tripathy, Devashree and Kumar, Amit and Subramanian, Lavanya and Yen, Luke and Lee, Vincent and Venkatesan, Vivek and Jindal, Amit and Shearer, Robert and Reddi, Vijay Janapa},
+title = {FARSI: An Early-Stage Design Space Exploration Framework to Tame the Domain-Specific System-on-Chip Complexity},
+year = {2022},
+publisher = {Association for Computing Machinery},
+url = {https://doi.org/10.1145/3544016},
+doi = {10.1145/3544016},
+journal = {ACM Trans. Embed. Comput. Syst.},
+month = {may}
+}
+
+
+## License
+Copyright (c) Facebook, Inc. and its affiliates.
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+
+
diff --git a/Project_FARSI/SHIP_IT_IS_ENABLED.md b/Project_FARSI/SHIP_IT_IS_ENABLED.md
new file mode 100644
index 00000000..e69de29b
diff --git a/Project_FARSI/SHIP_IT_IS_SET_UP.md b/Project_FARSI/SHIP_IT_IS_SET_UP.md
new file mode 100644
index 00000000..e69de29b
diff --git a/Project_FARSI/SIM_utils/SIM.py b/Project_FARSI/SIM_utils/SIM.py
new file mode 100644
index 00000000..285d5673
--- /dev/null
+++ b/Project_FARSI/SIM_utils/SIM.py
@@ -0,0 +1,82 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+ +from SIM_utils.components.perf_sim import * +from SIM_utils.components.pow_sim import * +#from OSSIM_utils.components.pow_knob_sim import * +from design_utils.design import * +from settings import config + +# This module is our top level simulator containing all simulators (perf, and pow simulator) +class OSASimulator: + def __init__(self, dp, database, pk_dp=""): + self.time_elapsed = 0 # time elapsed from the beginning of the simulation + self.dp = dp # design point to simulate + self.perf_sim = PerformanceSimulator(self.dp) # performance simulator instance + self.pow_sim = PowerSimulator(self.dp) # power simulator instance + + self.database = database + if config.simulation_method == "power_knobs": + self.pk_dp = pk_dp + #self.knob_change_sim = PowerKnobSimulator(self.dp, self.pk_dp, self.database) + self.completion_time = -1 # time passed for the simulation to complete + self.program_status = "idle" + self.cur_tick_time = self.next_tick_time = 0 # current tick time + + # ------------------------------ + # Functionality: + # whether the simulation should terminate + # ------------------------------ + def terminate(self, program_status): + if config.termination_mode == "workload_completion": + return program_status == "done" + elif config.termination_mode == "time_budget_reahced": + return self.time_elapsed >= config.time_budge + else: + return False + + # ------------------------------ + # Functionality: + # ticking the simulation. Note that the tick time varies depending on what is (dynamically) happening in the + # system + # ------------------------------ + def tick(self): + self.cur_tick_time = self.next_tick_time + + # ------------------------------ + # Functionality + # progress the simulation for clock_time forward + # ------------------------------ + def step(self, clock_time): + self.next_tick_time, self.program_status = self.perf_sim.simulate(clock_time) + + # ------------------------------ + # Functionality: + # simulation + # ------------------------------ + def simulate(self): + blah = time.time() + while not self.terminate(self.program_status): + self.tick() + self.step(self.cur_tick_time) + + if config.use_cacti: + self.dp.correct_power_area_with_cacti(self.database) + + # collect all the stats upon completion of simulation + self.dp.collect_dp_stats(self.database) + + if config.simulation_method == "power_knobs": + self.knob_change_sim.launch() + + self.completion_time = self.next_tick_time + self.dp.set_serial_design_time(self.perf_sim.serial_latency) + self.dp.set_par_speedup(self.perf_sim.serial_latency/self.completion_time) + self.dp.set_simulation_time_analytical_portion(self.perf_sim.task_update_time + self.perf_sim.phase_interval_calc_time) + self.dp.set_simulation_time_phase_driven_portion(self.perf_sim.phase_scheduling_time) + self.dp.set_simulation_time_phase_calculation_portion(self.perf_sim.phase_interval_calc_time) + self.dp.set_simulation_time_task_update_portion(self.perf_sim.task_update_time) + self.dp.set_simulation_time_phase_scheduling_portion(self.perf_sim.phase_scheduling_time) + + return self.dp \ No newline at end of file diff --git a/Project_FARSI/SIM_utils/__init__.py b/Project_FARSI/SIM_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/SIM_utils/components/__init__.py b/Project_FARSI/SIM_utils/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/SIM_utils/components/perf_sim.py b/Project_FARSI/SIM_utils/components/perf_sim.py new file mode 100644 index 
00000000..983c7ead --- /dev/null +++ b/Project_FARSI/SIM_utils/components/perf_sim.py @@ -0,0 +1,492 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +from design_utils.design import * +from functools import reduce + + +# This class is the performance simulator of FARSI +class PerformanceSimulator: + def __init__(self, sim_design): + self.design = sim_design # design to simulate + self.scheduled_kernels = [] # kernels already scheduled + self.driver_waiting_queue = [] # kernels whose trigger condition is met but can not run for various reasons + self.completed_kernels_for_memory_sizing = [] # kernels already completed + # List of all the kernels that are not scheduled yet (to be launched) + self.yet_to_schedule_kernels = self.design.get_kernels()[:] # kernels to be scheduled + self.all_kernels = self.yet_to_schedule_kernels[:] + self.task_token_queue = [] + self.old_clock_time = self.clock_time = 0 + self.program_status = "idle" # specifying the status of the program at the current tick + self.phase_num = -1 + self.krnl_latency_if_run_in_isolation = {} + self.serial_latency = 0 + self.workload_time_if_each_kernel_run_serially() + self.phase_interval_calc_time = 0 + self.phase_scheduling_time = 0 + self.task_update_time = 0 + + + def workload_time_if_each_kernel_run_serially(self): + self.serial_latency = 0 + for krnl in self.yet_to_schedule_kernels: + self.krnl_latency_if_run_in_isolation[krnl] = krnl.get_latency_if_krnel_run_in_isolation() + + for krnl, latency in self.krnl_latency_if_run_in_isolation.items(): + self.serial_latency += latency + + + def reset_perf_sim(self): + self.scheduled_kernels = [] + self.completed_kernels_for_memory_sizing = [] + # List of all the kernels that are not scheduled yet (to be launched) + self.yet_to_schedule_kernels = self.design.get_kernels()[:] + self.old_clock_time = self.clock_time = 0 + self.program_status = "idle" # specifying the status of the program at the current tick + self.phase_num = -1 + + + # ------------------------------ + # Functionality: + # tick the simulator clock_time forward + # ------------------------------ + def tick(self, clock_time): + self.clock_time = clock_time + + # ------------------------------ + # Functionality: + # find the next kernel to be scheduled time + # ------------------------------ + def next_kernel_to_be_scheduled_time(self): + timely_sorted_kernels = sorted(self.yet_to_schedule_kernels, key=lambda kernel: kernel.get_schedule().starting_time) + return timely_sorted_kernels[0].get_schedule().starting_time + + # ------------------------------ + # Functionality: + # convert the task to kernel + # ------------------------------ + def get_kernel_from_task(self, task): + for kernel in self.design.get_kernels()[:]: + if kernel.get_task() == task: + return kernel + raise Exception("kernel associated with task with name" + task.name + " is not found") + + # ------------------------------ + # Functionality: + # find the completion time of kernel that will be done the fastest + # ------------------------------ + def next_kernel_to_be_completed_time(self): + comp_time_list = [] # contains completion time of the running kernels + for kernel in self.scheduled_kernels: + comp_time_list.append(kernel.calc_kernel_completion_time()) + return min(comp_time_list) + self.clock_time + #else: + # return self.clock_time + """ + # ------------------------------ + # Functionality: + # all the dependencies 
of a kernel are done or no? + # ------------------------------ + def kernel_parents_all_done(self, kernel): + kernel_s_task = kernel.get_task() + parents_s_task = self.design.get_hardware_graph().get_task_graph().get_task_s_parents(kernel_s_task) + completed_tasks = [kernel.get_task() for kernel in self.completed_kernels_for_memory_sizing] + for task in parents_s_task: + if task not in completed_tasks: + return False + return True + """ + + def kernel_s_parents_done(self, krnl): + kernel_s_task = krnl.get_task() + parents_s_task = self.design.get_hardware_graph().get_task_graph().get_task_s_parents(kernel_s_task) + for parent in parents_s_task: + if not (parent, kernel_s_task) in self.task_token_queue: + return False + return True + + + # launch: Every iteration, we launch the kernel, i.e, + # we set the operating state appropriately, and size the hardware accordingly + def kernel_ready_to_be_launched(self, krnl): + if self.kernel_s_parents_done(krnl) and krnl not in self.scheduled_kernels and not self.krnl_done_iterating(krnl): + return True + return False + + def kernel_ready_to_fire(self, krnl): + if krnl.get_type() == "throughput_based" and krnl.throughput_time_trigger_achieved(self.clock_time): + return True + else: + return False + + def remove_parents_from_token_queue(self, krnl): + kernel_s_task = krnl.get_task() + parents_s_task = self.design.get_hardware_graph().get_task_graph().get_task_s_parents(kernel_s_task) + for parent in parents_s_task: + self.task_token_queue.remove((parent, kernel_s_task)) + + def krnl_done_iterating(self, krnl): + if krnl.iteration_ctr == -1 or krnl.iteration_ctr > 0: + return False + elif krnl.iteration_ctr == 0: + return True + + + # if multiple kernels running on the same PE's driver, + # only one can access it at a time. Thus, this function only keeps one on the PE. 
+ def serialize_DMA(self): + # append waiting kernels and to be sch kernels to the already scheduled kernels + scheduled_kernels_tmp = [] + for el in self.scheduled_kernels: + scheduled_kernels_tmp.append(el) + for kernel in self.driver_waiting_queue: + scheduled_kernels_tmp.append(kernel) + + PE_blocks_used = [] + scheduled_kernels = [] + driver_waiting_queue = [] + for el in scheduled_kernels_tmp: + # only for read/write we serialize + if el.get_operating_state() in ["read", "write", "none"]: + pe = [blk for blk in el.get_blocks() if blk.type == "pe"][0] + if pe in PE_blocks_used: + driver_waiting_queue.append(el) + else: + scheduled_kernels.append(el) + PE_blocks_used.append(pe) + else: + scheduled_kernels.append(el) + return scheduled_kernels, driver_waiting_queue + + # ------------------------------ + # Functionality: + # Finds the kernels that are free to be scheduled (their parents are completed) + # ------------------------------ + def schedule_kernels_token_based(self): + for krnl in self.all_kernels: + if self.kernel_ready_to_be_launched(krnl): + # launch: Every iteration, we launch the kernel, i.e, + # we set the operating state appropriately, and size the hardware accordingly + self.kernel_s_parents_done(krnl) + self.remove_parents_from_token_queue(krnl) + self.scheduled_kernels.append(krnl) + if krnl in self.yet_to_schedule_kernels: + self.yet_to_schedule_kernels.remove(krnl) + + # initialize #insts, tick, and kernel progress status + krnl.launch(self.clock_time) + # update PE's that host the kernel + krnl.update_mem_size(1) + krnl.update_pe_size() + krnl.update_ic_size() + elif krnl.status == "in_progress" and not krnl.get_task().is_task_dummy() and self.kernel_ready_to_fire(krnl): + if krnl in self.scheduled_kernels: + print("a throughput based kernel was scheduled before it met its desired throughput. " + "This can cause issues in the models. Fix Later") + self.scheduled_kernels.append(krnl) + + # filter out kernels based on DMA serialization, i.e., only keep one kernel using the PE's driver. + if config.DMA_mode == "serialized_read_write": + self.scheduled_kernels, self.driver_waiting_queue = self.serialize_DMA() + + # ------------------------------ + # Functionality: + # Finds the kernels that are free to be scheduled (their parents are completed) + # ------------------------------ + def schedule_kernels(self): + if config.scheduling_policy == "FRFS": + kernels_to_schedule = [kernel_ for kernel_ in self.yet_to_schedule_kernels + if self.kernel_parents_all_done(kernel_)] + elif config.scheduling_policy == "time_based": + kernels_to_schedule = [kernel_ for kernel_ in self.yet_to_schedule_kernels + if self.clock_time >= kernel_.get_schedule().starting_time] + else: + raise Exception("scheduling policy not supported") + + for kernel in kernels_to_schedule: + self.scheduled_kernels.append(kernel) + self.yet_to_schedule_kernels.remove(kernel) + # initialize #insts, tick, and kernel progress status + kernel.launch(self.clock_time) + # update memory size -> allocate memory regions on different mem blocks + kernel.update_mem_size(1) + # update pe allocation -> allocate a part of pe quantum for current task + # (Hadi Note: allocation looks arbitrary and without any meaning though - just to know that something + # is allocated or it is floating) + kernel.update_pe_size() + # empty function! 
+ kernel.update_ic_size() + + + def update_parallel_tasks(self): + # keep track of krnls that are present per phase + for krnl in self.scheduled_kernels: + if krnl.get_task_name() in ["souurce", "siink", "dummy_last"]: + continue + if krnl not in self.design.krnl_phase_present.keys(): + self.design.krnl_phase_present[krnl] = [] + self.design.krnl_phase_present_operating_state[krnl] = [] + self.design.krnl_phase_present[krnl].append(self.phase_num) + self.design.krnl_phase_present_operating_state[krnl].append((self.phase_num, krnl.operating_state)) + self.design.phase_krnl_present[self.phase_num] = self.scheduled_kernels[:] + + """ + scheduled_kernels = self.scheduled_kernels[:] + for idx, krnl in enumerate(scheduled_kernels): + if krnl.get_task_name() in ["souurce", "siink"]: + continue + if krnl not in self.design.parallel_kernels.keys(): + self.design.parallel_kernels[krnl] = [] + for idx, krnl_2 in enumerate(scheduled_kernels): + if krnl_2 == krnl: + continue + elif not krnl_2 in self.design.parallel_kernels[krnl]: + self.design.parallel_kernels[krnl].append(krnl_2) + + pass + """ + # ------------------------------ + # Functionality: + # update the status of each kernel, this means update + # how much work is left for each kernel (that is already schedulued) + # ------------------------------ + def update_scheduled_kernel_list(self): + scheduled_kernels = self.scheduled_kernels[:] + for kernel in scheduled_kernels: + if kernel.status == "completed": + self.scheduled_kernels.remove(kernel) + self.completed_kernels_for_memory_sizing.append(kernel) + kernel.set_stats() + for child_task in kernel.get_task().get_children(): + self.task_token_queue.append((kernel.get_task(), child_task)) + # iterate though parents and check if for each parent, all the children are completed. 
+ # if so, retract the memory + all_parent_kernels = [self.get_kernel_from_task(parent_task) for parent_task in + kernel.get_task().get_parents()] + for parent_kernel in all_parent_kernels: + all_children_kernels = [self.get_kernel_from_task(child_task) for child_task in + parent_kernel.get_task().get_children()] + + if all([child_kernel in self.completed_kernels_for_memory_sizing for child_kernel in all_children_kernels]): + parent_kernel.update_mem_size(-1) + for child_kernel in all_children_kernels: + self.completed_kernels_for_memory_sizing.remove(child_kernel) + + elif kernel.type == "throughput_based" and kernel.throughput_work_achieved(): + #del kernel.data_work_left_to_meet_throughput[kernel.operating_state][0] + del kernel.firing_work_to_meet_throughput[kernel.operating_state][0] + self.scheduled_kernels.remove(kernel) + + + # ------------------------------ + # Functionality: + # iterate through all kernels and step them + # ------------------------------ + def step_kernels(self): + # by stepping the kernels, we calculate how much work each kernel has done and how much of their + # work is left for them + _ = [kernel_.step(self.time_step_size, self.phase_num) for kernel_ in self.scheduled_kernels] + + # update kernel's status, sets the progress + _ = [kernel_.update_status(self.time_step_size, self.clock_time) for kernel_ in self.scheduled_kernels] + + # ------------------------------ + # Functionality: + # update the status of the program, i.e., whether it's done or still in progress + # ------------------------------ + def update_program_status(self): + if len(self.scheduled_kernels) == 0 and len(self.yet_to_schedule_kernels) == 0: + self.program_status = "done" + elif len(self.scheduled_kernels) == 0: + self.program_status = "idle" # nothing scheduled yet + elif len(self.yet_to_schedule_kernels) == 0: + self.program_status = "all_kernels_scheduled" + else: + self.program_status = "in_progress" + + def next_throughput_trigger_time(self): + throughput_achieved_time_list = [] + for krnl in self.all_kernels: + if krnl.get_type() == "throughput_based" and krnl.status == "in_progress" \ + and not krnl.get_task().is_task_dummy() and not krnl.operating_state == "execute": + throughput_achieved_time_list.extend(krnl.firing_time_to_meet_throughput[krnl.operating_state]) + + throughput_achieved_time_list_filtered = [el for el in throughput_achieved_time_list if el> self.clock_time] + #time_sorted = sorted(throughput_achieved_time_list_filtered) + return throughput_achieved_time_list_filtered + + def any_throughput_based_kernel(self): + for krnl in self.all_kernels: + if krnl.get_type() == "throughput_based": + return True + return False + + # ------------------------------ + # Functionality: + # find the next tick time + # ------------------------------ + def calc_new_tick_position(self): + if config.scheduling_policy == "FRFS": + new_clock_list = [] + if len(self.scheduled_kernels) > 0: + new_clock_list.append(self.next_kernel_to_be_completed_time()) + + if self.any_throughput_based_kernel(): + trigger_time = self.next_throughput_trigger_time() + if len(trigger_time) > 0: + new_clock_list.append(min(trigger_time)) + + if len(new_clock_list) == 0: + return self.clock_time + else: + return min(new_clock_list) + #new_clock = max(new_clock, min(throughput_achieved_time)) + """ + elif self.program_status == "in_progress": + if self.program_status == "in_progress": + if config.scheudling_policy == "time_based": + new_clock = min(self.next_kernel_to_be_scheduled_time(), 
self.next_kernel_to_be_completed_time()) + elif self.program_status == "all_kernels_scheduled": + new_clock = self.next_kernel_to_be_completed_time() + elif self.program_status == "idle": + new_clock = self.next_kernel_to_be_scheduled_time() + if self.program_status == "done": + new_clock = self.clock_time + else: + raise Exception("scheduling policy:" + config.scheduling_policy + " is not supported") + """ + + #return new_clock + + # ------------------------------ + # Functionality: + # determine the various KPIs for each kernel. + # work-rate is how quickly each kernel can be done, which depends on it's bottleneck + # ------------------------------ + def update_kernels_kpi_for_next_tick(self, design): + # update each kernels's work-rate (bandwidth) + _ = [kernel.update_block_att_work_rate(self.scheduled_kernels) for kernel in self.scheduled_kernels] + # update each pipe cluster's paths (inpipe-outpipe) work-rate + + _ = [kernel.update_pipe_clusters_pathlet_work_rate() for kernel in self.scheduled_kernels] + # update each pipe cluster's paths (inpipe-outpipe) latency. Note that latency update must run after path's work-rate + + # update as it depends on it + # for fast simulation, ignore this + #_ = [kernel.update_path_latency() for kernel in self.scheduled_kernels] + #_ = [kernel.update_path_structural_latency(design) for kernel in self.scheduled_kernels] + + + # ------------------------------ + # Functionality: + # how much work does each block do for each phase + # ------------------------------ + def calc_design_work(self): + for SOC_type, SOC_id in self.design.get_designs_SOCs(): + blocks_seen = [] + for kernel in self.scheduled_kernels: + if kernel.SOC_type == SOC_type and kernel.SOC_id == SOC_id: + for block, work in kernel.block_phase_work_dict.items(): + if block not in blocks_seen : + blocks_seen.append(block) + #if block in self.block_phase_work_dict.keys(): + if self.phase_num in self.design.block_phase_work_dict[block].keys(): + self.design.block_phase_work_dict[block][self.phase_num] += work[self.phase_num] + else: + self.design.block_phase_work_dict[block][self.phase_num] = work[self.phase_num] + all_blocks = self.design.get_blocks() + for block in all_blocks: + if block in blocks_seen: + continue + self.design.block_phase_work_dict[block][self.phase_num] = 0 + + # ------------------------------ + # Functionality: + # calculate the utilization of each block in the design + # ------------------------------ + def calc_design_utilization(self): + for SOC_type, SOC_id in self.design.get_designs_SOCs(): + for block,phase_work in self.design.block_phase_work_dict.items(): + if self.design.phase_latency_dict[self.phase_num] == 0: + work_rate = 0 + else: + work_rate = (self.design.block_phase_work_dict[block][self.phase_num])/self.design.phase_latency_dict[self.phase_num] + self.design.block_phase_utilization_dict[block][self.phase_num] = work_rate/block.peak_work_rate + + # ------------------------------ + # Functionality: + # Aggregates the energy consumed for current phase over all the blocks + # ------------------------------ + def calc_design_energy(self): + for SOC_type, SOC_id in self.design.get_designs_SOCs(): + self.design.SOC_phase_energy_dict[(SOC_type, SOC_id)][self.phase_num] = \ + sum([kernel.stats.phase_energy_dict[self.phase_num] for kernel in self.scheduled_kernels + if kernel.SOC_type == SOC_type and kernel.SOC_id == SOC_id]) + if config.simulation_method == "power_knobs": + # Add up the leakage energy to the total energy consumption + # Please note the 
phase_leakage_energy_dict only counts for PE and IC energy (no mem included) + # since memory cannot be cut-off; otherwise will lose its contents + self.design.SOC_phase_energy_dict[(SOC_type, SOC_id)][self.phase_num] += \ + sum([kernel.stats.phase_leakage_energy_dict[self.phase_num] for kernel in self.scheduled_kernels + if kernel.SOC_type == SOC_type and kernel.SOC_id == SOC_id]) + + # Add the leakage power for memories + self.design.SOC_phase_energy_dict[(SOC_type, SOC_id)][self.phase_num] += \ + sum([block.get_leakage_power() * self.time_step_size for block in self.design.get_blocks() + if block.get_block_type_name() == "mem"]) + + # ------------------------------ + # Functionality: + # step the simulator forward, by moving all the kernels forward in time + # ------------------------------ + def step(self): + # the time step of the previous phase + self.time_step_size = self.clock_time - self.old_clock_time + # add the time step (time spent in the phase) to the design phase duration dictionary + self.design.phase_latency_dict[self.phase_num] = self.time_step_size + + # advance kernels + before_time = time.time() + self.step_kernels() + self.task_update_time += (time.time() - before_time) + + before_time = time.time() + # Aggregates the energy consumed for current phase over all the blocks + self.calc_design_energy() # needs be done after kernels have stepped, to aggregate their energy and divide + self.calc_design_work() # calculate how much work does each block do for this phase + self.calc_design_utilization() + self.phase_interval_calc_time += (time.time() - before_time) + + before_time = time.time() + self.update_scheduled_kernel_list() # if a kernel is done, schedule it out + self.phase_scheduling_time += (time.time() - before_time) + + #self.schedule_kernels() # schedule ready to be scheduled kernels + + self.schedule_kernels_token_based() + self.old_clock_time = self.clock_time # update clock + + # check if execution is completed or not! + self.update_program_status() + before_time = time.time() + self.update_kernels_kpi_for_next_tick(self.design) # update each kernels' work rate + self.phase_interval_calc_time += (time.time() - before_time) + + self.phase_num += 1 + self.update_parallel_tasks() + + # return the new tick position + before_time = time.time() + new_tick_position = self.calc_new_tick_position() + self.phase_interval_calc_time += (time.time() - before_time) + + return new_tick_position, self.program_status + + # ------------------------------ + # Functionality: + # call the simulator + # ------------------------------ + def simulate(self, clock_time): + self.tick(clock_time) + return self.step() \ No newline at end of file diff --git a/Project_FARSI/SIM_utils/components/pow_sim.py b/Project_FARSI/SIM_utils/components/pow_sim.py new file mode 100644 index 00000000..7aa9525b --- /dev/null +++ b/Project_FARSI/SIM_utils/components/pow_sim.py @@ -0,0 +1,21 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
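+# NOTE: PowerSimulator is currently a placeholder; per-phase energy and power
+# accounting is handled by the performance simulator and the design statistics
+# (see calc_design_energy in perf_sim.py), so the methods below return stubs.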
diff --git a/Project_FARSI/SIM_utils/components/pow_sim.py b/Project_FARSI/SIM_utils/components/pow_sim.py new file mode 100644 index 00000000..7aa9525b --- /dev/null +++ b/Project_FARSI/SIM_utils/components/pow_sim.py @@ -0,0 +1,21 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+class PowerSimulator():
+    def __init__(self, design):
+        # stub: keep a handle on the design; the power model is not implemented yet
+        self.design = design
+
+    def power_model(self):
+        print("coming soon")
+        return 0
+
+    def step(self):
+        return 0
+
+    def tick(self, clock_time):
+        return 0
+
+    def simulate(self, clock_time):
+        self.tick(clock_time)
+        self.step()
diff --git a/Project_FARSI/__init__.py b/Project_FARSI/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/cacti_for_FARSI/2DDRAM_Samsung2GbDDR2.cfg b/Project_FARSI/cacti_for_FARSI/2DDRAM_Samsung2GbDDR2.cfg new file mode 100644 index 00000000..d035eae3 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/2DDRAM_Samsung2GbDDR2.cfg @@ -0,0 +1,194 @@
+# Cache size
+//-size (bytes) 528
+//-size (bytes) 4096
+//-size (bytes) 262144
+//-size (bytes) 1048576
+//-size (bytes) 2097152
+//-size (bytes) 4194304
+//-size (bytes) 8388608
+//-size (bytes) 16777216
+//-size (bytes) 33554432
+//-size (bytes) 134217728
+//-size (bytes) 268435456
+//-size (bytes) 536870912
+//-size (bytes) 67108864
+//-size (bytes) 536870912
+//-size (bytes) 1073741824
+# For 3D DRAM memory please use Gb as units
+-size (Gb) 2
+
+# Line size
+//-block size (bytes) 8
+-block size (bytes) 128
+
+# To model Fully Associative cache, set associativity to zero
+//-associativity 0
+//-associativity 2
+//-associativity 4
+-associativity 1
+//-associativity 16
+
+-read-write port 1
+-exclusive read port 0
+-exclusive write port 0
+-single ended read ports 0
+
+# Multiple banks connected using a bus
+-UCA bank count 16
+//-technology (u) 0.032
+//-technology (u) 0.040
+//-technology (u) 0.065
+//-technology (u) 0.078
+-technology (u) 0.080
+
+# following three parameters are meaningful only for main memories
+
+//-page size (bits) 8192
+-burst length 4
+-internal prefetch width 1
+
+# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram)
+//-Data array cell type - "itrs-hp"
+//-Data array cell type - "itrs-lstp"
+//-Data array cell type - "itrs-lop"
+-Data array cell type - "comm-dram"
+
+# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop)
+//-Data array peripheral type - "itrs-hp"
+-Data array peripheral type - "itrs-lstp"
+//-Data array peripheral type - "itrs-lop"
+
+# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram)
+-Tag array cell type - "itrs-hp"
+//-Tag array cell type - "itrs-lstp"
+
+# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop)
+-Tag array peripheral type - "itrs-hp"
+//-Tag array peripheral type - "itrs-lstp"
+
+# Bus width includes data bits and address bits required by the decoder
+//-output/input bus width 16
+//-output/input bus width 64
+-output/input bus width 64
+
+// 300-400 in steps of 10
+-operating temperature (K) 350
+
+# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file)
+# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report)
+//-cache type "cache"
+//-cache type "ram"
+//-cache type "main memory"
+-cache type "3D memory or 2D main memory"
+
+# Parameters for 3D DRAM
+//-page size (bits) 16384
+-page size (bits) 8192
+//-page size (bits) 4096
+-burst depth 4
+-IO width 4
+-system frequency (MHz) 266
+
+-stacked die count 1
+-partitioning granularity 0 // 0: coarse-grained rank-level; 1: fine-grained rank-level
+//-TSV projection 1 // 0: ITRS aggressive; 1: industrial conservative
+
+## End of parameters for 3D DRAM
+
+# to model special structure
like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 45 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +-access mode (normal, sequential, fast) - "fast" +//-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:0:100 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 50:100000:100000:100000:1000000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. 
+-Wire signaling (fullswing, lowswing, default) - "Global_5" +//-Wire signaling (fullswing, lowswing, default) - "default" +//-Wire signaling (fullswing, lowswing, default) - "lowswing" + +//-Wire inside mat - "global" +-Wire inside mat - "semi-global" +-Wire outside mat - "global" +//-Wire outside mat - "semi-global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +-Add ECC - "false" + +//-Print level (DETAILED, CONCISE) - "CONCISE" +-Print level (DETAILED, CONCISE) - "DETAILED" + +# for debugging +//-Print input parameters - "true" +-Print input parameters - "false" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +-Force cache config - "true" +//-Force cache config - "false" +-Ndwl 128 +-Ndbl 32 +-Nspd 1 +-Ndcm 1 +-Ndsam1 1 +-Ndsam2 1 + diff --git a/Project_FARSI/cacti_for_FARSI/2DDRAM_micron1Gb.cfg b/Project_FARSI/cacti_for_FARSI/2DDRAM_micron1Gb.cfg new file mode 100644 index 00000000..4b94de4f --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/2DDRAM_micron1Gb.cfg @@ -0,0 +1,194 @@ +# Cache size +//-size (bytes) 528 +//-size (bytes) 4096 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +//-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 268435456 +//-size (bytes) 536870912 +//-size (bytes) 67108864 +//-size (bytes) 536870912 +//-size (bytes) 1073741824 +# For 3D DRAM memory please use Gb as units +-size (Gb) 1 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 128 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +-associativity 1 +//-associativity 16 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 8 +//-technology (u) 0.032 +//-technology (u) 0.040 +//-technology (u) 0.065 +-technology (u) 0.078 +//-technology (u) 0.080 + +# following three parameters are meaningful only for main memories + +//-page size (bits) 8192 +-burst length 4 +-internal prefetch width 1 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +//-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" +-Data array cell type - "comm-dram" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +//-Data array peripheral type - "itrs-hp" +-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +//-output/input bus width 64 +-output/input bus width 64 + +// 300-400 in steps of 10 +-operating temperature (K) 350 + +# Type of memory - cache 
(with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +//-cache type "cache" +//-cache type "ram" +//-cache type "main memory" +-cache type "3D memory or 2D main memory" + +## Parameters for 3D DRAM +-page size (bits) 16384 +//-page size (bits) 8192 +-burst depth 8 +-IO width 4 +-system frequency (MHz) 533 + +-stacked die count 1 +-partitioning granularity 0 // 0: coarse-grained rank-level; 1: fine-grained rank-level +//-TSV projection 1 // 0: ITRS aggressive; 1: industrial conservative + +## End of parameters for 3D DRAM + + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 45 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +-access mode (normal, sequential, fast) - "fast" +//-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:0:10 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 50:100000:100000:100000:1000000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. 
+-Wire signaling (fullswing, lowswing, default) - "Global_30" +//-Wire signaling (fullswing, lowswing, default) - "default" +//-Wire signaling (fullswing, lowswing, default) - "lowswing" + +//-Wire inside mat - "global" +-Wire inside mat - "semi-global" +-Wire outside mat - "global" +//-Wire outside mat - "semi-global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +-Add ECC - "true" + +//-Print level (DETAILED, CONCISE) - "CONCISE" +-Print level (DETAILED, CONCISE) - "DETAILED" + +# for debugging +//-Print input parameters - "true" +-Print input parameters - "false" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +-Force cache config - "true" +//-Force cache config - "false" +-Ndwl 16 +-Ndbl 16 +-Nspd 1 +-Ndcm 1 +-Ndsam1 1 +-Ndsam2 1 + diff --git a/Project_FARSI/cacti_for_FARSI/3DDRAM_Samsung3D8Gb_extened.cfg b/Project_FARSI/cacti_for_FARSI/3DDRAM_Samsung3D8Gb_extened.cfg new file mode 100644 index 00000000..197bc21f --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/3DDRAM_Samsung3D8Gb_extened.cfg @@ -0,0 +1,197 @@ +# Cache size +//-size (bytes) 528 +//-size (bytes) 4096 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +//-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 268435456 +//-size (bytes) 536870912 +//-size (bytes) 67108864 +//-size (bytes) 536870912 +//-size (bytes) 1073741824 +# For 3D DRAM memory please use Gb as units +-size (Gb) 8 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 128 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +-associativity 1 +//-associativity 16 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 8 +//-technology (u) 0.032 +//-technology (u) 0.040 +//-technology (u) 0.065 +//-technology (u) 0.078 +//-technology (u) 0.080 +//-technology (u) 0.090 +-technology (u) 0.050 + +# following three parameters are meaningful only for main memories + +//-page size (bits) 8192 +-burst length 4 +-internal prefetch width 1 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +//-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" +-Data array cell type - "comm-dram" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +//-Data array peripheral type - "itrs-hp" +-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +//-output/input bus width 64 +-output/input bus width 64 + +// 
300-400 in steps of 10 +-operating temperature (K) 350 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +//-cache type "cache" +//-cache type "ram" +//-cache type "main memory" # old main memory model, in fact, it is eDRAM model. +-cache type "3D memory or 2D main memory" # once this parameter is used, the new parameter section below of will override the same parameter above + +# +//-page size (bits) 16384 +-page size (bits) 8192 +//-page size (bits) 4096 +-burst depth 8 // for 3D DRAM, IO per bank equals the product of burst depth and IO width +-IO width 4 +-system frequency (MHz) 677 + +-stacked die count 4 +-partitioning granularity 0 // 0: coarse-grained rank-level; 1: fine-grained rank-level +-TSV projection 1 // 0: ITRS aggressive; 1: industrial conservative + +## End of parameters for 3D DRAM + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 45 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +-access mode (normal, sequential, fast) - "fast" +//-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:0:100 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 50:100000:100000:100000:1000000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. 
CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. +-Wire signaling (fullswing, lowswing, default) - "Global_30" +//-Wire signaling (fullswing, lowswing, default) - "default" +//-Wire signaling (fullswing, lowswing, default) - "lowswing" + +//-Wire inside mat - "global" +-Wire inside mat - "semi-global" +-Wire outside mat - "global" +//-Wire outside mat - "semi-global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +-Add ECC - "true" + +//-Print level (DETAILED, CONCISE) - "CONCISE" +-Print level (DETAILED, CONCISE) - "DETAILED" + + +# for debugging +//-Print input parameters - "true" +-Print input parameters - "false" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +-Force cache config - "true" +//-Force cache config - "false" +-Ndwl 16 +-Ndbl 32 +-Nspd 1 +-Ndcm 1 +-Ndsam1 1 +-Ndsam2 1 + diff --git a/Project_FARSI/cacti_for_FARSI/README b/Project_FARSI/cacti_for_FARSI/README new file mode 100644 index 00000000..0dc88f52 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/README @@ -0,0 +1,122 @@ +----------------------------------------------------------- + + + ____ __ ________ __ + /\ _`\ /\ \__ __ /\_____ \ /'__`\ + \ \ \/\_\ __ ___\ \ ,_\/\_\ \/___//'/'/\ \/\ \ + \ \ \/_/_ /'__`\ /'___\ \ \/\/\ \ /' /' \ \ \ \ \ + \ \ \L\ \/\ \L\.\_/\ \__/\ \ \_\ \ \ /' /'__ \ \ \_\ \ + \ \____/\ \__/.\_\ \____\\ \__\\ \_\ /\_/ /\_\ \ \____/ + \/___/ \/__/\/_/\/____/ \/__/ \/_/ \// \/_/ \/___/ + + +A Tool to Model Caches/Memories, 3D stacking, and off-chip IO +----------------------------------------------------------- + +CACTI is an analytical tool that takes a set of cache/memory para- +meters as input and calculates its access time, power, cycle +time, and area. +CACTI was originally developed by Dr. Jouppi and Dr. Wilton +in 1993 and since then it has undergone six major +revisions. + +List of features (version 1-7): +=============================== +The following is the list of features supported by the tool. + +* Power, delay, area, and cycle time model for + direct mapped caches + set-associative caches + fully associative caches + Embedded DRAM memories + Commodity DRAM memories + +* Support for modeling multi-ported uniform cache access (UCA) + and multi-banked, multi-ported non-uniform cache access (NUCA). + +* Leakage power calculation that also considers the operating + temperature of the cache. + +* Router power model. + +* Interconnect model with different delay, power, and area + properties including low-swing wire model. + +* An interface to perform trade-off analysis involving power, delay, + area, and bandwidth. + +* All process specific values used by the tool are obtained + from ITRS and currently, the tool supports 90nm, 65nm, 45nm, + and 32nm technology nodes. + +* Chip IO model to calculate latency and energy for DDR bus. Users can model + different loads (fan-outs) and evaluate the impact on frequency and energy. 
+  This model can be used to study LR-DIMMs, R-DIMMs, etc.
+
+Version 7.0 is derived from 6.5 and merged with CACTI 3D.
+It has many new additions apart from code refinements and
+bug fixes: new IO model, 3D memory model, and power gating models.
+Ref: CACTI-IO: CACTI With OFF-chip Power-Area-Timing Models
+     MemCAD: An Interconnect Exploratory Tool for Innovative Memories Beyond DDR4
+     CACTI-3DD: Architecture-level modeling for 3D die-stacked DRAM main memory
+
+--------------------------------------------------------------------------
+Version 6.5 has a new C++ code base and includes numerous bug fixes.
+CACTI 5.3 and 6.0 activate an entire row of mats to read/write a single
+block of data. This technique improves reliability at the cost of
+power. CACTI 6.5 activates the minimum number of mats, just enough to
+retrieve a block, to minimize power.
+
+How to use the tool?
+====================
+Prior versions of CACTI took input parameters such as cache
+size and technology node as a set of command line arguments.
+To avoid a long list of command line arguments,
+CACTI 6.5 and later let users specify their cache model in a more
+detailed manner by using a config file (cache.cfg).
+
+-> define the cache model using cache.cfg
+-> run the "cacti" binary <./cacti -infile cache.cfg>
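+
+Example (illustrative, not part of CACTI): the config-file interface is
+easy to script; a minimal Python wrapper, assuming the cacti binary has
+been built in this directory, could look like
+
+    import subprocess
+
+    def run_cacti(cfg_path, cacti_dir="."):
+        # equivalent to: ./cacti -infile <cfg_path>
+        result = subprocess.run(["./cacti", "-infile", cfg_path],
+                                cwd=cacti_dir, capture_output=True,
+                                text=True, check=True)
+        return result.stdout  # the access time / power / area report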
+
+CACTI also provides a command line interface similar to earlier versions. The command line interface can be used as
+
+./cacti cache_size line_size associativity rw_ports excl_read_ports excl_write_ports
+  single_ended_read_ports search_ports banks tech_node output_width specific_tag tag_width
+  access_mode cache main_mem obj_func_delay obj_func_dynamic_power obj_func_leakage_power
+  obj_func_cycle_time obj_func_area dev_func_delay dev_func_dynamic_power dev_func_leakage_power
+  dev_func_area dev_func_cycle_time ed_ed2_none temp wt data_arr_ram_cell_tech_flavor_in
+  data_arr_peri_global_tech_flavor_in tag_arr_ram_cell_tech_flavor_in tag_arr_peri_global_tech_flavor_in
+  interconnect_projection_type_in wire_inside_mat_type_in wire_outside_mat_type_in
+  REPEATERS_IN_HTREE_SEGMENTS_in VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in
+  BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in PAGE_SIZE_BITS_in BURST_LENGTH_in
+  INTERNAL_PREFETCH_WIDTH_in force_wiretype wiretype force_config ndwl ndbl nspd ndcm
+  ndsam1 ndsam2 ecc
+
+For complete documentation of the tool, please refer
+to the following publications and reports.
+
+CACTI-5.3 & 6 reports - Details on memory/cache organizations and tradeoffs.
+
+Latency/Energy tradeoffs for large caches and NUCA design:
+  "Optimizing NUCA Organizations and Wiring Alternatives for Large Caches With CACTI 6.0", which appears in MICRO 2007.
+
+Memory IO design: CACTI-IO: CACTI With OFF-chip Power-Area-Timing Models,
+  MemCAD: An Interconnect Exploratory Tool for Innovative Memories Beyond DDR4
+  CACTI-IO Technical Report - http://www.hpl.hp.com/techreports/2013/HPL-2013-79.pdf
+
+3D model:
+  CACTI-3DD: Architecture-level modeling for 3D die-stacked DRAM main memory
+
+We are still improving the tool and refining the code. If you
+have any comments, questions, or suggestions please write to
+us.
+
+Naveen Muralimanohar
+naveen.muralimanohar@hpe.com
+
+Ali Shafiee
+shafiee@cs.utah.edu
+
+Vaishnav Srinivas
+vaishnav.srinivas@gmail.com
+
diff --git a/Project_FARSI/cacti_for_FARSI/TSV.cc b/Project_FARSI/cacti_for_FARSI/TSV.cc new file mode 100644 index 00000000..2821d4bd --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/TSV.cc @@ -0,0 +1,242 @@
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
+ *
+ ***************************************************************************/
+
+#include "TSV.h"
+
+TSV::TSV(enum TSV_type tsv_type,
+    /*TechnologyParameter::*/DeviceType *dt)://TSV driver's device type set to be peri_global
+    deviceType(dt), tsv_type(tsv_type)
+{
+  num_gates = 1;
+  num_gates_min = 1;//Is there a minimum number of stages?
+  min_w_pmos = deviceType -> n_to_p_eff_curr_drv_ratio * g_tp.min_w_nmos_;
+
+  switch (tsv_type)
+  {
+  case Fine:
+    cap = g_tp.tsv_parasitic_capacitance_fine;
+    res = g_tp.tsv_parasitic_resistance_fine;
+    min_area = g_tp.tsv_minimum_area_fine;
+    break;
+  case Coarse:
+    cap = g_tp.tsv_parasitic_capacitance_coarse;
+    res = g_tp.tsv_parasitic_resistance_coarse;
+    min_area = g_tp.tsv_minimum_area_coarse;
+    break;
+  default:
+    break;
+  }
+
+  for (int i = 0; i < MAX_NUMBER_GATES_STAGE; i++)
+  {
+    w_TSV_n[i] = 0;
+    w_TSV_p[i] = 0;
+  }
+
+  double first_buf_stg_coef = 5; // To tune the total buffer delay.
+  w_TSV_n[0] = g_tp.min_w_nmos_*first_buf_stg_coef;
+  w_TSV_p[0] = min_w_pmos *first_buf_stg_coef;
+
+  is_dram = 0;
+  is_wl_tr = 0;
+
+  //What does the function assert() mean? Should I put the function here?
+  compute_buffer_stage();
+  compute_area();
+  compute_delay();
+}
+
+TSV::~TSV()
+{
+}
+
+void TSV::compute_buffer_stage()
+{
+  double p_to_n_sz_ratio = deviceType->n_to_p_eff_curr_drv_ratio;
+
+  //BEOL parasitics in Katti's E modeling and charac. of TSV. Needs further detailed values.
+ //double res_beol = 0.1;//inaccurate + //double cap_beol = 1e-15; + + //C_load_TSV = cap_beol + cap + cap_beol + gate_C(g_tp.min_w_nmos_ + min_w_pmos, 0); + C_load_TSV = cap + gate_C(g_tp.min_w_nmos_ + min_w_pmos, 0); //+ 57.5e-15; + if(g_ip->print_detail_debug) + { + cout << " The input cap of 1st buffer: " << gate_C(w_TSV_n[0] + w_TSV_p[0], 0) * 1e15 << " fF"; + } + double F = C_load_TSV / gate_C(w_TSV_n[0] + w_TSV_p[0], 0); + if(g_ip->print_detail_debug) + { + cout<<"\nF is "<Vdd; + double cumulative_area = 0; + double cumulative_curr = 0; // cumulative leakage current + double cumulative_curr_Ig = 0; // cumulative leakage current + Buffer_area.h = g_tp.cell_h_def;//cell_h_def is the assigned height for memory cell (5um), is it correct to use it here? + + //logic_effort() didn't give the size of w_n[0] and w_p[0], which is min size inverter + //w_TSV_n[0] = g_tp.min_w_nmos_; + //w_TSV_p[0] = min_w_pmos; + + int i; + for (i = 0; i < num_gates; i++) + { + cumulative_area += compute_gate_area(INV, 1, w_TSV_p[i], w_TSV_n[i], Buffer_area.h); + if(g_ip->print_detail_debug) + { + cout << "\n\tArea up to the " << i+1 << " stages is: " << cumulative_area << " um2"; + } + cumulative_curr += cmos_Isub_leakage(w_TSV_n[i], w_TSV_p[i], 1, inv, is_dram); + cumulative_curr_Ig += cmos_Ig_leakage(w_TSV_n[i], w_TSV_p[i], 1, inv, is_dram);// The operator += is mistakenly put as = in decoder.cc + } + power.readOp.leakage = cumulative_curr * Vdd; + power.readOp.gate_leakage = cumulative_curr_Ig * Vdd; + + Buffer_area.set_area(cumulative_area); + Buffer_area.w = (cumulative_area / Buffer_area.h); + + TSV_metal_area.set_area(min_area * 3.1416/16); + + if( Buffer_area.get_area() < min_area - TSV_metal_area.get_area() ) + area.set_area(min_area); + else + area.set_area(Buffer_area.get_area() + TSV_metal_area.get_area()); + +} + +void TSV::compute_delay() +{ + //Buffer chain delay and Dynamic Power + double rd, tf, this_delay, c_load, c_intrinsic, inrisetime = 0/*The initial time*/; + //is_dram, is_wl_tr are declared to be false in the constructor + rd = tr_R_on(w_TSV_n[0], NCH, 1, is_dram, false, is_wl_tr); + c_load = gate_C(w_TSV_n[1] + w_TSV_p[1], 0.0, is_dram, false, is_wl_tr); + c_intrinsic = drain_C_(w_TSV_p[0], PCH, 1, 1, area.h, is_dram, false, is_wl_tr) + + drain_C_(w_TSV_n[0], NCH, 1, 1, area.h, is_dram, false, is_wl_tr); + tf = rd * (c_intrinsic + c_load); + //Refer to horowitz function definition + this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE); + delay += this_delay; + inrisetime = this_delay / (1.0 - 0.5); + + double Vdd = deviceType -> Vdd; + power.readOp.dynamic += (c_load + c_intrinsic) * Vdd * Vdd; + + int i; + for (i = 1; i < num_gates - 1; ++i) + { + rd = tr_R_on(w_TSV_n[i], NCH, 1, is_dram, false, is_wl_tr); + c_load = gate_C(w_TSV_p[i+1] + w_TSV_n[i+1], 0.0, is_dram, false, is_wl_tr); + c_intrinsic = drain_C_(w_TSV_p[i], PCH, 1, 1, area.h, is_dram, false, is_wl_tr) + + drain_C_(w_TSV_n[i], NCH, 1, 1, area.h, is_dram, false, is_wl_tr); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE); + delay += this_delay; + inrisetime = this_delay / (1.0 - 0.5); + power.readOp.dynamic += (c_load + c_intrinsic) * Vdd * Vdd; + } + + // add delay of final inverter that drives the TSV + i = num_gates - 1; + c_load = C_load_TSV; + rd = tr_R_on(w_TSV_n[i], NCH, 1, is_dram, false, is_wl_tr); + c_intrinsic = drain_C_(w_TSV_p[i], PCH, 1, 1, area.h, is_dram, false, is_wl_tr) + + drain_C_(w_TSV_n[i], NCH, 1, 1, area.h, is_dram, false, is_wl_tr); + //The delay method 
for the last stage of buffer chain in Decoder.cc + + //double res_beol = 0.1;//inaccurate + //double R_TSV_out = res_beol + res + res_beol; + double R_TSV_out = res; + tf = rd * (c_intrinsic + c_load) + R_TSV_out * c_load / 2; + this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE); + delay += this_delay; + + power.readOp.dynamic += (c_load + c_intrinsic) * Vdd * Vdd; //Dynamic power done + + //Is the delay actually delay/(1.0-0.5)?? + //ret_val = this_delay / (1.0 - 0.5); + //return ret_val;//Originally for decoder.cc to get outrise time + + + /* This part is to obtain delay in the TSV path, refer to Katti's paper. + * It can be used alternatively as the step to get the final-stage delay + double C_ext = c_intrinsic; + R_dr = rd; + double C_int = gate_C(g_tp.min_w_nmos_ + min_w_pmos, 0.0, is_dram, false, is_wl_tr); + delay_TSV_path = 0.693 * (R_dr * C_ext + (R_dr + res_beol) * cap_beol + (R_dr + res_beol + 0.5 * res) * cap + + (R_dr + res_beol + res + res_beol) * (cap_beol + C_int); + delay += delay_TSV_path; + */ +} + +void TSV::print_TSV() +{ + + cout << "\nTSV Properties:\n\n"; + cout << " Delay Optimal - "<< + " \n\tTSV Cap: " << cap * 1e15 << " fF" << + " \n\tTSV Res: " << res * 1e3 << " mOhm"<< + " \n\tNumber of Buffer Chain stages - " << num_gates << + " \n\tDelay - " << delay * 1e9 << " (ns) " + " \n\tPowerD - " << power.readOp.dynamic * 1e9<< " (nJ)" + " \n\tPowerL - " << power.readOp.leakage * 1e3<< " (mW)" + " \n\tPowerLgate - " << power.readOp.gate_leakage * 1e3<< " (mW)\n" << + " \n\tBuffer Area: " << Buffer_area.get_area() << " um2" << + " \n\tBuffer Height: " << Buffer_area.h << " um" << + " \n\tBuffer Width: " << Buffer_area.w << " um" << + " \n\tTSV metal area: " << TSV_metal_area.get_area() << " um2" << + " \n\tTSV minimum occupied area: " < +#include +#include + + +class TSV : public Component +{ + public: + TSV(enum TSV_type tsv_type, + /*TechnologyParameter::*/DeviceType * dt = &(g_tp.peri_global));//Should change peri_global to TSV in technology.cc + //TSV():len(20),rad(2.5),pitch(50){} + ~TSV(); + + double res;//TSV resistance + double cap;//TSV capacitance + double C_load_TSV;//The intrinsic load plus the load TSV is driving, needs changes? + double min_area; + + //int num_IO;//number of I/O + int num_gates; + int num_gates_min;//Necessary? + double w_TSV_n[MAX_NUMBER_GATES_STAGE]; + double w_TSV_p[MAX_NUMBER_GATES_STAGE]; + + //double delay_TSV_path;//Delay of TSV path including the parasitics + + double is_dram;//two external arguments, defaulted to be false in constructor + double is_wl_tr; + + void compute_buffer_stage(); + void compute_area(); + void compute_delay(); + void print_TSV(); + + Area TSV_metal_area; + Area Buffer_area; + + /*//Herigated from Component + double delay; + Area area; + powerDef power, rt_power; + double delay; + double cycle_time; + + int logical_effort();*/ + + private: + double min_w_pmos; + /*TechnologyParameter::*/DeviceType * deviceType; + unsigned int tsv_type; + +}; + + +#endif /* TSV_H_ */ diff --git a/Project_FARSI/cacti_for_FARSI/Ucache.cc b/Project_FARSI/cacti_for_FARSI/Ucache.cc new file mode 100644 index 00000000..7df02079 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/Ucache.cc @@ -0,0 +1,1073 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + +#include +#include + + +#include "area.h" +#include "bank.h" +#include "basic_circuit.h" +#include "component.h" +#include "const.h" +#include "decoder.h" +#include "parameter.h" +#include "Ucache.h" +#include "subarray.h" +#include "uca.h" + +#include +#include +#include +#include + +using namespace std; + +const uint32_t nthreads = NTHREADS; + + +void min_values_t::update_min_values(const min_values_t * val) +{ + min_delay = (min_delay > val->min_delay) ? val->min_delay : min_delay; + min_dyn = (min_dyn > val->min_dyn) ? val->min_dyn : min_dyn; + min_leakage = (min_leakage > val->min_leakage) ? val->min_leakage : min_leakage; + min_area = (min_area > val->min_area) ? val->min_area : min_area; + min_cyc = (min_cyc > val->min_cyc) ? val->min_cyc : min_cyc; +} + + + +void min_values_t::update_min_values(const uca_org_t & res) +{ + min_delay = (min_delay > res.access_time) ? res.access_time : min_delay; + min_dyn = (min_dyn > res.power.readOp.dynamic) ? res.power.readOp.dynamic : min_dyn; + min_leakage = (min_leakage > res.power.readOp.leakage) ? res.power.readOp.leakage : min_leakage; + min_area = (min_area > res.area) ? res.area : min_area; + min_cyc = (min_cyc > res.cycle_time) ? res.cycle_time : min_cyc; +} + +void min_values_t::update_min_values(const nuca_org_t * res) +{ + min_delay = (min_delay > res->nuca_pda.delay) ? res->nuca_pda.delay : min_delay; + min_dyn = (min_dyn > res->nuca_pda.power.readOp.dynamic) ? res->nuca_pda.power.readOp.dynamic : min_dyn; + min_leakage = (min_leakage > res->nuca_pda.power.readOp.leakage) ? res->nuca_pda.power.readOp.leakage : min_leakage; + min_area = (min_area > res->nuca_pda.area.get_area()) ? res->nuca_pda.area.get_area() : min_area; + min_cyc = (min_cyc > res->nuca_pda.cycle_time) ? res->nuca_pda.cycle_time : min_cyc; +} + +void min_values_t::update_min_values(const mem_array * res) +{ + min_delay = (min_delay > res->access_time) ? 
res->access_time : min_delay; + min_dyn = (min_dyn > res->power.readOp.dynamic) ? res->power.readOp.dynamic : min_dyn; + min_leakage = (min_leakage > res->power.readOp.leakage) ? res->power.readOp.leakage : min_leakage; + min_area = (min_area > res->area) ? res->area : min_area; + min_cyc = (min_cyc > res->cycle_time) ? res->cycle_time : min_cyc; +} + + + +void * calc_time_mt_wrapper(void * void_obj) +{ + calc_time_mt_wrapper_struct * calc_obj = (calc_time_mt_wrapper_struct *) void_obj; + uint32_t tid = calc_obj->tid; + list & data_arr = calc_obj->data_arr; + list & tag_arr = calc_obj->tag_arr; + bool is_tag = calc_obj->is_tag; + bool pure_ram = calc_obj->pure_ram; + bool pure_cam = calc_obj->pure_cam; + bool is_main_mem = calc_obj->is_main_mem; + double Nspd_min = calc_obj->Nspd_min; + min_values_t * data_res = calc_obj->data_res; + min_values_t * tag_res = calc_obj->tag_res; + + data_arr.clear(); + data_arr.push_back(new mem_array); + tag_arr.clear(); + tag_arr.push_back(new mem_array); + + uint32_t Ndwl_niter = _log2(MAXDATAN) + 1; + uint32_t Ndbl_niter = _log2(MAXDATAN) + 1; + uint32_t Ndcm_niter = _log2(MAX_COL_MUX) + 1; + uint32_t niter = Ndwl_niter * Ndbl_niter * Ndcm_niter; + + + bool is_valid_partition; + int wt_min, wt_max; + + if (g_ip->force_wiretype) { + if (g_ip->wt == Full_swing) { + wt_min = Global; + wt_max = Low_swing-1; + } + else { + switch(g_ip->wt) { + case Global: + wt_min = wt_max = Global; + break; + case Global_5: + wt_min = wt_max = Global_5; + break; + case Global_10: + wt_min = wt_max = Global_10; + break; + case Global_20: + wt_min = wt_max = Global_20; + break; + case Global_30: + wt_min = wt_max = Global_30; + break; + case Low_swing: + wt_min = wt_max = Low_swing; + break; + default: + cerr << "Unknown wire type!\n"; + exit(0); + } + } + } + else { + wt_min = Global; + wt_max = Low_swing; + } + + for (double Nspd = Nspd_min; Nspd <= MAXDATASPD; Nspd *= 2) + { + for (int wr = wt_min; wr <= wt_max; wr++) + { + for (uint32_t iter = tid; iter < niter; iter += nthreads) + { + // reconstruct Ndwl, Ndbl, Ndcm + unsigned int Ndwl = 1 << (iter / (Ndbl_niter * Ndcm_niter)); + unsigned int Ndbl = 1 << ((iter / (Ndcm_niter))%Ndbl_niter); + unsigned int Ndcm = 1 << (iter % Ndcm_niter); + for(unsigned int Ndsam_lev_1 = 1; Ndsam_lev_1 <= MAX_COL_MUX; Ndsam_lev_1 *= 2) + { + for(unsigned int Ndsam_lev_2 = 1; Ndsam_lev_2 <= MAX_COL_MUX; Ndsam_lev_2 *= 2) + { + //for debuging + if (g_ip->force_cache_config && is_tag == false) + { + wr = g_ip->wt; + Ndwl = g_ip->ndwl; + Ndbl = g_ip->ndbl; + Ndcm = g_ip->ndcm; + if(g_ip->nspd != 0) { + Nspd = g_ip->nspd; + } + if(g_ip->ndsam1 != 0) { + Ndsam_lev_1 = g_ip->ndsam1; + Ndsam_lev_2 = g_ip->ndsam2; + } + } + + if (is_tag == true) + { + is_valid_partition = calculate_time(is_tag, pure_ram, pure_cam, Nspd, Ndwl, + Ndbl, Ndcm, Ndsam_lev_1, Ndsam_lev_2, + tag_arr.back(), 0, NULL, NULL,(Wire_type) wr, + is_main_mem); + } + // If it's a fully-associative cache, the data array partition parameters are identical to that of + // the tag array, so compute data array partition properties also here. 
+ if (is_tag == false || g_ip->fully_assoc) + { + is_valid_partition = calculate_time(is_tag/*false*/, pure_ram, pure_cam, Nspd, Ndwl, + Ndbl, Ndcm, Ndsam_lev_1, Ndsam_lev_2, + data_arr.back(), 0, NULL, NULL,(Wire_type) wr, + is_main_mem); + if (g_ip->is_3d_mem) + { + Ndsam_lev_1 = MAX_COL_MUX+1; + Ndsam_lev_2 = MAX_COL_MUX+1; + } + } + + if (is_valid_partition) + { + if (is_tag == true) + { + tag_arr.back()->wt = (enum Wire_type) wr; + tag_res->update_min_values(tag_arr.back()); + tag_arr.push_back(new mem_array); + } + if (is_tag == false || g_ip->fully_assoc) + { + data_arr.back()->wt = (enum Wire_type) wr; + data_res->update_min_values(data_arr.back()); + data_arr.push_back(new mem_array); + } + } + + if (g_ip->force_cache_config && is_tag == false) + { + wr = wt_max; + iter = niter; + if(g_ip->nspd != 0) { + Nspd = MAXDATASPD; + } + if (g_ip->ndsam1 != 0) { + Ndsam_lev_1 = MAX_COL_MUX+1; + Ndsam_lev_2 = MAX_COL_MUX+1; + } + } + } + } + } + } + } + + delete data_arr.back(); + delete tag_arr.back(); + data_arr.pop_back(); + tag_arr.pop_back(); + + pthread_exit(NULL); +} + + + +bool calculate_time( + bool is_tag, + int pure_ram, + bool pure_cam, + double Nspd, + unsigned int Ndwl, + unsigned int Ndbl, + unsigned int Ndcm, + unsigned int Ndsam_lev_1, + unsigned int Ndsam_lev_2, + mem_array *ptr_array, + int flag_results_populate, + results_mem_array *ptr_results, + uca_org_t *ptr_fin_res, + Wire_type wt, // merge from cacti-7 to cacti3d + bool is_main_mem) +{ + DynamicParameter dyn_p(is_tag, pure_ram, pure_cam, Nspd, Ndwl, Ndbl, Ndcm, Ndsam_lev_1, Ndsam_lev_2, wt, is_main_mem); + + if (dyn_p.is_valid != true) + { + return false; + } + + UCA * uca = new UCA(dyn_p); + + + if (flag_results_populate) + { //For the final solution, populate the ptr_results data structure -- TODO: copy only necessary variables + } + else + { + int num_act_mats_hor_dir = uca->bank.dp.num_act_mats_hor_dir; + int num_mats = uca->bank.dp.num_mats; + bool is_fa = uca->bank.dp.fully_assoc; + bool pure_cam = uca->bank.dp.pure_cam; + ptr_array->Ndwl = Ndwl; + ptr_array->Ndbl = Ndbl; + ptr_array->Nspd = Nspd; + ptr_array->deg_bl_muxing = dyn_p.deg_bl_muxing; + ptr_array->Ndsam_lev_1 = Ndsam_lev_1; + ptr_array->Ndsam_lev_2 = Ndsam_lev_2; + ptr_array->access_time = uca->access_time; + ptr_array->cycle_time = uca->cycle_time; + ptr_array->multisubbank_interleave_cycle_time = uca->multisubbank_interleave_cycle_time; + ptr_array->area_ram_cells = uca->area_all_dataramcells; + ptr_array->area = uca->area.get_area(); + if(g_ip->is_3d_mem) + { //ptr_array->area = (uca->area_all_dataramcells)/0.5; + ptr_array->area = uca->area.get_area(); + if(g_ip->num_die_3d>1) + ptr_array->area += uca->area_TSV_tot; + } + + ptr_array->height = uca->area.h; + ptr_array->width = uca->area.w; + ptr_array->mat_height = uca->bank.mat.area.h; + ptr_array->mat_length = uca->bank.mat.area.w; + ptr_array->subarray_height = uca->bank.mat.subarray.area.h; + ptr_array->subarray_length = uca->bank.mat.subarray.area.w; + ptr_array->power = uca->power; + ptr_array->delay_senseamp_mux_decoder = + MAX(uca->delay_array_to_sa_mux_lev_1_decoder, + uca->delay_array_to_sa_mux_lev_2_decoder); + ptr_array->delay_before_subarray_output_driver = uca->delay_before_subarray_output_driver; + ptr_array->delay_from_subarray_output_driver_to_output = uca->delay_from_subarray_out_drv_to_out; + + ptr_array->delay_route_to_bank = uca->htree_in_add->delay; + ptr_array->delay_input_htree = uca->bank.htree_in_add->delay; + ptr_array->delay_row_predecode_driver_and_block = 
uca->bank.mat.r_predec->delay; + ptr_array->delay_row_decoder = uca->bank.mat.row_dec->delay; + ptr_array->delay_bitlines = uca->bank.mat.delay_bitline; + ptr_array->delay_matchlines = uca->bank.mat.delay_matchchline; + ptr_array->delay_sense_amp = uca->bank.mat.delay_sa; + ptr_array->delay_subarray_output_driver = uca->bank.mat.delay_subarray_out_drv_htree; + ptr_array->delay_dout_htree = uca->bank.htree_out_data->delay; + ptr_array->delay_comparator = uca->bank.mat.delay_comparator; + + if(g_ip->is_3d_mem) + { + ptr_array->delay_row_activate_net = uca->membus_RAS->delay_bus; + ptr_array->delay_row_predecode_driver_and_block = uca->membus_RAS->delay_add_predecoder; + ptr_array->delay_row_decoder = uca->membus_RAS->delay_add_decoder; + ptr_array->delay_local_wordline = uca->membus_RAS->delay_lwl_drv; + ptr_array->delay_column_access_net = uca->membus_CAS->delay_bus; + ptr_array->delay_column_predecoder = uca->membus_CAS->delay_add_predecoder; + ptr_array->delay_column_decoder = uca->membus_CAS->delay_add_decoder; + ptr_array->delay_column_selectline = 0; // Integrated into add_decoder + ptr_array->delay_datapath_net = uca->membus_data->delay_bus; + ptr_array->delay_global_data = uca->membus_data->delay_global_data; + ptr_array->delay_local_data_and_drv = uca->membus_data->delay_local_data; + ptr_array->delay_subarray_output_driver = uca->bank.mat.delay_subarray_out_drv; + ptr_array->delay_data_buffer = uca->membus_data->delay_data_buffer; + + /*ptr_array->energy_row_activate_net = uca->membus_RAS->add_bits * (uca->membus_RAS->center_stripe->power.readOp.dynamic + uca->membus_RAS->bank_bus->power.readOp.dynamic); + ptr_array->energy_row_predecode_driver_and_block = uca->membus_RAS->add_predec->power.readOp.dynamic; + ptr_array->energy_row_decoder = uca->membus_RAS->add_dec->power.readOp.dynamic; + ptr_array->energy_local_wordline = uca->membus_RAS->num_lwl_drv * uca->membus_RAS->lwl_drv->power.readOp.dynamic; + ptr_array->energy_column_access_net = uca->membus_CAS->add_bits * (uca->membus_CAS->center_stripe->power.readOp.dynamic + uca->membus_CAS->bank_bus->power.readOp.dynamic); + ptr_array->energy_column_predecoder = uca->membus_CAS->add_predec->power.readOp.dynamic; + ptr_array->energy_column_decoder = uca->membus_CAS->add_dec->power.readOp.dynamic; + ptr_array->energy_column_selectline = uca->membus_CAS->column_sel->power.readOp.dynamic; + ptr_array->energy_datapath_net = uca->membus_data->data_bits * (uca->membus_data->center_stripe->power.readOp.dynamic + uca->membus_data->bank_bus->power.readOp.dynamic); + ptr_array->energy_global_data = uca->membus_data->data_bits * (uca->membus_data->global_data->power.readOp.dynamic); + ptr_array->energy_local_data_and_drv = uca->membus_data->data_bits * (uca->membus_data->data_drv->power.readOp.dynamic); + ptr_array->energy_data_buffer = 0;*/ + + ptr_array->energy_row_activate_net = uca->membus_RAS->power_bus.readOp.dynamic; + ptr_array->energy_row_predecode_driver_and_block = uca->membus_RAS->power_add_predecoder.readOp.dynamic; + ptr_array->energy_row_decoder = uca->membus_RAS->power_add_decoders.readOp.dynamic; + ptr_array->energy_local_wordline = uca->membus_RAS->power_lwl_drv.readOp.dynamic; + ptr_array->energy_bitlines = dyn_p.Ndwl * uca->bank.mat.power_bitline.readOp.dynamic; + ptr_array->energy_sense_amp = dyn_p.Ndwl * uca->bank.mat.power_sa.readOp.dynamic; + + ptr_array->energy_column_access_net = uca->membus_CAS->power_bus.readOp.dynamic; + ptr_array->energy_column_predecoder = uca->membus_CAS->power_add_predecoder.readOp.dynamic; + 
ptr_array->energy_column_decoder = uca->membus_CAS->power_add_decoders.readOp.dynamic; + ptr_array->energy_column_selectline = uca->membus_CAS->power_col_sel.readOp.dynamic; + + ptr_array->energy_datapath_net = uca->membus_data->power_bus.readOp.dynamic; + ptr_array->energy_global_data = uca->membus_data->power_global_data.readOp.dynamic; + ptr_array->energy_local_data_and_drv = uca->membus_data->power_local_data.readOp.dynamic; + ptr_array->energy_subarray_output_driver = uca->bank.mat.power_subarray_out_drv.readOp.dynamic; // + ptr_array->energy_data_buffer = 0; + + ptr_array->area_lwl_drv = uca->area_lwl_drv; + ptr_array->area_row_predec_dec = uca->area_row_predec_dec; + ptr_array->area_col_predec_dec = uca->area_col_predec_dec; + ptr_array->area_subarray = uca->area_subarray; + ptr_array->area_bus = uca->area_bus; + ptr_array->area_address_bus = uca->area_address_bus; + ptr_array->area_data_bus = uca->area_data_bus; + ptr_array->area_data_drv = uca->area_data_drv; + ptr_array->area_IOSA = uca->area_IOSA; + ptr_array->area_sense_amp = uca->area_sense_amp; + + } + + ptr_array->all_banks_height = uca->area.h; + ptr_array->all_banks_width = uca->area.w; + //ptr_array->area_efficiency = uca->area_all_dataramcells * 100 / (uca->area.get_area()); + ptr_array->area_efficiency = uca->area_all_dataramcells * 100 / ptr_array->area; + + ptr_array->power_routing_to_bank = uca->power_routing_to_bank; + ptr_array->power_addr_input_htree = uca->bank.htree_in_add->power; + ptr_array->power_data_input_htree = uca->bank.htree_in_data->power; +// cout<<"power_data_input_htree"<bank.htree_in_data->power.readOp.leakage<power_data_output_htree = uca->bank.htree_out_data->power; +// cout<<"power_data_output_htree"<bank.htree_out_data->power.readOp.leakage<power_row_predecoder_drivers = uca->bank.mat.r_predec->driver_power; + ptr_array->power_row_predecoder_drivers.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_row_predecoder_drivers.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_row_predecoder_drivers.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_row_predecoder_blocks = uca->bank.mat.r_predec->block_power; + ptr_array->power_row_predecoder_blocks.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_row_predecoder_blocks.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_row_predecoder_blocks.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_row_decoders = uca->bank.mat.power_row_decoders; + ptr_array->power_row_decoders.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_row_decoders.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_row_decoders.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_bit_mux_predecoder_drivers = uca->bank.mat.b_mux_predec->driver_power; + ptr_array->power_bit_mux_predecoder_drivers.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_bit_mux_predecoder_drivers.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_bit_mux_predecoder_drivers.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_bit_mux_predecoder_blocks = uca->bank.mat.b_mux_predec->block_power; + ptr_array->power_bit_mux_predecoder_blocks.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_bit_mux_predecoder_blocks.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_bit_mux_predecoder_blocks.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_bit_mux_decoders = uca->bank.mat.power_bit_mux_decoders; + ptr_array->power_bit_mux_decoders.readOp.dynamic *= 
num_act_mats_hor_dir; + ptr_array->power_bit_mux_decoders.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_bit_mux_decoders.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_senseamp_mux_lev_1_predecoder_drivers = uca->bank.mat.sa_mux_lev_1_predec->driver_power; + ptr_array->power_senseamp_mux_lev_1_predecoder_drivers .readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_1_predecoder_drivers .writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_1_predecoder_drivers .searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_senseamp_mux_lev_1_predecoder_blocks = uca->bank.mat.sa_mux_lev_1_predec->block_power; + ptr_array->power_senseamp_mux_lev_1_predecoder_blocks.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_1_predecoder_blocks.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_1_predecoder_blocks.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_senseamp_mux_lev_1_decoders = uca->bank.mat.power_sa_mux_lev_1_decoders; + ptr_array->power_senseamp_mux_lev_1_decoders.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_1_decoders.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_1_decoders.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_senseamp_mux_lev_2_predecoder_drivers = uca->bank.mat.sa_mux_lev_2_predec->driver_power; + ptr_array->power_senseamp_mux_lev_2_predecoder_drivers.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_2_predecoder_drivers.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_2_predecoder_drivers.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_senseamp_mux_lev_2_predecoder_blocks = uca->bank.mat.sa_mux_lev_2_predec->block_power; + ptr_array->power_senseamp_mux_lev_2_predecoder_blocks.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_2_predecoder_blocks.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_2_predecoder_blocks.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_senseamp_mux_lev_2_decoders = uca->bank.mat.power_sa_mux_lev_2_decoders; + ptr_array->power_senseamp_mux_lev_2_decoders .readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_2_decoders .writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_senseamp_mux_lev_2_decoders .searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_bitlines = uca->bank.mat.power_bitline; + ptr_array->power_bitlines.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_bitlines.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_bitlines.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_sense_amps = uca->bank.mat.power_sa; + ptr_array->power_sense_amps.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_sense_amps.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_sense_amps.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_prechg_eq_drivers = uca->bank.mat.power_bl_precharge_eq_drv; + ptr_array->power_prechg_eq_drivers.readOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_prechg_eq_drivers.writeOp.dynamic *= num_act_mats_hor_dir; + ptr_array->power_prechg_eq_drivers.searchOp.dynamic *= num_act_mats_hor_dir; + + ptr_array->power_output_drivers_at_subarray = uca->bank.mat.power_subarray_out_drv; + ptr_array->power_output_drivers_at_subarray.readOp.dynamic *= num_act_mats_hor_dir; + 
ptr_array->power_output_drivers_at_subarray.writeOp.dynamic *= num_act_mats_hor_dir;
+ ptr_array->power_output_drivers_at_subarray.searchOp.dynamic *= num_act_mats_hor_dir;
+
+ ptr_array->power_comparators = uca->bank.mat.power_comparator;
+ ptr_array->power_comparators.readOp.dynamic *= num_act_mats_hor_dir;
+ ptr_array->power_comparators.writeOp.dynamic *= num_act_mats_hor_dir;
+ ptr_array->power_comparators.searchOp.dynamic *= num_act_mats_hor_dir;
+
+// cout << " num of mats: " << dyn_p.num_mats << endl;
+ if (is_fa || pure_cam)
+ {
+ ptr_array->power_htree_in_search = uca->bank.htree_in_search->power;
+// cout<<"power_htree_in_search"<<uca->bank.htree_in_search->power.readOp.leakage<<endl;
+ ptr_array->power_htree_out_search = uca->bank.htree_out_search->power;
+// cout<<"power_htree_out_search"<<uca->bank.htree_out_search->power.readOp.leakage<<endl;
+ ptr_array->power_searchline = uca->bank.mat.power_searchline;
+// cout<<"power_searchlineh"<<uca->bank.mat.power_searchline.readOp.leakage<<endl;
+ ptr_array->power_searchline.searchOp.dynamic *= num_mats;
+ ptr_array->power_searchline_precharge = uca->bank.mat.power_searchline_precharge;
+ ptr_array->power_searchline_precharge.searchOp.dynamic *= num_mats;
+ ptr_array->power_matchlines = uca->bank.mat.power_matchline;
+ ptr_array->power_matchlines.searchOp.dynamic *= num_mats;
+ ptr_array->power_matchline_precharge = uca->bank.mat.power_matchline_precharge;
+ ptr_array->power_matchline_precharge.searchOp.dynamic *= num_mats;
+ ptr_array->power_matchline_to_wordline_drv = uca->bank.mat.power_ml_to_ram_wl_drv;
+// cout<<"power_matchline.searchOp.leakage"<<uca->bank.mat.power_matchline.searchOp.leakage<<endl;
+ }
+
+ ptr_array->activate_energy = uca->activate_energy;
+ ptr_array->read_energy = uca->read_energy;
+ ptr_array->write_energy = uca->write_energy;
+ ptr_array->precharge_energy = uca->precharge_energy;
+ ptr_array->refresh_power = uca->refresh_power;
+ ptr_array->leak_power_subbank_closed_page = uca->leak_power_subbank_closed_page;
+ ptr_array->leak_power_subbank_open_page = uca->leak_power_subbank_open_page;
+ ptr_array->leak_power_request_and_reply_networks = uca->leak_power_request_and_reply_networks;
+
+ ptr_array->precharge_delay = uca->precharge_delay;
+
+ if(g_ip->is_3d_mem)
+ {
+ //CACTI3DD
+ ptr_array->t_RCD = uca->t_RCD;
+ ptr_array->t_RAS = uca->t_RAS;
+ ptr_array->t_RC = uca->t_RC;
+ ptr_array->t_CAS = uca->t_CAS;
+ ptr_array->t_RP = uca->t_RP;
+ ptr_array->t_RRD = uca->t_RRD;
+
+ ptr_array->activate_energy = uca->activate_energy;
+ ptr_array->read_energy = uca->read_energy;
+ ptr_array->write_energy = uca->write_energy;
+ ptr_array->precharge_energy = uca->precharge_energy;
+
+
+ ptr_array->activate_power = uca->activate_power;
+ ptr_array->read_power = uca->read_power;
+ ptr_array->write_power = uca->write_power;
+ ptr_array->peak_read_power = uca->read_energy/((g_ip->burst_depth)/(g_ip->sys_freq_MHz*1e6)/2);
+
+ ptr_array->num_row_subarray = dyn_p.num_r_subarray;
+ ptr_array->num_col_subarray = dyn_p.num_c_subarray;
+
+
+ ptr_array->delay_TSV_tot = uca->delay_TSV_tot;
+ ptr_array->area_TSV_tot = uca->area_TSV_tot;
+ ptr_array->dyn_pow_TSV_tot = uca->dyn_pow_TSV_tot;
+ ptr_array->dyn_pow_TSV_per_access = uca->dyn_pow_TSV_per_access;
+ ptr_array->num_TSV_tot = uca->num_TSV_tot;
+
+ //Covers the previous values
+ //ptr_array->area = g_ip->num_die_3d * (uca->area_per_bank * g_ip->nbanks);
+ //ptr_array->area_efficiency = g_ip->num_die_3d * uca->area_all_dataramcells * 100 / ptr_array->area;
+ }
+// cout<<"power_matchline.searchOp.leakage"<<uca->bank.mat.power_matchline.searchOp.leakage<<endl;
+// cout<<uca->bank.mat.subarray.get_total_cell_area()<<endl;
+ if (g_ip->power_gating)
+ {
+ ptr_array->sram_sleep_tx_width= uca->bank.mat.sram_sleep_tx->width;
+ ptr_array->sram_sleep_tx_area= uca->bank.mat.array_sleep_tx_area;
+ ptr_array->sram_sleep_wakeup_latency= uca->bank.mat.array_wakeup_t;
+ ptr_array->sram_sleep_wakeup_energy= uca->bank.mat.array_wakeup_e.readOp.dynamic;
+
+ ptr_array->wl_sleep_tx_width= uca->bank.mat.row_dec->sleeptx->width;
+ ptr_array->wl_sleep_tx_area= uca->bank.mat.wl_sleep_tx_area;
+ ptr_array->wl_sleep_wakeup_latency= uca->bank.mat.wl_wakeup_t;
+ ptr_array->wl_sleep_wakeup_energy= uca->bank.mat.wl_wakeup_e.readOp.dynamic;
+
+ ptr_array->bl_floating_wakeup_latency= uca->bank.mat.blfloating_wakeup_t;
+ ptr_array->bl_floating_wakeup_energy= uca->bank.mat.blfloating_wakeup_e.readOp.dynamic;
+
+ ptr_array->array_leakage= uca->bank.array_leakage;
+ ptr_array->wl_leakage= uca->bank.wl_leakage;
+ ptr_array->cl_leakage= uca->bank.cl_leakage;
+ }
+
+ ptr_array->num_active_mats = uca->bank.dp.num_act_mats_hor_dir;
+ ptr_array->num_submarray_mats = uca->bank.mat.num_subarrays_per_mat;
+// cout<<"array_leakage"<<uca->bank.array_leakage<<endl;
+// cout<<"wl_leakage"<<uca->bank.wl_leakage<<endl;
+// cout<<"cl_leakage"<<uca->bank.cl_leakage<<endl;
+
+ delete uca;
+ return true;
+}
+
+
+
+bool check_uca_org(uca_org_t & u, const min_values_t *minval)
+{
+ if (((u.access_time - minval->min_delay)*100/minval->min_delay) > g_ip->delay_dev) {
+ return false;
+ }
+ if (((u.power.readOp.dynamic - minval->min_dyn)/minval->min_dyn)*100 >
+ g_ip->dynamic_power_dev) {
+ return false;
+ }
+ if (((u.power.readOp.leakage - minval->min_leakage)/minval->min_leakage)*100 >
+ g_ip->leakage_power_dev) {
+ return false;
+ }
+ if (((u.cycle_time - minval->min_cyc)/minval->min_cyc)*100 >
+ g_ip->cycle_time_dev) {
+ return false;
+ }
+ if (((u.area - minval->min_area)/minval->min_area)*100 >
+ g_ip->area_dev) {
+ return false;
+ }
+ return true;
+}
+
+bool check_mem_org(mem_array & u, const min_values_t *minval)
+{
+ if (((u.access_time - minval->min_delay)*100/minval->min_delay) > g_ip->delay_dev) {
+ return false;
+ }
+ if (((u.power.readOp.dynamic - minval->min_dyn)/minval->min_dyn)*100 >
+ g_ip->dynamic_power_dev) {
+ return false;
+ }
+ if (((u.power.readOp.leakage - minval->min_leakage)/minval->min_leakage)*100 >
+ g_ip->leakage_power_dev) {
+ return false;
+ }
+ if (((u.cycle_time - minval->min_cyc)/minval->min_cyc)*100 >
+ g_ip->cycle_time_dev) {
+ return false;
+ }
+ if (((u.area - minval->min_area)/minval->min_area)*100 >
+ g_ip->area_dev) {
+ return false;
+ }
+ return true;
+}
+
+
+
+
+void find_optimal_uca(uca_org_t *res, min_values_t * minval, list<uca_org_t> & ulist)
+{
+ double cost = 0;
+ double min_cost = BIGNUM;
+ float d, a, dp, lp, c;
+
+ dp = g_ip->dynamic_power_wt;
+ lp = g_ip->leakage_power_wt;
+ a = g_ip->area_wt;
+ d = g_ip->delay_wt;
+ c = g_ip->cycle_time_wt;
+
+ if (ulist.empty() == true)
+ {
+ cout << "ERROR: no valid cache organizations found" << endl;
+ exit(0);
+ }
+
+ for (list<uca_org_t>::iterator niter = ulist.begin(); niter != ulist.end(); niter++)
+ {
+ if (g_ip->ed == 1)
+ {
+ cost = ((niter)->access_time/minval->min_delay) * ((niter)->power.readOp.dynamic/minval->min_dyn);
+ if (min_cost > cost)
+ {
+ min_cost = cost;
+ *res = (*(niter));
+ }
+ }
+ else if (g_ip->ed == 2)
+ {
+ cost = ((niter)->access_time/minval->min_delay)*
+ ((niter)->access_time/minval->min_delay)*
+ ((niter)->power.readOp.dynamic/minval->min_dyn);
+ if (min_cost > cost)
+ {
+ min_cost = cost;
+ *res = (*(niter));
+ }
+ }
+ else
+ {
+ /*
+ * check whether the current organization
+ * meets the input deviation constraints
+ */
+ bool v = check_uca_org(*niter, minval);
+ //if (minval->min_leakage == 0) minval->min_leakage = 0.1; //FIXME remove this after leakage modeling
+
+ if (v)
+ {
+ cost = (d * ((niter)->access_time/minval->min_delay) +
+ c * ((niter)->cycle_time/minval->min_cyc) +
+ dp * ((niter)->power.readOp.dynamic/minval->min_dyn) +
+ lp * ((niter)->power.readOp.leakage/minval->min_leakage) +
+ a * ((niter)->area/minval->min_area));
+ //fprintf(stderr, "cost = %g\n", cost);
+
+ if (min_cost > cost) {
+ min_cost = cost;
+ *res = (*(niter));
+ niter = ulist.erase(niter);
+ if (niter!=ulist.begin())
+ niter--;
+ }
+ }
+ else {
+ niter = ulist.erase(niter);
+ if (niter!=ulist.begin())
+ niter--;
+ }
+ }
+ }
+
+ if (min_cost == BIGNUM)
+ {
+ cout << "ERROR: no cache organizations met optimization criteria" << endl;
+ exit(0);
+ }
+}
+
+
+
+void filter_tag_arr(const min_values_t * min, list<mem_array *> & list)
+{
+ double cost = BIGNUM;
+ double cur_cost;
+ double wt_delay = g_ip->delay_wt, wt_dyn = g_ip->dynamic_power_wt, wt_leakage = g_ip->leakage_power_wt, wt_cyc = g_ip->cycle_time_wt, wt_area = g_ip->area_wt;
+ mem_array * res = NULL;
+
+ if (list.empty() == true)
+ {
+ cout << "ERROR: no valid tag organizations found" << endl;
+ exit(1);
+ }
+
+
+ while (list.empty() != true)
+ {
+ bool v = check_mem_org(*list.back(), min);
+ if (v)
+ {
+ cur_cost = wt_delay * (list.back()->access_time/min->min_delay) +
+ wt_dyn * (list.back()->power.readOp.dynamic/min->min_dyn) +
+ wt_leakage * (list.back()->power.readOp.leakage/min->min_leakage) +
+ wt_area * (list.back()->area/min->min_area) +
+ wt_cyc * (list.back()->cycle_time/min->min_cyc);
+ }
+ else
+ {
+ cur_cost = BIGNUM;
+ }
+ if (cur_cost < cost)
+ {
+ if (res != NULL)
+ {
+ delete res;
+ }
+ cost = cur_cost;
+ res = list.back();
+ }
+ else
+ {
+ delete list.back();
+ }
+ list.pop_back();
+ }
+ if(!res)
+ {
+ cout << "ERROR: no valid tag organizations found" << endl;
+ exit(0);
+ }
+
+ list.push_back(res);
+}
+
+
+
+void filter_data_arr(list<mem_array *> & curr_list)
+{
+ if (curr_list.empty() == true)
+ {
+ cout << "ERROR: no valid data array organizations found" << endl;
+ exit(1);
+ }
+
+ list<mem_array *>::iterator iter;
+
+ for (iter = curr_list.begin(); iter != curr_list.end(); ++iter)
+ {
+ mem_array * m = *iter;
+
+ if (m == NULL) exit(1);
+
+ if(((m->access_time - m->arr_min->min_delay)/m->arr_min->min_delay > 0.5) &&
+ ((m->power.readOp.dynamic - m->arr_min->min_dyn)/m->arr_min->min_dyn > 0.5))
+ {
+ delete m;
+ iter = curr_list.erase(iter);
+ iter --;
+ }
+ }
+}
+
+
+
+/*
+ * Performs exhaustive search across different sub-array sizes,
+ * wire types and aspect ratios to find an optimal UCA organization
+ * 1. First different valid tag array organizations are calculated
+ * and stored in tag_arr array
+ * 2. The exhaustive search is repeated to find valid data array
+ * organizations and stored in data_arr array
+ * 3. Cache area, delay, power, and cycle time for different
+ * cache organizations are calculated based on the
+ * above results
+ * 4. Cache model with least cost is picked from sol_list
+ */
+void solve(uca_org_t *fin_res)
+{
+ ///bool is_dram = false;
+ int pure_ram = g_ip->pure_ram;
+ bool pure_cam = g_ip->pure_cam;
+
+ init_tech_params(g_ip->F_sz_um, false);
+ g_ip->print_detail_debug = 0; // ---detail outputs for debug, initiated for 3D memory
+
+ list<mem_array *> tag_arr(0);
+ list<mem_array *> data_arr(0);
+ list<mem_array *>::iterator miter;
+ list<uca_org_t> sol_list(1, uca_org_t());
+
+ fin_res->tag_array.access_time = 0;
+ fin_res->tag_array.Ndwl = 0;
+ fin_res->tag_array.Ndbl = 0;
+ fin_res->tag_array.Nspd = 0;
+ fin_res->tag_array.deg_bl_muxing = 0;
+ fin_res->tag_array.Ndsam_lev_1 = 0;
+ fin_res->tag_array.Ndsam_lev_2 = 0;
+
+
+ // distribute calculate_time() execution to multiple threads
+ calc_time_mt_wrapper_struct * calc_array = new calc_time_mt_wrapper_struct[nthreads];
+ pthread_t threads[nthreads];
+
+ for (uint32_t t = 0; t < nthreads; t++)
+ {
+ calc_array[t].tid = t;
+ calc_array[t].pure_ram = pure_ram;
+ calc_array[t].pure_cam = pure_cam;
+ calc_array[t].data_res = new min_values_t();
+ calc_array[t].tag_res = new min_values_t();
+ }
+
+ bool is_tag;
+ ///uint32_t ram_cell_tech_type;
+
+ // If it's a cache, first calculate the area, delay and power for all tag array partitions.
+ if (!(pure_ram||pure_cam||g_ip->fully_assoc))
+ { //cache
+ is_tag = true;
+ /// ram_cell_tech_type = g_ip->tag_arr_ram_cell_tech_type;
+ /// is_dram = ((ram_cell_tech_type == lp_dram) || (ram_cell_tech_type == comm_dram));
+ init_tech_params(g_ip->F_sz_um, is_tag);
+
+ for (uint32_t t = 0; t < nthreads; t++)
+ {
+ calc_array[t].is_tag = is_tag;
+ calc_array[t].is_main_mem = false;
+ calc_array[t].Nspd_min = 0.125;
+ pthread_create(&threads[t], NULL, calc_time_mt_wrapper, (void *)(&(calc_array[t])));
+ }
+
+ for (uint32_t t = 0; t < nthreads; t++)
+ {
+ pthread_join(threads[t], NULL);
+ }
+
+ for (uint32_t t = 0; t < nthreads; t++)
+ {
+ calc_array[t].data_arr.sort(mem_array::lt);
+ data_arr.merge(calc_array[t].data_arr, mem_array::lt);
+ calc_array[t].tag_arr.sort(mem_array::lt);
+ tag_arr.merge(calc_array[t].tag_arr, mem_array::lt);
+ }
+ }
+
+
+ // calculate the area, delay and power for all data array partitions (for cache or plain RAM).
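+ // The data-array pass below follows the same multithreaded pattern as the
+ // tag-array pass above: each worker runs calc_time_mt_wrapper() over its
+ // slice of the partition space, records its local minima in
+ // calc_array[t].data_res, and the per-thread result lists are sorted and
+ // merged with mem_array::lt so data_arr stays ordered.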
+// if (!g_ip->fully_assoc) +// {//in the new cacti, cam, fully_associative cache are processed as single array in the data portion + is_tag = false; + /// ram_cell_tech_type = g_ip->data_arr_ram_cell_tech_type; + /// is_dram = ((ram_cell_tech_type == lp_dram) || (ram_cell_tech_type == comm_dram)); + init_tech_params(g_ip->F_sz_um, is_tag); + + for (uint32_t t = 0; t < nthreads; t++) + { + calc_array[t].is_tag = is_tag; + calc_array[t].is_main_mem = g_ip->is_main_mem; + if (!(pure_cam||g_ip->fully_assoc)) + { + calc_array[t].Nspd_min = (double)(g_ip->out_w)/(double)(g_ip->block_sz*8); + } + else + { + calc_array[t].Nspd_min = 1; + } + + pthread_create(&threads[t], NULL, calc_time_mt_wrapper, (void *)(&(calc_array[t]))); + } + + for (uint32_t t = 0; t < nthreads; t++) + { + pthread_join(threads[t], NULL); + } + + data_arr.clear(); + for (uint32_t t = 0; t < nthreads; t++) + { + calc_array[t].data_arr.sort(mem_array::lt); + data_arr.merge(calc_array[t].data_arr, mem_array::lt); + } +// } + + + min_values_t * d_min = new min_values_t(); + min_values_t * t_min = new min_values_t(); + min_values_t * cache_min = new min_values_t(); + + for (uint32_t t = 0; t < nthreads; t++) + { + d_min->update_min_values(calc_array[t].data_res); + t_min->update_min_values(calc_array[t].tag_res); + } + + for (miter = data_arr.begin(); miter != data_arr.end(); miter++) + { + (*miter)->arr_min = d_min; + } + + + //cout << data_arr.size() << "\t" << tag_arr.size() <<" before\n"; + filter_data_arr(data_arr); + if(!(pure_ram||pure_cam||g_ip->fully_assoc)) + { + filter_tag_arr(t_min, tag_arr); + } + //cout << data_arr.size() << "\t" << tag_arr.size() <<" after\n"; + + + if (pure_ram||pure_cam||g_ip->fully_assoc) + { + for (miter = data_arr.begin(); miter != data_arr.end(); miter++) + { + uca_org_t & curr_org = sol_list.back(); + curr_org.tag_array2 = NULL; + curr_org.data_array2 = (*miter); + + curr_org.find_delay(); + curr_org.find_energy(); + curr_org.find_area(); + curr_org.find_cyc(); + + //update min values for the entire cache + cache_min->update_min_values(curr_org); + + sol_list.push_back(uca_org_t()); + } + } + else + { + while (tag_arr.empty() != true) + { + mem_array * arr_temp = (tag_arr.back()); + //delete tag_arr.back(); + tag_arr.pop_back(); + + for (miter = data_arr.begin(); miter != data_arr.end(); miter++) + { + uca_org_t & curr_org = sol_list.back(); + curr_org.tag_array2 = arr_temp; + curr_org.data_array2 = (*miter); + + curr_org.find_delay(); + curr_org.find_energy(); + curr_org.find_area(); + curr_org.find_cyc(); + + //update min values for the entire cache + cache_min->update_min_values(curr_org); + + sol_list.push_back(uca_org_t()); + } + } + } + + sol_list.pop_back(); + + find_optimal_uca(fin_res, cache_min, sol_list); + + sol_list.clear(); + + for (miter = data_arr.begin(); miter != data_arr.end(); ++miter) + { + if (*miter != fin_res->data_array2) + { + delete *miter; + } + } + data_arr.clear(); + + for (uint32_t t = 0; t < nthreads; t++) + { + delete calc_array[t].data_res; + delete calc_array[t].tag_res; + } + + delete [] calc_array; + delete cache_min; + delete d_min; + delete t_min; +} + +void update(uca_org_t *fin_res) +{ + if(fin_res->tag_array2) + { + init_tech_params(g_ip->F_sz_um,true); + DynamicParameter tag_arr_dyn_p(true, g_ip->pure_ram, g_ip->pure_cam, fin_res->tag_array2->Nspd, fin_res->tag_array2->Ndwl, fin_res->tag_array2->Ndbl, fin_res->tag_array2->Ndcm, fin_res->tag_array2->Ndsam_lev_1, fin_res->tag_array2->Ndsam_lev_2, fin_res->data_array2->wt, g_ip->is_main_mem); + 
if(tag_arr_dyn_p.is_valid) + { + UCA * tag_arr = new UCA(tag_arr_dyn_p); + fin_res->tag_array2->power = tag_arr->power; + } + else + { + cout << "ERROR: Cannot retrieve array structure for leakage feedback" << endl; + exit(1); + } + } + init_tech_params(g_ip->F_sz_um,false); + DynamicParameter data_arr_dyn_p(false, g_ip->pure_ram, g_ip->pure_cam, fin_res->data_array2->Nspd, fin_res->data_array2->Ndwl, fin_res->data_array2->Ndbl, fin_res->data_array2->Ndcm, fin_res->data_array2->Ndsam_lev_1, fin_res->data_array2->Ndsam_lev_2, fin_res->data_array2->wt, g_ip->is_main_mem); + if(data_arr_dyn_p.is_valid) + { + UCA * data_arr = new UCA(data_arr_dyn_p); + fin_res->data_array2->power = data_arr->power; + } + else + { + cout << "ERROR: Cannot retrieve array structure for leakage feedback" << endl; + exit(1); + } + + fin_res->find_energy(); +} + diff --git a/Project_FARSI/cacti_for_FARSI/Ucache.h b/Project_FARSI/cacti_for_FARSI/Ucache.h new file mode 100644 index 00000000..bfa1a308 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/Ucache.h @@ -0,0 +1,118 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+
+#ifndef __UCACHE_H__
+#define __UCACHE_H__
+
+#include <list>
+#include "area.h"
+#include "router.h"
+#include "nuca.h"
+
+
+class min_values_t
+{
+ public:
+ double min_delay;
+ double min_dyn;
+ double min_leakage;
+ double min_area;
+ double min_cyc;
+
+ min_values_t() : min_delay(BIGNUM), min_dyn(BIGNUM), min_leakage(BIGNUM), min_area(BIGNUM), min_cyc(BIGNUM) { }
+
+ void update_min_values(const min_values_t * val);
+ void update_min_values(const uca_org_t & res);
+ void update_min_values(const nuca_org_t * res);
+ void update_min_values(const mem_array * res);
+};
+
+
+
+struct solution
+{
+ int tag_array_index;
+ int data_array_index;
+ list<mem_array *>::iterator tag_array_iter;
+ list<mem_array *>::iterator data_array_iter;
+ double access_time;
+ double cycle_time;
+ double area;
+ double efficiency;
+ powerDef total_power;
+};
+
+
+
+bool calculate_time(
+ bool is_tag,
+ int pure_ram,
+ bool pure_cam,
+ double Nspd,
+ unsigned int Ndwl,
+ unsigned int Ndbl,
+ unsigned int Ndcm,
+ unsigned int Ndsam_lev_1,
+ unsigned int Ndsam_lev_2,
+ mem_array *ptr_array,
+ int flag_results_populate,
+ results_mem_array *ptr_results,
+ uca_org_t *ptr_fin_res,
+ Wire_type wtype, // merge from cacti-7 to cacti3d
+ bool is_main_mem);
+void update(uca_org_t *fin_res);
+
+void solve(uca_org_t *fin_res);
+void init_tech_params(double tech, bool is_tag);
+
+
+struct calc_time_mt_wrapper_struct
+{
+ uint32_t tid;
+ bool is_tag;
+ bool pure_ram;
+ bool pure_cam;
+ bool is_main_mem;
+ double Nspd_min;
+
+ min_values_t * data_res;
+ min_values_t * tag_res;
+
+ list<mem_array *> data_arr;
+ list<mem_array *> tag_arr;
+};
+
+void *calc_time_mt_wrapper(void * void_obj);
+
+void print_g_tp();
+
+#endif
diff --git a/Project_FARSI/cacti_for_FARSI/arbiter.cc b/Project_FARSI/cacti_for_FARSI/arbiter.cc
new file mode 100644
index 00000000..f09dcb78
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/arbiter.cc
@@ -0,0 +1,130 @@
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + +#include "arbiter.h" + +Arbiter::Arbiter( + double n_req, + double flit_size_, + double output_len, + /*TechnologyParameter::*/DeviceType *dt + ):R(n_req), flit_size(flit_size_), + o_len (output_len), deviceType(dt) +{ + min_w_pmos = deviceType->n_to_p_eff_curr_drv_ratio*g_tp.min_w_nmos_; + Vdd = dt->Vdd; + double technology = g_ip->F_sz_um; + NTn1 = 13.5*technology/2; + PTn1 = 76*technology/2; + NTn2 = 13.5*technology/2; + PTn2 = 76*technology/2; + NTi = 12.5*technology/2; + PTi = 25*technology/2; + NTtr = 10*technology/2; /*Transmission gate's nmos tr. length*/ + PTtr = 20*technology/2; /* pmos tr. length*/ +} + +Arbiter::~Arbiter(){} + +double +Arbiter::arb_req() { + double temp = ((R-1)*(2*gate_C(NTn1, 0)+gate_C(PTn1, 0)) + 2*gate_C(NTn2, 0) + + gate_C(PTn2, 0) + gate_C(NTi, 0) + gate_C(PTi, 0) + + drain_C_(NTi, 0, 1, 1, g_tp.cell_h_def) + drain_C_(PTi, 1, 1, 1, g_tp.cell_h_def)); + return temp; +} + +double +Arbiter::arb_pri() { + double temp = 2*(2*gate_C(NTn1, 0)+gate_C(PTn1, 0)); /* switching capacitance + of flip-flop is ignored */ + return temp; +} + + +double +Arbiter::arb_grant() { + double temp = drain_C_(NTn1, 0, 1, 1, g_tp.cell_h_def)*2 + drain_C_(PTn1, 1, 1, 1, g_tp.cell_h_def) + crossbar_ctrline(); + return temp; +} + +double +Arbiter::arb_int() { + double temp = (drain_C_(NTn1, 0, 1, 1, g_tp.cell_h_def)*2 + drain_C_(PTn1, 1, 1, 1, g_tp.cell_h_def) + + 2*gate_C(NTn2, 0) + gate_C(PTn2, 0)); + return temp; +} + +void +Arbiter::compute_power() { + power.readOp.dynamic = (R*arb_req()*Vdd*Vdd/2 + R*arb_pri()*Vdd*Vdd/2 + + arb_grant()*Vdd*Vdd + arb_int()*0.5*Vdd*Vdd); + double nor1_leak = cmos_Isub_leakage(g_tp.min_w_nmos_*NTn1*2, min_w_pmos * PTn1*2, 2, nor); + double nor2_leak = cmos_Isub_leakage(g_tp.min_w_nmos_*NTn2*R, min_w_pmos * PTn2*R, 2, nor); + double not_leak = cmos_Isub_leakage(g_tp.min_w_nmos_*NTi, min_w_pmos * PTi, 1, inv); + double nor1_leak_gate = cmos_Ig_leakage(g_tp.min_w_nmos_*NTn1*2, min_w_pmos * PTn1*2, 2, nor); + double nor2_leak_gate = cmos_Ig_leakage(g_tp.min_w_nmos_*NTn2*R, min_w_pmos * PTn2*R, 2, nor); + double not_leak_gate = cmos_Ig_leakage(g_tp.min_w_nmos_*NTi, min_w_pmos * PTi, 1, inv); + power.readOp.leakage = (nor1_leak + nor2_leak + not_leak)*Vdd; //FIXME include priority table leakage + power.readOp.gate_leakage = nor1_leak_gate*Vdd + nor2_leak_gate*Vdd + not_leak_gate*Vdd; +} + +double //wire cap with triple spacing +Arbiter::Cw3(double length) { + Wire wc(g_ip->wt, length, 1, 3, 3); + double temp = (wc.wire_cap(length,true)); + return temp; +} + +double +Arbiter::crossbar_ctrline() { + double temp = (Cw3(o_len * 1e-6 /* m */) + + drain_C_(NTi, 0, 1, 1, g_tp.cell_h_def) + 
drain_C_(PTi, 1, 1, 1, g_tp.cell_h_def) + + gate_C(NTi, 0) + gate_C(PTi, 0)); + return temp; +} + +double +Arbiter::transmission_buf_ctrcap() { + double temp = gate_C(NTtr, 0)+gate_C(PTtr, 0); + return temp; +} + + +void Arbiter::print_arbiter() +{ + cout << "\nArbiter Stats (" << R << " input arbiter" << ")\n\n"; + cout << "Flit size : " << flit_size << " bits" << endl; + cout << "Dynamic Power : " << power.readOp.dynamic*1e9 << " (nJ)" << endl; + cout << "Leakage Power : " << power.readOp.leakage*1e3 << " (mW)" << endl; +} + + diff --git a/Project_FARSI/cacti_for_FARSI/arbiter.h b/Project_FARSI/cacti_for_FARSI/arbiter.h new file mode 100644 index 00000000..8358e957 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/arbiter.h @@ -0,0 +1,77 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + +#ifndef __ARBITER__ +#define __ARBITER__ + +#include +#include +#include "basic_circuit.h" +#include "cacti_interface.h" +#include "component.h" +#include "parameter.h" +#include "mat.h" +#include "wire.h" + +class Arbiter : public Component +{ + public: + Arbiter( + double Req, + double flit_sz, + double output_len, + /*TechnologyParameter::*/DeviceType *dt = &(g_tp.peri_global)); + ~Arbiter(); + + void print_arbiter(); + double arb_req(); + double arb_pri(); + double arb_grant(); + double arb_int(); + void compute_power(); + double Cw3(double len); + double crossbar_ctrline(); + double transmission_buf_ctrcap(); + + + + private: + double NTn1, PTn1, NTn2, PTn2, R, PTi, NTi; + double flit_size; + double NTtr, PTtr; + double o_len; + /*TechnologyParameter::*/DeviceType *deviceType; + double TriS1, TriS2; + double min_w_pmos, Vdd; + +}; + +#endif diff --git a/Project_FARSI/cacti_for_FARSI/area.cc b/Project_FARSI/cacti_for_FARSI/area.cc new file mode 100644 index 00000000..d6a37468 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/area.cc @@ -0,0 +1,46 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#include "area.h" +#include "component.h" +#include "decoder.h" +#include "parameter.h" +#include "basic_circuit.h" +#include +#include +#include + +using namespace std; + + + diff --git a/Project_FARSI/cacti_for_FARSI/area.h b/Project_FARSI/cacti_for_FARSI/area.h new file mode 100644 index 00000000..a592dbcc --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/area.h @@ -0,0 +1,71 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#ifndef __AREA_H__ +#define __AREA_H__ + +#include "cacti_interface.h" +#include "basic_circuit.h" + +using namespace std; + +class Area +{ + public: + double w; + double h; + + Area():w(0), h(0), area(0) { } + double get_w() const { return w; } + double get_h() const { return h; } + double get_area() const + { + if (w == 0 && h == 0) + { + return area; + } + else + { + return w*h; + } + } + void set_w(double w_) { w = w_; } + void set_h(double h_) { h = h_; } + void set_area(double a_) { area = a_; } + + private: + double area; +}; + +#endif + diff --git a/Project_FARSI/cacti_for_FARSI/bank.cc b/Project_FARSI/cacti_for_FARSI/bank.cc new file mode 100644 index 00000000..e7e5d819 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/bank.cc @@ -0,0 +1,206 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#include "bank.h" +#include + + +Bank::Bank(const DynamicParameter & dyn_p): + dp(dyn_p), mat(dp), + num_addr_b_mat(dyn_p.number_addr_bits_mat), + num_mats_hor_dir(dyn_p.num_mats_h_dir), num_mats_ver_dir(dyn_p.num_mats_v_dir), + array_leakage(0), + wl_leakage(0), + cl_leakage(0) +{ +// Mat temp(dyn_p); + int RWP; + int ERP; + int EWP; + int SCHP; + + if (dp.use_inp_params) + { + RWP = dp.num_rw_ports; + ERP = dp.num_rd_ports; + EWP = dp.num_wr_ports; + SCHP = dp.num_search_ports; + } + else + { + RWP = g_ip->num_rw_ports; + ERP = g_ip->num_rd_ports; + EWP = g_ip->num_wr_ports; + SCHP = g_ip->num_search_ports; + } + + int total_addrbits = (dp.number_addr_bits_mat + dp.number_subbanks_decode)*(RWP+ERP+EWP); + int datainbits = dp.num_di_b_bank_per_port * (RWP + EWP); + int dataoutbits = dp.num_do_b_bank_per_port * (RWP + ERP); + int searchinbits; + int searchoutbits; + + if (dp.fully_assoc || dp.pure_cam) + { + datainbits = dp.num_di_b_bank_per_port * (RWP + EWP); + dataoutbits = dp.num_do_b_bank_per_port * (RWP + ERP); + searchinbits = dp.num_si_b_bank_per_port * SCHP; + searchoutbits = dp.num_so_b_bank_per_port * SCHP; + } + + if (!(dp.fully_assoc || dp.pure_cam)) + { + if (g_ip->fast_access && dp.is_tag == false) + { + dataoutbits *= g_ip->data_assoc; + } + + htree_in_add = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits, 0,dataoutbits,0, num_mats_ver_dir*2, num_mats_hor_dir*2, Add_htree); + htree_in_data = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits, 0,dataoutbits,0, num_mats_ver_dir*2, num_mats_hor_dir*2, Data_in_htree); + htree_out_data = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits, 0,dataoutbits,0, num_mats_ver_dir*2, num_mats_hor_dir*2, Data_out_htree); + +// htree_out_data = new Htree2 (g_ip->wt,(double) 100, (double)100, +// total_addrbits, datainbits, 0,dataoutbits,0, num_mats_ver_dir*2, num_mats_hor_dir*2, Data_out_htree); + + area.w = htree_in_data->area.w; + area.h = htree_in_data->area.h; + } + else + { + htree_in_add = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits, searchinbits,dataoutbits,searchoutbits, num_mats_ver_dir*2, num_mats_hor_dir*2, Add_htree); + htree_in_data = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits,searchinbits, dataoutbits, searchoutbits, num_mats_ver_dir*2, num_mats_hor_dir*2, Data_in_htree); + htree_out_data = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits,searchinbits, dataoutbits, searchoutbits,num_mats_ver_dir*2, num_mats_hor_dir*2, Data_out_htree); + htree_in_search = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits,searchinbits, 
dataoutbits, searchoutbits, num_mats_ver_dir*2, num_mats_hor_dir*2, Data_in_htree,true, true); + htree_out_search = new Htree2 (dp.wtype/*g_ip->wt*/,(double) mat.area.w, (double)mat.area.h, + total_addrbits, datainbits,searchinbits, dataoutbits, searchoutbits,num_mats_ver_dir*2, num_mats_hor_dir*2, Data_out_htree,true); + + area.w = htree_in_data->area.w; + area.h = htree_in_data->area.h; + } + + num_addr_b_row_dec = _log2(mat.subarray.num_rows); + num_addr_b_routed_to_mat_for_act = num_addr_b_row_dec; + num_addr_b_routed_to_mat_for_rd_or_wr = num_addr_b_mat - num_addr_b_row_dec; +} + + + +Bank::~Bank() +{ + delete htree_in_add; + delete htree_out_data; + delete htree_in_data; + if (dp.fully_assoc || dp.pure_cam) + { + delete htree_in_search; + delete htree_out_search; + } +} + + + +double Bank::compute_delays(double inrisetime) +{ + return mat.compute_delays(inrisetime); +} + + + +void Bank::compute_power_energy() +{ + mat.compute_power_energy(); + + if (!(dp.fully_assoc || dp.pure_cam)) + { + power.readOp.dynamic += mat.power.readOp.dynamic * dp.num_act_mats_hor_dir; + power.readOp.leakage += mat.power.readOp.leakage * dp.num_mats; + power.readOp.gate_leakage += mat.power.readOp.gate_leakage * dp.num_mats; + + power.readOp.dynamic += htree_in_add->power.readOp.dynamic; + power.readOp.dynamic += htree_out_data->power.readOp.dynamic; + + array_leakage += mat.array_leakage*dp.num_mats; + wl_leakage += mat.wl_leakage*dp.num_mats; + cl_leakage += mat.cl_leakage*dp.num_mats; +// +// power.readOp.leakage += htree_in_add->power.readOp.leakage; +// power.readOp.leakage += htree_in_data->power.readOp.leakage; +// power.readOp.leakage += htree_out_data->power.readOp.leakage; +// power.readOp.gate_leakage += htree_in_add->power.readOp.gate_leakage; +// power.readOp.gate_leakage += htree_in_data->power.readOp.gate_leakage; +// power.readOp.gate_leakage += htree_out_data->power.readOp.gate_leakage; + } + else + { + + power.readOp.dynamic += mat.power.readOp.dynamic ;//for fa and cam num_act_mats_hor_dir is 1 for plain r/w + power.readOp.leakage += mat.power.readOp.leakage * dp.num_mats; + power.readOp.gate_leakage += mat.power.readOp.gate_leakage * dp.num_mats; + + power.searchOp.dynamic += mat.power.searchOp.dynamic * dp.num_mats; + power.searchOp.dynamic += mat.power_bl_precharge_eq_drv.searchOp.dynamic + + mat.power_sa.searchOp.dynamic + + mat.power_bitline.searchOp.dynamic + + mat.power_subarray_out_drv.searchOp.dynamic+ + mat.ml_to_ram_wl_drv->power.readOp.dynamic; + + power.readOp.dynamic += htree_in_add->power.readOp.dynamic; + power.readOp.dynamic += htree_out_data->power.readOp.dynamic; + + power.searchOp.dynamic += htree_in_search->power.searchOp.dynamic; + power.searchOp.dynamic += htree_out_search->power.searchOp.dynamic; + + power.readOp.leakage += htree_in_add->power.readOp.leakage; + power.readOp.leakage += htree_in_data->power.readOp.leakage; + power.readOp.leakage += htree_out_data->power.readOp.leakage; + power.readOp.leakage += htree_in_search->power.readOp.leakage; + power.readOp.leakage += htree_out_search->power.readOp.leakage; + + + power.readOp.gate_leakage += htree_in_add->power.readOp.gate_leakage; + power.readOp.gate_leakage += htree_in_data->power.readOp.gate_leakage; + power.readOp.gate_leakage += htree_out_data->power.readOp.gate_leakage; + power.readOp.gate_leakage += htree_in_search->power.readOp.gate_leakage; + power.readOp.gate_leakage += htree_out_search->power.readOp.gate_leakage; + + } + +} + diff --git a/Project_FARSI/cacti_for_FARSI/bank.h 
b/Project_FARSI/cacti_for_FARSI/bank.h new file mode 100644 index 00000000..e12665f7 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/bank.h @@ -0,0 +1,74 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#ifndef __BANK_H__ +#define __BANK_H__ + +#include "component.h" +#include "decoder.h" +#include "mat.h" +#include "htree2.h" + + +class Bank : public Component +{ + public: + Bank(const DynamicParameter & dyn_p); + ~Bank(); + double compute_delays(double inrisetime); // return outrisetime + void compute_power_energy(); + + const DynamicParameter & dp; + Mat mat; + Htree2 *htree_in_add; + Htree2 *htree_in_data; + Htree2 *htree_out_data; + Htree2 *htree_in_search; + Htree2 *htree_out_search; + + int num_addr_b_mat; + int num_mats_hor_dir; + int num_mats_ver_dir; + + int num_addr_b_row_dec; + int num_addr_b_routed_to_mat_for_act; + int num_addr_b_routed_to_mat_for_rd_or_wr; + + double array_leakage; + double wl_leakage; + double cl_leakage; +}; + + + +#endif diff --git a/Project_FARSI/cacti_for_FARSI/basic_circuit.cc b/Project_FARSI/cacti_for_FARSI/basic_circuit.cc new file mode 100644 index 00000000..696f45c4 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/basic_circuit.cc @@ -0,0 +1,999 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+
+
+
+#include "basic_circuit.h"
+#include "parameter.h"
+#include <math.h>
+#include <assert.h>
+#include <iostream>
+
+uint32_t _log2(uint64_t num)
+{
+ uint32_t log2 = 0;
+
+ if (num == 0)
+ {
+ std::cerr << "log0?"
<< std::endl; + exit(1); + } + + while (num > 1) + { + num = (num >> 1); + log2++; + } + + return log2; +} + + +bool is_pow2(int64_t val) +{ + if (val <= 0) + { + return false; + } + else if (val == 1) + { + return true; + } + else + { + return (_log2(val) != _log2(val-1)); + } +} + + +int powers (int base, int n) +{ + int i, p; + + p = 1; + for (i = 1; i <= n; ++i) + p *= base; + return p; +} + +/*----------------------------------------------------------------------*/ + +double logtwo (double x) +{ + assert(x > 0); + return ((double) (log (x) / log (2.0))); +} + +/*----------------------------------------------------------------------*/ + + +double gate_C( + double width, + double wirelength, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + const /*TechnologyParameter::*/DeviceType * dt; + + if (_is_dram && _is_cell) + { + dt = &g_tp.dram_acc; //DRAM cell access transistor + } + else if (_is_dram && _is_wl_tr) + { + dt = &g_tp.dram_wl; //DRAM wordline transistor + } + else if (!_is_dram && _is_cell) + { + dt = &g_tp.sram_cell; // SRAM cell access transistor + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { + dt = &g_tp.peri_global; + } + + return (dt->C_g_ideal + dt->C_overlap + 3*dt->C_fringe)*width + dt->l_phy*Cpolywire; +} + + +// returns gate capacitance in Farads +// actually this function is the same as gate_C() now +double gate_C_pass( + double width, // gate width in um (length is Lphy_periph_global) + double wirelength, // poly wire length going to gate in lambda + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + // v5.0 + const /*TechnologyParameter::*/DeviceType * dt; + + if ((_is_dram) && (_is_cell)) + { + dt = &g_tp.dram_acc; //DRAM cell access transistor + } + else if ((_is_dram) && (_is_wl_tr)) + { + dt = &g_tp.dram_wl; //DRAM wordline transistor + } + else if ((!_is_dram) && _is_cell) + { + dt = &g_tp.sram_cell; // SRAM cell access transistor + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { + dt = &g_tp.peri_global; + } + + return (dt->C_g_ideal + dt->C_overlap + 3*dt->C_fringe)*width + dt->l_phy*Cpolywire; +} + + + +double drain_C_( + double width, + int nchannel, + int stack, + int next_arg_thresh_folding_width_or_height_cell, + double fold_dimension, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + double w_folded_tr; + const /*TechnologyParameter::*/DeviceType * dt; + + if ((_is_dram) && (_is_cell)) + { + dt = &g_tp.dram_acc; // DRAM cell access transistor + } + else if ((_is_dram) && (_is_wl_tr)) + { + dt = &g_tp.dram_wl; // DRAM wordline transistor + } + else if ((!_is_dram) && _is_cell) + { + dt = &g_tp.sram_cell; // SRAM cell access transistor + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { + dt = &g_tp.peri_global; + } + + double c_junc_area = dt->C_junc; + double c_junc_sidewall = dt->C_junc_sidewall; + double c_fringe = 2*dt->C_fringe; + double c_overlap = 2*dt->C_overlap; + double drain_C_metal_connecting_folded_tr = 0; + + // determine the width of the transistor after folding (if it is getting folded) + if (next_arg_thresh_folding_width_or_height_cell == 0) + { // interpret fold_dimension as the the folding width threshold + // i.e. the value of transistor width above which the transistor gets folded + w_folded_tr = fold_dimension; + } + else + { // interpret fold_dimension as the height of the cell that this transistor is part of. 
+ double h_tr_region = fold_dimension - 2 * g_tp.HPOWERRAIL; + // TODO : w_folded_tr must come from Component::compute_gate_area() + double ratio_p_to_n = 2.0 / (2.0 + 1.0); + if (nchannel) + { + w_folded_tr = (1 - ratio_p_to_n) * (h_tr_region - g_tp.MIN_GAP_BET_P_AND_N_DIFFS); + } + else + { + w_folded_tr = ratio_p_to_n * (h_tr_region - g_tp.MIN_GAP_BET_P_AND_N_DIFFS); + } + } + int num_folded_tr = (int) (ceil(width / w_folded_tr)); + + if (num_folded_tr < 2) + { + w_folded_tr = width; + } + + double total_drain_w = (g_tp.w_poly_contact + 2 * g_tp.spacing_poly_to_contact) + // only for drain + (stack - 1) * g_tp.spacing_poly_to_poly; + double drain_h_for_sidewall = w_folded_tr; + double total_drain_height_for_cap_wrt_gate = w_folded_tr + 2 * w_folded_tr * (stack - 1); + if (num_folded_tr > 1) + { + total_drain_w += (num_folded_tr - 2) * (g_tp.w_poly_contact + 2 * g_tp.spacing_poly_to_contact) + + (num_folded_tr - 1) * ((stack - 1) * g_tp.spacing_poly_to_poly); + + if (num_folded_tr%2 == 0) + { + drain_h_for_sidewall = 0; + } + total_drain_height_for_cap_wrt_gate *= num_folded_tr; + drain_C_metal_connecting_folded_tr = g_tp.wire_local.C_per_um * total_drain_w; + } + + double drain_C_area = c_junc_area * total_drain_w * w_folded_tr; + double drain_C_sidewall = c_junc_sidewall * (drain_h_for_sidewall + 2 * total_drain_w); + double drain_C_wrt_gate = (c_fringe + c_overlap) * total_drain_height_for_cap_wrt_gate; + + return (drain_C_area + drain_C_sidewall + drain_C_wrt_gate + drain_C_metal_connecting_folded_tr); +} + + +double tr_R_on( + double width, + int nchannel, + int stack, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + const /*TechnologyParameter::*/DeviceType * dt; + + if ((_is_dram) && (_is_cell)) + { + dt = &g_tp.dram_acc; //DRAM cell access transistor + } + else if ((_is_dram) && (_is_wl_tr)) + { + dt = &g_tp.dram_wl; //DRAM wordline transistor + } + else if ((!_is_dram) && _is_cell) + { + dt = &g_tp.sram_cell; // SRAM cell access transistor + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { + dt = &g_tp.peri_global; + } + + double restrans = (nchannel) ? dt->R_nch_on : dt->R_pch_on; + return (stack * restrans / width); +} + + +/* This routine operates in reverse: given a resistance, it finds + * the transistor width that would have this R. It is used in the + * data wordline to estimate the wordline driver size. */ + +// returns width in um +double R_to_w( + double res, + int nchannel, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + const /*TechnologyParameter::*/DeviceType * dt; + + if ((_is_dram) && (_is_cell)) + { + dt = &g_tp.dram_acc; //DRAM cell access transistor + } + else if ((_is_dram) && (_is_wl_tr)) + { + dt = &g_tp.dram_wl; //DRAM wordline transistor + } + else if ((!_is_dram) && (_is_cell)) + { + dt = &g_tp.sram_cell; // SRAM cell access transistor + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { + dt = &g_tp.peri_global; + } + + double restrans = (nchannel) ? 
dt->R_nch_on : dt->R_pch_on; + return (restrans / res); +} + + +double pmos_to_nmos_sz_ratio( + bool _is_dram, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + double p_to_n_sizing_ratio; + if ((_is_dram) && (_is_wl_tr)) + { //DRAM wordline transistor + p_to_n_sizing_ratio = g_tp.dram_wl.n_to_p_eff_curr_drv_ratio; + } + else if (_is_sleep_tx) + { + p_to_n_sizing_ratio = g_tp.sleep_tx.n_to_p_eff_curr_drv_ratio; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + p_to_n_sizing_ratio = g_tp.peri_global.n_to_p_eff_curr_drv_ratio; + } + return p_to_n_sizing_ratio; +} + + +// "Timing Models for MOS Circuits" by Mark Horowitz, 1984 +double horowitz( + double inputramptime, // input rise time + double tf, // time constant of gate + double vs1, // threshold voltage + double vs2, // threshold voltage + int rise) // whether input rises or fall +{ + if (inputramptime == 0 && vs1 == vs2) + { + return tf * (vs1 < 1 ? -log(vs1) : log(vs1)); + } + double a, b, td; + + a = inputramptime / tf; + if (rise == RISE) + { + b = 0.5; + td = tf * sqrt(log(vs1)*log(vs1) + 2*a*b*(1.0 - vs1)) + tf*(log(vs1) - log(vs2)); + } + else + { + b = 0.4; + td = tf * sqrt(log(1.0 - vs1)*log(1.0 - vs1) + 2*a*b*(vs1)) + tf*(log(1.0 - vs1) - log(1.0 - vs2)); + } + return (td); +} + +double cmos_Ileak( + double nWidth, + double pWidth, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + /*TechnologyParameter::*/DeviceType * dt; + + if ((!_is_dram)&&(_is_cell)) + { //SRAM cell access transistor + dt = &(g_tp.sram_cell); + } + else if ((_is_dram)&&(_is_wl_tr)) + { //DRAM wordline transistor + dt = &(g_tp.dram_wl); + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + dt = &(g_tp.peri_global); + } + return nWidth*dt->I_off_n + pWidth*dt->I_off_p; +} + +int factorial(int n, int m) +{ + int fa = m, i; + for (i=m+1; i<=n; i++) + fa *=i; + return fa; +} + +int combination(int n, int m) +{ + int ret; + ret = factorial(n, m+1) / factorial(n - m); + return ret; +} + +double simplified_nmos_Isat( + double nwidth, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + /*TechnologyParameter::*/DeviceType * dt; + + if ((!_is_dram)&&(_is_cell)) + { //SRAM cell access transistor + dt = &(g_tp.sram_cell); + } + else if ((_is_dram)&&(_is_wl_tr)) + { //DRAM wordline transistor + dt = &(g_tp.dram_wl); + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + dt = &(g_tp.peri_global); + } + return nwidth * dt->I_on_n; +} + +double simplified_pmos_Isat( + double pwidth, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + /*TechnologyParameter::*/DeviceType * dt; + + if ((!_is_dram)&&(_is_cell)) + { //SRAM cell access transistor + dt = &(g_tp.sram_cell); + } + else if ((_is_dram)&&(_is_wl_tr)) + { //DRAM wordline transistor + dt = &(g_tp.dram_wl); + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + dt = &(g_tp.peri_global); + } + return pwidth * dt->I_on_n/dt->n_to_p_eff_curr_drv_ratio; +} + + +double simplified_nmos_leakage( + double nwidth, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + /*TechnologyParameter::*/DeviceType * dt; + + if ((!_is_dram)&&(_is_cell)) + { //SRAM cell access transistor + dt = &(g_tp.sram_cell); + } + else if ((_is_dram)&&(_is_wl_tr)) + { //DRAM wordline transistor + dt = 
&(g_tp.dram_wl); + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + dt = &(g_tp.peri_global); + } + return nwidth * dt->I_off_n; +} + +double simplified_pmos_leakage( + double pwidth, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + /*TechnologyParameter::*/DeviceType * dt; + + if ((!_is_dram)&&(_is_cell)) + { //SRAM cell access transistor + dt = &(g_tp.sram_cell); + } + else if ((_is_dram)&&(_is_wl_tr)) + { //DRAM wordline transistor + dt = &(g_tp.dram_wl); + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + dt = &(g_tp.peri_global); + } + return pwidth * dt->I_off_p; +} + +double cmos_Ig_n( + double nWidth, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + /*TechnologyParameter::*/DeviceType * dt; + + if ((!_is_dram)&&(_is_cell)) + { //SRAM cell access transistor + dt = &(g_tp.sram_cell); + } + else if ((_is_dram)&&(_is_wl_tr)) + { //DRAM wordline transistor + dt = &(g_tp.dram_wl); + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + dt = &(g_tp.peri_global); + } + return nWidth*dt->I_g_on_n; +} + +double cmos_Ig_p( + double pWidth, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx) +{ + /*TechnologyParameter::*/DeviceType * dt; + + if ((!_is_dram)&&(_is_cell)) + { //SRAM cell access transistor + dt = &(g_tp.sram_cell); + } + else if ((_is_dram)&&(_is_wl_tr)) + { //DRAM wordline transistor + dt = &(g_tp.dram_wl); + } + else if (_is_sleep_tx) + { + dt = &g_tp.sleep_tx; // Sleep transistor + } + else + { //DRAM or SRAM all other transistors + dt = &(g_tp.peri_global); + } + return pWidth*dt->I_g_on_p; +} + +double cmos_Isub_leakage( + double nWidth, + double pWidth, + int fanin, + enum Gate_type g_type, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx, + enum Half_net_topology topo) +{ + assert (fanin>=1); + double nmos_leak = simplified_nmos_leakage(nWidth, _is_dram, _is_cell, _is_wl_tr, _is_sleep_tx); + double pmos_leak = simplified_pmos_leakage(pWidth, _is_dram, _is_cell, _is_wl_tr, _is_sleep_tx); + double Isub=0; + int num_states; + int num_off_tx; + + num_states = int(pow(2.0, fanin)); + + switch (g_type) + { + case nmos: + if (fanin==1) + { + Isub = nmos_leak/num_states; + } + else + { + if (topo==parallel) + { + Isub=nmos_leak*fanin/num_states; //only when all tx are off, leakage power is non-zero. The possibility of this state is 1/num_states + } + else + { + for (num_off_tx=1; num_off_tx<=fanin; num_off_tx++) //when num_off_tx ==0 there is no leakage power + { + //Isub += nmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*(factorial(fanin)/(factorial(fanin, num_off_tx)*factorial(num_off_tx))); + Isub += nmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*combination(fanin, num_off_tx); + } + Isub /=num_states; + } + + } + break; + case pmos: + if (fanin==1) + { + Isub = pmos_leak/num_states; + } + else + { + if (topo==parallel) + { + Isub=pmos_leak*fanin/num_states; //only when all tx are off, leakage power is non-zero. 
The possibility of this state is 1/num_states + } + else + { + for (num_off_tx=1; num_off_tx<=fanin; num_off_tx++) //when num_off_tx ==0 there is no leakage power + { + //Isub += pmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*(factorial(fanin)/(factorial(fanin, num_off_tx)*factorial(num_off_tx))); + Isub += pmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*combination(fanin, num_off_tx); + } + Isub /=num_states; + } + + } + break; + case inv: + Isub = (nmos_leak + pmos_leak)/2; + break; + case nand: + Isub += fanin*pmos_leak;//the pullup network + for (num_off_tx=1; num_off_tx<=fanin; num_off_tx++) // the pulldown network + { + //Isub += nmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*(factorial(fanin)/(factorial(fanin, num_off_tx)*factorial(num_off_tx))); + Isub += nmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*combination(fanin, num_off_tx); + } + Isub /=num_states; + break; + case nor: + for (num_off_tx=1; num_off_tx<=fanin; num_off_tx++) // the pullup network + { + //Isub += pmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*(factorial(fanin)/(factorial(fanin, num_off_tx)*factorial(num_off_tx))); + Isub += pmos_leak*pow(UNI_LEAK_STACK_FACTOR,(num_off_tx-1))*combination(fanin, num_off_tx); + } + Isub += fanin*nmos_leak;//the pulldown network + Isub /=num_states; + break; + case tri: + Isub += (nmos_leak + pmos_leak)/2;//enabled + Isub += nmos_leak*UNI_LEAK_STACK_FACTOR; //disabled upper bound of leakage power + Isub /=2; + break; + case tg: + Isub = (nmos_leak + pmos_leak)/2; + break; + default: + assert(0); + break; + } + + return Isub; +} + + +double cmos_Ig_leakage( + double nWidth, + double pWidth, + int fanin, + enum Gate_type g_type, + bool _is_dram, + bool _is_cell, + bool _is_wl_tr, + bool _is_sleep_tx, + enum Half_net_topology topo) +{ + assert (fanin>=1); + double nmos_leak = cmos_Ig_n(nWidth, _is_dram, _is_cell, _is_wl_tr, _is_sleep_tx); + double pmos_leak = cmos_Ig_p(pWidth, _is_dram, _is_cell, _is_wl_tr, _is_sleep_tx); + double Ig_on=0; + int num_states; + int num_on_tx; + + num_states = int(pow(2.0, fanin)); + + switch (g_type) + { + case nmos: + if (fanin==1) + { + Ig_on = nmos_leak/num_states; + } + else + { + if (topo==parallel) + { + for (num_on_tx=1; num_on_tx<=fanin; num_on_tx++) + { + Ig_on += nmos_leak*combination(fanin, num_on_tx)*num_on_tx; + } + } + else + { + Ig_on += nmos_leak * fanin;//pull down network when all TXs are on. + //num_on_tx is the number of on tx + for (num_on_tx=1; num_on_txprint_detail_debug) + { + cout<<"TSV ox cap: "<1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR + +# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_ca 0.5 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T + +# Number of DQ pins + +-num_dq 72 //Number of DQ pins. Includes ECC pins. + +# Number of DQS pins. DQS is a data strobe that is sent along with a small number of data-lanes so the source synchronous timing is local to these DQ bits. Typically, 1 DQS per byte (8 DQ bits) is used. The DQS is also typically differential, just like the CLK pin. + +-num_dqs 18 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins. 
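+ +# Worked check for the two entries above: at 1 differential DQS pair per byte lane, 72 DQ pins (64 data + 8 ECC) give 72/8 = 9 byte lanes, i.e. 9 x 2 = 18 DQS pins, matching -num_dqs 18.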
+ +# Number of CA pins + +-num_ca 25 //Valid range 0 to 35 pins. + +# Number of CLK pins. CLK is typically a differential pair. In some cases additional CLK pairs may be used to limit the loading on the CLK pin. + +-num_clk 2 //2 x differential pair. Valid values: 0/2/4. + +# Number of Physical Ranks + +-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per buffer/register. If multiple LRDIMMs or buffer chips exist, the analysis for capacity and power is reported per buffer/register. + +# Width of the Memory Data Bus + +-mem_data_width 8 //x4 or x8 or x16 or x32 memories. For WideIO up to x128. + +# RTT Termination Resistance + +-rtt_value 10000 + +# RON Termination Resistance + +-ron_value 34 + +# Time of flight for DQ + +-tflight_value + +# Parameters related to MemCAD + +# Number of BoBs: 1,2,3,4,5,6, +-num_bobs 1 + +# Memory System Capacity in GB +-capacity 80 + +# Number of Channel per BoB: 1,2. +-num_channels_per_bob 1 + +# First Metric for ordering different design points +-first metric "Cost" +#-first metric "Bandwidth" +#-first metric "Energy" + +# Second Metric for ordering different design points +#-second metric "Cost" +-second metric "Bandwidth" +#-second metric "Energy" + +# Third Metric for ordering different design points +#-third metric "Cost" +#-third metric "Bandwidth" +-third metric "Energy" + + +# Possible DIMM option to consider +#-DIMM model "JUST_UDIMM" +#-DIMM model "JUST_RDIMM" +#-DIMM model "JUST_LRDIMM" +-DIMM model "ALL" + +#if channels of each bob have the same configurations +#-mirror_in_bob "T" +-mirror_in_bob "F" + +#if we want to see all channels/bobs/memory configurations explored +#-verbose "T" +#-verbose "F" + diff --git a/Project_FARSI/cacti_for_FARSI/cacti b/Project_FARSI/cacti_for_FARSI/cacti new file mode 100755 index 00000000..334437c3 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/cacti differ diff --git a/Project_FARSI/cacti_for_FARSI/cacti.i b/Project_FARSI/cacti_for_FARSI/cacti.i new file mode 100644 index 00000000..79641387 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/cacti.i @@ -0,0 +1,8 @@ +%module cacti +%{ +/* Includes the header in the wrapper code */ +#include "cacti_interface.h" +%} + +/* Parse the header file to generate wrappers */ +%include "cacti_interface.h" \ No newline at end of file diff --git a/Project_FARSI/cacti_for_FARSI/cacti.mk b/Project_FARSI/cacti_for_FARSI/cacti.mk new file mode 100644 index 00000000..f06f65a2 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/cacti.mk @@ -0,0 +1,53 @@ +TARGET = cacti +SHELL = /bin/sh +.PHONY: all depend clean +.SUFFIXES: .cc .o + +ifndef NTHREADS + NTHREADS = 8 +endif + + +LIBS = +INCS = -lm + +ifeq ($(TAG),dbg) + DBG = -Wall + OPT = -ggdb -g -O0 -DNTHREADS=1 +else + DBG = + OPT = -g -msse2 -mfpmath=sse -DNTHREADS=$(NTHREADS) +endif + +#CXXFLAGS = -Wall -Wno-unknown-pragmas -Winline $(DBG) $(OPT) +CXXFLAGS = -Wno-unknown-pragmas $(DBG) $(OPT) +CXX = g++ -m64 +CC = gcc -m64 + +SRCS = area.cc bank.cc mat.cc main.cc Ucache.cc io.cc technology.cc basic_circuit.cc parameter.cc \ + decoder.cc component.cc uca.cc subarray.cc wire.cc htree2.cc extio.cc extio_technology.cc \ + cacti_interface.cc router.cc nuca.cc crossbar.cc arbiter.cc powergating.cc TSV.cc memorybus.cc \ + memcad.cc memcad_parameters.cc + + +OBJS = $(patsubst %.cc,obj_$(TAG)/%.o,$(SRCS)) +PYTHONLIB_SRCS = $(patsubst main.cc, ,$(SRCS)) obj_$(TAG)/cacti_wrap.cc +PYTHONLIB_OBJS = $(patsubst %.cc,%.o,$(PYTHONLIB_SRCS)) +INCLUDES = -I /usr/include/python2.4 -I /usr/lib/python2.4/config + +all: obj_$(TAG)/$(TARGET) 
+ cp -f obj_$(TAG)/$(TARGET) $(TARGET) + +obj_$(TAG)/$(TARGET) : $(OBJS) + $(CXX) $(OBJS) -o $@ $(INCS) $(CXXFLAGS) $(LIBS) -pthread + +#obj_$(TAG)/%.o : %.cc +# $(CXX) -c $(CXXFLAGS) $(INCS) -o $@ $< + +obj_$(TAG)/%.o : %.cc + $(CXX) $(CXXFLAGS) -c $< -o $@ + +clean: + -rm -f *.o _cacti.so cacti.py $(TARGET) + + diff --git a/Project_FARSI/cacti_for_FARSI/cacti_interface.cc b/Project_FARSI/cacti_for_FARSI/cacti_interface.cc new file mode 100644 index 00000000..763b1d6f --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/cacti_interface.cc @@ -0,0 +1,174 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + +#include +#include + + +#include "area.h" +#include "basic_circuit.h" +#include "component.h" +#include "const.h" +#include "parameter.h" +#include "cacti_interface.h" +#include "Ucache.h" + +#include +#include +#include + +using namespace std; + + +bool mem_array::lt(const mem_array * m1, const mem_array * m2) +{ + if (m1->Nspd < m2->Nspd) return true; + else if (m1->Nspd > m2->Nspd) return false; + else if (m1->Ndwl < m2->Ndwl) return true; + else if (m1->Ndwl > m2->Ndwl) return false; + else if (m1->Ndbl < m2->Ndbl) return true; + else if (m1->Ndbl > m2->Ndbl) return false; + else if (m1->deg_bl_muxing < m2->deg_bl_muxing) return true; + else if (m1->deg_bl_muxing > m2->deg_bl_muxing) return false; + else if (m1->Ndsam_lev_1 < m2->Ndsam_lev_1) return true; + else if (m1->Ndsam_lev_1 > m2->Ndsam_lev_1) return false; + else if (m1->Ndsam_lev_2 < m2->Ndsam_lev_2) return true; + else return false; +} + + + +void uca_org_t::find_delay() +{ + mem_array * data_arr = data_array2; + mem_array * tag_arr = tag_array2; + + // check whether it is a regular cache or scratch ram + if (g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc) + { + access_time = data_arr->access_time; + } 
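+ // The three branches below differ in how tag and data access overlap: fast_access takes the max of the two (full overlap), is_seq_acc takes their sum (tag first, then data), and the default case overlaps them partially, adding the output-driver delay after the later of way-select decoding and the data array's pre-driver path.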
+ // Both tag and data lookup happen in parallel + // and the entire set is sent over the data array h-tree without + // waiting for the way-select signal --TODO add the corresponding + // power overhead Nav + else if (g_ip->fast_access == true) + { + access_time = MAX(tag_arr->access_time, data_arr->access_time); + } + // Tag is accessed first. On a hit, way-select signal along with the + // address is sent to read/write the appropriate block in the data + // array + else if (g_ip->is_seq_acc == true) + { + access_time = tag_arr->access_time + data_arr->access_time; + } + // Normal access: tag array access and data array access happen in parallel. + // But, the data array will wait for the way-select and transfer only the + // appropriate block over the h-tree. + else + { + access_time = MAX(tag_arr->access_time + data_arr->delay_senseamp_mux_decoder, + data_arr->delay_before_subarray_output_driver) + + data_arr->delay_from_subarray_output_driver_to_output; + } +} + + + +void uca_org_t::find_energy() +{ + if (!(g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc))//(g_ip->is_cache) + power = data_array2->power + tag_array2->power; + else + power = data_array2->power; +} + + + +void uca_org_t::find_area() +{ + if (g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc)//(g_ip->is_cache == false) + { + cache_ht = data_array2->height; + cache_len = data_array2->width; + } + else + { + cache_ht = MAX(tag_array2->height, data_array2->height); + cache_len = tag_array2->width + data_array2->width; + } + area = cache_ht * cache_len; +} + +void uca_org_t::adjust_area() +{ + double area_adjust; + if (g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc) + { + if (data_array2->area_efficiency/100.0<0.2) + { + //area_adjust = sqrt(area/(area*(data_array2->area_efficiency/100.0)/0.2)); + area_adjust = sqrt(0.2/(data_array2->area_efficiency/100.0)); + cache_ht = cache_ht/area_adjust; + cache_len = cache_len/area_adjust; + } + } + area = cache_ht * cache_len; +} + +void uca_org_t::find_cyc() +{ + if ((g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc))//(g_ip->is_cache == false) + { + cycle_time = data_array2->cycle_time; + } + else + { + cycle_time = MAX(tag_array2->cycle_time, + data_array2->cycle_time); + } +} + +uca_org_t :: uca_org_t() +:tag_array2(0), + data_array2(0) +{ + +} + +void uca_org_t :: cleanup() +{ + if (data_array2!=0) + delete data_array2; + if (tag_array2!=0) + delete tag_array2; +} diff --git a/Project_FARSI/cacti_for_FARSI/cacti_interface.h b/Project_FARSI/cacti_for_FARSI/cacti_interface.h new file mode 100644 index 00000000..a2b8e2d2 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/cacti_interface.h @@ -0,0 +1,904 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#ifndef __CACTI_INTERFACE_H__ +#define __CACTI_INTERFACE_H__ + +#include +#include +#include +#include +#include +#include "const.h" + +using namespace std; + + +class min_values_t; +class mem_array; +class uca_org_t; + + +class powerComponents +{ + public: + double dynamic; + double leakage; + double gate_leakage; + double short_circuit; + double longer_channel_leakage; + + powerComponents() : dynamic(0), leakage(0), gate_leakage(0), short_circuit(0), longer_channel_leakage(0) { } + powerComponents(const powerComponents & obj) { *this = obj; } + powerComponents & operator=(const powerComponents & rhs) + { + dynamic = rhs.dynamic; + leakage = rhs.leakage; + gate_leakage = rhs.gate_leakage; + short_circuit = rhs.short_circuit; + longer_channel_leakage = rhs.longer_channel_leakage; + return *this; + } + void reset() { dynamic = 0; leakage = 0; gate_leakage = 0; short_circuit = 0;longer_channel_leakage = 0;} + + friend powerComponents operator+(const powerComponents & x, const powerComponents & y); + friend powerComponents operator*(const powerComponents & x, double const * const y); +}; + + + +class powerDef +{ + public: + powerComponents readOp; + powerComponents writeOp; + powerComponents searchOp;//: for CAM and FA + + powerDef() : readOp(), writeOp(), searchOp() { } + void reset() { readOp.reset(); writeOp.reset(); searchOp.reset();} + + friend powerDef operator+(const powerDef & x, const powerDef & y); + friend powerDef operator*(const powerDef & x, double const * const y); +}; + +enum Wire_type +{ + Global /* global wires with repeaters */, + Global_5 /* 5% delay penalty */, + Global_10 /* 10% delay penalty */, + Global_20 /* 20% delay penalty */, + Global_30 /* 30% delay penalty */, + Low_swing /* differential low power wires with high area overhead */, + Semi_global /* mid-level wires with repeaters*/, + Full_swing /* models all global wires with different latencies (Global_x )*/, + Transmission 
/* transmission lines with high area overhead */, + Optical /* optical wires */, + Invalid_wtype +}; + +enum TSV_type +{ + Fine, /*ITRS high density*/ + Coarse /*Industry reported in 2010*/ +}; + +// ali + +enum Mem_IO_type +{ + DDR3, + DDR4, + LPDDR2, + WideIO, + Low_Swing_Diff, + Serial +}; + +enum Mem_DIMM +{ + UDIMM, + RDIMM, + LRDIMM +}; + +enum Mem_state +{ + READ, + WRITE, + IDLE, + SLEEP +}; + +enum Mem_ECC +{ + NO_ECC, + SECDED, // single error correction, double error detection + CHIP_KILL +}; + +enum DIMM_Model +{ + JUST_UDIMM,JUST_RDIMM,JUST_LRDIMM,ALL +}; + +enum MemCad_metrics +{ + Bandwidth, Energy, Cost +}; + +/** +enum BoB_LINK +{ + PARALLEL, // e.g. Intel SMB c104 + SERIAL // e.g. Intel SMB 7510, IBM AMB +}; +**/ +// end ali + + +class InputParameter +{ + public: + + InputParameter(); + void parse_cfg(const string & infile); + + bool error_checking(); // return false if the input parameters are problematic + void display_ip(); + + unsigned int cache_sz; // in bytes + unsigned int line_sz; + unsigned int assoc; + unsigned int nbanks; + unsigned int out_w;// == nr_bits_out + bool specific_tag; + unsigned int tag_w; + unsigned int access_mode; + unsigned int obj_func_dyn_energy; + unsigned int obj_func_dyn_power; + unsigned int obj_func_leak_power; + unsigned int obj_func_cycle_t; + + double F_sz_nm; // feature size in nm + double F_sz_um; // feature size in um + unsigned int num_rw_ports; + unsigned int num_rd_ports; + unsigned int num_wr_ports; + unsigned int num_se_rd_ports; // number of single ended read ports + unsigned int num_search_ports; // : number of search ports for CAM + bool is_main_mem; + bool is_3d_mem; + bool print_detail_debug; + bool is_cache; + bool pure_ram; + bool pure_cam; + bool rpters_in_htree; // if there are repeaters in htree segment + unsigned int ver_htree_wires_over_array; + unsigned int broadcast_addr_din_over_ver_htrees; + unsigned int temp; + + unsigned int ram_cell_tech_type; + unsigned int peri_global_tech_type; + unsigned int data_arr_ram_cell_tech_type; + unsigned int data_arr_peri_global_tech_type; + unsigned int tag_arr_ram_cell_tech_type; + unsigned int tag_arr_peri_global_tech_type; + + unsigned int burst_len; + unsigned int int_prefetch_w; + unsigned int page_sz_bits; + + unsigned int num_die_3d; + unsigned int burst_depth; + unsigned int io_width; + unsigned int sys_freq_MHz; + + unsigned int tsv_is_subarray_type; + unsigned int tsv_os_bank_type; + unsigned int TSV_proj_type; + + int partition_gran; + unsigned int num_tier_row_sprd; + unsigned int num_tier_col_sprd; + bool fine_gran_bank_lvl; + + unsigned int ic_proj_type; // interconnect_projection_type + unsigned int wire_is_mat_type; // wire_inside_mat_type + unsigned int wire_os_mat_type; // wire_outside_mat_type + enum Wire_type wt; + int force_wiretype; + bool print_input_args; + unsigned int nuca_cache_sz; // TODO + int ndbl, ndwl, nspd, ndsam1, ndsam2, ndcm; + bool force_cache_config; + + int cache_level; + int cores; + int nuca_bank_count; + int force_nuca_bank; + + int delay_wt, dynamic_power_wt, leakage_power_wt, + cycle_time_wt, area_wt; + int delay_wt_nuca, dynamic_power_wt_nuca, leakage_power_wt_nuca, + cycle_time_wt_nuca, area_wt_nuca; + + int delay_dev, dynamic_power_dev, leakage_power_dev, + cycle_time_dev, area_dev; + int delay_dev_nuca, dynamic_power_dev_nuca, leakage_power_dev_nuca, + cycle_time_dev_nuca, area_dev_nuca; + int ed; //ED or ED2 optimization + int nuca; + + bool fast_access; + unsigned int block_sz; // bytes + unsigned int tag_assoc; + unsigned int 
data_assoc; + bool is_seq_acc; + bool fully_assoc; + unsigned int nsets; // == number_of_sets + int print_detail; + + + bool add_ecc_b_; + //parameters for design constraint + double throughput; + double latency; + bool pipelinable; + int pipeline_stages; + int per_stage_vector; + bool with_clock_grid; + + bool array_power_gated; + bool bitline_floating; + bool wl_power_gated; + bool cl_power_gated; + bool interconect_power_gated; + bool power_gating; + + double perfloss; + + bool cl_vertical; + + // Parameters related to off-chip I/O + + double addr_timing, duty_cycle, mem_density, bus_bw, activity_dq, activity_ca, bus_freq; + int mem_data_width, num_mem_dq, num_clk, num_ca, num_dqs, num_dq; + + double rtt_value, ron_value, tflight_value; //FIXME + + Mem_state iostate; + + ///char iostate, dram_ecc, io_type; + + Mem_ECC dram_ecc; + Mem_IO_type io_type; + Mem_DIMM dram_dimm; + + int num_bobs; // BoB is buffer-on-board such as Intel SMB c102 + + int capacity; // in GB + + int num_channels_per_bob; // 1 means no bob + + MemCad_metrics first_metric; + + MemCad_metrics second_metric; + + MemCad_metrics third_metric; + + DIMM_Model dimm_model; + + bool low_power_permitted; // Not yet implemented. It determines acceptable VDDs. + + double load; // between 0 to 1 + + double row_buffer_hit_rate; + + double rd_2_wr_ratio; + + bool same_bw_in_bob; // true if all the channels in the bob have the same bandwidth. + + bool mirror_in_bob;// true if all the channels in the bob have the same configs + + bool total_power; // false means just considering I/O Power + + bool verbose; + + + +}; + + +typedef struct{ + int Ndwl; + int Ndbl; + double Nspd; + int deg_bl_muxing; + int Ndsam_lev_1; + int Ndsam_lev_2; + int number_activated_mats_horizontal_direction; + int number_subbanks; + int page_size_in_bits; + double delay_route_to_bank; + double delay_crossbar; + double delay_addr_din_horizontal_htree; + double delay_addr_din_vertical_htree; + double delay_row_predecode_driver_and_block; + double delay_row_decoder; + double delay_bitlines; + double delay_sense_amp; + double delay_subarray_output_driver; + double delay_bit_mux_predecode_driver_and_block; + double delay_bit_mux_decoder; + double delay_senseamp_mux_lev_1_predecode_driver_and_block; + double delay_senseamp_mux_lev_1_decoder; + double delay_senseamp_mux_lev_2_predecode_driver_and_block; + double delay_senseamp_mux_lev_2_decoder; + double delay_input_htree; + double delay_output_htree; + double delay_dout_vertical_htree; + double delay_dout_horizontal_htree; + double delay_comparator; + double access_time; + double cycle_time; + double multisubbank_interleave_cycle_time; + double delay_request_network; + double delay_inside_mat; + double delay_reply_network; + double trcd; + double cas_latency; + double precharge_delay; + powerDef power_routing_to_bank; + powerDef power_addr_input_htree; + powerDef power_data_input_htree; + powerDef power_data_output_htree; + powerDef power_addr_horizontal_htree; + powerDef power_datain_horizontal_htree; + powerDef power_dataout_horizontal_htree; + powerDef power_addr_vertical_htree; + powerDef power_datain_vertical_htree; + powerDef power_row_predecoder_drivers; + powerDef power_row_predecoder_blocks; + powerDef power_row_decoders; + powerDef power_bit_mux_predecoder_drivers; + powerDef power_bit_mux_predecoder_blocks; + powerDef power_bit_mux_decoders; + powerDef power_senseamp_mux_lev_1_predecoder_drivers; + powerDef power_senseamp_mux_lev_1_predecoder_blocks; + powerDef power_senseamp_mux_lev_1_decoders; + 
powerDef power_senseamp_mux_lev_2_predecoder_drivers; + powerDef power_senseamp_mux_lev_2_predecoder_blocks; + powerDef power_senseamp_mux_lev_2_decoders; + powerDef power_bitlines; + powerDef power_sense_amps; + powerDef power_prechg_eq_drivers; + powerDef power_output_drivers_at_subarray; + powerDef power_dataout_vertical_htree; + powerDef power_comparators; + powerDef power_crossbar; + powerDef total_power; + double area; + double all_banks_height; + double all_banks_width; + double bank_height; + double bank_width; + double subarray_memory_cell_area_height; + double subarray_memory_cell_area_width; + double mat_height; + double mat_width; + double routing_area_height_within_bank; + double routing_area_width_within_bank; + double area_efficiency; +// double perc_power_dyn_routing_to_bank; +// double perc_power_dyn_addr_horizontal_htree; +// double perc_power_dyn_datain_horizontal_htree; +// double perc_power_dyn_dataout_horizontal_htree; +// double perc_power_dyn_addr_vertical_htree; +// double perc_power_dyn_datain_vertical_htree; +// double perc_power_dyn_row_predecoder_drivers; +// double perc_power_dyn_row_predecoder_blocks; +// double perc_power_dyn_row_decoders; +// double perc_power_dyn_bit_mux_predecoder_drivers; +// double perc_power_dyn_bit_mux_predecoder_blocks; +// double perc_power_dyn_bit_mux_decoders; +// double perc_power_dyn_senseamp_mux_lev_1_predecoder_drivers; +// double perc_power_dyn_senseamp_mux_lev_1_predecoder_blocks; +// double perc_power_dyn_senseamp_mux_lev_1_decoders; +// double perc_power_dyn_senseamp_mux_lev_2_predecoder_drivers; +// double perc_power_dyn_senseamp_mux_lev_2_predecoder_blocks; +// double perc_power_dyn_senseamp_mux_lev_2_decoders; +// double perc_power_dyn_bitlines; +// double perc_power_dyn_sense_amps; +// double perc_power_dyn_prechg_eq_drivers; +// double perc_power_dyn_subarray_output_drivers; +// double perc_power_dyn_dataout_vertical_htree; +// double perc_power_dyn_comparators; +// double perc_power_dyn_crossbar; +// double perc_power_dyn_spent_outside_mats; +// double perc_power_leak_routing_to_bank; +// double perc_power_leak_addr_horizontal_htree; +// double perc_power_leak_datain_horizontal_htree; +// double perc_power_leak_dataout_horizontal_htree; +// double perc_power_leak_addr_vertical_htree; +// double perc_power_leak_datain_vertical_htree; +// double perc_power_leak_row_predecoder_drivers; +// double perc_power_leak_row_predecoder_blocks; +// double perc_power_leak_row_decoders; +// double perc_power_leak_bit_mux_predecoder_drivers; +// double perc_power_leak_bit_mux_predecoder_blocks; +// double perc_power_leak_bit_mux_decoders; +// double perc_power_leak_senseamp_mux_lev_1_predecoder_drivers; +// double perc_power_leak_senseamp_mux_lev_1_predecoder_blocks; +// double perc_power_leak_senseamp_mux_lev_1_decoders; +// double perc_power_leak_senseamp_mux_lev_2_predecoder_drivers; +// double perc_power_leak_senseamp_mux_lev_2_predecoder_blocks; +// double perc_power_leak_senseamp_mux_lev_2_decoders; +// double perc_power_leak_bitlines; +// double perc_power_leak_sense_amps; +// double perc_power_leak_prechg_eq_drivers; +// double perc_power_leak_subarray_output_drivers; +// double perc_power_leak_dataout_vertical_htree; +// double perc_power_leak_comparators; +// double perc_power_leak_crossbar; +// double perc_leak_mats; +// double perc_active_mats; + double refresh_power; + double dram_refresh_period; + double dram_array_availability; + double dyn_read_energy_from_closed_page; + double dyn_read_energy_from_open_page; + 
double leak_power_subbank_closed_page; + double leak_power_subbank_open_page; + double leak_power_request_and_reply_networks; + double activate_energy; + double read_energy; + double write_energy; + double precharge_energy; +} results_mem_array; + + +class uca_org_t +{ + public: + mem_array * tag_array2; + mem_array * data_array2; + double access_time; + double cycle_time; + double area; + double area_efficiency; + powerDef power; + double leak_power_with_sleep_transistors_in_mats; + double cache_ht; + double cache_len; + char file_n[100]; + double vdd_periph_global; + bool valid; + results_mem_array tag_array; + results_mem_array data_array; + + uca_org_t(); + void find_delay(); + void find_energy(); + void find_area(); + void find_cyc(); + void adjust_area();//for McPAT only to adjust routing overhead + void cleanup(); + ~uca_org_t(){}; +}; + + +class IO_org_t +{ + public: + double io_area; + double io_timing_margin; + double io_voltage_margin; + double io_dynamic_power; + double io_phy_power; + double io_wakeup_time; + double io_termination_power; + IO_org_t():io_area(0),io_timing_margin(0),io_voltage_margin(0) + ,io_dynamic_power(0),io_phy_power(0),io_wakeup_time(0),io_termination_power(0) + {} +}; + + +void reconfigure(InputParameter *local_interface, uca_org_t *fin_res); + +uca_org_t cacti_interface(const string & infile_name); +//McPAT's plain interface, please keep !!! +uca_org_t cacti_interface(InputParameter * const local_interface); +//McPAT's plain interface, please keep !!! +uca_org_t init_interface(InputParameter * const local_interface); +//McPAT's plain interface, please keep !!! +uca_org_t cacti_interface( + int cache_size, + int line_size, + int associativity, + int rw_ports, + int excl_read_ports, + int excl_write_ports, + int single_ended_read_ports, + int search_ports, + int banks, + double tech_node, + int output_width, + int specific_tag, + int tag_width, + int access_mode, + int cache, + int main_mem, + int obj_func_delay, + int obj_func_dynamic_power, + int obj_func_leakage_power, + int obj_func_cycle_time, + int obj_func_area, + int dev_func_delay, + int dev_func_dynamic_power, + int dev_func_leakage_power, + int dev_func_area, + int dev_func_cycle_time, + int ed_ed2_none, // 0 - ED, 1 - ED^2, 2 - use weight and deviate + int temp, + int wt, //0 - default(search across everything), 1 - global, 2 - 5% delay penalty, 3 - 10%, 4 - 20 %, 5 - 30%, 6 - low-swing + int data_arr_ram_cell_tech_flavor_in, + int data_arr_peri_global_tech_flavor_in, + int tag_arr_ram_cell_tech_flavor_in, + int tag_arr_peri_global_tech_flavor_in, + int interconnect_projection_type_in, + int wire_inside_mat_type_in, + int wire_outside_mat_type_in, + int REPEATERS_IN_HTREE_SEGMENTS_in, + int VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in, + int BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in, + int PAGE_SIZE_BITS_in, + int BURST_LENGTH_in, + int INTERNAL_PREFETCH_WIDTH_in, + int force_wiretype, + int wiretype, + int force_config, + int ndwl, + int ndbl, + int nspd, + int ndcm, + int ndsam1, + int ndsam2, + int ecc); +// int cache_size, +// int line_size, +// int associativity, +// int rw_ports, +// int excl_read_ports, +// int excl_write_ports, +// int single_ended_read_ports, +// int banks, +// double tech_node, +// int output_width, +// int specific_tag, +// int tag_width, +// int access_mode, +// int cache, +// int main_mem, +// int obj_func_delay, +// int obj_func_dynamic_power, +// int obj_func_leakage_power, +// int obj_func_area, +// int obj_func_cycle_time, +// int dev_func_delay, +// int 
dev_func_dynamic_power, +// int dev_func_leakage_power, +// int dev_func_area, +// int dev_func_cycle_time, +// int temp, +// int data_arr_ram_cell_tech_flavor_in, +// int data_arr_peri_global_tech_flavor_in, +// int tag_arr_ram_cell_tech_flavor_in, +// int tag_arr_peri_global_tech_flavor_in, +// int interconnect_projection_type_in, +// int wire_inside_mat_type_in, +// int wire_outside_mat_type_in, +// int REPEATERS_IN_HTREE_SEGMENTS_in, +// int VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in, +// int BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in, +//// double MAXAREACONSTRAINT_PERC_in, +//// double MAXACCTIMECONSTRAINT_PERC_in, +//// double MAX_PERC_DIFF_IN_DELAY_FROM_BEST_DELAY_REPEATER_SOLUTION_in, +// int PAGE_SIZE_BITS_in, +// int BURST_LENGTH_in, +// int INTERNAL_PREFETCH_WIDTH_in); + +//Naveen's interface +uca_org_t cacti_interface( + int cache_size, + int line_size, + int associativity, + int rw_ports, + int excl_read_ports, + int excl_write_ports, + int single_ended_read_ports, + int banks, + double tech_node, + int page_sz, + int burst_length, + int pre_width, + int output_width, + int specific_tag, + int tag_width, + int access_mode, //0 normal, 1 seq, 2 fast + int cache, //scratch ram or cache + int main_mem, + int obj_func_delay, + int obj_func_dynamic_power, + int obj_func_leakage_power, + int obj_func_area, + int obj_func_cycle_time, + int dev_func_delay, + int dev_func_dynamic_power, + int dev_func_leakage_power, + int dev_func_area, + int dev_func_cycle_time, + int ed_ed2_none, // 0 - ED, 1 - ED^2, 2 - use weight and deviate + int temp, + int wt, //0 - default(search across everything), 1 - global, 2 - 5% delay penalty, 3 - 10%, 4 - 20 %, 5 - 30%, 6 - low-swing + int data_arr_ram_cell_tech_flavor_in, + int data_arr_peri_global_tech_flavor_in, + int tag_arr_ram_cell_tech_flavor_in, + int tag_arr_peri_global_tech_flavor_in, + int interconnect_projection_type_in, // 0 - aggressive, 1 - normal + int wire_inside_mat_type_in, + int wire_outside_mat_type_in, + int is_nuca, // 0 - UCA, 1 - NUCA + int core_count, + int cache_level, // 0 - L2, 1 - L3 + int nuca_bank_count, + int nuca_obj_func_delay, + int nuca_obj_func_dynamic_power, + int nuca_obj_func_leakage_power, + int nuca_obj_func_area, + int nuca_obj_func_cycle_time, + int nuca_dev_func_delay, + int nuca_dev_func_dynamic_power, + int nuca_dev_func_leakage_power, + int nuca_dev_func_area, + int nuca_dev_func_cycle_time, + int REPEATERS_IN_HTREE_SEGMENTS_in,//TODO for now only wires with repeaters are supported + int p_input); + + +//CACTI3DD interface +uca_org_t cacti_interface( + int cache_size, + int line_size, + int associativity, + int rw_ports, + int excl_read_ports,// para5 + int excl_write_ports, + int single_ended_read_ports, + int search_ports, + int banks, + double tech_node,//para10 + int output_width, + int specific_tag, + int tag_width, + int access_mode, + int cache, //para15 + int main_mem, + int obj_func_delay, + int obj_func_dynamic_power, + int obj_func_leakage_power, + int obj_func_cycle_time, //para20 + int obj_func_area, + int dev_func_delay, + int dev_func_dynamic_power, + int dev_func_leakage_power, + int dev_func_area, //para25 + int dev_func_cycle_time, + int ed_ed2_none, // 0 - ED, 1 - ED^2, 2 - use weight and deviate + int temp, + int wt, //0 - default(search across everything), 1 - global, 2 - 5% delay penalty, 3 - 10%, 4 - 20 %, 5 - 30%, 6 - low-swing + int data_arr_ram_cell_tech_flavor_in,//para30 + int data_arr_peri_global_tech_flavor_in, + int tag_arr_ram_cell_tech_flavor_in, + int 
tag_arr_peri_global_tech_flavor_in, + int interconnect_projection_type_in, + int wire_inside_mat_type_in,//para35 + int wire_outside_mat_type_in, + int REPEATERS_IN_HTREE_SEGMENTS_in, + int VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in, + int BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in, + int PAGE_SIZE_BITS_in,//para40 + int BURST_LENGTH_in, + int INTERNAL_PREFETCH_WIDTH_in, + int force_wiretype, + int wiretype, + int force_config,//para45 + int ndwl, + int ndbl, + int nspd, + int ndcm, + int ndsam1,//para50 + int ndsam2, + int ecc, + int is_3d_dram, + int burst_depth, + int IO_width, + int sys_freq, + int debug_detail, + int num_dies, + int tsv_gran_is_subarray, + int tsv_gran_os_bank, + int num_tier_row_sprd, + int num_tier_col_sprd, + int partition_level); + +class mem_array +{ + public: + int Ndcm; + int Ndwl; + int Ndbl; + double Nspd; + int deg_bl_muxing; + int Ndsam_lev_1; + int Ndsam_lev_2; + double access_time; + double cycle_time; + double multisubbank_interleave_cycle_time; + double area_ram_cells; + double area; + powerDef power; + double delay_senseamp_mux_decoder; + double delay_before_subarray_output_driver; + double delay_from_subarray_output_driver_to_output; + double height; + double width; + + double mat_height; + double mat_length; + double subarray_length; + double subarray_height; + + double delay_route_to_bank, + delay_input_htree, + delay_row_predecode_driver_and_block, + delay_row_decoder, + delay_bitlines, + delay_sense_amp, + delay_subarray_output_driver, + delay_dout_htree, + delay_comparator, + delay_matchlines; + //CACTI3DD 3d stats + double delay_row_activate_net, + delay_local_wordline, + + delay_column_access_net, + delay_column_predecoder, + delay_column_decoder, + delay_column_selectline, + delay_datapath_net, + delay_global_data, + delay_local_data_and_drv, + delay_data_buffer; + + double energy_row_activate_net, + energy_row_predecode_driver_and_block, + energy_row_decoder, + energy_local_wordline, + energy_bitlines, + energy_sense_amp, + energy_column_access_net, + energy_column_predecoder, + energy_column_decoder, + energy_column_selectline, + energy_datapath_net, + energy_global_data, + energy_local_data_and_drv, + energy_data_buffer, + energy_subarray_output_driver; + + double all_banks_height, + all_banks_width, + area_efficiency; + + powerDef power_routing_to_bank; + powerDef power_addr_input_htree; + powerDef power_data_input_htree; + powerDef power_data_output_htree; + powerDef power_htree_in_search; + powerDef power_htree_out_search; + powerDef power_row_predecoder_drivers; + powerDef power_row_predecoder_blocks; + powerDef power_row_decoders; + powerDef power_bit_mux_predecoder_drivers; + powerDef power_bit_mux_predecoder_blocks; + powerDef power_bit_mux_decoders; + powerDef power_senseamp_mux_lev_1_predecoder_drivers; + powerDef power_senseamp_mux_lev_1_predecoder_blocks; + powerDef power_senseamp_mux_lev_1_decoders; + powerDef power_senseamp_mux_lev_2_predecoder_drivers; + powerDef power_senseamp_mux_lev_2_predecoder_blocks; + powerDef power_senseamp_mux_lev_2_decoders; + powerDef power_bitlines; + powerDef power_sense_amps; + powerDef power_prechg_eq_drivers; + powerDef power_output_drivers_at_subarray; + powerDef power_dataout_vertical_htree; + powerDef power_comparators; + + powerDef power_cam_bitline_precharge_eq_drv; + powerDef power_searchline; + powerDef power_searchline_precharge; + powerDef power_matchlines; + powerDef power_matchline_precharge; + powerDef power_matchline_to_wordline_drv; + + min_values_t *arr_min; + enum Wire_type wt; + 
+ // dram stats + double activate_energy, read_energy, write_energy, precharge_energy, + refresh_power, leak_power_subbank_closed_page, leak_power_subbank_open_page, + leak_power_request_and_reply_networks; + + double precharge_delay; + + //Power-gating stats + double array_leakage; + double wl_leakage; + double cl_leakage; + + double sram_sleep_tx_width, wl_sleep_tx_width, cl_sleep_tx_width; + double sram_sleep_tx_area, wl_sleep_tx_area, cl_sleep_tx_area; + double sram_sleep_wakeup_latency, wl_sleep_wakeup_latency, cl_sleep_wakeup_latency, bl_floating_wakeup_latency; + double sram_sleep_wakeup_energy, wl_sleep_wakeup_energy, cl_sleep_wakeup_energy, bl_floating_wakeup_energy; + + int num_active_mats; + int num_submarray_mats; + + static bool lt(const mem_array * m1, const mem_array * m2); + + //CACTI3DD 3d dram stats + double t_RCD, t_RAS, t_RC, t_CAS, t_RP, t_RRD; + double activate_power, read_power, write_power, peak_read_power; + int num_row_subarray, num_col_subarray; + double delay_TSV_tot, area_TSV_tot, dyn_pow_TSV_tot, dyn_pow_TSV_per_access; + unsigned int num_TSV_tot; + double area_lwl_drv, area_row_predec_dec, area_col_predec_dec, + area_subarray, area_bus, area_address_bus, area_data_bus, area_data_drv, area_IOSA, area_sense_amp; + +}; + + +#endif + diff --git a/Project_FARSI/cacti_for_FARSI/component.cc b/Project_FARSI/cacti_for_FARSI/component.cc new file mode 100644 index 00000000..ea486597 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/component.cc @@ -0,0 +1,237 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + + +#include +#include +#include + +#include "bank.h" +#include "component.h" +#include "decoder.h" + +using namespace std; + + + +Component::Component() + :area(), power(), rt_power(),delay(0) +{ +} + + + +Component::~Component() +{ +} + + + +double Component::compute_diffusion_width(int num_stacked_in, int num_folded_tr) +{ + double w_poly = g_ip->F_sz_um; + double spacing_poly_to_poly = g_tp.w_poly_contact + 2 * g_tp.spacing_poly_to_contact; + double total_diff_w = 2 * spacing_poly_to_poly + // for both source and drain + num_stacked_in * w_poly + + (num_stacked_in - 1) * g_tp.spacing_poly_to_poly; + + if (num_folded_tr > 1) + { + total_diff_w += (num_folded_tr - 2) * 2 * spacing_poly_to_poly + + (num_folded_tr - 1) * num_stacked_in * w_poly + + (num_folded_tr - 1) * (num_stacked_in - 1) * g_tp.spacing_poly_to_poly; + } + + return total_diff_w; +} + + + +double Component::compute_gate_area( + int gate_type, + int num_inputs, + double w_pmos, + double w_nmos, + double h_gate) +{ + if (w_pmos <= 0.0 || w_nmos <= 0.0) + { + return 0.0; + } + + double w_folded_pmos, w_folded_nmos; + int num_folded_pmos, num_folded_nmos; + double total_ndiff_w, total_pdiff_w; + Area gate; + + double h_tr_region = h_gate - 2 * g_tp.HPOWERRAIL; + double ratio_p_to_n = w_pmos / (w_pmos + w_nmos); + + if (ratio_p_to_n >= 1 || ratio_p_to_n <= 0) + { + return 0.0; + } + + w_folded_pmos = (h_tr_region - g_tp.MIN_GAP_BET_P_AND_N_DIFFS) * ratio_p_to_n; + w_folded_nmos = (h_tr_region - g_tp.MIN_GAP_BET_P_AND_N_DIFFS) * (1 - ratio_p_to_n); + + assert(w_folded_pmos > 0); + + num_folded_pmos = (int) (ceil(w_pmos / w_folded_pmos)); + num_folded_nmos = (int) (ceil(w_nmos / w_folded_nmos)); + + switch (gate_type) + { + case INV: + total_ndiff_w = compute_diffusion_width(1, num_folded_nmos); + total_pdiff_w = compute_diffusion_width(1, num_folded_pmos); + break; + + case NOR: + total_ndiff_w = compute_diffusion_width(1, num_inputs * num_folded_nmos); + total_pdiff_w = compute_diffusion_width(num_inputs, num_folded_pmos); + break; + + case NAND: + total_ndiff_w = compute_diffusion_width(num_inputs, num_folded_nmos); + total_pdiff_w = compute_diffusion_width(1, num_inputs * num_folded_pmos); + break; + default: + cout << "Unknown gate type: " << gate_type << endl; + exit(1); + } + + gate.w = MAX(total_ndiff_w, total_pdiff_w); + + if (w_folded_nmos > w_nmos) + { + //means that the height of the gate can + //be made smaller than the input height specified, so calculate the height of the gate. + gate.h = w_nmos + w_pmos + g_tp.MIN_GAP_BET_P_AND_N_DIFFS + 2 * g_tp.HPOWERRAIL; + } + else + { + gate.h = h_gate; + } + return gate.get_area(); +} + + + +double Component::compute_tr_width_after_folding( + double input_width, + double threshold_folding_width) +{//This is actually the width of the cell not the width of a device. +//The width of a cell and the width of a device is orthogonal. 
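+// Worked example: input_width = 10um with threshold_folding_width = 3um folds into ceil(10/3) = 4 poly fingers, so the cell width computed below is 4*width_poly + 5*spacing_poly_to_poly.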
+ if (input_width <= 0) + { + return 0; + } + + int num_folded_tr = (int) (ceil(input_width / threshold_folding_width)); + double spacing_poly_to_poly = g_tp.w_poly_contact + 2 * g_tp.spacing_poly_to_contact; + double width_poly = g_ip->F_sz_um; + double total_diff_width = num_folded_tr * width_poly + (num_folded_tr + 1) * spacing_poly_to_poly; + + return total_diff_width; +} + + + +double Component::height_sense_amplifier(double pitch_sense_amp) +{ + // compute the height occupied by all PMOS transistors + double h_pmos_tr = compute_tr_width_after_folding(g_tp.w_sense_p, pitch_sense_amp) * 2 + + compute_tr_width_after_folding(g_tp.w_iso, pitch_sense_amp) + + 2 * g_tp.MIN_GAP_BET_SAME_TYPE_DIFFS; + + // compute the height occupied by all NMOS transistors + double h_nmos_tr = compute_tr_width_after_folding(g_tp.w_sense_n, pitch_sense_amp) * 2 + + compute_tr_width_after_folding(g_tp.w_sense_en, pitch_sense_amp) + + 2 * g_tp.MIN_GAP_BET_SAME_TYPE_DIFFS; + + // compute total height by considering gap between the p and n diffusion areas + return h_pmos_tr + h_nmos_tr + g_tp.MIN_GAP_BET_P_AND_N_DIFFS; +} + + + +int Component::logical_effort( + int num_gates_min, + double g, + double F, + double * w_n, + double * w_p, + double C_load, + double p_to_n_sz_ratio, + bool is_dram_, + bool is_wl_tr_, + double max_w_nmos) +{ + int num_gates = (int) (log(F) / log(fopt)); + + // check if num_gates is odd. if so, add 1 to make it even + num_gates+= (num_gates % 2) ? 1 : 0; + num_gates = MAX(num_gates, num_gates_min); + + // recalculate the effective fanout of each stage + double f = pow(F, 1.0 / num_gates); + int i = num_gates - 1; + double C_in = C_load / f; + w_n[i] = (1.0 / (1.0 + p_to_n_sz_ratio)) * C_in / gate_C(1, 0, is_dram_, false, is_wl_tr_); + w_n[i] = MAX(w_n[i], g_tp.min_w_nmos_); + w_p[i] = p_to_n_sz_ratio * w_n[i]; + + if (w_n[i] > max_w_nmos) // && !g_ip->is_3d_mem) + { + double C_ld = gate_C((1 + p_to_n_sz_ratio) * max_w_nmos, 0, is_dram_, false, is_wl_tr_); + F = g * C_ld / gate_C(w_n[0] + w_p[0], 0, is_dram_, false, is_wl_tr_); + num_gates = (int) (log(F) / log(fopt)) + 1; + num_gates+= (num_gates % 2) ? 1 : 0; + num_gates = MAX(num_gates, num_gates_min); + f = pow(F, 1.0 / (num_gates - 1)); + i = num_gates - 1; + w_n[i] = max_w_nmos; + w_p[i] = p_to_n_sz_ratio * w_n[i]; + } + + for (i = num_gates - 2; i >= 1; i--) + { + w_n[i] = MAX(w_n[i+1] / f, g_tp.min_w_nmos_); + w_p[i] = p_to_n_sz_ratio * w_n[i]; + } + + assert(num_gates <= MAX_NUMBER_GATES_STAGE); + return num_gates; +} + diff --git a/Project_FARSI/cacti_for_FARSI/component.h b/Project_FARSI/cacti_for_FARSI/component.h new file mode 100644 index 00000000..7d6dbf8f --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/component.h @@ -0,0 +1,84 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#ifndef __COMPONENT_H__ +#define __COMPONENT_H__ + +#include "parameter.h" +#include "area.h" + +using namespace std; + +class Crossbar; +class Bank; + +class Component +{ + public: + Component(); + ~Component(); + + Area area; + powerDef power,rt_power; + double delay; + double cycle_time; + + double compute_gate_area( + int gate_type, + int num_inputs, + double w_pmos, + double w_nmos, + double h_gate); + + double compute_tr_width_after_folding(double input_width, double threshold_folding_width); + double height_sense_amplifier(double pitch_sense_amp); + + protected: + int logical_effort( + int num_gates_min, + double g, + double F, + double * w_n, + double * w_p, + double C_load, + double p_to_n_sz_ratio, + bool is_dram_, + bool is_wl_tr_, + double max_w_nmos); + + private: + double compute_diffusion_width(int num_stacked_in, int num_folded_tr); +}; + +#endif + diff --git a/Project_FARSI/cacti_for_FARSI/const.h b/Project_FARSI/cacti_for_FARSI/const.h new file mode 100644 index 00000000..a2851d70 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/const.h @@ -0,0 +1,273 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + +#ifndef __CONST_H__ +#define __CONST_H__ + +#include +#include +#include +#include +#include + +/* The following are things you might want to change + * when compiling + */ + +/* + * Address bits in a word, and number of output bits from the cache + */ + +/* +was: #define ADDRESS_BITS 32 +now: 42 bits as in the Power4 +This is 36 bits in Pentium 4 +and 40 bits in Opteron. +*/ +const int ADDRESS_BITS = 42; + +/*dt: In addition to the tag bits, the tags also include 1 valid bit, 1 dirty bit, 2 bits for a 4-state + cache coherency protocol (MESI), 1 bit for MRU (change this to log(ways) for full LRU). + So in total we have 1 + 1 + 2 + 1 = 5 */ +const int EXTRA_TAG_BITS = 5; + +/* limits on the various N parameters */ + +const unsigned int MAXDATAN = 512; // maximum for Ndwl and Ndbl +const unsigned int MAXSUBARRAYS = 1048576; // maximum subarrays for data and tag arrays +const unsigned int MAXDATASPD = 256; // maximum for Nspd +const unsigned int MAX_COL_MUX = 256; + + + +#define ROUTER_TYPES 3 +#define WIRE_TYPES 6 + +const double Cpolywire = 0; + + +/* Threshold voltages (as a proportion of Vdd) + If you don't know them, set all values to 0.5 */ +#define VTHFA1 0.452 +#define VTHFA2 0.304 +#define VTHFA3 0.420 +#define VTHFA4 0.413 +#define VTHFA5 0.405 +#define VTHFA6 0.452 +#define VSINV 0.452 +#define VTHCOMPINV 0.437 +#define VTHMUXNAND 0.548 // TODO : this constant must be revisited +#define VTHEVALINV 0.452 +#define VTHSENSEEXTDRV 0.438 + + +//WmuxdrvNANDn and WmuxdrvNANDp are no longer being used but they're part of the old +//delay_comparator function which we are using exactly as it used to be, so just setting these to 0 +const double WmuxdrvNANDn = 0; +const double WmuxdrvNANDp = 0; + + +/*===================================================================*/ +/* + * The following are things you probably wouldn't want to change. 
+ */ + +#define BIGNUM 1e30 +#define INF 9999999 +#define MAX(a,b) (((a)>(b))?(a):(b)) +#define MIN(a,b) (((a)<(b))?(a):(b)) + +/* Used to communicate with the horowitz model */ +#define RISE 1 +#define FALL 0 +#define NCH 1 +#define PCH 0 + + +#define EPSILON 0.5 //v4.1: This constant is being used in order to fix floating point -> integer +//conversion problems that were occurring within CACTI. Typical problem that was occurring was +//that with different compilers a floating point number like 3.0 would get represented as either +//2.9999....or 3.00000001 and then the integer part of the floating point number (3.0) would +//be computed differently depending on the compiler. What we are doing now is to replace +//int (x) with (int) (x+EPSILON) where EPSILON is 0.5. This would fix such problems. Note that +//this works only when x is an integer >= 0. +/* + * thinks this is more a solution to solve the simple truncate problem + * (http://www.cs.tut.fi/~jkorpela/round.html) rather than the problem mentioned above. + * Unfortunately, this solution causes nasty bugs (different results when using O0 and O3). + * Moreover, round is not correct in CACTI since when an extra fraction of bit/line is needed, + * we need to provide a complete bit/line even if the fraction is just 0.01. + * So, in versions later than 6.5 we use (int)ceil() to get double to int conversion. + */ + +#define EPSILON2 0.1 +#define EPSILON3 0.6 + + +#define MINSUBARRAYROWS 16 //For simplicity in modeling, for the row decoding structure, we assume +//that each row predecode block is composed of at least one 2-4 decoder. When the outputs from the +//row predecode blocks are combined this means that there are at least 4*4=16 row decode outputs +#define MAXSUBARRAYROWS 262144 //Each row predecode block produces a max of 2^9 outputs. So +//the maximum number of row decode outputs will be 2^9*2^9 +#define MINSUBARRAYCOLS 2 +#define MAXSUBARRAYCOLS 262144 + + +#define INV 0 +#define NOR 1 +#define NAND 2 + + +#define NUMBER_TECH_FLAVORS 4 + +#define NUMBER_INTERCONNECT_PROJECTION_TYPES 2 //aggressive and conservative +//0 = Aggressive projections, 1 = Conservative projections +#define NUMBER_WIRE_TYPES 4 //local, semi-global and global +//1 = 'Semi-global' wire type, 2 = 'Global' wire type +#define NUMBER_TSV_TYPES 3 +//0 = ITRS projected fine TSV type, 1 = Industrial reported large TSV type, 2 = TBD + +const int dram_cell_tech_flavor = 3; + + +#define VBITSENSEMIN 0.08 //minimum bitline sense voltage is fixed to be 80 mV. + +#define fopt 4.0 + +#define INPUT_WIRE_TO_INPUT_GATE_CAP_RATIO 0 +#define BUFFER_SEPARATION_LENGTH_MULTIPLIER 1 +#define NUMBER_MATS_PER_REDUNDANT_MAT 8 + +#define NUMBER_STACKED_DIE_LAYERS 1 + +// this variable can be set to carry out solution optimization for +// a maximum area allocation. +#define STACKED_DIE_LAYER_ALLOTED_AREA_mm2 0 //6.24 //6.21//71.5 + +// this variable can also be employed when solution optimization +// with maximum area allocation is carried out. +#define MAX_PERCENT_AWAY_FROM_ALLOTED_AREA 50 + +// this variable can also be employed when solution optimization +// with maximum area allocation is carried out. +#define MIN_AREA_EFFICIENCY 20 + +// this variable can be employed when solution with a desired +// aspect ratio is required. +#define STACKED_DIE_LAYER_ASPECT_RATIO 1 + +// this variable can be employed when solution with a desired +// aspect ratio is required. 
+#define MAX_PERCENT_AWAY_FROM_ASPECT_RATIO 101 + +// this variable can be employed to carry out solution optimization +// for a certain target random cycle time. +#define TARGET_CYCLE_TIME_ns 1000000000 + +#define NUMBER_PIPELINE_STAGES 4 + +// this can be used to model the length of interconnect +// between a bank and a crossbar +#define LENGTH_INTERCONNECT_FROM_BANK_TO_CROSSBAR 0 //3791 // 2880//micron + +#define IS_CROSSBAR 0 +#define NUMBER_INPUT_PORTS_CROSSBAR 8 +#define NUMBER_OUTPUT_PORTS_CROSSBAR 8 +#define NUMBER_SIGNALS_PER_PORT_CROSSBAR 256 + + +#define MAT_LEAKAGE_REDUCTION_DUE_TO_SLEEP_TRANSISTORS_FACTOR 1 +#define LEAKAGE_REDUCTION_DUE_TO_LONG_CHANNEL_HP_TRANSISTORS_FACTOR 1 + +#define PAGE_MODE 0 + +#define MAIN_MEM_PER_CHIP_STANDBY_CURRENT_mA 60 +// We are actually not using this variable in the CACTI code. We just want to acknowledge that +// this current should be multiplied by the DDR(n) system VDD value to compute the standby power +// consumed during precharge. + + +const double VDD_STORAGE_LOSS_FRACTION_WORST = 0.125; +const double CU_RESISTIVITY = 0.022; //ohm-micron +const double BULK_CU_RESISTIVITY = 0.018; //ohm-micron +const double PERMITTIVITY_FREE_SPACE = 8.854e-18; //F/micron + +const static uint32_t sram_num_cells_wl_stitching_ = 16; +const static uint32_t dram_num_cells_wl_stitching_ = 64; +const static uint32_t comm_dram_num_cells_wl_stitching_ = 256; +const static double num_bits_per_ecc_b_ = 8.0; + +const double bit_to_byte = 8.0; + +#define MAX_NUMBER_GATES_STAGE 20 +#define MAX_NUMBER_HTREE_NODES 20 +#define NAND2_LEAK_STACK_FACTOR 0.2 +#define NAND3_LEAK_STACK_FACTOR 0.2 +#define NOR2_LEAK_STACK_FACTOR 0.2 +#define INV_LEAK_STACK_FACTOR 0.5 +#define MAX_NUMBER_ARRAY_PARTITIONS 1000000 + +// abbreviations used in this project +// ---------------------------------- +// +// num : number +// rw : read/write +// rd : read +// wr : write +// se : single-ended +// sz : size +// F : feature +// w : width +// h : height or horizontal +// v : vertical or velocity + + +enum ram_cell_tech_type_num +{ + itrs_hp = 0, + itrs_lstp = 1, + itrs_lop = 2, + lp_dram = 3, + comm_dram = 4 +}; + +const double pppm[4] = {1,1,1,1}; +const double pppm_lkg[4] = {0,1,1,0}; +const double pppm_dyn[4] = {1,0,0,0}; +const double pppm_Isub[4] = {0,1,0,0}; +const double pppm_Ig[4] = {0,0,1,0}; +const double pppm_sc[4] = {0,0,0,1}; + +const double Ilinear_to_Isat_ratio =2.0; + + + +#endif diff --git a/Project_FARSI/cacti_for_FARSI/contention.dat b/Project_FARSI/cacti_for_FARSI/contention.dat new file mode 100644 index 00000000..826553e7 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/contention.dat @@ -0,0 +1,126 @@ +l34c64l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l34c64l2b: 9 11 19 29 43 62 81 102 +l34c64l4b: 6 8 12 17 24 29 39 47 +l34c64l8b: 7 8 10 14 18 22 25 30 +l34c64l16b: 7 7 9 12 14 17 20 24 +l34c64l32b: 7 7 9 12 14 17 20 24 -r +l34c64l64b: 7 7 9 12 14 17 20 24 -r +l34c128l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l34c128l2b: 4 10 19 30 44 64 82 103 +l34c128l4b: 3 6 11 17 24 31 38 47 +l34c128l8b: 3 5 9 13 17 21 25 29 +l34c128l16b: 4 5 7 10 13 16 19 22 +l34c128l32b: 4 5 7 10 13 16 19 22 -r +l34c128l64b: 4 5 7 10 13 16 19 22 -r +l34c256l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l34c256l2b: 3 10 19 30 44 63 82 103 +l34c256l4b: 3 6 11 17 24 31 38 47 +l34c256l8b: 2 5 8 12 16 20 24 29 +l34c256l16b: 2 4 7 9 12 15 18 21 +l34c256l32b: 2 4 7 9 12 15 18 21 -r +l34c256l64b: 2 4 7 9 12 15 18 21 -r +l38c64l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l38c64l2b: 57 59 77 90 137 187 
219 245 +l38c64l4b: 35 40 48 56 43 61 80 101 +l38c64l8b: 18 27 41 45 52 58 58 58 -r +l38c64l16b: 16 17 19 35 40 49 53 53 -r +l38c64l32b: 15 15 17 19 22 25 30 30 -r +l38c64l64b: 15 15 17 19 22 25 30 30 -r +l38c128l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l38c128l2b: 38 50 78 93 139 188 220 245 +l38c128l4b: 29 37 46 56 43 61 81 102 +l38c128l8b: 16 30 39 44 50 57 57 57 -r +l38c128l16b: 14 16 19 33 40 47 52 52 -r +l38c128l32b: 14 15 17 20 23 27 31 31 -r +l38c128l64b: 14 15 17 20 23 27 31 31 -r +l38c256l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l38c256l2b: 35 50 78 94 139 188 220 246 +l38c256l4b: 28 36 45 55 55 61 81 102 +l38c256l8b: 17 30 38 43 50 57 57 57 -r +l38c256l16b: 15 17 21 32 40 47 51 51 +l38c256l32b: 15 17 19 21 24 29 33 33 +l38c256l64b: 15 17 19 21 24 29 33 33 -r +l316c64l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l316c64l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l316c64l4b: 34 35 78 126 178 220 252 274 +l316c64l8b: 9 11 23 43 62 87 105 130 +l316c64l16b: 7 9 13 23 33 45 56 67 +l316c64l32b: 5 6 7 10 13 19 25 30 +l316c64l64b: 4 5 6 8 10 14 18 21 +l316c128l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l316c128l2b: 25 131 243 1000 1000 1000 1000 1000 +l316c128l4b: 8 28 79 127 179 221 253 274 +l316c128l8b: 4 9 22 43 62 88 106 131 +l316c128l16b: 4 6 11 21 32 44 55 67 +l316c128l32b: 4 6 11 12 12 18 24 29 +l316c128l64b: 2 3 5 7 9 13 17 21 +l316c256l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l316c256l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l316c256l4b: 5 28 80 128 180 221 253 274 +l316c256l8b: 3 8 22 43 63 88 107 131 +l316c256l16b: 2 5 11 21 32 44 55 67 +l316c256l32b: 2 3 5 8 12 18 24 29 +l316c256l64b: 2 3 4 6 9 13 17 21 +l24c64l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l24c64l2b: 10 12 24 41 60 86 105 122 +l24c64l4b: 5 7 13 20 29 38 47 56 +l24c64l8b: 5 6 9 14 18 24 29 35 +l24c64l16b: 4 5 7 10 12 16 19 22 +l24c64l32b: 5 5 6 8 10 12 14 17 +l24c64l64b: 5 5 6 8 10 12 14 16 +l24c128l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l24c128l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l24c128l4b: 3 7 13 20 29 38 47 57 +l24c128l8b: 3 5 9 13 18 23 29 35 +l24c128l16b: 3 4 6 9 12 15 19 22 +l24c128l32b: 3 4 5 7 9 11 14 16 +l24c128l64b: 1000 1000 1000 1000 1000 1000 1000 1000 +l24c256l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l24c256l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l24c256l4b: 2 6 13 20 29 38 47 57 +l24c256l8b: 2 4 8 13 18 23 28 35 +l24c256l16b: 2 3 6 8 11 15 18 22 +l24c256l32b: 2 3 5 6 8 11 14 16 +l24c256l64b: 1000 1000 1000 1000 1000 1000 1000 1000 +l28c64l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l28c64l2b: 46 52 117 157 188 225 246 261 +l28c64l4b: 19 25 39 54 96 107 120 150 +l28c64l8b: 9 12 21 30 39 47 58 79 +l28c64l16b: 8 9 11 16 25 32 37 42 +l28c64l32b: 7 8 9 11 14 19 23 28 +l28c64l64b: 7 7 8 10 12 14 18 22 +l28c128l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l28c128l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l28c128l4b: 12 22 39 54 98 108 130 151 +l28c128l8b: 7 12 21 30 39 48 59 80 +l28c128l16b: 6 8 11 16 24 31 37 42 +l28c128l32b: 6 7 9 11 14 19 24 28 +l28c128l64b: 6 7 9 11 14 19 24 28 +l28c256l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l28c256l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l28c256l4b: 12 22 39 54 100 108 130 152 +l28c256l8b: 7 12 21 30 39 48 59 81 +l28c256l16b: 6 8 11 16 24 31 37 42 +l28c256l32b: 6 7 9 11 14 19 24 28 +l28c256l64b: 6 7 9 11 14 19 24 28 +l216c64l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l216c64l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l216c64l4b: 34 35 78 126 178 220 252 274 +l216c64l8b: 9 11 23 43 62 87 105 130 
+l216c64l16b: 7 9 13 23 33 45 56 67 +l216c64l32b: 5 6 7 10 13 19 25 30 +l216c64l64b: 4 5 6 8 10 14 18 21 +l216c128l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l216c128l2b: 25 131 243 1000 1000 1000 1000 1000 +l216c128l4b: 8 28 79 127 179 221 253 274 +l216c128l8b: 4 9 22 43 62 88 106 131 +l216c128l16b: 4 6 11 21 32 44 55 67 +l216c128l32b: 4 6 11 12 12 18 24 29 +l216c128l64b: 2 3 5 7 9 13 17 21 +l216c256l1b: 1000 1000 1000 1000 1000 1000 1000 1000 +l216c256l2b: 1000 1000 1000 1000 1000 1000 1000 1000 +l216c256l4b: 5 28 80 128 180 221 253 274 +l216c256l8b: 3 8 22 43 63 88 107 131 +l216c256l16b: 2 5 11 21 32 44 55 67 +l216c256l32b: 2 3 5 8 12 18 24 29 +l216c256l64b: 2 3 4 6 9 13 17 21 diff --git a/Project_FARSI/cacti_for_FARSI/crossbar.cc b/Project_FARSI/cacti_for_FARSI/crossbar.cc new file mode 100644 index 00000000..be32736f --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/crossbar.cc @@ -0,0 +1,161 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+#include "crossbar.h"
+
+#define ASPECT_THRESHOLD .8
+#define ADJ 1
+
+Crossbar::Crossbar(
+    double n_inp_,
+    double n_out_,
+    double flit_size_,
+    /*TechnologyParameter::*/DeviceType *dt
+    ):n_inp(n_inp_), n_out(n_out_), flit_size(flit_size_), deviceType(dt)
+{
+  min_w_pmos = deviceType->n_to_p_eff_curr_drv_ratio*g_tp.min_w_nmos_;
+  Vdd = dt->Vdd;
+  CB_ADJ = 1;
+}
+
+Crossbar::~Crossbar(){}
+
+double Crossbar::output_buffer()
+{
+
+  //Wire winit(4, 4);
+  double l_eff = n_inp*flit_size*g_tp.wire_outside_mat.pitch;
+  Wire w1(g_ip->wt, l_eff);
+  //double s1 = w1.repeater_size *l_eff*ADJ/w1.repeater_spacing;
+  double s1 = w1.repeater_size * (l_eff<w1.repeater_spacing ? l_eff*ADJ/w1.repeater_spacing : 1);
+  double pton_size = deviceType->n_to_p_eff_curr_drv_ratio;
+  // the model assumes input capacitance of the wire driver = input capacitance of nand + nor = input cap of the driver transistor
+  TriS1 = s1*(1 + pton_size)/(2 + pton_size + 1 + 2*pton_size);
+  TriS2 = s1; //driver transistor
+
+  if (TriS1 < 1)
+    TriS1 = 1;
+
+  double input_cap = gate_C(TriS1*(2*min_w_pmos + g_tp.min_w_nmos_), 0) +
+    gate_C(TriS1*(min_w_pmos + 2*g_tp.min_w_nmos_), 0);
+//  input_cap += drain_C_(TriS1*g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) +
+//    drain_C_(TriS1*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def)*2 +
+//    gate_C(TriS2*g_tp.min_w_nmos_, 0)+
+//    drain_C_(TriS1*min_w_pmos, NCH, 1, 1, g_tp.cell_h_def)*2 +
+//    drain_C_(TriS1*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) +
+//    gate_C(TriS2*min_w_pmos, 0);
+  tri_int_cap = drain_C_(TriS1*g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) +
+    drain_C_(TriS1*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def)*2 +
+    gate_C(TriS2*g_tp.min_w_nmos_, 0)+
+    drain_C_(TriS1*min_w_pmos, NCH, 1, 1, g_tp.cell_h_def)*2 +
+    drain_C_(TriS1*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) +
+    gate_C(TriS2*min_w_pmos, 0);
+  double output_cap = drain_C_(TriS2*g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) +
+    drain_C_(TriS2*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def);
+  double ctr_cap = gate_C(TriS2 *(min_w_pmos + g_tp.min_w_nmos_), 0);
+
+  tri_inp_cap = input_cap;
+  tri_out_cap = output_cap;
+  tri_ctr_cap = ctr_cap;
+  return input_cap + output_cap + ctr_cap;
+}
+
+void Crossbar::compute_power()
+{
+
+  Wire winit(4, 4);
+  double tri_cap = output_buffer();
+  assert(tri_cap > 0);
+  //area of a tristate logic
+  double g_area = compute_gate_area(INV, 1, TriS2*g_tp.min_w_nmos_, TriS2*min_w_pmos, g_tp.cell_h_def);
+  g_area *= 2; // to model area of output transistors
+  g_area += compute_gate_area (NAND, 2, TriS1*2*g_tp.min_w_nmos_, TriS1*min_w_pmos, g_tp.cell_h_def);
+  g_area += compute_gate_area (NOR, 2, TriS1*g_tp.min_w_nmos_, TriS1*2*min_w_pmos, g_tp.cell_h_def);
+  double width /*per tristate*/ = g_area/(CB_ADJ * g_tp.cell_h_def);
+  // effective no. 
of tristate buffers that need to be laid side by side + int ntri = (int)ceil(g_tp.cell_h_def/(g_tp.wire_outside_mat.pitch)); + double wire_len = MAX(width*ntri*n_out, flit_size*g_tp.wire_outside_mat.pitch*n_out); + Wire w1(g_ip->wt, wire_len); + + area.w = wire_len; + area.h = g_tp.wire_outside_mat.pitch*n_inp*flit_size * CB_ADJ; + Wire w2(g_ip->wt, area.h); + + double aspect_ratio_cb = (area.h/area.w)*(n_out/n_inp); + if (aspect_ratio_cb > 1) aspect_ratio_cb = 1/aspect_ratio_cb; + + if (aspect_ratio_cb < ASPECT_THRESHOLD) { + if (n_out > 2 && n_inp > 2) { + CB_ADJ+=0.2; + //cout << "CB ADJ " << CB_ADJ << endl; + if (CB_ADJ < 4) { + this->compute_power(); + } + } + } + + + + power.readOp.dynamic = (w1.power.readOp.dynamic + w2.power.readOp.dynamic + (tri_inp_cap * n_out + tri_out_cap * n_inp + tri_ctr_cap + tri_int_cap) * Vdd*Vdd)*flit_size; + power.readOp.leakage = n_inp * n_out * flit_size * ( + cmos_Isub_leakage(g_tp.min_w_nmos_*TriS2*2, min_w_pmos*TriS2*2, 1, inv) *Vdd+ + cmos_Isub_leakage(g_tp.min_w_nmos_*TriS1*3, min_w_pmos*TriS1*3, 2, nand)*Vdd+ + cmos_Isub_leakage(g_tp.min_w_nmos_*TriS1*3, min_w_pmos*TriS1*3, 2, nor) *Vdd+ + w1.power.readOp.leakage + w2.power.readOp.leakage); + power.readOp.gate_leakage = n_inp * n_out * flit_size * ( + cmos_Ig_leakage(g_tp.min_w_nmos_*TriS2*2, min_w_pmos*TriS2*2, 1, inv) *Vdd+ + cmos_Ig_leakage(g_tp.min_w_nmos_*TriS1*3, min_w_pmos*TriS1*3, 2, nand)*Vdd+ + cmos_Ig_leakage(g_tp.min_w_nmos_*TriS1*3, min_w_pmos*TriS1*3, 2, nor) *Vdd+ + w1.power.readOp.gate_leakage + w2.power.readOp.gate_leakage); + + // delay calculation + double l_eff = n_inp*flit_size*g_tp.wire_outside_mat.pitch; + Wire wdriver(g_ip->wt, l_eff); + double res = g_tp.wire_outside_mat.R_per_um * (area.w+area.h) + tr_R_on(g_tp.min_w_nmos_*wdriver.repeater_size, NCH, 1); + double cap = g_tp.wire_outside_mat.C_per_um * (area.w + area.h) + n_out*tri_inp_cap + n_inp*tri_out_cap; + delay = horowitz(w1.signal_rise_time(), res*cap, deviceType->Vth/deviceType->Vdd, deviceType->Vth/deviceType->Vdd, RISE); + + Wire wreset(); +} + +void Crossbar::print_crossbar() +{ + cout << "\nCrossbar Stats (" << n_inp << "x" << n_out << ")\n\n"; + cout << "Flit size : " << flit_size << " bits" << endl; + cout << "Width : " << area.w << " u" << endl; + cout << "Height : " << area.h << " u" << endl; + cout << "Dynamic Power : " << power.readOp.dynamic*1e9 * MIN(n_inp, n_out) << " (nJ)" << endl; + cout << "Leakage Power : " << power.readOp.leakage*1e3 << " (mW)" << endl; + cout << "Gate Leakage Power : " << power.readOp.gate_leakage*1e3 << " (mW)" << endl; + cout << "Crossbar Delay : " << delay*1e12 << " ps\n"; +} + + diff --git a/Project_FARSI/cacti_for_FARSI/crossbar.h b/Project_FARSI/cacti_for_FARSI/crossbar.h new file mode 100644 index 00000000..529db9c6 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/crossbar.h @@ -0,0 +1,83 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.

+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+
+#ifndef __CROSSBAR__
+#define __CROSSBAR__
+
+#include <assert.h>
+#include <iostream>
+#include "basic_circuit.h"
+#include "cacti_interface.h"
+#include "component.h"
+#include "parameter.h"
+#include "mat.h"
+#include "wire.h"
+
+class Crossbar : public Component
+{
+  public:
+    Crossbar(
+        double in,
+        double out,
+        double flit_sz,
+        /*TechnologyParameter::*/DeviceType *dt = &(g_tp.peri_global));
+    ~Crossbar();
+
+    void print_crossbar();
+    double output_buffer();
+    void compute_power();
+
+    double n_inp, n_out;
+    double flit_size;
+    double tri_inp_cap, tri_out_cap, tri_ctr_cap, tri_int_cap;
+
+  private:
+    double CB_ADJ;
+    /*
+     * Adjust factor for the height of the cross-point (tri-state buffer) cell (layout) in the crossbar.
+     * The buffer is adjusted to get an aspect ratio of the whole crossbar close to one;
+     * when adjusting the ratio, the number of wires routed over the tri-state buffers does not change,
+     * but the effective wiring pitch does. Specifically, since CB_ADJ increases
+     * during the adjustment, the tri-state buffer becomes taller and thinner, and the effective wiring pitch
+     * increases. As a result, the height of the crossbar (area.h) will increase.
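+     * (Concretely: Crossbar::compute_power() starts from CB_ADJ = 1 and, for crossbars
+     * larger than 2x2 whose aspect ratio falls below ASPECT_THRESHOLD, bumps CB_ADJ
+     * in steps of 0.2 and recomputes itself, giving up once CB_ADJ reaches 4.)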
+ */ + + /*TechnologyParameter::*/DeviceType *deviceType; + double TriS1, TriS2; + double min_w_pmos, Vdd; + +}; + + + + +#endif diff --git a/Project_FARSI/cacti_for_FARSI/ddr3.cfg b/Project_FARSI/cacti_for_FARSI/ddr3.cfg new file mode 100644 index 00000000..89086127 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/ddr3.cfg @@ -0,0 +1,254 @@ +# Cache size +//-size (bytes) 2048 +//-size (bytes) 4096 +//-size (bytes) 32768 +//-size (bytes) 131072 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 67108864 +//-size (bytes) 1073741824 + +# power gating +-Array Power Gating - "false" +-WL Power Gating - "false" +-CL Power Gating - "false" +-Bitline floating - "false" +-Interconnect Power Gating - "false" +-Power Gating Performance Loss 0.01 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +//-associativity 8 +-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +-technology (u) 0.022 +//-technology (u) 0.040 +//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 8192 +-burst length 8 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 512 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. 
+# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. 
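+# Note: "Global_30" below appears to be one of CACTI's named full-swing global
+# wire variants (Global_5/Global_10/Global_20/Global_30; see Wire_type in
+# cacti_interface.h), where the suffix is the percentage delay penalty
+# tolerated relative to the delay-optimal global wire.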
+-Wire signaling (fullswing, lowswing, default) - "Global_30" +//-Wire signaling (fullswing, lowswing, default) - "default" +//-Wire signaling (fullswing, lowswing, default) - "lowswing" + +//-Wire inside mat - "global" +-Wire inside mat - "semi-global" +//-Wire outside mat - "global" +-Wire outside mat - "semi-global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +-Add ECC - "true" + +//-Print level (DETAILED, CONCISE) - "CONCISE" +-Print level (DETAILED, CONCISE) - "DETAILED" + +# for debugging +//-Print input parameters - "true" +-Print input parameters - "false" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +//-Force cache config - "true" +-Force cache config - "false" +-Ndwl 1 +-Ndbl 1 +-Nspd 0 +-Ndcm 1 +-Ndsam1 0 +-Ndsam2 0 + + + +#### Default CONFIGURATION values for baseline external IO parameters to DRAM. More details can be found in the CACTI-IO technical report (), especially Chapters 2 and 3. + +# Memory Type (D=DDR3, L=LPDDR2, W=WideIO). Additional memory types can be defined by the user in extio_technology.cc, along with their technology and configuration parameters. + +-dram_type "D" +//-dram_type "L" +//-dram_type "W" +//-dram_type "S" + +# Memory State (R=Read, W=Write, I=Idle or S=Sleep) + +//-iostate "R" +-iostate "W" +//-iostate "I" +//-iostate "S" + +#Address bus timing. To alleviate the timing on the command and address bus due to high loading (shared across all memories on the channel), the interface allows for multi-cycle timing options. + +-addr_timing 0.5 //DDR +//-addr_timing 1.0 //SDR (half of DQ rate) +//-addr_timing 2.0 //2T timing (One fourth of DQ rate) +//-addr_timing 3.0 // 3T timing (One sixth of DQ rate) + +# Memory Density (Gbit per memory/DRAM die) + +-mem_density 8 Gb //Valid values 2^n Gb + +# IO frequency (MHz) (frequency of the external memory interface). + +-bus_freq 800 MHz //As of current memory standards (2013), valid range 0 to 1.5 GHz for DDR3, 0 to 533 MHz for LPDDR2, 0 - 800 MHz for WideIO and 0 - 3 GHz for Low-swing differential. However this can change, and the user is free to define valid ranges based on new memory types or extending beyond existing standards for existing dram types. + +# Duty Cycle (fraction of time in the Memory State defined above) + +-duty_cycle 1.0 //Valid range 0 to 1.0 + +# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR +#-activity_dq .50 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR + +# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_ca 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T +#-activity_ca 0.25 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T + +# Number of DQ pins + +-num_dq 72 //Number of DQ pins. Includes ECC pins. + +# Number of DQS pins. 
DQS is a data strobe that is sent along with a small number of data-lanes so the source synchronous timing is local to these DQ bits. Typically, 1 DQS per byte (8 DQ bits) is used. The DQS is also typically differential, just like the CLK pin.

-num_dqs 36 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins.

# Number of CA pins

-num_ca 35 //Valid range 0 to 35 pins.
#-num_ca 25 //Valid range 0 to 35 pins.

# Number of CLK pins. CLK is typically a differential pair. In some cases additional CLK pairs may be used to limit the loading on the CLK pin.

-num_clk 2 //2 x differential pair. Valid values: 0/2/4.

# Number of Physical Ranks

-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per buffer/register. If multiple LRDIMMs or buffer chips exist, the analysis for capacity and power is reported per buffer/register.

# Width of the Memory Data Bus

-mem_data_width 4 //x4 or x8 or x16 or x32 memories. For WideIO, up to x128. diff --git a/Project_FARSI/cacti_for_FARSI/decoder.cc b/Project_FARSI/cacti_for_FARSI/decoder.cc new file mode 100644 index 00000000..6ab9bb5a --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/decoder.cc @@ -0,0 +1,1673 @@ +/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.

+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+
+
+#include "area.h"
+#include "decoder.h"
+#include "parameter.h"
+#include <assert.h>
+#include <iostream>
+#include <math.h>
+
+using namespace std;
+
+
+Decoder::Decoder(
+    int    _num_dec_signals,
+    bool   flag_way_select,
+    double _C_ld_dec_out,
+    double _R_wire_dec_out,
+    bool   fully_assoc_,
+    bool   is_dram_,
+    bool   is_wl_tr_,
+    const  Area & cell_)
+:exist(false),
+  C_ld_dec_out(_C_ld_dec_out),
+  R_wire_dec_out(_R_wire_dec_out),
+  num_gates(0), num_gates_min(2),
+  delay(0),
+  //power(),
+  fully_assoc(fully_assoc_), is_dram(is_dram_),
+  is_wl_tr(is_wl_tr_),
+  total_driver_nwidth(0),
+  total_driver_pwidth(0),
+  cell(cell_),
+  nodes_DSTN(1)
+{
+
+  for (int i = 0; i < MAX_NUMBER_GATES_STAGE; i++)
+  {
+    w_dec_n[i] = 0;
+    w_dec_p[i] = 0;
+  }
+
+  /*
+   * _num_dec_signals is the number of decoded signals the decoder produces as output;
+   * num_addr_bits_dec is the number of signals to be decoded,
+   * i.e. the decoder's input.
+   */
+  int num_addr_bits_dec = _log2(_num_dec_signals);
+
+  if (num_addr_bits_dec < 4)
+  {
+    if (flag_way_select)
+    {
+      exist = true;
+      num_in_signals = 2;
+    }
+    else
+    {
+      num_in_signals = 0;
+    }
+  }
+  else
+  {
+    exist = true;
+
+    if (flag_way_select)
+    {
+      num_in_signals = 3;
+    }
+    else
+    {
+      num_in_signals = 2;
+    }
+  }
+
+  assert(cell.h>0);
+  assert(cell.w>0);
+  // the height of a row-decoder-driver cell is fixed to be 4 * cell.h;
+  //area.h = 4 * cell.h;
+  area.h = g_tp.h_dec * cell.h;
+
+  compute_widths();
+  compute_area();
+
+}
+
+
+
+void Decoder::compute_widths()
+{
+  double F;
+  double p_to_n_sz_ratio = pmos_to_nmos_sz_ratio(is_dram, is_wl_tr);
+  double gnand2 = (2 + p_to_n_sz_ratio) / (1 + p_to_n_sz_ratio);
+  double gnand3 = (3 + p_to_n_sz_ratio) / (1 + p_to_n_sz_ratio);
+
+  if (exist)
+  {
+    if (num_in_signals == 2 || fully_assoc)
+    {
+      w_dec_n[0] = 2 * g_tp.min_w_nmos_;
+      w_dec_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_;
+      F = gnand2;
+    }
+    else
+    {
+      w_dec_n[0] = 3 * g_tp.min_w_nmos_;
+      w_dec_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_;
+      F = gnand3;
+    }
+
+    F *= C_ld_dec_out / (gate_C(w_dec_n[0], 0, is_dram, false, is_wl_tr) +
+                         gate_C(w_dec_p[0], 0, is_dram, false, is_wl_tr));
+    num_gates = logical_effort(
+        num_gates_min,
+        num_in_signals == 2 ? 
gnand2 : gnand3, + F, + w_dec_n, + w_dec_p, + C_ld_dec_out, + p_to_n_sz_ratio, + is_dram, + is_wl_tr, + g_tp.max_w_nmos_dec); + + } +} + + + +void Decoder::compute_area() +{ + double cumulative_area = 0; + double cumulative_curr = 0; // cumulative leakage current + double cumulative_curr_Ig = 0; // cumulative leakage current + + if (exist) + { // First check if this decoder exists + if (num_in_signals == 2) + { + cumulative_area = compute_gate_area(NAND, 2, w_dec_p[0], w_dec_n[0], area.h); + cumulative_curr = cmos_Isub_leakage(w_dec_n[0], w_dec_p[0], 2, nand,is_dram); + cumulative_curr_Ig = cmos_Ig_leakage(w_dec_n[0], w_dec_p[0], 2, nand,is_dram); + } + else if (num_in_signals == 3) + { + cumulative_area = compute_gate_area(NAND, 3, w_dec_p[0], w_dec_n[0], area.h); + cumulative_curr = cmos_Isub_leakage(w_dec_n[0], w_dec_p[0], 3, nand, is_dram);; + cumulative_curr_Ig = cmos_Ig_leakage(w_dec_n[0], w_dec_p[0], 3, nand, is_dram); + } + + for (int i = 1; i < num_gates; i++) + { + cumulative_area += compute_gate_area(INV, 1, w_dec_p[i], w_dec_n[i], area.h); + cumulative_curr += cmos_Isub_leakage(w_dec_n[i], w_dec_p[i], 1, inv, is_dram); + cumulative_curr_Ig = cmos_Ig_leakage(w_dec_n[i], w_dec_p[i], 1, inv, is_dram); + } + power.readOp.leakage = cumulative_curr * g_tp.peri_global.Vdd; + power.readOp.gate_leakage = cumulative_curr_Ig * g_tp.peri_global.Vdd; + + area.w = (cumulative_area / area.h); + } +} + +void Decoder::compute_power_gating() +{ + //For all driver change there is only one sleep transistors to save area + //Total transistor width for sleep tx calculation + for (int i = 1; i <=num_gates; i++) + { + total_driver_nwidth += w_dec_n[i]; + total_driver_pwidth += w_dec_p[i]; + } + + //compute sleep tx + bool is_footer = false; + double Isat_subarray = simplified_nmos_Isat(total_driver_nwidth); + double detalV; + double c_wakeup; + + c_wakeup = drain_C_(total_driver_pwidth, PCH, 1, 1, cell.h);//Psleep tx + detalV = g_tp.peri_global.Vdd-g_tp.peri_global.Vcc_min; + if (g_ip->power_gating) + sleeptx = new Sleep_tx (g_ip->perfloss, + Isat_subarray, + is_footer, + c_wakeup, + detalV, + nodes_DSTN, + area); +} + +double Decoder::compute_delays(double inrisetime) +{ + if (exist) + { + double ret_val = 0; // outrisetime + int i; + double rd, tf, this_delay, c_load, c_intrinsic, Vpp; + double Vdd = g_tp.peri_global.Vdd; + + if ((is_wl_tr) && (is_dram)) + { + Vpp = g_tp.vpp; + } + else if (is_wl_tr) + { + Vpp = g_tp.sram_cell.Vdd; + } + else + { + Vpp = g_tp.peri_global.Vdd; + } + + // first check whether a decoder is required at all + rd = tr_R_on(w_dec_n[0], NCH, num_in_signals, is_dram, false, is_wl_tr); + c_load = gate_C(w_dec_n[1] + w_dec_p[1], 0.0, is_dram, false, is_wl_tr); + c_intrinsic = drain_C_(w_dec_p[0], PCH, 1, 1, area.h, is_dram, false, is_wl_tr) * num_in_signals + + drain_C_(w_dec_n[0], NCH, num_in_signals, 1, area.h, is_dram, false, is_wl_tr); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE); + delay += this_delay; + inrisetime = this_delay / (1.0 - 0.5); + power.readOp.dynamic += (c_load + c_intrinsic) * Vdd * Vdd; + + for (i = 1; i < num_gates - 1; ++i) + { + rd = tr_R_on(w_dec_n[i], NCH, 1, is_dram, false, is_wl_tr); + c_load = gate_C(w_dec_p[i+1] + w_dec_n[i+1], 0.0, is_dram, false, is_wl_tr); + c_intrinsic = drain_C_(w_dec_p[i], PCH, 1, 1, area.h, is_dram, false, is_wl_tr) + + drain_C_(w_dec_n[i], NCH, 1, 1, area.h, is_dram, false, is_wl_tr); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE); 
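+      // horowitz() maps the RC time constant tf and the input slope onto a
+      // stage delay; with both switching thresholds set to 0.5*Vdd here, the
+      // slope handed to the next stage is this_delay/(1.0 - 0.5).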
+ delay += this_delay; + inrisetime = this_delay / (1.0 - 0.5); + power.readOp.dynamic += (c_load + c_intrinsic) * Vdd * Vdd; + } + + // add delay of final inverter that drives the wordline + i = num_gates - 1; + c_load = C_ld_dec_out; + rd = tr_R_on(w_dec_n[i], NCH, 1, is_dram, false, is_wl_tr); + c_intrinsic = drain_C_(w_dec_p[i], PCH, 1, 1, area.h, is_dram, false, is_wl_tr) + + drain_C_(w_dec_n[i], NCH, 1, 1, area.h, is_dram, false, is_wl_tr); + tf = rd * (c_intrinsic + c_load) + R_wire_dec_out * c_load / 2; + this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE); + delay += this_delay; + ret_val = this_delay / (1.0 - 0.5); + power.readOp.dynamic += c_load * Vpp * Vpp + c_intrinsic * Vdd * Vdd; + + compute_power_gating(); + return ret_val; + } + else + { + return 0.0; + } +} + +void Decoder::leakage_feedback(double temperature) +{ + double cumulative_curr = 0; // cumulative leakage current + double cumulative_curr_Ig = 0; // cumulative leakage current + + if (exist) + { // First check if this decoder exists + if (num_in_signals == 2) + { + cumulative_curr = cmos_Isub_leakage(w_dec_n[0], w_dec_p[0], 2, nand,is_dram); + cumulative_curr_Ig = cmos_Ig_leakage(w_dec_n[0], w_dec_p[0], 2, nand,is_dram); + } + else if (num_in_signals == 3) + { + cumulative_curr = cmos_Isub_leakage(w_dec_n[0], w_dec_p[0], 3, nand, is_dram);; + cumulative_curr_Ig = cmos_Ig_leakage(w_dec_n[0], w_dec_p[0], 3, nand, is_dram); + } + + for (int i = 1; i < num_gates; i++) + { + cumulative_curr += cmos_Isub_leakage(w_dec_n[i], w_dec_p[i], 1, inv, is_dram); + cumulative_curr_Ig = cmos_Ig_leakage(w_dec_n[i], w_dec_p[i], 1, inv, is_dram); + } + + power.readOp.leakage = cumulative_curr * g_tp.peri_global.Vdd; + power.readOp.gate_leakage = cumulative_curr_Ig * g_tp.peri_global.Vdd; + } +} + +PredecBlk::PredecBlk( + int num_dec_signals, + Decoder * dec_, + double C_wire_predec_blk_out, + double R_wire_predec_blk_out_, + int num_dec_per_predec, + bool is_dram, + bool is_blk1) + :dec(dec_), + exist(false), + number_input_addr_bits(0), + C_ld_predec_blk_out(0), + R_wire_predec_blk_out(0), + branch_effort_nand2_gate_output(1), + branch_effort_nand3_gate_output(1), + flag_two_unique_paths(false), + flag_L2_gate(0), + number_inputs_L1_gate(0), + number_gates_L1_nand2_path(0), + number_gates_L1_nand3_path(0), + number_gates_L2(0), + min_number_gates_L1(2), + min_number_gates_L2(2), + num_L1_active_nand2_path(0), + num_L1_active_nand3_path(0), + delay_nand2_path(0), + delay_nand3_path(0), + power_nand2_path(), + power_nand3_path(), + power_L2(), + is_dram_(is_dram) +{ + int branch_effort_predec_out; + double C_ld_dec_gate; + int num_addr_bits_dec = _log2(num_dec_signals); + int blk1_num_input_addr_bits = (num_addr_bits_dec + 1) / 2; + int blk2_num_input_addr_bits = num_addr_bits_dec - blk1_num_input_addr_bits; + + w_L1_nand2_n[0] = 0; + w_L1_nand2_p[0] = 0; + w_L1_nand3_n[0] = 0; + w_L1_nand3_p[0] = 0; + + if (is_blk1 == true) + { + if (num_addr_bits_dec <= 0) + { + return; + } + else if (num_addr_bits_dec < 4) + { + // Just one predecoder block is required with NAND2 gates. No decoder required. 
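+      // (At most 3 address bits means at most 8 decoded outputs, so the
+      // predecode gates can drive the decoded-output load themselves.)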
+ // The first level of predecoding directly drives the decoder output load + exist = true; + number_input_addr_bits = num_addr_bits_dec; + R_wire_predec_blk_out = dec->R_wire_dec_out; + C_ld_predec_blk_out = dec->C_ld_dec_out; + } + else + { + exist = true; + number_input_addr_bits = blk1_num_input_addr_bits; + branch_effort_predec_out = (1 << blk2_num_input_addr_bits); + C_ld_dec_gate = num_dec_per_predec * gate_C(dec->w_dec_n[0] + dec->w_dec_p[0], 0, is_dram_, false, false); + R_wire_predec_blk_out = R_wire_predec_blk_out_; + C_ld_predec_blk_out = branch_effort_predec_out * C_ld_dec_gate + C_wire_predec_blk_out; + } + } + else + { + if (num_addr_bits_dec >= 4) + { + exist = true; + number_input_addr_bits = blk2_num_input_addr_bits; + branch_effort_predec_out = (1 << blk1_num_input_addr_bits); + C_ld_dec_gate = num_dec_per_predec * gate_C(dec->w_dec_n[0] + dec->w_dec_p[0], 0, is_dram_, false, false); + R_wire_predec_blk_out = R_wire_predec_blk_out_; + C_ld_predec_blk_out = branch_effort_predec_out * C_ld_dec_gate + C_wire_predec_blk_out; + } + } + + compute_widths(); + compute_area(); +} + + + +void PredecBlk::compute_widths() +{ + double F, c_load_nand3_path, c_load_nand2_path; + double p_to_n_sz_ratio = pmos_to_nmos_sz_ratio(is_dram_); + double gnand2 = (2 + p_to_n_sz_ratio) / (1 + p_to_n_sz_ratio); + double gnand3 = (3 + p_to_n_sz_ratio) / (1 + p_to_n_sz_ratio); + + if (exist == false) return; + + + switch (number_input_addr_bits) + { + case 1: + flag_two_unique_paths = false; + number_inputs_L1_gate = 2; + flag_L2_gate = 0; + break; + case 2: + flag_two_unique_paths = false; + number_inputs_L1_gate = 2; + flag_L2_gate = 0; + break; + case 3: + flag_two_unique_paths = false; + number_inputs_L1_gate = 3; + flag_L2_gate = 0; + break; + case 4: + flag_two_unique_paths = false; + number_inputs_L1_gate = 2; + flag_L2_gate = 2; + branch_effort_nand2_gate_output = 4; + break; + case 5: + flag_two_unique_paths = true; + flag_L2_gate = 2; + branch_effort_nand2_gate_output = 8; + branch_effort_nand3_gate_output = 4; + break; + case 6: + flag_two_unique_paths = false; + number_inputs_L1_gate = 3; + flag_L2_gate = 2; + branch_effort_nand3_gate_output = 8; + break; + case 7: + flag_two_unique_paths = true; + flag_L2_gate = 3; + branch_effort_nand2_gate_output = 32; + branch_effort_nand3_gate_output = 16; + break; + case 8: + flag_two_unique_paths = true; + flag_L2_gate = 3; + branch_effort_nand2_gate_output = 64; + branch_effort_nand3_gate_output = 32; + break; + case 9: + flag_two_unique_paths = false; + number_inputs_L1_gate = 3; + flag_L2_gate = 3; + branch_effort_nand3_gate_output = 64; + break; + default: + assert(0); + break; + } + + // find the number of gates and sizing in second level of predecoder (if there is a second level) + if (flag_L2_gate) + { + if (flag_L2_gate == 2) + { // 2nd level is a NAND2 gate + w_L2_n[0] = 2 * g_tp.min_w_nmos_; + F = gnand2; + } + else + { // 2nd level is a NAND3 gate + w_L2_n[0] = 3 * g_tp.min_w_nmos_; + F = gnand3; + } + w_L2_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_; + F *= C_ld_predec_blk_out / (gate_C(w_L2_n[0], 0, is_dram_) + gate_C(w_L2_p[0], 0, is_dram_)); + number_gates_L2 = logical_effort( + min_number_gates_L2, + flag_L2_gate == 2 ? 
gnand2 : gnand3, + F, + w_L2_n, + w_L2_p, + C_ld_predec_blk_out, + p_to_n_sz_ratio, + is_dram_, false, + g_tp.max_w_nmos_); + + // Now find the number of gates and widths in first level of predecoder + if ((flag_two_unique_paths)||(number_inputs_L1_gate == 2)) + { // Whenever flag_two_unique_paths is true, it means first level of decoder employs + // both NAND2 and NAND3 gates. Or when number_inputs_L1_gate is 2, it means + // a NAND2 gate is used in the first level of the predecoder + c_load_nand2_path = branch_effort_nand2_gate_output * + (gate_C(w_L2_n[0], 0, is_dram_) + + gate_C(w_L2_p[0], 0, is_dram_)); + w_L1_nand2_n[0] = 2 * g_tp.min_w_nmos_; + w_L1_nand2_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_; + F = gnand2 * c_load_nand2_path / + (gate_C(w_L1_nand2_n[0], 0, is_dram_) + + gate_C(w_L1_nand2_p[0], 0, is_dram_)); + number_gates_L1_nand2_path = logical_effort( + min_number_gates_L1, + gnand2, + F, + w_L1_nand2_n, + w_L1_nand2_p, + c_load_nand2_path, + p_to_n_sz_ratio, + is_dram_, false, + g_tp.max_w_nmos_); + } + + //Now find widths of gates along path in which first gate is a NAND3 + if ((flag_two_unique_paths)||(number_inputs_L1_gate == 3)) + { // Whenever flag_two_unique_paths is TRUE, it means first level of decoder employs + // both NAND2 and NAND3 gates. Or when number_inputs_L1_gate is 3, it means + // a NAND3 gate is used in the first level of the predecoder + c_load_nand3_path = branch_effort_nand3_gate_output * + (gate_C(w_L2_n[0], 0, is_dram_) + + gate_C(w_L2_p[0], 0, is_dram_)); + w_L1_nand3_n[0] = 3 * g_tp.min_w_nmos_; + w_L1_nand3_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_; + F = gnand3 * c_load_nand3_path / + (gate_C(w_L1_nand3_n[0], 0, is_dram_) + + gate_C(w_L1_nand3_p[0], 0, is_dram_)); + number_gates_L1_nand3_path = logical_effort( + min_number_gates_L1, + gnand3, + F, + w_L1_nand3_n, + w_L1_nand3_p, + c_load_nand3_path, + p_to_n_sz_ratio, + is_dram_, false, + g_tp.max_w_nmos_); + } + } + else + { // find number of gates and widths in first level of predecoder block when there is no second level + if (number_inputs_L1_gate == 2) + { + w_L1_nand2_n[0] = 2 * g_tp.min_w_nmos_; + w_L1_nand2_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_; + F = gnand2*C_ld_predec_blk_out / + (gate_C(w_L1_nand2_n[0], 0, is_dram_) + + gate_C(w_L1_nand2_p[0], 0, is_dram_)); + number_gates_L1_nand2_path = logical_effort( + min_number_gates_L1, + gnand2, + F, + w_L1_nand2_n, + w_L1_nand2_p, + C_ld_predec_blk_out, + p_to_n_sz_ratio, + is_dram_, false, + g_tp.max_w_nmos_); + } + else if (number_inputs_L1_gate == 3) + { + w_L1_nand3_n[0] = 3 * g_tp.min_w_nmos_; + w_L1_nand3_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_; + F = gnand3*C_ld_predec_blk_out / + (gate_C(w_L1_nand3_n[0], 0, is_dram_) + + gate_C(w_L1_nand3_p[0], 0, is_dram_)); + number_gates_L1_nand3_path = logical_effort( + min_number_gates_L1, + gnand3, + F, + w_L1_nand3_n, + w_L1_nand3_p, + C_ld_predec_blk_out, + p_to_n_sz_ratio, + is_dram_, false, + g_tp.max_w_nmos_); + } + } +} + + + +void PredecBlk::compute_area() +{ + if (exist) + { // First check whether a predecoder block is needed + int num_L1_nand2 = 0; + int num_L1_nand3 = 0; + int num_L2 = 0; + double tot_area_L1_nand3 =0; + double leak_L1_nand3 =0; + double gate_leak_L1_nand3 =0; + + double tot_area_L1_nand2 = compute_gate_area(NAND, 2, w_L1_nand2_p[0], w_L1_nand2_n[0], g_tp.cell_h_def); + double leak_L1_nand2 = cmos_Isub_leakage(w_L1_nand2_n[0], w_L1_nand2_p[0], 2, nand, is_dram_); + double gate_leak_L1_nand2 = cmos_Ig_leakage(w_L1_nand2_n[0], w_L1_nand2_p[0], 2, nand, is_dram_); + 
if (number_inputs_L1_gate != 3) { + tot_area_L1_nand3 = 0; + leak_L1_nand3 = 0; + gate_leak_L1_nand3 =0; + } + else { + tot_area_L1_nand3 = compute_gate_area(NAND, 3, w_L1_nand3_p[0], w_L1_nand3_n[0], g_tp.cell_h_def); + leak_L1_nand3 = cmos_Isub_leakage(w_L1_nand3_n[0], w_L1_nand3_p[0], 3, nand); + gate_leak_L1_nand3 = cmos_Ig_leakage(w_L1_nand3_n[0], w_L1_nand3_p[0], 3, nand); + } + + switch (number_input_addr_bits) + { + case 1: //2 NAND2 gates + num_L1_nand2 = 2; + num_L2 = 0; + num_L1_active_nand2_path =1; + num_L1_active_nand3_path =0; + break; + case 2: //4 NAND2 gates + num_L1_nand2 = 4; + num_L2 = 0; + num_L1_active_nand2_path =1; + num_L1_active_nand3_path =0; + break; + case 3: //8 NAND3 gates + num_L1_nand3 = 8; + num_L2 = 0; + num_L1_active_nand2_path =0; + num_L1_active_nand3_path =1; + break; + case 4: //4 + 4 NAND2 gates + num_L1_nand2 = 8; + num_L2 = 16; + num_L1_active_nand2_path =2; + num_L1_active_nand3_path =0; + break; + case 5: //4 NAND2 gates, 8 NAND3 gates + num_L1_nand2 = 4; + num_L1_nand3 = 8; + num_L2 = 32; + num_L1_active_nand2_path =1; + num_L1_active_nand3_path =1; + break; + case 6: //8 + 8 NAND3 gates + num_L1_nand3 = 16; + num_L2 = 64; + num_L1_active_nand2_path =0; + num_L1_active_nand3_path =2; + break; + case 7: //4 + 4 NAND2 gates, 8 NAND3 gates + num_L1_nand2 = 8; + num_L1_nand3 = 8; + num_L2 = 128; + num_L1_active_nand2_path =2; + num_L1_active_nand3_path =1; + break; + case 8: //4 NAND2 gates, 8 + 8 NAND3 gates + num_L1_nand2 = 4; + num_L1_nand3 = 16; + num_L2 = 256; + num_L1_active_nand2_path =2; + num_L1_active_nand3_path =2; + break; + case 9: //8 + 8 + 8 NAND3 gates + num_L1_nand3 = 24; + num_L2 = 512; + num_L1_active_nand2_path =0; + num_L1_active_nand3_path =3; + break; + default: + break; + } + + for (int i = 1; i < number_gates_L1_nand2_path; ++i) + { + tot_area_L1_nand2 += compute_gate_area(INV, 1, w_L1_nand2_p[i], w_L1_nand2_n[i], g_tp.cell_h_def); + leak_L1_nand2 += cmos_Isub_leakage(w_L1_nand2_n[i], w_L1_nand2_p[i], 2, nand, is_dram_); + gate_leak_L1_nand2 += cmos_Ig_leakage(w_L1_nand2_n[i], w_L1_nand2_p[i], 2, nand, is_dram_); + } + tot_area_L1_nand2 *= num_L1_nand2; + leak_L1_nand2 *= num_L1_nand2; + gate_leak_L1_nand2 *= num_L1_nand2; + + for (int i = 1; i < number_gates_L1_nand3_path; ++i) + { + tot_area_L1_nand3 += compute_gate_area(INV, 1, w_L1_nand3_p[i], w_L1_nand3_n[i], g_tp.cell_h_def); + leak_L1_nand3 += cmos_Isub_leakage(w_L1_nand3_n[i], w_L1_nand3_p[i], 3, nand, is_dram_); + gate_leak_L1_nand3 += cmos_Ig_leakage(w_L1_nand3_n[i], w_L1_nand3_p[i], 3, nand, is_dram_); + } + tot_area_L1_nand3 *= num_L1_nand3; + leak_L1_nand3 *= num_L1_nand3; + gate_leak_L1_nand3 *= num_L1_nand3; + + double cumulative_area_L1 = tot_area_L1_nand2 + tot_area_L1_nand3; + double cumulative_area_L2 = 0.0; + double leakage_L2 = 0.0; + double gate_leakage_L2 = 0.0; + + if (flag_L2_gate == 2) + { + cumulative_area_L2 = compute_gate_area(NAND, 2, w_L2_p[0], w_L2_n[0], g_tp.cell_h_def); + leakage_L2 = cmos_Isub_leakage(w_L2_n[0], w_L2_p[0], 2, nand, is_dram_); + gate_leakage_L2 = cmos_Ig_leakage(w_L2_n[0], w_L2_p[0], 2, nand, is_dram_); + } + else if (flag_L2_gate == 3) + { + cumulative_area_L2 = compute_gate_area(NAND, 3, w_L2_p[0], w_L2_n[0], g_tp.cell_h_def); + leakage_L2 = cmos_Isub_leakage(w_L2_n[0], w_L2_p[0], 3, nand, is_dram_); + gate_leakage_L2 = cmos_Ig_leakage(w_L2_n[0], w_L2_p[0], 3, nand, is_dram_); + } + + for (int i = 1; i < number_gates_L2; ++i) + { + cumulative_area_L2 += compute_gate_area(INV, 1, w_L2_p[i], w_L2_n[i], g_tp.cell_h_def); + 
leakage_L2 += cmos_Isub_leakage(w_L2_n[i], w_L2_p[i], 2, inv, is_dram_);
+      gate_leakage_L2 += cmos_Ig_leakage(w_L2_n[i], w_L2_p[i], 2, inv, is_dram_);
+    }
+    cumulative_area_L2 *= num_L2;
+    leakage_L2 *= num_L2;
+    gate_leakage_L2 *= num_L2;
+
+    power_nand2_path.readOp.leakage = leak_L1_nand2 * g_tp.peri_global.Vdd;
+    power_nand3_path.readOp.leakage = leak_L1_nand3 * g_tp.peri_global.Vdd;
+    power_L2.readOp.leakage = leakage_L2 * g_tp.peri_global.Vdd;
+    area.set_area(cumulative_area_L1 + cumulative_area_L2);
+    power_nand2_path.readOp.gate_leakage = gate_leak_L1_nand2 * g_tp.peri_global.Vdd;
+    power_nand3_path.readOp.gate_leakage = gate_leak_L1_nand3 * g_tp.peri_global.Vdd;
+    power_L2.readOp.gate_leakage = gate_leakage_L2 * g_tp.peri_global.Vdd;
+  }
+}
+
+
+
+pair<double, double> PredecBlk::compute_delays(
+    pair<double, double> inrisetime) //
+{
+  pair<double, double> ret_val;
+  ret_val.first  = 0; // outrisetime_nand2_path
+  ret_val.second = 0; // outrisetime_nand3_path
+
+  double inrisetime_nand2_path = inrisetime.first;
+  double inrisetime_nand3_path = inrisetime.second;
+  int i;
+  double rd, c_load, c_intrinsic, tf, this_delay;
+  double Vdd = g_tp.peri_global.Vdd;
+
+  // TODO: following delay calculation part can be greatly simplified.
+  // first check whether a predecoder block is required
+  if (exist)
+  {
+    //Find delay in first level of predecoder block
+    //First find delay in path
+    if ((flag_two_unique_paths) || (number_inputs_L1_gate == 2))
+    {
+      //First gate is a NAND2 gate
+      rd = tr_R_on(w_L1_nand2_n[0], NCH, 2, is_dram_);
+      c_load = gate_C(w_L1_nand2_n[1] + w_L1_nand2_p[1], 0.0, is_dram_);
+      c_intrinsic = 2 * drain_C_(w_L1_nand2_p[0], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                        drain_C_(w_L1_nand2_n[0], NCH, 2, 1, g_tp.cell_h_def, is_dram_);
+      tf = rd * (c_intrinsic + c_load);
+      this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE);
+      delay_nand2_path += this_delay;
+      inrisetime_nand2_path = this_delay / (1.0 - 0.5);
+      power_nand2_path.readOp.dynamic += (c_load + c_intrinsic) * Vdd * Vdd;
+
+      //Add delays of all but the last inverter in the chain
+      for (i = 1; i < number_gates_L1_nand2_path - 1; ++i)
+      {
+        rd = tr_R_on(w_L1_nand2_n[i], NCH, 1, is_dram_);
+        c_load = gate_C(w_L1_nand2_n[i+1] + w_L1_nand2_p[i+1], 0.0, is_dram_);
+        c_intrinsic = drain_C_(w_L1_nand2_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                      drain_C_(w_L1_nand2_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_);
+        tf = rd * (c_intrinsic + c_load);
+        this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE);
+        delay_nand2_path += this_delay;
+        inrisetime_nand2_path = this_delay / (1.0 - 0.5);
+        power_nand2_path.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd;
+      }
+
+      //Add delay of the last inverter
+      i = number_gates_L1_nand2_path - 1;
+      rd = tr_R_on(w_L1_nand2_n[i], NCH, 1, is_dram_);
+      if (flag_L2_gate)
+      {
+        c_load = branch_effort_nand2_gate_output*(gate_C(w_L2_n[0], 0, is_dram_) + gate_C(w_L2_p[0], 0, is_dram_));
+        c_intrinsic = drain_C_(w_L1_nand2_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                      drain_C_(w_L1_nand2_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_);
+        tf = rd * (c_intrinsic + c_load);
+        this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE);
+        delay_nand2_path += this_delay;
+        inrisetime_nand2_path = this_delay / (1.0 - 0.5);
+        power_nand2_path.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd;
+      }
+      else
+      { //First level directly drives decoder output load
+        c_load = C_ld_predec_blk_out;
+        c_intrinsic = drain_C_(w_L1_nand2_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                      drain_C_(w_L1_nand2_n[i], NCH, 1, 1, 
g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load) + R_wire_predec_blk_out * c_load / 2; + this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE); + delay_nand2_path += this_delay; + ret_val.first = this_delay / (1.0 - 0.5); + power_nand2_path.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + } + + if ((flag_two_unique_paths) || (number_inputs_L1_gate == 3)) + { //Check if the number of gates in the first level is more than 1. + //First gate is a NAND3 gate + rd = tr_R_on(w_L1_nand3_n[0], NCH, 3, is_dram_); + c_load = gate_C(w_L1_nand3_n[1] + w_L1_nand3_p[1], 0.0, is_dram_); + c_intrinsic = 3 * drain_C_(w_L1_nand3_p[0], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L1_nand3_n[0], NCH, 3, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE); + delay_nand3_path += this_delay; + inrisetime_nand3_path = this_delay / (1.0 - 0.5); + power_nand3_path.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + + //Add delays of all but the last inverter in the chain + for (i = 1; i < number_gates_L1_nand3_path - 1; ++i) + { + rd = tr_R_on(w_L1_nand3_n[i], NCH, 1, is_dram_); + c_load = gate_C(w_L1_nand3_n[i+1] + w_L1_nand3_p[i+1], 0.0, is_dram_); + c_intrinsic = drain_C_(w_L1_nand3_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L1_nand3_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE); + delay_nand3_path += this_delay; + inrisetime_nand3_path = this_delay / (1.0 - 0.5); + power_nand3_path.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + + //Add delay of the last inverter + i = number_gates_L1_nand3_path - 1; + rd = tr_R_on(w_L1_nand3_n[i], NCH, 1, is_dram_); + if (flag_L2_gate) + { + c_load = branch_effort_nand3_gate_output*(gate_C(w_L2_n[0], 0, is_dram_) + gate_C(w_L2_p[0], 0, is_dram_)); + c_intrinsic = drain_C_(w_L1_nand3_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L1_nand3_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE); + delay_nand3_path += this_delay; + inrisetime_nand3_path = this_delay / (1.0 - 0.5); + power_nand3_path.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + else + { //First level directly drives decoder output load + c_load = C_ld_predec_blk_out; + c_intrinsic = drain_C_(w_L1_nand3_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L1_nand3_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load) + R_wire_predec_blk_out * c_load / 2; + this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE); + delay_nand3_path += this_delay; + ret_val.second = this_delay / (1.0 - 0.5); + power_nand3_path.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + } + + // Find delay through second level + if (flag_L2_gate) + { + if (flag_L2_gate == 2) + { + rd = tr_R_on(w_L2_n[0], NCH, 2, is_dram_); + c_load = gate_C(w_L2_n[1] + w_L2_p[1], 0.0, is_dram_); + c_intrinsic = 2 * drain_C_(w_L2_p[0], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L2_n[0], NCH, 2, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE); + delay_nand2_path += this_delay; + inrisetime_nand2_path = this_delay / (1.0 - 0.5); + power_L2.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + else + { // flag_L2_gate = 3 + rd = 
tr_R_on(w_L2_n[0], NCH, 3, is_dram_); + c_load = gate_C(w_L2_n[1] + w_L2_p[1], 0.0, is_dram_); + c_intrinsic = 3 * drain_C_(w_L2_p[0], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L2_n[0], NCH, 3, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE); + delay_nand3_path += this_delay; + inrisetime_nand3_path = this_delay / (1.0 - 0.5); + power_L2.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + + for (i = 1; i < number_gates_L2 - 1; ++i) + { + rd = tr_R_on(w_L2_n[i], NCH, 1, is_dram_); + c_load = gate_C(w_L2_n[i+1] + w_L2_p[i+1], 0.0, is_dram_); + c_intrinsic = drain_C_(w_L2_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L2_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load); + this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE); + delay_nand2_path += this_delay; + inrisetime_nand2_path = this_delay / (1.0 - 0.5); + this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE); + delay_nand3_path += this_delay; + inrisetime_nand3_path = this_delay / (1.0 - 0.5); + power_L2.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + + //Add delay of final inverter that drives the wordline decoders + i = number_gates_L2 - 1; + c_load = C_ld_predec_blk_out; + rd = tr_R_on(w_L2_n[i], NCH, 1, is_dram_); + c_intrinsic = drain_C_(w_L2_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(w_L2_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_load) + R_wire_predec_blk_out * c_load / 2; + this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE); + delay_nand2_path += this_delay; + ret_val.first = this_delay / (1.0 - 0.5); + this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE); + delay_nand3_path += this_delay; + ret_val.second = this_delay / (1.0 - 0.5); + power_L2.readOp.dynamic += (c_intrinsic + c_load) * Vdd * Vdd; + } + } + + delay = (ret_val.first > ret_val.second) ? 
ret_val.first : ret_val.second; + return ret_val; +} + +void PredecBlk::leakage_feedback(double temperature) +{ + if (exist) + { // First check whether a predecoder block is needed + int num_L1_nand2 = 0; + int num_L1_nand3 = 0; + int num_L2 = 0; + double leak_L1_nand3 =0; + double gate_leak_L1_nand3 =0; + + double leak_L1_nand2 = cmos_Isub_leakage(w_L1_nand2_n[0], w_L1_nand2_p[0], 2, nand, is_dram_); + double gate_leak_L1_nand2 = cmos_Ig_leakage(w_L1_nand2_n[0], w_L1_nand2_p[0], 2, nand, is_dram_); + if (number_inputs_L1_gate != 3) { + leak_L1_nand3 = 0; + gate_leak_L1_nand3 =0; + } + else { + leak_L1_nand3 = cmos_Isub_leakage(w_L1_nand3_n[0], w_L1_nand3_p[0], 3, nand); + gate_leak_L1_nand3 = cmos_Ig_leakage(w_L1_nand3_n[0], w_L1_nand3_p[0], 3, nand); + } + + switch (number_input_addr_bits) + { + case 1: //2 NAND2 gates + num_L1_nand2 = 2; + num_L2 = 0; + num_L1_active_nand2_path =1; + num_L1_active_nand3_path =0; + break; + case 2: //4 NAND2 gates + num_L1_nand2 = 4; + num_L2 = 0; + num_L1_active_nand2_path =1; + num_L1_active_nand3_path =0; + break; + case 3: //8 NAND3 gates + num_L1_nand3 = 8; + num_L2 = 0; + num_L1_active_nand2_path =0; + num_L1_active_nand3_path =1; + break; + case 4: //4 + 4 NAND2 gates + num_L1_nand2 = 8; + num_L2 = 16; + num_L1_active_nand2_path =2; + num_L1_active_nand3_path =0; + break; + case 5: //4 NAND2 gates, 8 NAND3 gates + num_L1_nand2 = 4; + num_L1_nand3 = 8; + num_L2 = 32; + num_L1_active_nand2_path =1; + num_L1_active_nand3_path =1; + break; + case 6: //8 + 8 NAND3 gates + num_L1_nand3 = 16; + num_L2 = 64; + num_L1_active_nand2_path =0; + num_L1_active_nand3_path =2; + break; + case 7: //4 + 4 NAND2 gates, 8 NAND3 gates + num_L1_nand2 = 8; + num_L1_nand3 = 8; + num_L2 = 128; + num_L1_active_nand2_path =2; + num_L1_active_nand3_path =1; + break; + case 8: //4 NAND2 gates, 8 + 8 NAND3 gates + num_L1_nand2 = 4; + num_L1_nand3 = 16; + num_L2 = 256; + num_L1_active_nand2_path =2; + num_L1_active_nand3_path =2; + break; + case 9: //8 + 8 + 8 NAND3 gates + num_L1_nand3 = 24; + num_L2 = 512; + num_L1_active_nand2_path =0; + num_L1_active_nand3_path =3; + break; + default: + break; + } + + for (int i = 1; i < number_gates_L1_nand2_path; ++i) + { + leak_L1_nand2 += cmos_Isub_leakage(w_L1_nand2_n[i], w_L1_nand2_p[i], 2, nand, is_dram_); + gate_leak_L1_nand2 += cmos_Ig_leakage(w_L1_nand2_n[i], w_L1_nand2_p[i], 2, nand, is_dram_); + } + leak_L1_nand2 *= num_L1_nand2; + gate_leak_L1_nand2 *= num_L1_nand2; + + for (int i = 1; i < number_gates_L1_nand3_path; ++i) + { + leak_L1_nand3 += cmos_Isub_leakage(w_L1_nand3_n[i], w_L1_nand3_p[i], 3, nand, is_dram_); + gate_leak_L1_nand3 += cmos_Ig_leakage(w_L1_nand3_n[i], w_L1_nand3_p[i], 3, nand, is_dram_); + } + leak_L1_nand3 *= num_L1_nand3; + gate_leak_L1_nand3 *= num_L1_nand3; + + double leakage_L2 = 0.0; + double gate_leakage_L2 = 0.0; + + if (flag_L2_gate == 2) + { + leakage_L2 = cmos_Isub_leakage(w_L2_n[0], w_L2_p[0], 2, nand, is_dram_); + gate_leakage_L2 = cmos_Ig_leakage(w_L2_n[0], w_L2_p[0], 2, nand, is_dram_); + } + else if (flag_L2_gate == 3) + { + leakage_L2 = cmos_Isub_leakage(w_L2_n[0], w_L2_p[0], 3, nand, is_dram_); + gate_leakage_L2 = cmos_Ig_leakage(w_L2_n[0], w_L2_p[0], 3, nand, is_dram_); + } + + for (int i = 1; i < number_gates_L2; ++i) + { + leakage_L2 += cmos_Isub_leakage(w_L2_n[i], w_L2_p[i], 2, inv, is_dram_); + gate_leakage_L2 += cmos_Ig_leakage(w_L2_n[i], w_L2_p[i], 2, inv, is_dram_); + } + leakage_L2 *= num_L2; + gate_leakage_L2 *= num_L2; + + power_nand2_path.readOp.leakage = leak_L1_nand2 * 
g_tp.peri_global.Vdd; + power_nand3_path.readOp.leakage = leak_L1_nand3 * g_tp.peri_global.Vdd; + power_L2.readOp.leakage = leakage_L2 * g_tp.peri_global.Vdd; + + power_nand2_path.readOp.gate_leakage = gate_leak_L1_nand2 * g_tp.peri_global.Vdd; + power_nand3_path.readOp.gate_leakage = gate_leak_L1_nand3 * g_tp.peri_global.Vdd; + power_L2.readOp.gate_leakage = gate_leakage_L2 * g_tp.peri_global.Vdd; + } +} + +PredecBlkDrv::PredecBlkDrv( + int way_select_, + PredecBlk * blk_, + bool is_dram) + :flag_driver_exists(0), + number_gates_nand2_path(0), + number_gates_nand3_path(0), + min_number_gates(2), + num_buffers_driving_1_nand2_load(0), + num_buffers_driving_2_nand2_load(0), + num_buffers_driving_4_nand2_load(0), + num_buffers_driving_2_nand3_load(0), + num_buffers_driving_8_nand3_load(0), + num_buffers_nand3_path(0), + c_load_nand2_path_out(0), + c_load_nand3_path_out(0), + r_load_nand2_path_out(0), + r_load_nand3_path_out(0), + delay_nand2_path(0), + delay_nand3_path(0), + power_nand2_path(), + power_nand3_path(), + blk(blk_), dec(blk->dec), + is_dram_(is_dram), + way_select(way_select_) +{ + for (int i = 0; i < MAX_NUMBER_GATES_STAGE; i++) + { + width_nand2_path_n[i] = 0; + width_nand2_path_p[i] = 0; + width_nand3_path_n[i] = 0; + width_nand3_path_p[i] = 0; + } + + number_input_addr_bits = blk->number_input_addr_bits; + + if (way_select > 1) + { + flag_driver_exists = 1; + number_input_addr_bits = way_select; + if (dec->num_in_signals == 2) + { + c_load_nand2_path_out = gate_C(dec->w_dec_n[0] + dec->w_dec_p[0], 0, is_dram_); + num_buffers_driving_2_nand2_load = number_input_addr_bits; + } + else if (dec->num_in_signals == 3) + { + c_load_nand3_path_out = gate_C(dec->w_dec_n[0] + dec->w_dec_p[0], 0, is_dram_); + num_buffers_driving_2_nand3_load = number_input_addr_bits; + } + } + else if (way_select == 0) + { + if (blk->exist) + { + flag_driver_exists = 1; + } + } + + compute_widths(); + compute_area(); +} + + + +void PredecBlkDrv::compute_widths() +{ + // The predecode block driver accepts as input the address bits from the h-tree network. For + // each addr bit it then generates addr and addrbar as outputs. For now ignore the effect of + // inversion to generate addrbar and simply treat addrbar as addr. 
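The width computations in this driver, like those in PredecBlk and Driver further down, are instances of classic logical-effort sizing: logical_effort() is handed the total electrical effort F = C_load / C_gate(first stage) and picks a chain depth and per-stage widths. A minimal standalone sketch of that sizing rule, assuming the canonical optimum stage effort of about 4 (sketch_num_stages and f_opt are illustrative names, not CACTI's API; CACTI's helper additionally caps the NMOS width and enforces the minimum stage count passed to it):

#include <algorithm>
#include <cmath>

// Editorial sketch only -- not the CACTI implementation of logical_effort().
static int sketch_num_stages(double F, int min_stages)
{
    const double f_opt = 4.0;  // target per-stage effort
    int n = static_cast<int>(std::ceil(std::log(std::max(F, 1.0)) / std::log(f_opt)));
    return std::max(n, min_stages);
}
// Each stage is then made roughly F^(1/n) times wider than the previous one,
// which is what the width_nand*_path_n/p arrays filled in below end up holding.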
+ + double F; + double p_to_n_sz_ratio = pmos_to_nmos_sz_ratio(is_dram_); + + if (flag_driver_exists) + { + double C_nand2_gate_blk = gate_C(blk->w_L1_nand2_n[0] + blk->w_L1_nand2_p[0], 0, is_dram_); + double C_nand3_gate_blk = gate_C(blk->w_L1_nand3_n[0] + blk->w_L1_nand3_p[0], 0, is_dram_); + + if (way_select == 0) + { + if (blk->number_input_addr_bits == 1) + { //2 NAND2 gates + num_buffers_driving_2_nand2_load = 1; + c_load_nand2_path_out = 2 * C_nand2_gate_blk; + } + else if (blk->number_input_addr_bits == 2) + { //4 NAND2 gates one 2-4 decoder + num_buffers_driving_4_nand2_load = 2; + c_load_nand2_path_out = 4 * C_nand2_gate_blk; + } + else if (blk->number_input_addr_bits == 3) + { //8 NAND3 gates one 3-8 decoder + num_buffers_driving_8_nand3_load = 3; + c_load_nand3_path_out = 8 * C_nand3_gate_blk; + } + else if (blk->number_input_addr_bits == 4) + { //4 + 4 NAND2 gates two 2-4 decoder + num_buffers_driving_4_nand2_load = 4; + c_load_nand2_path_out = 4 * C_nand2_gate_blk; + } + else if (blk->number_input_addr_bits == 5) + { //4 NAND2 gates, 8 NAND3 gates one 2-4 decoder and one 3-8 decoder + num_buffers_driving_4_nand2_load = 2; + num_buffers_driving_8_nand3_load = 3; + c_load_nand2_path_out = 4 * C_nand2_gate_blk; + c_load_nand3_path_out = 8 * C_nand3_gate_blk; + } + else if (blk->number_input_addr_bits == 6) + { //8 + 8 NAND3 gates two 3-8 decoder + num_buffers_driving_8_nand3_load = 6; + c_load_nand3_path_out = 8 * C_nand3_gate_blk; + } + else if (blk->number_input_addr_bits == 7) + { //4 + 4 NAND2 gates, 8 NAND3 gates two 2-4 decoder and one 3-8 decoder + num_buffers_driving_4_nand2_load = 4; + num_buffers_driving_8_nand3_load = 3; + c_load_nand2_path_out = 4 * C_nand2_gate_blk; + c_load_nand3_path_out = 8 * C_nand3_gate_blk; + } + else if (blk->number_input_addr_bits == 8) + { //4 NAND2 gates, 8 + 8 NAND3 gates one 2-4 decoder and two 3-8 decoder + num_buffers_driving_4_nand2_load = 2; + num_buffers_driving_8_nand3_load = 6; + c_load_nand2_path_out = 4 * C_nand2_gate_blk; + c_load_nand3_path_out = 8 * C_nand3_gate_blk; + } + else if (blk->number_input_addr_bits == 9) + { //8 + 8 + 8 NAND3 gates three 3-8 decoder + num_buffers_driving_8_nand3_load = 9; + c_load_nand3_path_out = 8 * C_nand3_gate_blk; + } + } + + if ((blk->flag_two_unique_paths) || + (blk->number_inputs_L1_gate == 2) || + (number_input_addr_bits == 0) || + ((way_select)&&(dec->num_in_signals == 2))) + { //this means that way_select is driving NAND2 in decoder. + width_nand2_path_n[0] = g_tp.min_w_nmos_; + width_nand2_path_p[0] = p_to_n_sz_ratio * width_nand2_path_n[0]; + F = c_load_nand2_path_out / gate_C(width_nand2_path_n[0] + width_nand2_path_p[0], 0, is_dram_); + number_gates_nand2_path = logical_effort( + min_number_gates, + 1, + F, + width_nand2_path_n, + width_nand2_path_p, + c_load_nand2_path_out, + p_to_n_sz_ratio, + is_dram_, false, g_tp.max_w_nmos_); + } + + if ((blk->flag_two_unique_paths) || + (blk->number_inputs_L1_gate == 3) || + ((way_select)&&(dec->num_in_signals == 3))) + { //this means that way_select is driving NAND3 in decoder. 
+ width_nand3_path_n[0] = g_tp.min_w_nmos_; + width_nand3_path_p[0] = p_to_n_sz_ratio * width_nand3_path_n[0]; + F = c_load_nand3_path_out / gate_C(width_nand3_path_n[0] + width_nand3_path_p[0], 0, is_dram_); + number_gates_nand3_path = logical_effort( + min_number_gates, + 1, + F, + width_nand3_path_n, + width_nand3_path_p, + c_load_nand3_path_out, + p_to_n_sz_ratio, + is_dram_, false, g_tp.max_w_nmos_); + } + } +} + + + +void PredecBlkDrv::compute_area() +{ + double area_nand2_path = 0; + double area_nand3_path = 0; + double leak_nand2_path = 0; + double leak_nand3_path = 0; + double gate_leak_nand2_path = 0; + double gate_leak_nand3_path = 0; + + if (flag_driver_exists) + { // first check whether a predecoder block driver is needed + for (int i = 0; i < number_gates_nand2_path; ++i) + { + area_nand2_path += compute_gate_area(INV, 1, width_nand2_path_p[i], width_nand2_path_n[i], g_tp.cell_h_def); + leak_nand2_path += cmos_Isub_leakage(width_nand2_path_n[i], width_nand2_path_p[i], 1, inv,is_dram_); + gate_leak_nand2_path += cmos_Ig_leakage(width_nand2_path_n[i], width_nand2_path_p[i], 1, inv,is_dram_); + } + area_nand2_path *= (num_buffers_driving_1_nand2_load + + num_buffers_driving_2_nand2_load + + num_buffers_driving_4_nand2_load); + leak_nand2_path *= (num_buffers_driving_1_nand2_load + + num_buffers_driving_2_nand2_load + + num_buffers_driving_4_nand2_load); + gate_leak_nand2_path *= (num_buffers_driving_1_nand2_load + + num_buffers_driving_2_nand2_load + + num_buffers_driving_4_nand2_load); + + for (int i = 0; i < number_gates_nand3_path; ++i) + { + area_nand3_path += compute_gate_area(INV, 1, width_nand3_path_p[i], width_nand3_path_n[i], g_tp.cell_h_def); + leak_nand3_path += cmos_Isub_leakage(width_nand3_path_n[i], width_nand3_path_p[i], 1, inv,is_dram_); + gate_leak_nand3_path += cmos_Ig_leakage(width_nand3_path_n[i], width_nand3_path_p[i], 1, inv,is_dram_); + } + area_nand3_path *= (num_buffers_driving_2_nand3_load + num_buffers_driving_8_nand3_load); + leak_nand3_path *= (num_buffers_driving_2_nand3_load + num_buffers_driving_8_nand3_load); + gate_leak_nand3_path *= (num_buffers_driving_2_nand3_load + num_buffers_driving_8_nand3_load); + + power_nand2_path.readOp.leakage = leak_nand2_path * g_tp.peri_global.Vdd; + power_nand3_path.readOp.leakage = leak_nand3_path * g_tp.peri_global.Vdd; + power_nand2_path.readOp.gate_leakage = gate_leak_nand2_path * g_tp.peri_global.Vdd; + power_nand3_path.readOp.gate_leakage = gate_leak_nand3_path * g_tp.peri_global.Vdd; + area.set_area(area_nand2_path + area_nand3_path); + } +} + + + +pair PredecBlkDrv::compute_delays( + double inrisetime_nand2_path, + double inrisetime_nand3_path) +{ + pair ret_val; + ret_val.first = 0; // outrisetime_nand2_path + ret_val.second = 0; // outrisetime_nand3_path + int i; + double rd, c_gate_load, c_load, c_intrinsic, tf, this_delay; + double Vdd = g_tp.peri_global.Vdd; + + if (flag_driver_exists) + { + for (i = 0; i < number_gates_nand2_path - 1; ++i) + { + rd = tr_R_on(width_nand2_path_n[i], NCH, 1, is_dram_); + c_gate_load = gate_C(width_nand2_path_p[i+1] + width_nand2_path_n[i+1], 0.0, is_dram_); + c_intrinsic = drain_C_(width_nand2_path_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) + + drain_C_(width_nand2_path_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_); + tf = rd * (c_intrinsic + c_gate_load); + this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE); + delay_nand2_path += this_delay; + inrisetime_nand2_path = this_delay / (1.0 - 0.5); + power_nand2_path.readOp.dynamic += (c_gate_load + 
c_intrinsic) * 0.5 * Vdd * Vdd;
+    }
+
+    // Final inverter drives the predecoder block or the decoder output load
+    if (number_gates_nand2_path != 0)
+    {
+      i = number_gates_nand2_path - 1;
+      rd = tr_R_on(width_nand2_path_n[i], NCH, 1, is_dram_);
+      c_intrinsic = drain_C_(width_nand2_path_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                    drain_C_(width_nand2_path_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_);
+      c_load = c_load_nand2_path_out;
+      tf = rd * (c_intrinsic + c_load) + r_load_nand2_path_out * c_load / 2;
+      this_delay = horowitz(inrisetime_nand2_path, tf, 0.5, 0.5, RISE);
+      delay_nand2_path += this_delay;
+      ret_val.first = this_delay / (1.0 - 0.5);
+      power_nand2_path.readOp.dynamic += (c_intrinsic + c_load) * 0.5 * Vdd * Vdd;
+//      cout << "c_intrinsic = " << c_intrinsic << " c_load = " << c_load << endl;
+    }
+
+    for (i = 0; i < number_gates_nand3_path - 1; ++i)
+    {
+      rd = tr_R_on(width_nand3_path_n[i], NCH, 1, is_dram_);
+      c_gate_load = gate_C(width_nand3_path_p[i+1] + width_nand3_path_n[i+1], 0.0, is_dram_);
+      c_intrinsic = drain_C_(width_nand3_path_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                    drain_C_(width_nand3_path_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_);
+      tf = rd * (c_intrinsic + c_gate_load);
+      this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE);
+      delay_nand3_path += this_delay;
+      inrisetime_nand3_path = this_delay / (1.0 - 0.5);
+      power_nand3_path.readOp.dynamic += (c_gate_load + c_intrinsic) * 0.5 * Vdd * Vdd;
+    }
+
+    // Final inverter of the NAND3 path drives the predecoder block or the decoder output load
+    if (number_gates_nand3_path != 0)
+    {
+      i = number_gates_nand3_path - 1;
+      rd = tr_R_on(width_nand3_path_n[i], NCH, 1, is_dram_);
+      c_intrinsic = drain_C_(width_nand3_path_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                    drain_C_(width_nand3_path_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_);
+      c_load = c_load_nand3_path_out;
+      tf = rd * (c_intrinsic + c_load) + r_load_nand3_path_out * c_load / 2;
+      this_delay = horowitz(inrisetime_nand3_path, tf, 0.5, 0.5, RISE);
+      delay_nand3_path += this_delay;
+      ret_val.second = this_delay / (1.0 - 0.5);
+      power_nand3_path.readOp.dynamic += (c_intrinsic + c_load) * 0.5 * Vdd * Vdd;
+    }
+  }
+
+  return ret_val;
+}
+
+
+Predec::Predec(
+    PredecBlkDrv * drv1_,
+    PredecBlkDrv * drv2_)
+ :blk1(drv1_->blk), blk2(drv2_->blk), drv1(drv1_), drv2(drv2_)
+{
+  driver_power.readOp.leakage = drv1->power_nand2_path.readOp.leakage +
+                                drv1->power_nand3_path.readOp.leakage +
+                                drv2->power_nand2_path.readOp.leakage +
+                                drv2->power_nand3_path.readOp.leakage;
+  block_power.readOp.leakage = blk1->power_nand2_path.readOp.leakage +
+                               blk1->power_nand3_path.readOp.leakage +
+                               blk1->power_L2.readOp.leakage +
+                               blk2->power_nand2_path.readOp.leakage +
+                               blk2->power_nand3_path.readOp.leakage +
+                               blk2->power_L2.readOp.leakage;
+  power.readOp.leakage = driver_power.readOp.leakage + block_power.readOp.leakage;
+
+  driver_power.readOp.gate_leakage = drv1->power_nand2_path.readOp.gate_leakage +
+                                     drv1->power_nand3_path.readOp.gate_leakage +
+                                     drv2->power_nand2_path.readOp.gate_leakage +
+                                     drv2->power_nand3_path.readOp.gate_leakage;
+  block_power.readOp.gate_leakage = blk1->power_nand2_path.readOp.gate_leakage +
+                                    blk1->power_nand3_path.readOp.gate_leakage +
+                                    blk1->power_L2.readOp.gate_leakage +
+                                    blk2->power_nand2_path.readOp.gate_leakage +
+                                    blk2->power_nand3_path.readOp.gate_leakage +
+                                    blk2->power_L2.readOp.gate_leakage;
+  power.readOp.gate_leakage = driver_power.readOp.gate_leakage + block_power.readOp.gate_leakage;
+}
+
+void PredecBlkDrv::leakage_feedback(double temperature)
+{
+  double leak_nand2_path = 0;
+  double leak_nand3_path = 0;
+  double gate_leak_nand2_path = 0;
+  double gate_leak_nand3_path = 0;
+
+  if (flag_driver_exists)
+  { // first check whether a predecoder block driver is needed
+    for (int i = 0; i < number_gates_nand2_path; ++i)
+    {
+      leak_nand2_path      += cmos_Isub_leakage(width_nand2_path_n[i], width_nand2_path_p[i], 1, inv, is_dram_);
+      gate_leak_nand2_path += cmos_Ig_leakage(width_nand2_path_n[i], width_nand2_path_p[i], 1, inv, is_dram_);
+    }
+    leak_nand2_path *= (num_buffers_driving_1_nand2_load +
+                        num_buffers_driving_2_nand2_load +
+                        num_buffers_driving_4_nand2_load);
+    gate_leak_nand2_path *= (num_buffers_driving_1_nand2_load +
+                             num_buffers_driving_2_nand2_load +
+                             num_buffers_driving_4_nand2_load);
+
+    for (int i = 0; i < number_gates_nand3_path; ++i)
+    {
+      leak_nand3_path      += cmos_Isub_leakage(width_nand3_path_n[i], width_nand3_path_p[i], 1, inv, is_dram_);
+      gate_leak_nand3_path += cmos_Ig_leakage(width_nand3_path_n[i], width_nand3_path_p[i], 1, inv, is_dram_);
+    }
+    leak_nand3_path      *= (num_buffers_driving_2_nand3_load + num_buffers_driving_8_nand3_load);
+    gate_leak_nand3_path *= (num_buffers_driving_2_nand3_load + num_buffers_driving_8_nand3_load);
+
+    power_nand2_path.readOp.leakage = leak_nand2_path * g_tp.peri_global.Vdd;
+    power_nand3_path.readOp.leakage = leak_nand3_path * g_tp.peri_global.Vdd;
+    power_nand2_path.readOp.gate_leakage = gate_leak_nand2_path * g_tp.peri_global.Vdd;
+    power_nand3_path.readOp.gate_leakage = gate_leak_nand3_path * g_tp.peri_global.Vdd;
+  }
+}
+
+double Predec::compute_delays(double inrisetime)
+{
+  // TODO: Jung Ho thinks that the predecoder block driver sits between the decoder and the predecoder block.
+  pair<double, double> tmp_pair1, tmp_pair2;
+  tmp_pair1 = drv1->compute_delays(inrisetime, inrisetime);
+  tmp_pair1 = blk1->compute_delays(tmp_pair1);
+  tmp_pair2 = drv2->compute_delays(inrisetime, inrisetime);
+  tmp_pair2 = blk2->compute_delays(tmp_pair2);
+  tmp_pair1 = get_max_delay_before_decoder(tmp_pair1, tmp_pair2);
+
+  driver_power.readOp.dynamic =
+    drv1->num_addr_bits_nand2_path() * drv1->power_nand2_path.readOp.dynamic +
+    drv1->num_addr_bits_nand3_path() * drv1->power_nand3_path.readOp.dynamic +
+    drv2->num_addr_bits_nand2_path() * drv2->power_nand2_path.readOp.dynamic +
+    drv2->num_addr_bits_nand3_path() * drv2->power_nand3_path.readOp.dynamic;
+
+  block_power.readOp.dynamic =
+    blk1->power_nand2_path.readOp.dynamic * blk1->num_L1_active_nand2_path +
+    blk1->power_nand3_path.readOp.dynamic * blk1->num_L1_active_nand3_path +
+    blk1->power_L2.readOp.dynamic +
+    blk2->power_nand2_path.readOp.dynamic * blk2->num_L1_active_nand2_path +
+    blk2->power_nand3_path.readOp.dynamic * blk2->num_L1_active_nand3_path +
+    blk2->power_L2.readOp.dynamic;
+
+  power.readOp.dynamic = driver_power.readOp.dynamic + block_power.readOp.dynamic;
+
+  delay = tmp_pair1.first;
+  return tmp_pair1.second;
+}
+
+
+void Predec::leakage_feedback(double temperature)
+{
+  drv1->leakage_feedback(temperature);
+  drv2->leakage_feedback(temperature);
+  blk1->leakage_feedback(temperature);
+  blk2->leakage_feedback(temperature);
+
+  driver_power.readOp.leakage = drv1->power_nand2_path.readOp.leakage +
+                                drv1->power_nand3_path.readOp.leakage +
+                                drv2->power_nand2_path.readOp.leakage +
+                                drv2->power_nand3_path.readOp.leakage;
+  block_power.readOp.leakage = blk1->power_nand2_path.readOp.leakage +
+                               blk1->power_nand3_path.readOp.leakage +
+                               blk1->power_L2.readOp.leakage +
+                               blk2->power_nand2_path.readOp.leakage +
+                               blk2->power_nand3_path.readOp.leakage +
+                               blk2->power_L2.readOp.leakage;
+  power.readOp.leakage = driver_power.readOp.leakage + block_power.readOp.leakage;
+
+  driver_power.readOp.gate_leakage = drv1->power_nand2_path.readOp.gate_leakage +
+                                     drv1->power_nand3_path.readOp.gate_leakage +
+                                     drv2->power_nand2_path.readOp.gate_leakage +
+                                     drv2->power_nand3_path.readOp.gate_leakage;
+  block_power.readOp.gate_leakage = blk1->power_nand2_path.readOp.gate_leakage +
+                                    blk1->power_nand3_path.readOp.gate_leakage +
+                                    blk1->power_L2.readOp.gate_leakage +
+                                    blk2->power_nand2_path.readOp.gate_leakage +
+                                    blk2->power_nand3_path.readOp.gate_leakage +
+                                    blk2->power_L2.readOp.gate_leakage;
+  power.readOp.gate_leakage = driver_power.readOp.gate_leakage + block_power.readOp.gate_leakage;
+}
+
+// returns <max delay, rise time at the corresponding input>
+pair<double, double> Predec::get_max_delay_before_decoder(
+    pair<double, double> input_pair1,
+    pair<double, double> input_pair2)
+{
+  pair<double, double> ret_val;
+  double delay;
+
+  delay = drv1->delay_nand2_path + blk1->delay_nand2_path;
+  ret_val.first  = delay;
+  ret_val.second = input_pair1.first;
+  delay = drv1->delay_nand3_path + blk1->delay_nand3_path;
+  if (ret_val.first < delay)
+  {
+    ret_val.first  = delay;
+    ret_val.second = input_pair1.second;
+  }
+  delay = drv2->delay_nand2_path + blk2->delay_nand2_path;
+  if (ret_val.first < delay)
+  {
+    ret_val.first  = delay;
+    ret_val.second = input_pair2.first;
+  }
+  delay = drv2->delay_nand3_path + blk2->delay_nand3_path;
+  if (ret_val.first < delay)
+  {
+    ret_val.first  = delay;
+    ret_val.second = input_pair2.second;
+  }
+
+  return ret_val;
+}
+
+
+
+Driver::Driver(double c_gate_load_, double c_wire_load_, double r_wire_load_, bool is_dram)
+:number_gates(0),
+  min_number_gates(2),
+  c_gate_load(c_gate_load_),
+  c_wire_load(c_wire_load_),
+  r_wire_load(r_wire_load_),
+  delay(0),
+// power(),
+  is_dram_(is_dram),
+  total_driver_nwidth(0),
+  total_driver_pwidth(0),
+  sleeptx(NULL)
+{
+  for (int i = 0; i < MAX_NUMBER_GATES_STAGE; i++)
+  {
+    width_n[i] = 0;
+    width_p[i] = 0;
+  }
+
+  compute_widths();
+  compute_area();
+}
+
+
+void Driver::compute_widths()
+{
+  double p_to_n_sz_ratio = pmos_to_nmos_sz_ratio(is_dram_);
+  double c_load = c_gate_load + c_wire_load;
+  width_n[0] = g_tp.min_w_nmos_;
+  width_p[0] = p_to_n_sz_ratio * g_tp.min_w_nmos_;
+
+  double F = c_load / gate_C(width_n[0] + width_p[0], 0, is_dram_);
+  number_gates = logical_effort(
+      min_number_gates,
+      1,
+      F,
+      width_n,
+      width_p,
+      c_load,
+      p_to_n_sz_ratio,
+      is_dram_, false,
+      g_tp.max_w_nmos_);
+}
+
+void Driver::compute_area()
+{
+  double cumulative_area = 0;
+  ///double cumulative_curr = 0;    // cumulative leakage current
+  ///double cumulative_curr_Ig = 0; // cumulative gate leakage current
+  area.h = g_tp.cell_h_def;
+  for (int i = 0; i < number_gates; i++)
+  {
+    cumulative_area += compute_gate_area(INV, 1, width_p[i], width_n[i], area.h);
+    ///cumulative_curr += cmos_Isub_leakage(width_n[i], width_p[i], 1, inv, is_dram_);
+    ///cumulative_curr_Ig = cmos_Ig_leakage(width_n[i], width_p[i], 1, inv, is_dram_);
+  }
+  area.w = (cumulative_area / area.h);
+}
+
+void Driver::compute_power_gating()
+{
+  // For the whole driver chain there is only one sleep transistor, to save area.
+  // The total transistor width is needed to size that sleep transistor.
+  for (int i = 0; i < number_gates; i++)
+  {
+    total_driver_nwidth += width_n[i];
+    total_driver_pwidth += width_p[i];
+  }
+
+  // compute sleep tx
+  bool is_footer = false;
+  double Isat_subarray = simplified_nmos_Isat(total_driver_nwidth);
+  double deltaV;
+  double c_wakeup;
+
+  c_wakeup = drain_C_(total_driver_pwidth, PCH, 1, 1, area.h); // header (PMOS) sleep tx
+  deltaV = g_tp.peri_global.Vdd - g_tp.peri_global.Vcc_min;
+  if (g_ip->power_gating)
+    sleeptx = new Sleep_tx(g_ip->perfloss,
+                           Isat_subarray,
+                           is_footer,
+                           c_wakeup,
+                           deltaV,
+                           1,
+                           area);
+}
+
+
+double Driver::compute_delay(double inrisetime)
+{
+  int i;
+  double rd, c_load, c_intrinsic, tf;
+  double this_delay = 0;
+
+  for (i = 0; i < number_gates - 1; ++i)
+  {
+    rd = tr_R_on(width_n[i], NCH, 1, is_dram_);
+    c_load = gate_C(width_n[i+1] + width_p[i+1], 0.0, is_dram_);
+    c_intrinsic = drain_C_(width_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                  drain_C_(width_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_);
+    tf = rd * (c_intrinsic + c_load);
+    this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE);
+    delay += this_delay;
+    inrisetime = this_delay / (1.0 - 0.5);
+    power.readOp.dynamic += (c_intrinsic + c_load) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd;
+    power.readOp.leakage += cmos_Isub_leakage(width_n[i], width_p[i], 1, inv, is_dram_) * g_tp.peri_global.Vdd;
+    power.readOp.gate_leakage += cmos_Ig_leakage(width_n[i], width_p[i], 1, inv, is_dram_) * g_tp.peri_global.Vdd;
+  }
+
+  i = number_gates - 1;
+  c_load = c_gate_load + c_wire_load;
+  rd = tr_R_on(width_n[i], NCH, 1, is_dram_);
+  c_intrinsic = drain_C_(width_p[i], PCH, 1, 1, g_tp.cell_h_def, is_dram_) +
+                drain_C_(width_n[i], NCH, 1, 1, g_tp.cell_h_def, is_dram_);
+  tf = rd * (c_intrinsic + c_load) + r_wire_load * (c_wire_load / 2 + c_gate_load);
+  this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE);
+  delay += this_delay;
power.readOp.dynamic += (c_intrinsic + c_load) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd; + power.readOp.leakage += cmos_Isub_leakage(width_n[i], width_p[i], 1, inv, is_dram_) * g_tp.peri_global.Vdd; + power.readOp.gate_leakage += cmos_Ig_leakage(width_n[i], width_p[i], 1, inv, is_dram_)* g_tp.peri_global.Vdd; + + return this_delay / (1.0 - 0.5); +} + +/* +void Driver::compute_area() +{ + double cumulative_area = 0; + double cumulative_curr = 0; // cumulative leakage current + double cumulative_curr_Ig = 0; // cumulative leakage current + + area.h = g_tp.h_dec * g_tp.dram.b_h; + for (int i = 1; i < number_gates; i++) + { + cumulative_area += compute_gate_area(INV, 1, width_p[i], width_n[i], area.h); + cumulative_curr += cmos_Isub_leakage(width_n[i], width_p[i], 1, inv, is_dram_); + cumulative_curr_Ig = cmos_Ig_leakage(width_n[i], width_p[i], 1, inv, is_dram_); + } + area.w = (cumulative_area / area.h); + +} +*/ diff --git a/Project_FARSI/cacti_for_FARSI/decoder.h b/Project_FARSI/cacti_for_FARSI/decoder.h new file mode 100644 index 00000000..bd74c644 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/decoder.h @@ -0,0 +1,272 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
+ *
+ ***************************************************************************/
+
+
+#ifndef __DECODER_H__
+#define __DECODER_H__
+
+#include "area.h"
+#include "component.h"
+#include "parameter.h"
+#include "powergating.h"
+#include <vector>
+
+using namespace std;
+
+
+class Decoder : public Component
+{
+  public:
+    Decoder(
+        int    _num_dec_signals,
+        bool   flag_way_select,
+        double _C_ld_dec_out,
+        double _R_wire_dec_out,
+        bool   fully_assoc_,
+        bool   is_dram_,
+        bool   is_wl_tr_,
+        const  Area & cell_);
+
+    bool   exist;
+    int    num_in_signals;
+    double C_ld_dec_out;
+    double R_wire_dec_out;
+    int    num_gates;
+    int    num_gates_min;
+    double w_dec_n[MAX_NUMBER_GATES_STAGE];
+    double w_dec_p[MAX_NUMBER_GATES_STAGE];
+    double delay;
+    //powerDef power;
+    bool   fully_assoc;
+    bool   is_dram;
+    bool   is_wl_tr;
+
+    double total_driver_nwidth;
+    double total_driver_pwidth;
+    Sleep_tx * sleeptx;
+
+    const Area & cell;
+    int nodes_DSTN;
+
+    void   compute_widths();
+    void   compute_area();
+    double compute_delays(double inrisetime); // return outrisetime
+    void   compute_power_gating();
+
+    void leakage_feedback(double temperature);
+
+    ~Decoder()
+    {
+      if (sleeptx)
+        delete sleeptx;
+    };
+};
+
+
+
+class PredecBlk : public Component
+{
+  public:
+    PredecBlk(
+        int    num_dec_signals,
+        Decoder * dec,
+        double C_wire_predec_blk_out,
+        double R_wire_predec_blk_out,
+        int    num_dec_per_predec,
+        bool   is_dram_,
+        bool   is_blk1);
+
+    Decoder * dec;
+    bool   exist;
+    int    number_input_addr_bits;
+    double C_ld_predec_blk_out;
+    double R_wire_predec_blk_out;
+    int    branch_effort_nand2_gate_output;
+    int    branch_effort_nand3_gate_output;
+    bool   flag_two_unique_paths;
+    int    flag_L2_gate;
+    int    number_inputs_L1_gate;
+    int    number_gates_L1_nand2_path;
+    int    number_gates_L1_nand3_path;
+    int    number_gates_L2;
+    int    min_number_gates_L1;
+    int    min_number_gates_L2;
+    int    num_L1_active_nand2_path;
+    int    num_L1_active_nand3_path;
+    double w_L1_nand2_n[MAX_NUMBER_GATES_STAGE];
+    double w_L1_nand2_p[MAX_NUMBER_GATES_STAGE];
+    double w_L1_nand3_n[MAX_NUMBER_GATES_STAGE];
+    double w_L1_nand3_p[MAX_NUMBER_GATES_STAGE];
+    double w_L2_n[MAX_NUMBER_GATES_STAGE];
+    double w_L2_p[MAX_NUMBER_GATES_STAGE];
+    double delay_nand2_path;
+    double delay_nand3_path;
+    powerDef power_nand2_path;
+    powerDef power_nand3_path;
+    powerDef power_L2;
+
+    bool is_dram_;
+
+    void compute_widths();
+    void compute_area();
+
+    void leakage_feedback(double temperature);
+
+    pair<double, double> compute_delays(pair<double, double> inrisetime); // <nand2, nand3>
+    // return <outrisetime_nand2, outrisetime_nand3>
+};
+
+
+class PredecBlkDrv : public Component
+{
+  public:
+    PredecBlkDrv(
+        int    way_select,
+        PredecBlk * blk_,
+        bool   is_dram);
+
+    int    flag_driver_exists;
+    int    number_input_addr_bits;
+    int    number_gates_nand2_path;
+    int    number_gates_nand3_path;
+    int    min_number_gates;
+    int    num_buffers_driving_1_nand2_load;
+    int    num_buffers_driving_2_nand2_load;
+    int    num_buffers_driving_4_nand2_load;
+    int    num_buffers_driving_2_nand3_load;
+    int    num_buffers_driving_8_nand3_load;
+    int    num_buffers_nand3_path;
+    double c_load_nand2_path_out;
+    double c_load_nand3_path_out;
+    double r_load_nand2_path_out;
+    double r_load_nand3_path_out;
+    double width_nand2_path_n[MAX_NUMBER_GATES_STAGE];
+    double width_nand2_path_p[MAX_NUMBER_GATES_STAGE];
+    double width_nand3_path_n[MAX_NUMBER_GATES_STAGE];
+    double width_nand3_path_p[MAX_NUMBER_GATES_STAGE];
+    double delay_nand2_path;
+    double delay_nand3_path;
+    powerDef power_nand2_path;
+    powerDef power_nand3_path;
+
+    PredecBlk * blk;
+    Decoder   * dec;
+    bool  is_dram_;
+    int   way_select;
+
+    void compute_widths();
+    void compute_area();
+
+    void leakage_feedback(double temperature);
+
+
+    pair<double, double> compute_delays(
+        double inrisetime_nand2_path,
+        double inrisetime_nand3_path); // return <outrisetime_nand2_path, outrisetime_nand3_path>
+
+    inline int num_addr_bits_nand2_path()
+    {
+      return num_buffers_driving_1_nand2_load +
+             num_buffers_driving_2_nand2_load +
+             num_buffers_driving_4_nand2_load;
+    }
+    inline int num_addr_bits_nand3_path()
+    {
+      return num_buffers_driving_2_nand3_load +
+             num_buffers_driving_8_nand3_load;
+    }
+    double get_rdOp_dynamic_E(int num_act_mats_hor_dir);
+};
+
+
+
+class Predec : public Component
+{
+  public:
+    Predec(
+        PredecBlkDrv * drv1,
+        PredecBlkDrv * drv2);
+
+    double compute_delays(double inrisetime); // return outrisetime
+
+    void leakage_feedback(double temperature);
+    PredecBlk    * blk1;
+    PredecBlk    * blk2;
+    PredecBlkDrv * drv1;
+    PredecBlkDrv * drv2;
+
+    powerDef block_power;
+    powerDef driver_power;
+
+  private:
+    // returns <max delay, rise time at the corresponding input>
+    pair<double, double> get_max_delay_before_decoder(
+        pair<double, double> input_pair1,
+        pair<double, double> input_pair2);
+};
+
+
+
+class Driver : public Component
+{
+  public:
+    Driver(double c_gate_load_, double c_wire_load_, double r_wire_load_, bool is_dram);
+
+    int    number_gates;
+    int    min_number_gates;
+    double width_n[MAX_NUMBER_GATES_STAGE];
+    double width_p[MAX_NUMBER_GATES_STAGE];
+    double c_gate_load;
+    double c_wire_load;
+    double r_wire_load;
+    double delay;
+//    powerDef power;
+    bool   is_dram_;
+
+    double total_driver_nwidth;
+    double total_driver_pwidth;
+    Sleep_tx * sleeptx;
+
+    void   compute_widths();
+    void   compute_area();
+    double compute_delay(double inrisetime);
+
+    void compute_power_gating();
+
+    ~Driver()
+    {
+      if (sleeptx)
+        delete sleeptx;
+    };
+};
+
+
+#endif
diff --git a/Project_FARSI/cacti_for_FARSI/dram.cfg b/Project_FARSI/cacti_for_FARSI/dram.cfg
new file mode 100644
index 00000000..f55b5b37
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/dram.cfg
@@ -0,0 +1,114 @@
+//-size (bytes) 16777216
+//-size (bytes) 33554432
+-size (bytes) 134217728
+//-size (bytes) 67108864
+//-size (bytes) 1073741824
+
+-block size (bytes) 64
+-associativity 1
+-read-write port 1
+-exclusive read port 0
+-exclusive write port 0
+-single ended read ports 0
+-UCA bank count 1
+//-technology (u) 0.032
+//-technology (u) 0.045
+-technology (u) 0.068
+//-technology (u) 0.078
+
+# following three parameters are meaningful only for main memories
+-page size (bits) 8192
+-burst length 8
+-internal prefetch width 8
+
+# following parameter can have one of the five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram)
+-Data array cell type - "comm-dram"
+
+# following parameter can have one of the three values -- (itrs-hp, itrs-lstp, itrs-lop)
+-Data array peripheral type - "itrs-hp"
+
+# following parameter can have one of the five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram)
+-Tag array cell type - "itrs-hp"
+
+# following parameter can have one of the three values -- (itrs-hp, itrs-lstp, itrs-lop)
+-Tag array peripheral type - "itrs-hp"
+
+# Bus width includes data bits and
address bits required by the decoder +//-output/input bus width 512 +-output/input bus width 64 + +-operating temperature (K) 350 + +-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 45 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +//-design objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:0 +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:1000000 +//-deviate (delay, dynamic power, leakage power, cycle time, area) 200:100000:100000:100000:20 + +-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" + +//-Wire signalling (fullswing, lowswing, default) - "default" +-Wire signalling (fullswing, lowswing, default) - "Global_10" + +-Wire inside mat - "global" +//-Wire inside mat - "semi-global" +-Wire outside mat - "global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +-Add ECC - "true" + +-Print level (DETAILED, CONCISE) - "DETAILED" + +# for debugging +-Print input parameters - "true" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +//-Force cache config - "true" +-Force cache config - "false" +-Ndwl 1 +-Ndbl 1 +-Nspd 0 +-Ndcm 1 +-Ndsam1 0 +-Ndsam2 0 + +########### NUCA Params ############ + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + diff --git a/Project_FARSI/cacti_for_FARSI/extio.cc b/Project_FARSI/cacti_for_FARSI/extio.cc new file mode 100644 index 00000000..f09be49f --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/extio.cc @@ -0,0 +1,506 @@ +#include "extio.h" +#include + + +Extio::Extio(IOTechParam *iot): +io_param(iot){} + + +//External IO AREA. Does not include PHY or decap, includes only IO active circuit. More details can be found in the CACTI-IO technical report (), Chapter 2.3. 
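As a worked evaluation of the area polynomial below: with the DDR3 coefficients defined later in extio_technology.cc (ioarea_c = 0.01, ioarea_k0 = 0.5, ioarea_k1 = 1.5e-4, ioarea_k2 = 4.5e-8, ioarea_k3 = 1.5e-11) and assuming r_on = 34 ohm (the value the DDR3 sensitivity formulas normalize against) at 800 MHz, a minimal sketch:

#include <cstdio>

// Editorial sketch only; the r_on and frequency values are assumptions.
int main()
{
    double r_on = 34.0, f = 800.0; // ohm, MHz
    double area_mm2 = 0.01 + 0.5 / r_on
                    + (1.0 / r_on) * (1.5e-4 * f + 4.5e-8 * f * f + 1.5e-11 * f * f * f);
    std::printf("per-IO area ~ %.4f mm^2\n", area_mm2); // ~0.0294 mm^2
    return 0;
}

The 1/r_on scaling reflects that a stronger (lower-impedance) driver needs proportionally wider output transistors, while the frequency terms capture the extra pre-driver stages needed at higher data rates.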
+ +void Extio::extio_area() +{ + + //Area per IO, assuming drive stage and ODT are shared + double single_io_area = io_param->ioarea_c + + (io_param->ioarea_k0/io_param->r_on)+(1/io_param->r_on)* + (io_param->ioarea_k1*io_param->frequency + + io_param->ioarea_k2*io_param->frequency*io_param->frequency + + io_param->ioarea_k3*io_param->frequency* + io_param->frequency*io_param->frequency); // IO Area in sq.mm. + + //Area per IO if ODT requirements are more stringent than the Ron + //requirements in determining size of driver + if (2*io_param->rtt1_dq_read < io_param->r_on) { + single_io_area = io_param->ioarea_c + + (io_param->ioarea_k0/(2*io_param->rtt1_dq_read))+ + (1/io_param->r_on)*(io_param->ioarea_k1*io_param->frequency + + io_param->ioarea_k2*io_param->frequency*io_param->frequency + + io_param->ioarea_k3*io_param->frequency*io_param->frequency*io_param->frequency); + } + + //Total IO area + io_area = (g_ip->num_dq + g_ip->num_dqs + g_ip->num_ca + g_ip->num_clk) * + single_io_area; + + //printf("IO Area (sq.mm) = "); + //cout << io_area << endl; + +} + +//External IO Termination Power. More details can be found in the CACTI-IO technical report (), Chapter 2.1. + +void Extio::extio_power_term() +{ + + //IO Termination and Bias Power + + //Bias and Leakage Power + power_bias = io_param->i_bias * io_param->vdd_io + + io_param->i_leak * (g_ip->num_dq + + g_ip->num_dqs + + g_ip->num_clk + + g_ip->num_ca) * io_param->vdd_io/1000000; + + + //Termination Power + power_termination_read = 1000 * (g_ip->num_dq + g_ip->num_dqs) * + io_param->vdd_io * io_param->vdd_io * 0.25 * + (1/(io_param->r_on + io_param->rpar_read + io_param->rs1_dq) + + 1/(io_param->rtt1_dq_read) + 1/(io_param->rtt2_dq_read)) + + 1000 * g_ip->num_ca * io_param->vdd_io * io_param->vdd_io * + (0.5 / (2 * (io_param->r_on_ca + io_param->rtt_ca))); + + power_termination_write = 1000 * (g_ip->num_dq + g_ip->num_dqs) * + io_param->vdd_io * io_param->vdd_io * 0.25 * + (1/(io_param->r_on + io_param->rpar_write) + + 1/(io_param->rtt1_dq_write) + 1/(io_param->rtt2_dq_write)) + + 1000 * g_ip->num_ca * io_param->vdd_io * io_param->vdd_io * + (0.5 / (2 * (io_param->r_on_ca + io_param->rtt_ca))); + + power_clk_bias = io_param->vdd_io * io_param->v_sw_clk / io_param->r_diff_term * 1000; + + + if (io_param->io_type == Serial) + { power_termination_read= 1000*(g_ip->num_dq)*io_param->vdd_io*io_param->v_sw_clk/io_param->r_diff_term; + power_termination_write= 1000*(g_ip->num_dq)*io_param->vdd_io*io_param->v_sw_clk/io_param->r_diff_term; + power_clk_bias=0; + } + + if (io_param->io_type == DDR4) + { + power_termination_read=1000 * (g_ip->num_dq + g_ip->num_dqs) * + io_param->vdd_io * io_param->vdd_io *0.5 * (1/(io_param->r_on + io_param->rpar_read + io_param->rs1_dq)) + + 1000 * g_ip->num_ca * io_param->vdd_io * io_param->vdd_io * + (0.5 / (2 * (io_param->r_on_ca + io_param->rtt_ca))); + + + + power_termination_write = 1000 * (g_ip->num_dq + g_ip->num_dqs) * + io_param->vdd_io * io_param->vdd_io * 0.5 * + (1/(io_param->r_on + io_param->rpar_write)) + + 1000 * g_ip->num_ca * io_param->vdd_io * io_param->vdd_io * + (0.5 / (2 * (io_param->r_on_ca + io_param->rtt_ca))); + + + + } + + + //Combining the power terms based on STATE (READ/WRITE/IDLE/SLEEP) + if (g_ip->iostate == READ) + { + io_power_term = g_ip->duty_cycle * + (power_termination_read + power_bias + power_clk_bias); + } + else if (g_ip->iostate == WRITE) + { + io_power_term = g_ip->duty_cycle * + (power_termination_write + power_bias + power_clk_bias); + } + else if (g_ip->iostate == 
IDLE) + { + io_power_term = g_ip->duty_cycle * + (power_termination_write + power_bias + power_clk_bias); + if (io_param->io_type == DDR4) + { io_power_term = 1e-6*io_param->i_leak*io_param->vdd_io; // IDLE IO power for DDR4 is leakage since bus can be parked at VDDQ + } + } + else if (g_ip->iostate == SLEEP) + { + io_power_term = 1e-6*io_param->i_leak*io_param->vdd_io; //nA to mW + } + else + { + io_power_term = 0; + } + + + //printf("IO Termination and Bias Power (mW) = "); + //cout << io_power_term << endl; +} + + +//External PHY Power and Wakeup Times. More details can be found in the CACTI-IO technical report (), Chapter 2.1. + +void Extio::extio_power_phy () +{ + + + phy_static_power = io_param->phy_datapath_s + io_param->phy_phase_rotator_s + + io_param->phy_clock_tree_s + io_param->phy_rx_s + io_param->phy_dcc_s + + io_param->phy_deskew_s + io_param->phy_leveling_s + io_param->phy_pll_s; // in mW + + phy_dynamic_power = io_param->phy_datapath_d + io_param->phy_phase_rotator_d + + io_param->phy_clock_tree_d + io_param->phy_rx_d + io_param->phy_dcc_d + + io_param->phy_deskew_d + io_param->phy_leveling_d + + io_param->phy_pll_d; // in mW/Gbps + + + +//Combining the power terms based on STATE (READ/WRITE/IDLE/SLEEP) + if (g_ip->iostate == READ) + { + phy_power = phy_static_power + 2 * io_param->frequency * g_ip->num_dq * phy_dynamic_power / 1000; // Total PHY power in mW + } + else if (g_ip->iostate == WRITE) + { + phy_power = phy_static_power + 2 * io_param->frequency * g_ip->num_dq * phy_dynamic_power / 1000; // Total PHY power in mW + } + else if (g_ip->iostate == IDLE) + { + phy_power = phy_static_power; // Total PHY power in mW + + } + else if (g_ip->iostate == SLEEP) + { + phy_power = 0; // Total PHY power in mW; + } + else + { + phy_power = 0; // Total PHY power in mW; + } + + + phy_wtime = io_param->phy_pll_wtime + io_param->phy_phase_rotator_wtime + io_param->phy_rx_wtime + io_param->phy_bandgap_wtime + io_param->phy_deskew_wtime + io_param->phy_vrefgen_wtime; // Total Wakeup time from SLEEP to ACTIVE. Some of the Wakeup time can be hidden if all components do not need to be serially brought out of SLEEP. This depends on the implementation and user can modify the Wakeup times accordingly. + + + //printf("PHY Power (mW) = "); + //cout << phy_power << " "; + //printf("PHY Wakeup Time (us) = "); + //cout << phy_wtime << endl; + +} + + +//External IO Dynamic Power. Does not include termination or PHY. More details can be found in the CACTI-IO technical report (), Chapter 2.1. 
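Every term in the function below follows the same switching-power pattern: power in mW is signals x activity x capacitance (pF) x Vdd x swing x frequency (MHz) / 1000, summed over the line, receiver-load and internal capacitances. A minimal sketch of that pattern; the 64-signal, 50% activity, 5 pF, full 1.5 V swing values are illustrative assumptions, not CACTI defaults:

// Editorial sketch only -- the per-group pattern used in extio_power_dynamic().
static double sketch_switching_power_mw(int n_signals, double activity,
                                        double c_pf, double vdd, double v_sw,
                                        double f_mhz)
{
    // pF * V^2 * MHz yields microwatts; dividing by 1000 converts to milliwatts.
    return n_signals * activity * c_pf * vdd * v_sw * f_mhz / 1000.0;
}
// e.g. sketch_switching_power_mw(64, 0.5, 5.0, 1.5, 1.5, 800.0) = 288 mW.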
+ +void Extio::extio_power_dynamic() +{ + + if (io_param->io_type == Serial) + { + power_dq_write = 0; + + power_dqs_write = 0; + + power_ca_write = 0; + + power_dq_read = 0; + + power_dqs_read = 0; + + power_ca_read = 0; + + power_clk = 0; + + } + else + { + + + //Line capacitance calculations for effective c_line + + double c_line =1e6/(io_param->z0*2*io_param->frequency); //For DDR signals: DQ, DQS, CLK + double c_line_ca=c_line; //For DDR CA + double c_line_sdr=1e6/(io_param->z0*io_param->frequency); //For SDR CA + double c_line_2T=1e6*2/(io_param->z0*io_param->frequency); //For 2T timing + double c_line_3T=1e6*3/(io_param->z0*io_param->frequency); //For 3T timing + + //Line capacitance if flight time is less than half the bit period + + if (io_param->t_flight < 1e3/(4*io_param->frequency)){ + c_line = 1e3*io_param->t_flight/io_param->z0; + } + + if (io_param->t_flight_ca < 1e3/(4*io_param->frequency)){ + c_line_ca = 1e3*io_param->t_flight/io_param->z0; + } + + if (io_param->t_flight_ca < 1e3/(2*io_param->frequency)){ + c_line_sdr = 1e3*io_param->t_flight/io_param->z0; + } + + if (io_param->t_flight_ca < 1e3*2/(2*io_param->frequency)){ + c_line_2T = 1e3*io_param->t_flight/io_param->z0; + } + + if (io_param->t_flight_ca < 1e3*3/(2*io_param->frequency)){ + c_line_3T = 1e3*io_param->t_flight/io_param->z0; + } + + //Line capacitance calculation for the address bus, depending on what address timing is chosen (DDR/SDR/2T/3T) + + if (g_ip->addr_timing==1.0) { + c_line_ca = c_line_sdr; + } + else if (g_ip->addr_timing==2.0){ + c_line_ca = c_line_2T; + } + else if (g_ip->addr_timing==3.0){ + c_line_ca = c_line_3T; + } + + //Dynamic power per signal group for WRITE and READ modes + + power_dq_write = g_ip->num_dq * g_ip->activity_dq * + (io_param->c_tx + c_line) * io_param->vdd_io * + io_param->v_sw_data_write_line * io_param->frequency / 1000 + + g_ip->num_dq * g_ip->activity_dq * io_param->c_data * + io_param->vdd_io * io_param->v_sw_data_write_load1 * + io_param->frequency / 1000 + + g_ip->num_dq * g_ip->activity_dq * ((g_ip->num_mem_dq-1) * + io_param->c_data) * io_param->vdd_io * + io_param->v_sw_data_write_load2 * io_param->frequency / 1000 + + g_ip->num_dq * g_ip->activity_dq * io_param->c_int * + io_param->vdd_io * io_param->vdd_io * io_param->frequency / 1000; + + power_dqs_write = g_ip->num_dqs * (io_param->c_tx + c_line) * + io_param->vdd_io * io_param->v_sw_data_write_line * + io_param->frequency / 1000 + + g_ip->num_dqs * io_param->c_data * io_param->vdd_io * + io_param->v_sw_data_write_load1 * io_param->frequency / 1000 + + g_ip->num_dqs * ((g_ip->num_mem_dq-1) * io_param->c_data) * + io_param->vdd_io * io_param->v_sw_data_write_load2 * + io_param->frequency / 1000 + + g_ip->num_dqs * io_param->c_int * io_param->vdd_io * + io_param->vdd_io * io_param->frequency / 1000; + + power_ca_write = g_ip->num_ca * g_ip->activity_ca * + (io_param->c_tx + io_param->num_mem_ca * io_param->c_addr + + c_line_ca) * + io_param->vdd_io * io_param->v_sw_addr * io_param->frequency / 1000 + + g_ip->num_ca * g_ip->activity_ca * io_param->c_int * + io_param->vdd_io * io_param->vdd_io * io_param->frequency / 1000; + + power_dq_read = g_ip->num_dq * g_ip->activity_dq * + (io_param->c_tx + c_line) * io_param->vdd_io * + io_param->v_sw_data_read_line * io_param->frequency / 1000.0 + + g_ip->num_dq * g_ip->activity_dq * io_param->c_data * + io_param->vdd_io * io_param->v_sw_data_read_load1 * io_param->frequency / 1000.0 + + g_ip->num_dq *g_ip->activity_dq * ((g_ip->num_mem_dq-1) * io_param->c_data) * + 
io_param->vdd_io * io_param->v_sw_data_read_load2 * io_param->frequency / 1000.0 + + g_ip->num_dq * g_ip->activity_dq * io_param->c_int * io_param->vdd_io * + io_param->vdd_io * io_param->frequency / 1000.0; + + power_dqs_read = g_ip->num_dqs * (io_param->c_tx + c_line) * + io_param->vdd_io * io_param->v_sw_data_read_line * + io_param->frequency / 1000.0 + + g_ip->num_dqs * io_param->c_data * io_param->vdd_io * + io_param->v_sw_data_read_load1 * io_param->frequency / 1000.0 + + g_ip->num_dqs * ((g_ip->num_mem_dq-1) * io_param->c_data) * + io_param->vdd_io * io_param->v_sw_data_read_load2 * io_param->frequency / 1000.0 + + g_ip->num_dqs * io_param->c_int * io_param->vdd_io * io_param->vdd_io * + io_param->frequency / 1000.0; + + power_ca_read = g_ip->num_ca * g_ip->activity_ca * + (io_param->c_tx + io_param->num_mem_ca * + io_param->c_addr + c_line_ca) * + io_param->vdd_io * io_param->v_sw_addr * io_param->frequency / 1000 + + g_ip->num_ca * g_ip->activity_ca * io_param->c_int * + io_param->vdd_io * io_param->vdd_io * io_param->frequency / 1000; + + power_clk = g_ip->num_clk * + (io_param->c_tx + io_param->num_mem_clk * + io_param->c_data + c_line) * + io_param->vdd_io * io_param->v_sw_clk *io_param->frequency / 1000 + + g_ip->num_clk * io_param->c_int * io_param->vdd_io * + io_param->vdd_io * io_param->frequency / 1000; + + + + } + + //Combining the power terms based on STATE (READ/WRITE/IDLE/SLEEP) + + if (g_ip->iostate == READ) { + io_power_dynamic = g_ip->duty_cycle * (power_dq_read + + power_ca_read + power_dqs_read + power_clk); + + } + else if (g_ip->iostate == WRITE) { + io_power_dynamic = g_ip->duty_cycle * + (power_dq_write + power_ca_write + power_dqs_write + power_clk); + } + else if (g_ip->iostate == IDLE) { + io_power_dynamic = g_ip->duty_cycle * (power_clk); + } + else if (g_ip->iostate == SLEEP) { + io_power_dynamic = 0; + } + else { + io_power_dynamic = 0; + } + + + //printf("IO Dynamic Power (mW) = "); + //cout << io_power_dynamic << " "; +} + + +//External IO Timing and Voltage Margins. More details can be found in the CACTI-IO technical report (), Chapter 2.2. 
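The margins computed below are simple linear budgets: start from the quarter bit period (for DDR data) or the command period scaled by g_ip->addr_timing (for CA), and subtract every setup/hold, duty-cycle-distortion, skew and jitter term. A minimal sketch of the write-setup branch, using the DDR3 defaults that appear later in extio_technology.cc and an assumed 800 MHz clock, with the channel-sensitivity scaling of the jitter term ignored:

// Editorial sketch only -- write-setup budget with assumed DDR3-like numbers.
static double sketch_write_setup_margin_ps(double f_mhz)
{
    const double quarter_period_ps = 1e6 / (4.0 * f_mhz);
    const double t_ds = 150.0, t_error_soc = 25.0, t_jitter_setup = 100.0,
                 t_skew_setup = 25.0, t_cor_margin = 30.0; // all in ps
    return quarter_period_ps - t_ds - t_error_soc - t_jitter_setup
           - t_skew_setup + t_cor_margin; // 42.5 ps at f_mhz = 800
}
// A margin at or below zero means the data eye is closed for that configuration.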
+ +void Extio::extio_eye() +{ + + if (io_param->io_type == Serial) + {io_vmargin=0; + } + else + { + + //VOLTAGE MARGINS + //Voltage noise calculations based on proportional and independent noise + //sources for WRITE, READ and CA + double v_noise_write = io_param->k_noise_write_sen * io_param->v_sw_data_write_line + + io_param->v_noise_independent_write; + double v_noise_read = io_param->k_noise_read_sen * io_param->v_sw_data_read_line + + io_param->v_noise_independent_read; + double v_noise_addr = io_param->k_noise_addr_sen * io_param->v_sw_addr + + io_param->v_noise_independent_addr; + + + //Worst-case voltage margin (Swing/2 - Voltage noise) calculations per state + //depending on DQ voltage margin and CA voltage margin (lesser or the two is + //reported) + if (g_ip->iostate == READ) + { + if ((io_param->v_sw_data_read_line/2 - v_noise_read) < + (io_param->v_sw_addr/2 - v_noise_addr)) { + io_vmargin = io_param->v_sw_data_read_line/2 - v_noise_read; + } + else { + io_vmargin = io_param->v_sw_addr/2 - v_noise_addr; + } + } + else if (g_ip->iostate == WRITE) { + if ((io_param->v_sw_data_write_line/2 - v_noise_write) < + (io_param->v_sw_addr/2 - v_noise_addr)) { + io_vmargin = io_param->v_sw_data_write_line/2 - v_noise_write; + } + else { + io_vmargin = io_param->v_sw_addr/2 - v_noise_addr; + } + } + else { + io_vmargin = 0; + } + + } + + //TIMING MARGINS + + double t_margin_write_setup,t_margin_write_hold,t_margin_read_setup + ,t_margin_read_hold,t_margin_addr_setup,t_margin_addr_hold; + + if (io_param->io_type == Serial) + { + + t_margin_write_setup = (1e6/(4*io_param->frequency)) - + io_param->t_ds - + io_param->t_jitter_setup_sen; + + t_margin_write_hold = (1e6/(4*io_param->frequency)) - + io_param->t_dh - io_param->t_dcd_soc - + io_param->t_jitter_hold_sen; + + t_margin_read_setup = (1e6/(4*io_param->frequency)) - + io_param->t_soc_setup - + io_param->t_jitter_setup_sen; + + t_margin_read_hold = (1e6/(4*io_param->frequency)) - + io_param->t_soc_hold - io_param->t_dcd_dram - + io_param->t_dcd_soc - + io_param->t_jitter_hold_sen; + + + + t_margin_addr_setup = (1e6*g_ip->addr_timing/(2*io_param->frequency)); + + + t_margin_addr_hold = (1e6*g_ip->addr_timing/(2*io_param->frequency)); + + + + } + else + { + + + + //Setup and Hold timing margins for DQ WRITE, DQ READ and CA based on timing + //budget + t_margin_write_setup = (1e6/(4*io_param->frequency)) - + io_param->t_ds - io_param->t_error_soc - + io_param->t_jitter_setup_sen - io_param->t_skew_setup + io_param->t_cor_margin; + + t_margin_write_hold = (1e6/(4*io_param->frequency)) - + io_param->t_dh - io_param->t_dcd_soc - io_param->t_error_soc - + io_param->t_jitter_hold_sen - io_param->t_skew_hold + io_param->t_cor_margin; + + t_margin_read_setup = (1e6/(4*io_param->frequency)) - + io_param->t_soc_setup - io_param->t_error_soc - + io_param->t_jitter_setup_sen - io_param->t_skew_setup - + io_param->t_dqsq + io_param->t_cor_margin; + + t_margin_read_hold = (1e6/(4*io_param->frequency)) - + io_param->t_soc_hold - io_param->t_dcd_dram - + io_param->t_dcd_soc - io_param->t_error_soc - + io_param->t_jitter_hold_sen - io_param->t_skew_hold + io_param->t_cor_margin; + + + + t_margin_addr_setup = (1e6*g_ip->addr_timing/(2*io_param->frequency)) - + io_param->t_is - io_param->t_error_soc - + io_param->t_jitter_addr_setup_sen - io_param->t_skew_setup + io_param->t_cor_margin; + + + t_margin_addr_hold = (1e6*g_ip->addr_timing/(2*io_param->frequency)) - + io_param->t_ih - io_param->t_dcd_soc - io_param->t_error_soc - + 
io_param->t_jitter_addr_hold_sen - io_param->t_skew_hold + io_param->t_cor_margin; + } + + //Worst-case timing margin per state depending on DQ and CA timing margins + if (g_ip->iostate == READ) { + io_tmargin = t_margin_read_setup < t_margin_read_hold ? + t_margin_read_setup : t_margin_read_hold; + io_tmargin = io_tmargin < t_margin_addr_setup ? + io_tmargin : t_margin_addr_setup; + io_tmargin = io_tmargin < t_margin_addr_hold ? + io_tmargin : t_margin_addr_hold; + } + else if (g_ip->iostate == WRITE) { + io_tmargin = t_margin_write_setup < t_margin_write_hold ? + t_margin_write_setup : t_margin_write_hold; + io_tmargin = io_tmargin < t_margin_addr_setup ? + io_tmargin : t_margin_addr_setup; + io_tmargin = io_tmargin < t_margin_addr_hold ? + io_tmargin : t_margin_addr_hold; + } + else { + io_tmargin = 0; + } + + + + + + //OUTPUTS + + + //printf("IO Timing Margin (ps) = "); + //cout << io_tmargin < + +/* This file contains configuration parameters, including + * default configuration for DDR3, LPDDR2 and WIDEIO. The configuration + * parameters include technology parameters - voltage, load capacitances, IO + * area coefficients, timing parameters, as well as external io configuration parameters - + * termination values, voltage noise coefficients and voltage/timing noise + * sensitivity parameters. More details can be found in the CACTI-IO technical + * report (), especially Chapters 2 and 3. The user can define new dram types here. */ + + + +///////////// DDR3 /////////////////// + + const double rtt1_wr_lrdimm_ddr3[8][4] = +{ + {INF,INF,120,120}, + {INF,INF,120,120}, + {INF,120,120,80}, + {120,120,120,60}, + {120,120,120,60}, + {120,80,80,60}, + {120,80,80,60}, + {120,80,60,40} +}; + + const double rtt2_wr_lrdimm_ddr3[8][4] = +{ + {INF,INF,INF,INF},//1 + {INF,INF,120,120},//2 + {120,120,120,80}, //3 + {120,120,80,60}, //4 + {120,120,80,60}, + {120,80,60,40}, //6 + {120,80,60,40}, + {80,80,40,30}//8 +}; + + const double rtt1_rd_lrdimm_ddr3[8][4] = +{ + {INF,INF,120,120},//1 + {INF,INF,120,120},//2 + {INF,120,120,80}, //3 + {120,120,120,60}, //4 + {120,120,120,60}, + {120,80,80,60}, //6 + {120,80,80,60}, + {120,80,60,40}//8 +}; + + const double rtt2_rd_lrdimm_ddr3[8][4] = +{ + {INF,INF,INF,INF},//1 + {INF,120,80,60},//2 + {120,80,80,40}, //3 + {120,80,60,40}, //4 + {120,80,60,40}, + {80,60,60,30}, //6 + {80,60,60,30}, + {80,60,40,20}//8 +}; + + + const double rtt1_wr_host_dimm_ddr3[3][4]= +{ + {120,120,120,60}, + {120,80,80,60}, + {120,80,60,40} +}; + +const double rtt2_wr_host_dimm_ddr3[3][4]= +{ + {120,120,80,60}, + {120,80,60,40}, + {80,80,40,30} +}; + + const double rtt1_rd_host_dimm_ddr3[3][4]= +{ + {120,120,120,60}, + {120,80,80,60}, + {120,80,60,40} +}; + + const double rtt2_rd_host_dimm_ddr3[3][4]= +{ + {120,80,60,40}, + {80,60,60,30}, + {80,60,40,20} +}; + + + const double rtt1_wr_bob_dimm_ddr3[3][4]= +{ + {INF,120,120,80}, + {120,120,120,60}, + {120,80,80,60} +}; + + const double rtt2_wr_bob_dimm_ddr3[3][4]= +{ + {120,120,120,80}, + {120,120,80,60}, + {120,80,60,40} +}; + + const double rtt1_rd_bob_dimm_ddr3[3][4]= +{ + {INF,120,120,80}, + {120,120,120,60}, + {120,80,80,60} +}; + + const double rtt2_rd_bob_dimm_ddr3[3][4]= +{ + {120,80,80,40}, + {120,80,60,40}, + {80,60,60,30} +}; + + +///////////// DDR4 /////////////////// + + const double rtt1_wr_lrdimm_ddr4[8][4] = +{ + {120,120,80,80},//1 + {120,120,80,80},//2 + {120,80,80,60}, //3 + {80,60,60,60}, //4 + {80,60,60,60}, + {60,60,60,40}, //6 + {60,60,60,40}, + {40,40,40,40}//8 +}; + + const double rtt2_wr_lrdimm_ddr4[8][4] = +{ 
+ {INF,INF,INF,INF},//1 + {120,120,120,80},//2 + {120,80,80,80},//3 + {80,80,80,60},//4 + {80,80,80,60}, + {60,60,60,40},//6 + {60,60,60,40}, + {60,40,40,30}//8 +}; + + const double rtt1_rd_lrdimm_ddr4[8][4] = +{ + {120,120,80,80},//1 + {120,120,80,60},//2 + {120,80,80,60}, //3 + {120,60,60,60}, //4 + {120,60,60,60}, + {80,60,60,40}, //6 + {80,60,60,40}, + {60,40,40,30}//8 +}; + + const double rtt2_rd_lrdimm_ddr4[8][4] = +{ + {INF,INF,INF,INF},//1 + {80,60,60,60},//2 + {60,60,40,40}, //3 + {60,40,40,40}, //4 + {60,40,40,40}, + {40,40,40,30}, //6 + {40,40,40,30}, + {40,30,30,20}//8 +}; + + + + const double rtt1_wr_host_dimm_ddr4[3][4]= +{ + {80,60,60,60}, + {60,60,60,60}, + {40,40,40,40} +}; + + const double rtt2_wr_host_dimm_ddr4[3][4]= +{ + {80,80,80,60}, + {60,60,60,40}, + {60,40,40,30} +}; + + const double rtt1_rd_host_dimm_ddr4[3][4]= +{ + {120,60,60,60}, + {80,60,60,40}, + {60,40,40,30} +}; + + const double rtt2_rd_host_dimm_ddr4[3][4]= +{ + {60,40,40,40}, + {40,40,40,30}, + {40,30,30,20} +}; + + + const double rtt1_wr_bob_dimm_ddr4[3][4]= +{ + {120,80,80,60}, + {80,60,60,60}, + {60,60,60,40} +}; + + const double rtt2_wr_bob_dimm_ddr4[3][4]= +{ + {120,80,80,80}, + {80,80,80,60}, + {60,60,60,40} +}; + + const double rtt1_rd_bob_dimm_ddr4[3][4]= +{ + {120,80,80,60}, + {120,60,60,60}, + {80,60,60,40} +}; + + const double rtt2_rd_bob_dimm_ddr4[3][4]= +{ + {60,60,40,40}, + {60,40,40,40}, + {40,40,40,30} +}; + + +///////////////////////////////////////////// + +int IOTechParam::frequnecy_index(Mem_IO_type type) +{ + if(type==DDR3) + { + if(frequency<=400) + return 0; + else if(frequency<=533) + return 1; + else if(frequency<=667) + return 2; + else + return 3; + } + else if(type==DDR4) + { + if(frequency<=800) + return 0; + else if(frequency<=933) + return 1; + else if(frequency<=1066) + return 2; + else + return 3; + } + else + { + assert(false); + } + return 0; +} + + + +IOTechParam::IOTechParam(InputParameter * g_ip) +{ + num_mem_ca = g_ip->num_mem_dq * (g_ip->num_dq/g_ip->mem_data_width); + num_mem_clk = g_ip->num_mem_dq * + (g_ip->num_dq/g_ip->mem_data_width)/(g_ip->num_clk/2); + + + if (g_ip->io_type == LPDDR2) { //LPDDR + //Technology Parameters + + vdd_io = 1.2; + v_sw_clk = 1; + + // Loading paramters + c_int = 1.5; + c_tx = 2; + c_data = 1.5; + c_addr = 0.75; + i_bias = 5; + i_leak = 1000; + + // IO Area coefficients + + ioarea_c = 0.01; + ioarea_k0 = 0.5; + ioarea_k1 = 0.00008; + ioarea_k2 = 0.000000030; + ioarea_k3 = 0.000000000008; + + // Timing parameters (ps) + t_ds = 250; + t_is = 250; + t_dh = 250; + t_ih = 250; + t_dcd_soc = 50; + t_dcd_dram = 50; + t_error_soc = 50; + t_skew_setup = 50; + t_skew_hold = 50; + t_dqsq = 250; + t_soc_setup = 50; + t_soc_hold = 50; + t_jitter_setup = 200; + t_jitter_hold = 200; + t_jitter_addr_setup = 200; + t_jitter_addr_hold = 200; + t_cor_margin = 40; + + //External IO Configuration Parameters + + r_diff_term = 480; + rtt1_dq_read = 100000; + rtt2_dq_read = 100000; + rtt1_dq_write = 100000; + rtt2_dq_write = 100000; + rtt_ca = 240; + rs1_dq = 0; + rs2_dq = 0; + r_stub_ca = 0; + r_on = 50; + r_on_ca = 50; + z0 = 50; + t_flight = 0.5; + t_flight_ca = 0.5; + + // Voltage noise coeffecients + k_noise_write = 0.2; + k_noise_read = 0.2; + k_noise_addr = 0.2; + v_noise_independent_write = 0.1; + v_noise_independent_read = 0.1; + v_noise_independent_addr = 0.1; + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + +/* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. 
The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. */ + + k_noise_write_sen = k_noise_write * (1 + 0.2*(r_on/34 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + k_noise_read_sen = k_noise_read * (1 + 0.2*(r_on/34 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + k_noise_addr_sen = k_noise_addr * (1 + 0.1*(rtt_ca/100 - 1) + + 0.2*(r_on/34 - 1) + 0.2*(num_mem_ca/16 - 1)); + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.1*(r_on/34 - 1) + + 0.3*(g_ip->num_mem_dq/2 - 1)); + t_jitter_hold_sen = t_jitter_hold * (1 + 0.1*(r_on/34 - 1) + + 0.3*(g_ip->num_mem_dq/2 - 1)); + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.2*(rtt_ca/100 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.2*(rtt_ca/100 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + + phy_datapath_s = 0; + phy_phase_rotator_s = 5; + phy_clock_tree_s = 0; + phy_rx_s = 3; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 2; + + + // PHY Dynamic Power Coefficients (mW/Gbps) + + phy_datapath_d = 0.3; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.4; + phy_rx_d = 0.2; + phy_dcc_d = 0; + phy_deskew_d = 0; + phy_leveling_d = 0; + phy_pll_d = 0.05; + + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 5; + phy_rx_wtime = 2; + phy_bandgap_wtime = 10; + phy_deskew_wtime = 0; + phy_vrefgen_wtime = 0; + + + } + else if (g_ip->io_type == WideIO) { //WIDEIO + //Technology Parameters + vdd_io = 1.2; + v_sw_clk = 1.2; + + // Loading parameters + c_int = 0.5; + c_tx = 0.5; + c_data = 0.5; + c_addr = 0.35; + i_bias = 0; + i_leak = 500; + + // IO Area coefficients + ioarea_c = 0.003; + ioarea_k0 = 0.2; + ioarea_k1 = 0.00004; + ioarea_k2 = 0.000000020; + ioarea_k3 = 0.000000000004; + + // Timing parameters (ps) + t_ds = 250; + t_is = 250; + t_dh = 250; + t_ih = 250; + t_dcd_soc = 50; + t_dcd_dram = 50; + t_error_soc = 50; + t_skew_setup = 50; + t_skew_hold = 50; + t_dqsq = 250; + t_soc_setup = 50; + t_soc_hold = 50; + t_jitter_setup = 200; + t_jitter_hold = 200; + t_jitter_addr_setup = 200; + t_jitter_addr_hold = 200; + t_cor_margin = 50; + + //External IO Configuration Parameters + + r_diff_term = 100000; + rtt1_dq_read = 100000; + rtt2_dq_read = 100000; + rtt1_dq_write = 100000; + rtt2_dq_write = 100000; + rtt_ca = 100000; + rs1_dq = 0; + rs2_dq = 0; + r_stub_ca = 0; + r_on = 75; + r_on_ca = 75; + z0 = 50; + t_flight = 0.05; + t_flight_ca = 0.05; + + // Voltage noise coeffecients + k_noise_write = 0.2; + k_noise_read = 0.2; + k_noise_addr = 0.2; + v_noise_independent_write = 0.1; + v_noise_independent_read = 0.1; + v_noise_independent_addr = 0.1; + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + + /* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. 
The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. */ + + k_noise_write_sen = k_noise_write * (1 + 0.2*(r_on/50 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + k_noise_read_sen = k_noise_read * (1 + 0.2*(r_on/50 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + k_noise_addr_sen = k_noise_addr * (1 + 0.2*(r_on/50 - 1) + + 0.2*(num_mem_ca/16 - 1)); + + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.1*(r_on/50 - 1) + + 0.3*(g_ip->num_mem_dq/2 - 1)); + t_jitter_hold_sen = t_jitter_hold * (1 + 0.1*(r_on/50 - 1) + + 0.3*(g_ip->num_mem_dq/2 - 1)); + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.1*(r_on/50 - 1) + + 0.4*(num_mem_ca/16 - 1)); + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.1*(r_on/50 - 1) + + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 1; + phy_clock_tree_s = 0; + phy_rx_s = 0; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 0; + + + // PHY Dynamic Power Coefficients (mW/Gbps) + phy_datapath_d = 0.3; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.2; + phy_rx_d = 0.1; + phy_dcc_d = 0; + phy_deskew_d = 0; + phy_leveling_d = 0; + phy_pll_d = 0; + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 0; + phy_rx_wtime = 0; + phy_bandgap_wtime = 0; + phy_deskew_wtime = 0; + phy_vrefgen_wtime = 0; + + + } + else if (g_ip->io_type == DDR3) + { //Default parameters for DDR3 + // IO Supply voltage (V) + vdd_io = 1.5; + v_sw_clk = 0.75; + + // Loading parameters + c_int = 1.5; + c_tx = 2; + c_data = 1.5; + c_addr = 0.75; + i_bias = 15; + i_leak = 1000; + + // IO Area coefficients + ioarea_c = 0.01; + ioarea_k0 = 0.5; + ioarea_k1 = 0.00015; + ioarea_k2 = 0.000000045; + ioarea_k3 = 0.000000000015; + + // Timing parameters (ps) + t_ds = 150; + t_is = 150; + t_dh = 150; + t_ih = 150; + t_dcd_soc = 50; + t_dcd_dram = 50; + t_error_soc = 25; + t_skew_setup = 25; + t_skew_hold = 25; + t_dqsq = 100; + t_soc_setup = 50; + t_soc_hold = 50; + t_jitter_setup = 100; + t_jitter_hold = 100; + t_jitter_addr_setup = 100; + t_jitter_addr_hold = 100; + t_cor_margin = 30; + + + //External IO Configuration Parameters + + r_diff_term = 100; + rtt1_dq_read = g_ip->rtt_value; + rtt2_dq_read = g_ip->rtt_value; + rtt1_dq_write = g_ip->rtt_value; + rtt2_dq_write = g_ip->rtt_value; + rtt_ca = 50; + rs1_dq = 15; + rs2_dq = 15; + r_stub_ca = 0; + r_on = g_ip->ron_value; + r_on_ca = 50; + z0 = 50; + t_flight = g_ip->tflight_value; + t_flight_ca = 2; + + // Voltage noise coeffecients + + k_noise_write = 0.2; + k_noise_read = 0.2; + k_noise_addr = 0.2; + v_noise_independent_write = 0.1; + v_noise_independent_read = 0.1; + v_noise_independent_addr = 0.1; + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + + /* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. 
The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. */ + + k_noise_write_sen = k_noise_write * (1 + 0.1*(rtt1_dq_write/60 - 1) + + 0.2*(rtt2_dq_write/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + + k_noise_read_sen = k_noise_read * (1 + 0.1*(rtt1_dq_read/60 - 1) + + 0.2*(rtt2_dq_read/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + + k_noise_addr_sen = k_noise_addr * (1 + 0.1*(rtt_ca/50 - 1) + + 0.2*(r_on/34 - 1) + 0.2*(num_mem_ca/16 - 1)); + + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + 0.1*(r_on/34 - 1) + + 0.3*(g_ip->num_mem_dq/2 - 1)); + + t_jitter_hold_sen = t_jitter_hold * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + + 0.1*(r_on/34 - 1) + 0.3*(g_ip->num_mem_dq/2 - 1)); + + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 10; + phy_clock_tree_s = 0; + phy_rx_s = 10; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 10; + + // PHY Dynamic Power Coefficients (mW/Gbps) + phy_datapath_d = 0.5; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.5; + phy_rx_d = 0.5; + phy_dcc_d = 0.05; + phy_deskew_d = 0.1; + phy_leveling_d = 0.05; + phy_pll_d = 0.05; + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 5; + phy_rx_wtime = 2; + phy_bandgap_wtime = 10; + phy_deskew_wtime = 0.003; + phy_vrefgen_wtime = 0.5; + + + } + else if (g_ip->io_type == DDR4) + { //Default parameters for DDR4 + // IO Supply voltage (V) + vdd_io = 1.2; + v_sw_clk = 0.6; + + // Loading parameters + c_int = 1.5; + c_tx = 2; + c_data = 1; + c_addr = 0.75; + i_bias = 15; + i_leak = 1000; + + // IO Area coefficients + ioarea_c = 0.01; + ioarea_k0 = 0.35; + ioarea_k1 = 0.00008; + ioarea_k2 = 0.000000035; + ioarea_k3 = 0.000000000010; + + // Timing parameters (ps) + t_ds = 30; + t_is = 60; + t_dh = 30; + t_ih = 60; + t_dcd_soc = 20; + t_dcd_dram = 20; + t_error_soc = 15; + t_skew_setup = 15; + t_skew_hold = 15; + t_dqsq = 50; + t_soc_setup = 20; + t_soc_hold = 10; + t_jitter_setup = 30; + t_jitter_hold = 30; + t_jitter_addr_setup = 60; + t_jitter_addr_hold = 60; + t_cor_margin = 10; + + + //External IO Configuration Parameters + + r_diff_term = 100; + rtt1_dq_read = g_ip->rtt_value; + rtt2_dq_read = g_ip->rtt_value; + rtt1_dq_write = g_ip->rtt_value; + rtt2_dq_write = g_ip->rtt_value; + rtt_ca = 50; + rs1_dq = 15; + rs2_dq = 15; + r_stub_ca = 0; + r_on = g_ip->ron_value; + r_on_ca = 50; + z0 = 50; + t_flight = g_ip->tflight_value; + t_flight_ca = 2; + + // Voltage noise coeffecients + + k_noise_write = 0.2; + k_noise_read = 0.2; + k_noise_addr = 0.2; + 
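// NOTE: k_noise_* are proportional (swing-dependent) noise coefficients,
+    // while the v_noise_independent_* terms below are fixed noise voltages;
+    // the CACTI-IO technical report describes how the two are combined into
+    // the worst-case voltage margins.
+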
v_noise_independent_write = 0.1; + v_noise_independent_read = 0.1; + v_noise_independent_addr = 0.1; + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + + /* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. */ + + k_noise_write_sen = k_noise_write * (1 + 0.1*(rtt1_dq_write/60 - 1) + + 0.2*(rtt2_dq_write/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + + k_noise_read_sen = k_noise_read * (1 + 0.1*(rtt1_dq_read/60 - 1) + + 0.2*(rtt2_dq_read/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(g_ip->num_mem_dq/2 - 1)); + + k_noise_addr_sen = k_noise_addr * (1 + 0.1*(rtt_ca/50 - 1) + + 0.2*(r_on/34 - 1) + 0.2*(num_mem_ca/16 - 1)); + + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + 0.1*(r_on/34 - 1) + + 0.3*(g_ip->num_mem_dq/2 - 1)); + + t_jitter_hold_sen = t_jitter_hold * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + + 0.1*(r_on/34 - 1) + 0.3*(g_ip->num_mem_dq/2 - 1)); + + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 10; + phy_clock_tree_s = 0; + phy_rx_s = 10; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 10; + + // PHY Dynamic Power Coefficients (mW/Gbps) + phy_datapath_d = 0.5; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.5; + phy_rx_d = 0.5; + phy_dcc_d = 0.05; + phy_deskew_d = 0.1; + phy_leveling_d = 0.05; + phy_pll_d = 0.05; + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 5; + phy_rx_wtime = 2; + phy_bandgap_wtime = 10; + phy_deskew_wtime = 0.003; + phy_vrefgen_wtime = 0.5; + + + } + else if (g_ip->io_type == Serial) + { //Default parameters for Serial + // IO Supply voltage (V) + vdd_io = 1.2; + v_sw_clk = 0.75; + + // IO Area coefficients + ioarea_c = 0.01; + ioarea_k0 = 0.15; + ioarea_k1 = 0.00005; + ioarea_k2 = 0.000000025; + ioarea_k3 = 0.000000000005; + + // Timing parameters (ps) + t_ds = 15; + t_dh = 15; + t_dcd_soc = 10; + t_dcd_dram = 10; + t_soc_setup = 10; + t_soc_hold = 10; + t_jitter_setup = 20; + t_jitter_hold = 20; + + //External IO Configuration Parameters + + r_diff_term = 100; + + + t_jitter_setup_sen = t_jitter_setup; + + t_jitter_hold_sen = t_jitter_hold; + + t_jitter_addr_setup_sen = t_jitter_addr_setup; + + t_jitter_addr_hold_sen = t_jitter_addr_hold; + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 10; + phy_clock_tree_s = 0; + phy_rx_s = 10; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 10; + + // PHY Dynamic Power Coefficients (mW/Gbps) + 
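// These coefficients are specified in mW per Gbps of per-pin data rate,
+    // so the PHY dynamic power computed downstream scales linearly with the
+    // interface speed.
+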
phy_datapath_d = 0.5;
+    phy_phase_rotator_d = 0.01;
+    phy_clock_tree_d = 0.5;
+    phy_rx_d = 0.5;
+    phy_dcc_d = 0.05;
+    phy_deskew_d = 0.1;
+    phy_leveling_d = 0.05;
+    phy_pll_d = 0.05;
+
+    //PHY Wakeup Times (Sleep to Active) (microseconds)
+
+    phy_pll_wtime = 10;
+    phy_phase_rotator_wtime = 5;
+    phy_rx_wtime = 2;
+    phy_bandgap_wtime = 10;
+    phy_deskew_wtime = 0.003;
+    phy_vrefgen_wtime = 0.5;
+
+  }
+  else
+  {
+    cout << "Not Yet supported" << endl;
+    exit(1);
+  }
+
+
+  //SWING AND TERMINATION CALCULATIONS
+
+  //R|| calculation
+  rpar_write = (rtt1_dq_write + rs1_dq)*(rtt2_dq_write + rs2_dq)/
+    (rtt1_dq_write + rs1_dq + rtt2_dq_write + rs2_dq);
+  rpar_read = (rtt1_dq_read)*(rtt2_dq_read + rs2_dq)/
+    (rtt1_dq_read + rtt2_dq_read + rs2_dq);
+
+  //Swing calculation
+  v_sw_data_read_load1 = vdd_io * (rtt1_dq_read)*(rtt2_dq_read + rs2_dq) /
+    ((rtt1_dq_read + rtt2_dq_read + rs2_dq)*(r_on + rs1_dq + rpar_read));
+  v_sw_data_read_load2 = vdd_io * (rtt1_dq_read)*(rtt2_dq_read) /
+    ((rtt1_dq_read + rtt2_dq_read + rs2_dq)*(r_on + rs1_dq + rpar_read));
+  v_sw_data_read_line = vdd_io * rpar_read / (r_on + rs1_dq + rpar_read);
+  v_sw_addr = vdd_io * rtt_ca / (50 + rtt_ca);
+  v_sw_data_write_load1 = vdd_io * (rtt1_dq_write)*(rtt2_dq_write + rs2_dq) /
+    ((rtt1_dq_write + rs1_dq + rtt2_dq_write + rs2_dq)*(r_on + rpar_write));
+  v_sw_data_write_load2 = vdd_io * (rtt2_dq_write)*(rtt1_dq_write + rs1_dq) /
+    ((rtt1_dq_write + rs1_dq + rtt2_dq_write + rs2_dq)*(r_on + rpar_write));
+  v_sw_data_write_line = vdd_io * rpar_write / (r_on + rpar_write);
+
+}
+
+// This constructor receives most of its input from g_ip;
+// however, some of the other parameters can be customized by the
+// caller through the explicit arguments listed in the signature.
+// connection: 0 = bob-dimm, 1 = host-dimm, 2 = lrdimm
+
+
+IOTechParam::IOTechParam(InputParameter * g_ip, Mem_IO_type io_type1, int num_mem_dq, int mem_data_width,
+    int num_dq, int connection, int num_loads, double freq)
+{
+  // Loads on CA = ranks x devices per rank (num_dq/mem_data_width), matching
+  // the g_ip-based constructor above.
+  num_mem_ca = num_mem_dq * (num_dq/mem_data_width);
+  num_mem_clk = num_mem_dq *
+    (num_dq/mem_data_width)/(g_ip->num_clk/2);
+
+  io_type = io_type1;
+  frequency = freq;
+
+
+  if (io_type == LPDDR2) { //LPDDR
+    //Technology Parameters
+
+    vdd_io = 1.2;
+    v_sw_clk = 1;
+
+    // Loading parameters
+    c_int = 1.5;
+    c_tx = 2;
+    c_data = 1.5;
+    c_addr = 0.75;
+    i_bias = 5;
+    i_leak = 1000;
+
+    // IO Area coefficients
+
+    ioarea_c = 0.01;
+    ioarea_k0 = 0.5;
+    ioarea_k1 = 0.00008;
+    ioarea_k2 = 0.000000030;
+    ioarea_k3 = 0.000000000008;
+
+    // Timing parameters (ps)
+    t_ds = 250;
+    t_is = 250;
+    t_dh = 250;
+    t_ih = 250;
+    t_dcd_soc = 50;
+    t_dcd_dram = 50;
+    t_error_soc = 50;
+    t_skew_setup = 50;
+    t_skew_hold = 50;
+    t_dqsq = 250;
+    t_soc_setup = 50;
+    t_soc_hold = 50;
+    t_jitter_setup = 200;
+    t_jitter_hold = 200;
+    t_jitter_addr_setup = 200;
+    t_jitter_addr_hold = 200;
+    t_cor_margin = 40;
+
+    //External IO Configuration Parameters
+
+    r_diff_term = 480;
+    rtt1_dq_read = 100000;
+    rtt2_dq_read = 100000;
+    rtt1_dq_write = 100000;
+    rtt2_dq_write = 100000;
+    rtt_ca = 240;
+    rs1_dq = 0;
+    rs2_dq = 0;
+    r_stub_ca = 0;
+    r_on = 50;
+    r_on_ca = 50;
+    z0 = 50;
+    t_flight = 0.5;
+    t_flight_ca = 0.5;
+
+    // Voltage noise coefficients
+    k_noise_write = 0.2;
+    k_noise_read = 0.2;
+    k_noise_addr = 0.2;
+    v_noise_independent_write = 0.1;
+    v_noise_independent_read = 0.1;
+    v_noise_independent_addr = 0.1;
+
+    //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE
+
+/* This is a user-defined section that depends on the channel sensitivity
+ * to IO and DRAM parameters.
The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. */ + + k_noise_write_sen = k_noise_write * (1 + 0.2*(r_on/34 - 1) + + 0.2*(num_mem_dq/2 - 1)); + k_noise_read_sen = k_noise_read * (1 + 0.2*(r_on/34 - 1) + + 0.2*(num_mem_dq/2 - 1)); + k_noise_addr_sen = k_noise_addr * (1 + 0.1*(rtt_ca/100 - 1) + + 0.2*(r_on/34 - 1) + 0.2*(num_mem_ca/16 - 1)); + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.1*(r_on/34 - 1) + + 0.3*(num_mem_dq/2 - 1)); + t_jitter_hold_sen = t_jitter_hold * (1 + 0.1*(r_on/34 - 1) + + 0.3*(num_mem_dq/2 - 1)); + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.2*(rtt_ca/100 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.2*(rtt_ca/100 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + + phy_datapath_s = 0; + phy_phase_rotator_s = 5; + phy_clock_tree_s = 0; + phy_rx_s = 3; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 2; + + + // PHY Dynamic Power Coefficients (mW/Gbps) + + phy_datapath_d = 0.3; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.4; + phy_rx_d = 0.2; + phy_dcc_d = 0; + phy_deskew_d = 0; + phy_leveling_d = 0; + phy_pll_d = 0.05; + + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 5; + phy_rx_wtime = 2; + phy_bandgap_wtime = 10; + phy_deskew_wtime = 0; + phy_vrefgen_wtime = 0; + + + } + else if (io_type == WideIO) { //WIDEIO + //Technology Parameters + vdd_io = 1.2; + v_sw_clk = 1.2; + + // Loading parameters + c_int = 0.5; + c_tx = 0.5; + c_data = 0.5; + c_addr = 0.35; + i_bias = 0; + i_leak = 500; + + // IO Area coefficients + ioarea_c = 0.003; + ioarea_k0 = 0.2; + ioarea_k1 = 0.00004; + ioarea_k2 = 0.000000020; + ioarea_k3 = 0.000000000004; + + // Timing parameters (ps) + t_ds = 250; + t_is = 250; + t_dh = 250; + t_ih = 250; + t_dcd_soc = 50; + t_dcd_dram = 50; + t_error_soc = 50; + t_skew_setup = 50; + t_skew_hold = 50; + t_dqsq = 250; + t_soc_setup = 50; + t_soc_hold = 50; + t_jitter_setup = 200; + t_jitter_hold = 200; + t_jitter_addr_setup = 200; + t_jitter_addr_hold = 200; + t_cor_margin = 50; + + //External IO Configuration Parameters + + r_diff_term = 100000; + rtt1_dq_read = 100000; + rtt2_dq_read = 100000; + rtt1_dq_write = 100000; + rtt2_dq_write = 100000; + rtt_ca = 100000; + rs1_dq = 0; + rs2_dq = 0; + r_stub_ca = 0; + r_on = 75; + r_on_ca = 75; + z0 = 50; + t_flight = 0.05; + t_flight_ca = 0.05; + + // Voltage noise coeffecients + k_noise_write = 0.2; + k_noise_read = 0.2; + k_noise_addr = 0.2; + v_noise_independent_write = 0.1; + v_noise_independent_read = 0.1; + v_noise_independent_addr = 0.1; + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + + /* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. 
The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. */ + + k_noise_write_sen = k_noise_write * (1 + 0.2*(r_on/50 - 1) + + 0.2*(num_mem_dq/2 - 1)); + k_noise_read_sen = k_noise_read * (1 + 0.2*(r_on/50 - 1) + + 0.2*(num_mem_dq/2 - 1)); + k_noise_addr_sen = k_noise_addr * (1 + 0.2*(r_on/50 - 1) + + 0.2*(num_mem_ca/16 - 1)); + + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.1*(r_on/50 - 1) + + 0.3*(num_mem_dq/2 - 1)); + t_jitter_hold_sen = t_jitter_hold * (1 + 0.1*(r_on/50 - 1) + + 0.3*(num_mem_dq/2 - 1)); + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.1*(r_on/50 - 1) + + 0.4*(num_mem_ca/16 - 1)); + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.1*(r_on/50 - 1) + + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 1; + phy_clock_tree_s = 0; + phy_rx_s = 0; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 0; + + + // PHY Dynamic Power Coefficients (mW/Gbps) + phy_datapath_d = 0.3; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.2; + phy_rx_d = 0.1; + phy_dcc_d = 0; + phy_deskew_d = 0; + phy_leveling_d = 0; + phy_pll_d = 0; + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 0; + phy_rx_wtime = 0; + phy_bandgap_wtime = 0; + phy_deskew_wtime = 0; + phy_vrefgen_wtime = 0; + + + } + else if (io_type == DDR3) + { //Default parameters for DDR3 + // IO Supply voltage (V) + vdd_io = 1.5; + v_sw_clk = 0.75; + + // Loading parameters + c_int = 1.5; + c_tx = 2; + c_data = 1.5; + c_addr = 0.75; + i_bias = 15; + i_leak = 1000; + + // IO Area coefficients + ioarea_c = 0.01; + ioarea_k0 = 0.5; + ioarea_k1 = 0.00015; + ioarea_k2 = 0.000000045; + ioarea_k3 = 0.000000000015; + + // Timing parameters (ps) + t_ds = 150; + t_is = 150; + t_dh = 150; + t_ih = 150; + t_dcd_soc = 50; + t_dcd_dram = 50; + t_error_soc = 25; + t_skew_setup = 25; + t_skew_hold = 25; + t_dqsq = 100; + t_soc_setup = 50; + t_soc_hold = 50; + t_jitter_setup = 100; + t_jitter_hold = 100; + t_jitter_addr_setup = 100; + t_jitter_addr_hold = 100; + t_cor_margin = 30; + + + //External IO Configuration Parameters + + r_diff_term = 100; + + /* + rtt1_dq_read = g_ip->rtt_value; + rtt2_dq_read = g_ip->rtt_value; + rtt1_dq_write = g_ip->rtt_value; + rtt2_dq_write = g_ip->rtt_value; + */ + switch(connection) + { + case(0): + rtt1_dq_write = rtt1_wr_bob_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_write = rtt2_wr_bob_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt1_dq_read = rtt1_rd_bob_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_read = rtt2_rd_bob_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + break; + case(1): + rtt1_dq_write = rtt1_wr_host_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_write = rtt2_wr_host_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt1_dq_read = rtt1_rd_host_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_read = rtt2_rd_host_dimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + break; + case(2): + 
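// connection == 2: LRDIMM. Read and write ODT values are taken from the
+      // DDR3 LRDIMM lookup tables above, indexed by the load count
+      // (num_loads - 1) and the frequency bin from frequnecy_index().
+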
rtt1_dq_write = rtt1_wr_lrdimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_write = rtt2_wr_lrdimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt1_dq_read = rtt1_rd_lrdimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_read = rtt2_rd_lrdimm_ddr3[num_loads-1][frequnecy_index(io_type)]; + break; + default: + break; + } + + + rtt_ca = 50; + rs1_dq = 15; + rs2_dq = 15; + r_stub_ca = 0; + r_on = g_ip->ron_value; + r_on_ca = 50; + z0 = 50; + t_flight = g_ip->tflight_value; + t_flight_ca = 2; + + // Voltage noise coeffecients + + k_noise_write = 0.2; + k_noise_read = 0.2; + k_noise_addr = 0.2; + v_noise_independent_write = 0.1; + v_noise_independent_read = 0.1; + v_noise_independent_addr = 0.1; + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + + /* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. */ + + k_noise_write_sen = k_noise_write * (1 + 0.1*(rtt1_dq_write/60 - 1) + + 0.2*(rtt2_dq_write/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(num_mem_dq/2 - 1)); + + k_noise_read_sen = k_noise_read * (1 + 0.1*(rtt1_dq_read/60 - 1) + + 0.2*(rtt2_dq_read/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(num_mem_dq/2 - 1)); + + k_noise_addr_sen = k_noise_addr * (1 + 0.1*(rtt_ca/50 - 1) + + 0.2*(r_on/34 - 1) + 0.2*(num_mem_ca/16 - 1)); + + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + 0.1*(r_on/34 - 1) + + 0.3*(num_mem_dq/2 - 1)); + + t_jitter_hold_sen = t_jitter_hold * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + + 0.1*(r_on/34 - 1) + 0.3*(num_mem_dq/2 - 1)); + + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 10; + phy_clock_tree_s = 0; + phy_rx_s = 10; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 10; + + // PHY Dynamic Power Coefficients (mW/Gbps) + phy_datapath_d = 0.5; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.5; + phy_rx_d = 0.5; + phy_dcc_d = 0.05; + phy_deskew_d = 0.1; + phy_leveling_d = 0.05; + phy_pll_d = 0.05; + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 5; + phy_rx_wtime = 2; + phy_bandgap_wtime = 10; + phy_deskew_wtime = 0.003; + phy_vrefgen_wtime = 0.5; + + + } + else if (io_type == DDR4) + { //Default parameters for DDR4 + // IO Supply voltage (V) + vdd_io = 1.2; + v_sw_clk = 0.6; + + // Loading parameters + c_int = 1.5; + c_tx = 2; + c_data = 1; + c_addr = 0.75; + i_bias = 15; + i_leak = 1000; + + // IO Area coefficients + ioarea_c = 0.01; + ioarea_k0 = 0.35; + ioarea_k1 = 0.00008; + ioarea_k2 = 0.000000035; + ioarea_k3 = 
0.000000000010; + + // Timing parameters (ps) + t_ds = 30; + t_is = 60; + t_dh = 30; + t_ih = 60; + t_dcd_soc = 20; + t_dcd_dram = 20; + t_error_soc = 15; + t_skew_setup = 15; + t_skew_hold = 15; + t_dqsq = 50; + t_soc_setup = 20; + t_soc_hold = 10; + t_jitter_setup = 30; + t_jitter_hold = 30; + t_jitter_addr_setup = 60; + t_jitter_addr_hold = 60; + t_cor_margin = 10; + + + //External IO Configuration Parameters + + r_diff_term = 100; + /* + rtt1_dq_read = g_ip->rtt_value; + rtt2_dq_read = g_ip->rtt_value; + rtt1_dq_write = g_ip->rtt_value; + rtt2_dq_write = g_ip->rtt_value; + */ + + switch(connection) + { + case(0): + rtt1_dq_write = rtt1_wr_bob_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_write = rtt2_wr_bob_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt1_dq_read = rtt1_rd_bob_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_read = rtt2_rd_bob_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + break; + case(1): + rtt1_dq_write = rtt1_wr_host_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_write = rtt2_wr_host_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt1_dq_read = rtt1_rd_host_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_read = rtt2_rd_host_dimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + break; + case(2): + rtt1_dq_write = rtt1_wr_lrdimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_write = rtt2_wr_lrdimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt1_dq_read = rtt1_rd_lrdimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + rtt2_dq_read = rtt2_rd_lrdimm_ddr4[num_loads-1][frequnecy_index(io_type)]; + break; + default: + break; + } + + rtt_ca = 50; + rs1_dq = 15; + rs2_dq = 15; + r_stub_ca = 0; + r_on = g_ip->ron_value; + r_on_ca = 50; + z0 = 50; + t_flight = g_ip->tflight_value; + t_flight_ca = 2; + + // Voltage noise coeffecients + + k_noise_write = 0.2; + k_noise_read = 0.2; + k_noise_addr = 0.2; + v_noise_independent_write = 0.1; + v_noise_independent_read = 0.1; + v_noise_independent_addr = 0.1; + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + + /* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. E.g. a + * linear relationship, a non-linear analytical relationship or a lookup + * table. The sensitivity coefficients are based on channel analysis + * performed on the channel of interest.Given below is an example of such + * a sensitivity relationship. + * Such a linear fit can be found efficiently using an orthogonal design + * of experiments method shown in the technical report (), in Chapter 2.2. 
*/ + + k_noise_write_sen = k_noise_write * (1 + 0.1*(rtt1_dq_write/60 - 1) + + 0.2*(rtt2_dq_write/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(num_mem_dq/2 - 1)); + + k_noise_read_sen = k_noise_read * (1 + 0.1*(rtt1_dq_read/60 - 1) + + 0.2*(rtt2_dq_read/60 - 1) + 0.2*(r_on/34 - 1) + + 0.2*(num_mem_dq/2 - 1)); + + k_noise_addr_sen = k_noise_addr * (1 + 0.1*(rtt_ca/50 - 1) + + 0.2*(r_on/34 - 1) + 0.2*(num_mem_ca/16 - 1)); + + + t_jitter_setup_sen = t_jitter_setup * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + 0.1*(r_on/34 - 1) + + 0.3*(num_mem_dq/2 - 1)); + + t_jitter_hold_sen = t_jitter_hold * (1 + 0.2*(rtt1_dq_write/60 - 1) + + 0.3*(rtt2_dq_write/60 - 1) + + 0.1*(r_on/34 - 1) + 0.3*(num_mem_dq/2 - 1)); + + t_jitter_addr_setup_sen = t_jitter_addr_setup * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + t_jitter_addr_hold_sen = t_jitter_addr_hold * (1 + 0.2*(rtt_ca/50 - 1) + + 0.1*(r_on/34 - 1) + 0.4*(num_mem_ca/16 - 1)); + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 10; + phy_clock_tree_s = 0; + phy_rx_s = 10; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 10; + + // PHY Dynamic Power Coefficients (mW/Gbps) + phy_datapath_d = 0.5; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.5; + phy_rx_d = 0.5; + phy_dcc_d = 0.05; + phy_deskew_d = 0.1; + phy_leveling_d = 0.05; + phy_pll_d = 0.05; + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 5; + phy_rx_wtime = 2; + phy_bandgap_wtime = 10; + phy_deskew_wtime = 0.003; + phy_vrefgen_wtime = 0.5; + + + } + else if (io_type == Serial) + { //Default parameters for Serial + // IO Supply voltage (V) + vdd_io = 1.2; + v_sw_clk = 0.75; + + // IO Area coefficients + ioarea_c = 0.01; + ioarea_k0 = 0.15; + ioarea_k1 = 0.00005; + ioarea_k2 = 0.000000025; + ioarea_k3 = 0.000000000005; + + // Timing parameters (ps) + t_ds = 15; + t_dh = 15; + t_dcd_soc = 10; + t_dcd_dram = 10; + t_soc_setup = 10; + t_soc_hold = 10; + t_jitter_setup = 20; + t_jitter_hold = 20; + + //External IO Configuration Parameters + + r_diff_term = 100; + + + t_jitter_setup_sen = t_jitter_setup; + + t_jitter_hold_sen = t_jitter_hold; + + t_jitter_addr_setup_sen = t_jitter_addr_setup; + + t_jitter_addr_hold_sen = t_jitter_addr_hold; + + // PHY Static Power Coefficients (mW) + phy_datapath_s = 0; + phy_phase_rotator_s = 10; + phy_clock_tree_s = 0; + phy_rx_s = 10; + phy_dcc_s = 0; + phy_deskew_s = 0; + phy_leveling_s = 0; + phy_pll_s = 10; + + // PHY Dynamic Power Coefficients (mW/Gbps) + phy_datapath_d = 0.5; + phy_phase_rotator_d = 0.01; + phy_clock_tree_d = 0.5; + phy_rx_d = 0.5; + phy_dcc_d = 0.05; + phy_deskew_d = 0.1; + phy_leveling_d = 0.05; + phy_pll_d = 0.05; + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + phy_pll_wtime = 10; + phy_phase_rotator_wtime = 5; + phy_rx_wtime = 2; + phy_bandgap_wtime = 10; + phy_deskew_wtime = 0.003; + phy_vrefgen_wtime = 0.5; + + + } + else + { + cout << "Not Yet supported" << endl; + exit(1); + } + + + //SWING AND TERMINATION CALCULATIONS + + //R|| calculation + rpar_write =(rtt1_dq_write + rs1_dq)*(rtt2_dq_write + rs2_dq)/ + (rtt1_dq_write + rs1_dq + rtt2_dq_write + rs2_dq); + rpar_read =(rtt1_dq_read)*(rtt2_dq_read + rs2_dq)/ + (rtt1_dq_read + rtt2_dq_read + rs2_dq); + + + + //Swing calculation + v_sw_data_read_load1 =vdd_io * (rtt1_dq_read)*(rtt2_dq_read + rs2_dq) / + ((rtt1_dq_read + rtt2_dq_read + rs2_dq)*(r_on + rs1_dq + rpar_read)); + v_sw_data_read_load2 =vdd_io * 
(rtt1_dq_read)*(rtt2_dq_read) /
+    ((rtt1_dq_read + rtt2_dq_read + rs2_dq)*(r_on + rs1_dq + rpar_read));
+  v_sw_data_read_line = vdd_io * rpar_read / (r_on + rs1_dq + rpar_read);
+  v_sw_addr = vdd_io * rtt_ca / (50 + rtt_ca);
+  v_sw_data_write_load1 = vdd_io * (rtt1_dq_write)*(rtt2_dq_write + rs2_dq) /
+    ((rtt1_dq_write + rs1_dq + rtt2_dq_write + rs2_dq)*(r_on + rpar_write));
+  v_sw_data_write_load2 = vdd_io * (rtt2_dq_write)*(rtt1_dq_write + rs1_dq) /
+    ((rtt1_dq_write + rs1_dq + rtt2_dq_write + rs2_dq)*(r_on + rpar_write));
+  v_sw_data_write_line = vdd_io * rpar_write / (r_on + rpar_write);
+
+}
+
+
+
+IOTechParam::~IOTechParam()
+{}
diff --git a/Project_FARSI/cacti_for_FARSI/extio_technology.h b/Project_FARSI/cacti_for_FARSI/extio_technology.h
new file mode 100644
index 00000000..2f3d3087
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/extio_technology.h
@@ -0,0 +1,225 @@
+#ifndef __EXTIO_TECH__
+#define __EXTIO_TECH__
+
+#include <iostream>
+#include "parameter.h"
+#include "const.h"
+
+#define NUM_DIMM 1
+
+
+extern const double rtt1_wr_lrdimm_ddr3[8][4];
+extern const double rtt2_wr_lrdimm_ddr3[8][4];
+extern const double rtt1_rd_lrdimm_ddr3[8][4];
+extern const double rtt2_rd_lrdimm_ddr3[8][4];
+
+extern const double rtt1_wr_host_dimm_ddr3[3][4];
+extern const double rtt2_wr_host_dimm_ddr3[3][4];
+extern const double rtt1_rd_host_dimm_ddr3[3][4];
+extern const double rtt2_rd_host_dimm_ddr3[3][4];
+
+extern const double rtt1_wr_bob_dimm_ddr3[3][4];
+extern const double rtt2_wr_bob_dimm_ddr3[3][4];
+extern const double rtt1_rd_bob_dimm_ddr3[3][4];
+extern const double rtt2_rd_bob_dimm_ddr3[3][4];
+
+
+extern const double rtt1_wr_lrdimm_ddr4[8][4];
+extern const double rtt2_wr_lrdimm_ddr4[8][4];
+extern const double rtt1_rd_lrdimm_ddr4[8][4];
+extern const double rtt2_rd_lrdimm_ddr4[8][4];
+
+extern const double rtt1_wr_host_dimm_ddr4[3][4];
+extern const double rtt2_wr_host_dimm_ddr4[3][4];
+extern const double rtt1_rd_host_dimm_ddr4[3][4];
+extern const double rtt2_rd_host_dimm_ddr4[3][4];
+
+extern const double rtt1_wr_bob_dimm_ddr4[3][4];
+extern const double rtt2_wr_bob_dimm_ddr4[3][4];
+extern const double rtt1_rd_bob_dimm_ddr4[3][4];
+extern const double rtt2_rd_bob_dimm_ddr4[3][4];
+
+class IOTechParam
+{
+  public:
+    IOTechParam(InputParameter *);
+    // connection : 0(bob-dimm), 1(host-dimm), 2(lrdimm)
+    IOTechParam(InputParameter *, Mem_IO_type io_type, int num_mem_dq, int mem_data_width, int num_dq, int connection, int num_loads, double freq);
+    ~IOTechParam();
+    double num_mem_ca; /* Number of loads on the address bus, based on the
+      total number of memories in the channel. For registered or buffered
+      configurations, num_mem_dq and num_mem_ca are per buffer.
*/
+
+    double num_mem_clk; /* Number of loads on the clock, as total
+      memories in the channel / number of clock lines available */
+
+    //Technology Parameters
+    double vdd_io;   // IO Supply voltage (V)
+    double v_sw_clk; /* Voltage swing on CLK/CLKB (V) (swing on the CLK pin
+      if it is differentially terminated) */
+
+    // Loading parameters
+
+    double c_int; /* Internal IO loading (pF) (loading within the IO, due to
+      predriver nets) */
+    double c_tx; /* IO TX self-load including package (pF) (loading at the
+      CPU TX pin) */
+    double c_data; /* Device loading per memory data pin (pF) (DRAM device
+      load for DQ per die) */
+    double c_addr; /* Device loading per memory address pin (pF) (DRAM
+      device load for CA per die) */
+    double i_bias; /* Bias current (mA) (includes bias current for the whole
+      memory bus due to RX Vref based receivers) */
+    double i_leak; // Active leakage current per pin (nA)
+
+
+    // IO Area coefficients
+
+    double ioarea_c;  /* sq.mm. (IO Area baseline coefficient for control
+      circuitry and overhead) */
+    double ioarea_k0; /* sq.mm * ohms (IO Area coefficient for the driver, for
+      unit drive strength or output impedance) */
+    double ioarea_k1; /* sq.mm * ohms / MHz (IO Area coefficient for the
+      predriver final stage, based on fanout needed) */
+    double ioarea_k2; /* sq.mm * ohms / MHz^2 (IO Area coefficient for the
+      predriver middle stage, based on fanout needed) */
+    double ioarea_k3; /* sq.mm * ohms / MHz^3 (IO Area coefficient for the
+      predriver first stage, based on fanout needed) */
+
+
+    // Timing parameters (ps)
+
+    double t_ds; //DQ setup time at DRAM
+    double t_is; //CA setup time at DRAM
+    double t_dh; //DQ hold time at DRAM
+    double t_ih; //CA hold time at DRAM
+    double t_dcd_soc;  //Duty-cycle distortion at the CPU/SOC
+    double t_dcd_dram; //Duty-cycle distortion at the DRAM
+    double t_error_soc; //Timing error due to edge placement uncertainty of the DLL
+    double t_skew_setup; //Setup skew between DQ/DQS or CA/CLK after deskewing the lines
+    double t_skew_hold;  //Hold skew between DQ/DQS or CA/CLK after deskewing the lines
+    double t_dqsq; //DQ-DQS skew at the DRAM output during Read
+    //double t_qhs; //DQ-DQS hold factor at the DRAM output during Read. FIXME: commented out because the variable is never used.
+ double t_soc_setup; //Setup time at SOC input dueing Read + double t_soc_hold; //Hold time at SOC input during Read + double t_jitter_setup; /* Half-cycle jitter on the DQS at DRAM input + affecting setup time */ + double t_jitter_hold; /* Half-cycle jitter on the DQS at the DRAM input + affecting hold time */ + double t_jitter_addr_setup; /* Half-cycle jitter on the CLK at DRAM input + affecting setup time */ + double t_jitter_addr_hold; /* Half-cycle jitter on the CLK at the DRAM + input affecting hold time */ + double t_cor_margin; // Statistical correlation margin + + + //Termination Parameters + + double r_diff_term; /* Differential termination resister if + used for CLK (Ohm) */ + + + // ODT related termination resistor values (Ohm) + + double rtt1_dq_read; //DQ Read termination at CPU + double rtt2_dq_read; //DQ Read termination at inactive DRAM + double rtt1_dq_write; //DQ Write termination at active DRAM + double rtt2_dq_write; //DQ Write termination at inactive DRAM + double rtt_ca; //CA fly-by termination + double rs1_dq; //Series resistor at active DRAM + double rs2_dq; //Series resistor at inactive DRAM + double r_stub_ca; //Series resistor for the fly-by channel + double r_on; //Driver impedance + double r_on_ca; //CA driver impedance + + double z0; //Line impedance (ohms): Characteristic impedance of the route. + double t_flight; /* Flight time of the interconnect (ns) (approximately + 180ps/inch for FR4) */ + double t_flight_ca; /* Flight time of the Control/Address (CA) + interconnect (ns) (approximately 180ps/inch for FR4) */ + + // Voltage noise coeffecients + + double k_noise_write; //Proportional noise coefficient for Write mode + double k_noise_read; //Proportional noise coefficient for Read mode + double k_noise_addr; //Proportional noise coefficient for Address bus + double v_noise_independent_write; //Independent noise voltage for Write mode + double v_noise_independent_read; //Independent noise voltage for Read mode + double v_noise_independent_addr; //Independent noise voltage for Address bus + + + //SENSITIVITY INPUTS FOR TIMING AND VOLTAGE NOISE + + /* This is a user-defined section that depends on the channel sensitivity + * to IO and DRAM parameters. The t_jitter_* and k_noise_* are the + * parameters that are impacted based on the channel analysis. The user + * can define any relationship between the termination, loading and + * configuration parameters AND the t_jitter/k_noise parameters. 
*/ + + double k_noise_write_sen; + double k_noise_read_sen; + double k_noise_addr_sen; + double t_jitter_setup_sen; + double t_jitter_hold_sen; + double t_jitter_addr_setup_sen; + double t_jitter_addr_hold_sen; + + //SWING AND TERMINATION CALCULATIONS + //R|| calculation + + double rpar_write; + double rpar_read; + + //Swing calculation + + double v_sw_data_read_load1; //Swing for DQ at dram1 during READ + double v_sw_data_read_load2; //Swing for DQ at dram2 during READ + double v_sw_data_read_line; //Swing for DQ on the line during READ + double v_sw_addr; //Swing for the address bus + double v_sw_data_write_load1; //Swing for DQ at dram1 during WRITE + double v_sw_data_write_load2; //Swing for DQ at dram2 during WRITE + double v_sw_data_write_line; //Swing for DQ on the line during WRITE + + // PHY Static Power Coefficients (mW) + + double phy_datapath_s; // Datapath Static Power + double phy_phase_rotator_s; // Phase Rotator Static Power + double phy_clock_tree_s; // Clock Tree Static Power + double phy_rx_s; // Receiver Static Power + double phy_dcc_s; // Duty Cycle Correction Static Power + double phy_deskew_s; // Deskewing Static Power + double phy_leveling_s; // Write and Read Leveling Static Power + double phy_pll_s; // PHY PLL Static Power + + + // PHY Dynamic Power Coefficients (mW/Gbps) + + double phy_datapath_d; // Datapath Dynamic Power + double phy_phase_rotator_d; // Phase Rotator Dynamic Power + double phy_clock_tree_d; // Clock Tree Dynamic Power + double phy_rx_d; // Receiver Dynamic Power + double phy_dcc_d; // Duty Cycle Correction Dynamic Power + double phy_deskew_d; // Deskewing Dynamic Power + double phy_leveling_d; // Write and Read Leveling Dynamic Power + double phy_pll_d; // PHY PLL Dynamic Power + + + //PHY Wakeup Times (Sleep to Active) (microseconds) + + double phy_pll_wtime; // PHY PLL Wakeup Time + double phy_phase_rotator_wtime; // Phase Rotator Wakeup Time + double phy_rx_wtime; // Receiver Wakeup Time + double phy_bandgap_wtime; // Bandgap Wakeup Time + double phy_deskew_wtime; // Deskewing Wakeup Time + double phy_vrefgen_wtime; // VREF Generator Wakeup Time + + + // RTT values depends on the number of loads, frequency, and link_type + double frequency; + Mem_IO_type io_type; + int frequnecy_index(Mem_IO_type type); +}; + +#endif diff --git a/Project_FARSI/cacti_for_FARSI/farsi_gen.cfg b/Project_FARSI/cacti_for_FARSI/farsi_gen.cfg new file mode 100644 index 00000000..50bd9ad9 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/farsi_gen.cfg @@ -0,0 +1,254 @@ +# Cache size +//-size (bytes) 2048 +//-size (bytes) 4096 +//-size (bytes) 32768 +//-size (bytes) 131072 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +//-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 67108864 +//-size (bytes) 1073741824 + +# power gating +-Array Power Gating - "false" +-WL Power Gating - "false" +-CL Power Gating - "false" +-Bitline floating - "false" +-Interconnect Power Gating - "false" +-Power Gating Performance Loss 0.01 + +# Line size +-block size (bytes) 32 +//-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +-associativity 2 +//-associativity 4 +//-associativity 8 +//-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +//-technology (u) 0.022 +-technology (u) 0.040 
+//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 1024 +-burst length 4 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +//-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 128 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +//-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. 
product
+# Note: Optimize tag will disable the weight or deviate values mentioned above.
+# Set it to NONE to let the weight and deviate values determine the
+# appropriate cache configuration
+//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED"
+-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2"
+//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE"
+
+-Cache model (NUCA, UCA) - "UCA"
+//-Cache model (NUCA, UCA) - "NUCA"
+
+# In order for CACTI to find the optimal NUCA bank value, the following
+# variable should be assigned 0.
+-NUCA bank count 0
+
+# NOTE: for NUCA, the network frequency is set to a default value of
+# 5GHz in time.c. CACTI automatically
+# calculates the maximum possible frequency and downgrades this value if necessary
+
+# By default CACTI considers both full-swing and low-swing
+# wires to find an optimal configuration. However, it is possible to
+# restrict the search space by changing the signaling from "default" to
+# "fullswing" or "lowswing" type.
+-Wire signaling (fullswing, lowswing, default) - "Global_30"
+//-Wire signaling (fullswing, lowswing, default) - "default"
+//-Wire signaling (fullswing, lowswing, default) - "lowswing"
+
+//-Wire inside mat - "global"
+-Wire inside mat - "semi-global"
+//-Wire outside mat - "global"
+-Wire outside mat - "semi-global"
+
+-Interconnect projection - "conservative"
+//-Interconnect projection - "aggressive"
+
+# Contention in the network (which is a function of core count and cache level) is one of
+# the critical factors used for deciding the optimal bank count value
+# core count can be 4, 8, or 16
+//-Core count 4
+-Core count 8
+//-Core count 16
+-Cache level (L2/L3) - "L3"
+
+-Add ECC - "true"
+
+//-Print level (DETAILED, CONCISE) - "CONCISE"
+-Print level (DETAILED, CONCISE) - "DETAILED"
+
+# for debugging
+//-Print input parameters - "true"
+-Print input parameters - "false"
+# force CACTI to model the cache with the
+# following Ndbl, Ndwl, Nspd, Ndsam,
+# and Ndcm values
+//-Force cache config - "true"
+-Force cache config - "false"
+-Ndwl 1
+-Ndbl 1
+-Nspd 0
+-Ndcm 1
+-Ndsam1 0
+-Ndsam2 0
+
+
+
+#### Default CONFIGURATION values for baseline external IO parameters to DRAM. More details can be found in the CACTI-IO technical report (), especially Chapters 2 and 3.
+
+# Memory Type (D=DDR3, L=LPDDR2, W=WideIO). Additional memory types can be defined by the user in extio_technology.cc, along with their technology and configuration parameters.
+
+-dram_type "D"
+//-dram_type "L"
+//-dram_type "W"
+//-dram_type "S"
+
+# Memory State (R=Read, W=Write, I=Idle or S=Sleep)
+
+//-iostate "R"
+-iostate "W"
+//-iostate "I"
+//-iostate "S"
+
+# Address bus timing. To alleviate the timing on the command and address bus due to high loading (shared across all memories on the channel), the interface allows for multi-cycle timing options.
+
+-addr_timing 0.5 //DDR
+//-addr_timing 1.0 //SDR (half of DQ rate)
+//-addr_timing 2.0 //2T timing (One fourth of DQ rate)
+//-addr_timing 3.0 //3T timing (One sixth of DQ rate)
+
+# Memory Density (Gbit per memory/DRAM die)
+
+-mem_density 8 Gb //Valid values 2^n Gb
+
+# IO frequency (MHz) (frequency of the external memory interface).
+
+-bus_freq 100 MHz //As of current memory standards (2013), the valid range is 0 to 1.5 GHz for DDR3, 0 to 533 MHz for LPDDR2, 0 to 800 MHz for WideIO and 0 to 3 GHz for low-swing differential. However, this can change, and the user is free to define valid ranges for new memory types or to extend beyond existing standards for existing DRAM types.
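+
+# Note: with DDR signaling, data toggles on both clock edges, so the per-pin
+# data rate is twice bus_freq (e.g. 100 MHz bus_freq -> 200 MT/s on DQ).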
+
+# Duty Cycle (fraction of time in the Memory State defined above)
+
+-duty_cycle 1.0 //Valid range 0 to 1.0
+
+# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+#-activity_dq .50 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+
+# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_ca 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T
+#-activity_ca 0.25 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T
+
+# Number of DQ pins
+
+-num_dq 72 //Number of DQ pins. Includes ECC pins.
+
+# Number of DQS pins. DQS is a data strobe that is sent along with a small number of data lanes so that the source synchronous timing is local to these DQ bits. Typically, 1 DQS per byte (8 DQ bits) is used. The DQS is also typically differential, just like the CLK pin.
+
+-num_dqs 36 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins.
+
+# Number of CA pins
+
+-num_ca 35 //Valid range 0 to 35 pins.
+#-num_ca 25 //Valid range 0 to 35 pins.
+
+# Number of CLK pins. CLK is typically a differential pair. In some cases additional CLK pairs may be used to limit the loading on the CLK pin.
+
+-num_clk 2 //2 x differential pair. Valid values: 0/2/4.
+
+# Number of Physical Ranks
+
+-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per buffer/register. If multiple LRDIMMs or buffer chips exist, the analysis for capacity and power is reported per buffer/register.
+
+# Width of the Memory Data Bus
+
+-mem_data_width 32 //x4 or x8 or x16 or x32 memories. For WideIO, up to x128.
diff --git a/Project_FARSI/cacti_for_FARSI/htree2.cc b/Project_FARSI/cacti_for_FARSI/htree2.cc
new file mode 100644
index 00000000..3077820b
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/htree2.cc
@@ -0,0 +1,640 @@
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ ***************************************************************************/
+
+
+
+#include "htree2.h"
+#include "wire.h"
+#include <assert.h>
+#include <iostream>
+
+Htree2::Htree2(
+    enum Wire_type wire_model, double mat_w, double mat_h,
+    int a_bits, int d_inbits, int search_data_in, int d_outbits, int search_data_out, int bl, int wl, enum Htree_type htree_type,
+    bool uca_tree_, bool search_tree_, /*TechnologyParameter::*/DeviceType *dt)
+  :in_rise_time(0), out_rise_time(0),
+  tree_type(htree_type), mat_width(mat_w), mat_height(mat_h),
+  add_bits(a_bits), data_in_bits(d_inbits), search_data_in_bits(search_data_in), data_out_bits(d_outbits),
+  search_data_out_bits(search_data_out), ndbl(bl), ndwl(wl),
+  uca_tree(uca_tree_), search_tree(search_tree_), wt(wire_model), deviceType(dt)
+{
+  assert(ndbl >= 2 && ndwl >= 2);
+
+//  if (ndbl == 1 && ndwl == 1)
+//  {
+//    delay = 0;
+//    power.readOp.dynamic = 0;
+//    power.readOp.leakage = 0;
+//    area.w = mat_w;
+//    area.h = mat_h;
+//    return;
+//  }
+//  if (ndwl == 1) ndwl++;
+//  if (ndbl == 1) ndbl++;
+
+  max_unpipelined_link_delay = 0; //TODO
+  min_w_nmos = g_tp.min_w_nmos_;
+  min_w_pmos = deviceType->n_to_p_eff_curr_drv_ratio * min_w_nmos;
+
+  switch (htree_type)
+  {
+    case Add_htree:
+      wire_bw = init_wire_bw = add_bits;
+      in_htree();
+      break;
+    case Data_in_htree:
+      wire_bw = init_wire_bw = data_in_bits;
+      in_htree();
+      break;
+    case Data_out_htree:
+      wire_bw = init_wire_bw = data_out_bits;
+      out_htree();
+      break;
+    case Search_in_htree:
+      wire_bw = init_wire_bw = search_data_in_bits; //in_search_tree is broadcast, out_htree is not.
+      in_htree();
+      break;
+    case Search_out_htree:
+      wire_bw = init_wire_bw = search_data_out_bits;
+      out_htree();
+      break;
+    default:
+      assert(0);
+      break;
+  }
+
+  power_bit = power;
+  power.readOp.dynamic *= init_wire_bw;
+
+  assert(power.readOp.dynamic >= 0);
+  assert(power.readOp.leakage >= 0);
+}
+
+
+
+// nand gate sizing calculation
+void Htree2::input_nand(double s1, double s2, double l_eff)
+{
+  Wire w1(wt, l_eff);
+  double pton_size = deviceType->n_to_p_eff_curr_drv_ratio;
+  // input capacitance of a repeater = input capacitance of nand.
+  double nsize = s1*(1 + pton_size)/(2 + pton_size);
+  nsize = (nsize < 1) ?
1 : nsize; + + double tc = 2*tr_R_on(nsize*min_w_nmos, NCH, 1) * + (drain_C_(nsize*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def)*2 + + 2 * gate_C(s2*(min_w_nmos + min_w_pmos), 0)); + delay+= horowitz (w1.out_rise_time, tc, + deviceType->Vth/deviceType->Vdd, deviceType->Vth/deviceType->Vdd, RISE); + power.readOp.dynamic += 0.5 * + (2*drain_C_(pton_size * nsize*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + drain_C_(nsize*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + + 2*gate_C(s2*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd; + + power.searchOp.dynamic += 0.5 * + (2*drain_C_(pton_size * nsize*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + drain_C_(nsize*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + + 2*gate_C(s2*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd * wire_bw ; + power.readOp.leakage += (wire_bw*cmos_Isub_leakage(min_w_nmos*(nsize*2), min_w_pmos * nsize * 2, 2, nand))*deviceType->Vdd; + power.readOp.gate_leakage += (wire_bw*cmos_Ig_leakage(min_w_nmos*(nsize*2), min_w_pmos * nsize * 2, 2, nand))*deviceType->Vdd; +} + + + +// tristate buffer model consisting of not, nand, nor, and driver transistors +void Htree2::output_buffer(double s1, double s2, double l_eff) +{ + Wire w1(wt, l_eff); + double pton_size = deviceType->n_to_p_eff_curr_drv_ratio; + // input capacitance of repeater = input capacitance of nand + nor. + double size = s1*(1 + pton_size)/(2 + pton_size + 1 + 2*pton_size); + double s_eff = //stage eff of a repeater in a wire + (gate_C(s2*(min_w_nmos + min_w_pmos), 0) + w1.wire_cap(l_eff*1e-6,true))/ + gate_C(s2*(min_w_nmos + min_w_pmos), 0); + double tr_size = gate_C(s1*(min_w_nmos + min_w_pmos), 0) * 1/2/(s_eff*gate_C(min_w_pmos, 0)); + size = (size < 1) ? 1 : size; + + double res_nor = 2*tr_R_on(size*min_w_pmos, PCH, 1); + double res_ptrans = tr_R_on(tr_size*min_w_nmos, NCH, 1); + double cap_nand_out = drain_C_(size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + + drain_C_(size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def)*2 + + gate_C(tr_size*min_w_pmos, 0); + double cap_ptrans_out = 2 *(drain_C_(tr_size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + drain_C_(tr_size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def)) + + gate_C(s1*(min_w_nmos + min_w_pmos), 0); + + double tc = res_nor * cap_nand_out + (res_nor + res_ptrans) * cap_ptrans_out; + + + delay += horowitz (w1.out_rise_time, tc, + deviceType->Vth/deviceType->Vdd, deviceType->Vth/deviceType->Vdd, RISE); + + //nand + power.readOp.dynamic += 0.5 * + (2*drain_C_(size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + drain_C_(size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + + gate_C(tr_size*(min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd; + + power.searchOp.dynamic += 0.5 * + (2*drain_C_(size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + drain_C_(size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + + gate_C(tr_size*(min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd*init_wire_bw; + + //not + power.readOp.dynamic += 0.5 * + (drain_C_(size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + +drain_C_(size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + +gate_C(size*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd; + + power.searchOp.dynamic += 0.5 * + (drain_C_(size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + +drain_C_(size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + +gate_C(size*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd*init_wire_bw; + + //nor + power.readOp.dynamic += 0.5 * + (drain_C_(size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + 2*drain_C_(size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + 
+gate_C(tr_size*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd; + + power.searchOp.dynamic += 0.5 * + (drain_C_(size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + 2*drain_C_(size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def) + +gate_C(tr_size*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd*init_wire_bw; + + //output transistor + power.readOp.dynamic += 0.5 * + ((drain_C_(tr_size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + +drain_C_(tr_size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def))*2 + + gate_C(s1*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd; + + power.searchOp.dynamic += 0.5 * + ((drain_C_(tr_size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + +drain_C_(tr_size*min_w_nmos, NCH, 1, 1, g_tp.cell_h_def))*2 + + gate_C(s1*(min_w_nmos + min_w_pmos), 0)) * + deviceType->Vdd * deviceType->Vdd*init_wire_bw; + + if(uca_tree) { + power.readOp.leakage += cmos_Isub_leakage(min_w_nmos*tr_size*2, min_w_pmos*tr_size*2, 1, inv)*deviceType->Vdd*wire_bw;/*inverter + output tr*/ + power.readOp.leakage += cmos_Isub_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nand)*deviceType->Vdd*wire_bw;//nand + power.readOp.leakage += cmos_Isub_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nor)*deviceType->Vdd*wire_bw;//nor + + power.readOp.gate_leakage += cmos_Ig_leakage(min_w_nmos*tr_size*2, min_w_pmos*tr_size*2, 1, inv)*deviceType->Vdd*wire_bw;/*inverter + output tr*/ + power.readOp.gate_leakage += cmos_Ig_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nand)*deviceType->Vdd*wire_bw;//nand + power.readOp.gate_leakage += cmos_Ig_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nor)*deviceType->Vdd*wire_bw;//nor + //power.readOp.gate_leakage *=; + } + else { + power.readOp.leakage += cmos_Isub_leakage(min_w_nmos*tr_size*2, min_w_pmos*tr_size*2, 1, inv)*deviceType->Vdd*wire_bw;/*inverter + output tr*/ + power.readOp.leakage += cmos_Isub_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nand)*deviceType->Vdd*wire_bw;//nand + power.readOp.leakage += cmos_Isub_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nor)*deviceType->Vdd*wire_bw;//nor + + power.readOp.gate_leakage += cmos_Ig_leakage(min_w_nmos*tr_size*2, min_w_pmos*tr_size*2, 1, inv)*deviceType->Vdd*wire_bw;/*inverter + output tr*/ + power.readOp.gate_leakage += cmos_Ig_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nand)*deviceType->Vdd*wire_bw;//nand + power.readOp.gate_leakage += cmos_Ig_leakage(min_w_nmos*size*3, min_w_pmos*size*3, 2, nor)*deviceType->Vdd*wire_bw;//nor + //power.readOp.gate_leakage *=deviceType->Vdd*wire_bw; + } +} + + + +/* calculates the input h-tree delay/power + * A nand gate is used at each node to + * limit the signal + * The area of an unbalanced htree (rows != columns) + * depends on how data is traversed. + * In the following function, if ( no. of rows < no. of columns), + * then data first traverse in excess hor. links until vertical + * and horizontal nodes are same. + * If no. of rows is bigger, then data traverse in + * a hor. link followed by a ver. link in a repeated + * fashion (similar to a balanced tree) until there are no + * hor. links left. After this it goes through the remaining vertical + * links. 
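+ * e.g., with ndwl = 8 and ndbl = 4, h = _log2(8/2) = 2 and v = _log2(4/2) = 1,
+ * so the loop below first consumes the single excess horizontal link
+ * (option 0) and then one horizontal/vertical pair (option 1), halving the
+ * segment length at each level.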
+ */ + void +Htree2::in_htree() +{ + //temp var + double s1 = 0, s2 = 0, s3 = 0; + double l_eff = 0; + Wire *wtemp1 = 0, *wtemp2 = 0, *wtemp3 = 0; + double len = 0, ht = 0; + int option = 0; + + int h = (int) _log2(ndwl/2); // horizontal nodes + int v = (int) _log2(ndbl/2); // vertical nodes + double len_temp; + double ht_temp; + if (uca_tree) + {//: this computation do not consider the wires that route from edge to middle. + ht_temp = (mat_height*ndbl/2 +/* since uca_tree models interbank tree, mat_height => bank height */ + ((add_bits + data_in_bits + data_out_bits + (search_data_in_bits + search_data_out_bits)) * g_tp.wire_outside_mat.pitch * + 2 * (1-pow(0.5,h))))/2; + len_temp = (mat_width*ndwl/2 + + ((add_bits + data_in_bits + data_out_bits + (search_data_in_bits + search_data_out_bits)) * g_tp.wire_outside_mat.pitch * + 2 * (1-pow(0.5,v))))/2; + } + else + { + if (ndwl == ndbl) { + ht_temp = ((mat_height*ndbl/2) + + ((add_bits + (search_data_in_bits + search_data_out_bits))* (ndbl/2-1) * g_tp.wire_outside_mat.pitch) + + ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * h) + )/2; + len_temp = (mat_width*ndwl/2 + + ((add_bits + (search_data_in_bits + search_data_out_bits)) * (ndwl/2-1) * g_tp.wire_outside_mat.pitch) + + ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * v))/2; + } + else if (ndwl > ndbl) { + double excess_part = (_log2(ndwl/2) - _log2(ndbl/2)); + ht_temp = ((mat_height*ndbl/2) + + ((add_bits + + (search_data_in_bits + search_data_out_bits)) * ((ndbl/2-1) + excess_part) * g_tp.wire_outside_mat.pitch) + + (data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * + (2*(1 - pow(0.5, h-v)) + pow(0.5, v-h) * v))/2; + len_temp = (mat_width*ndwl/2 + + ((add_bits + (search_data_in_bits + search_data_out_bits))* (ndwl/2-1) * g_tp.wire_outside_mat.pitch) + + ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * v))/2; + } + else { + double excess_part = (_log2(ndbl/2) - _log2(ndwl/2)); + ht_temp = ((mat_height*ndbl/2) + + ((add_bits + (search_data_in_bits + search_data_out_bits))* ((ndwl/2-1) + excess_part) * g_tp.wire_outside_mat.pitch) + + ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * h) + )/2; + len_temp = (mat_width*ndwl/2 + + ((add_bits + (search_data_in_bits + search_data_out_bits)) * ((ndwl/2-1) + excess_part) * g_tp.wire_outside_mat.pitch) + + (data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * (h + 2*(1-pow(0.5, v-h))))/2; + } + } + + area.h = ht_temp * 2; + area.w = len_temp * 2; + delay = 0; + power.readOp.dynamic = 0; + power.readOp.leakage = 0; + power.searchOp.dynamic =0; + len = len_temp; + ht = ht_temp/2; + + while (v > 0 || h > 0) + { + if (wtemp1) delete wtemp1; + if (wtemp2) delete wtemp2; + if (wtemp3) delete wtemp3; + + if (h > v) + { + //the iteration considers only one horizontal link + wtemp1 = new Wire(wt, len); // hor + wtemp2 = new Wire(wt, len/2); // ver + len_temp = len; + len /= 2; + wtemp3 = 0; + h--; + option = 0; + } + else if (v>0 && h>0) + { + //considers one horizontal link and one vertical link + wtemp1 = new Wire(wt, len); // hor + wtemp2 = new Wire(wt, ht); // ver + wtemp3 = new Wire(wt, len/2); // next hor + len_temp = len; + ht_temp = ht; + len /= 2; + ht /= 2; + v--; + h--; + option = 1; + } + else + { + // considers only one vertical link + assert(h == 0); + wtemp1 = new Wire(wt, ht); // ver + wtemp2 = new Wire(wt, ht/2); // hor + ht_temp = ht; + ht /= 2; + wtemp3 = 0; + v--; + option = 2; + } + + delay += wtemp1->delay; + power.readOp.dynamic += 
wtemp1->power.readOp.dynamic; + power.searchOp.dynamic += wtemp1->power.readOp.dynamic*wire_bw; + power.readOp.leakage += wtemp1->power.readOp.leakage*wire_bw; + power.readOp.gate_leakage += wtemp1->power.readOp.gate_leakage*wire_bw; + if ((uca_tree == false && option == 2) || search_tree==true) + { + wire_bw*=2; // wire bandwidth doubles only for vertical branches + } + + if (uca_tree == false) + { + if (len_temp > wtemp1->repeater_spacing) + { + s1 = wtemp1->repeater_size; + l_eff = wtemp1->repeater_spacing; + } + else + { + s1 = (len_temp/wtemp1->repeater_spacing) * wtemp1->repeater_size; + l_eff = len_temp; + } + + if (ht_temp > wtemp2->repeater_spacing) + { + s2 = wtemp2->repeater_size; + } + else + { + s2 = (len_temp/wtemp2->repeater_spacing) * wtemp2->repeater_size; + } + // first level + input_nand(s1, s2, l_eff); + } + + + if (option != 1) + { + continue; + } + + // second level + delay += wtemp2->delay; + power.readOp.dynamic += wtemp2->power.readOp.dynamic; + power.searchOp.dynamic += wtemp2->power.readOp.dynamic*wire_bw; + power.readOp.leakage += wtemp2->power.readOp.leakage*wire_bw; + power.readOp.gate_leakage += wtemp2->power.readOp.gate_leakage*wire_bw; + + if (uca_tree) + { + power.readOp.leakage += (wtemp2->power.readOp.leakage*wire_bw); + power.readOp.gate_leakage += wtemp2->power.readOp.gate_leakage*wire_bw; + } + else + { + power.readOp.leakage += (wtemp2->power.readOp.leakage*wire_bw); + power.readOp.gate_leakage += wtemp2->power.readOp.gate_leakage*wire_bw; + wire_bw*=2; + + if (ht_temp > wtemp3->repeater_spacing) + { + s3 = wtemp3->repeater_size; + l_eff = wtemp3->repeater_spacing; + } + else + { + s3 = (len_temp/wtemp3->repeater_spacing) * wtemp3->repeater_size; + l_eff = ht_temp; + } + + input_nand(s2, s3, l_eff); + } + } + + if (wtemp1) delete wtemp1; + if (wtemp2) delete wtemp2; + if (wtemp3) delete wtemp3; +} + + + +/* a tristate buffer is used to handle fan-ins + * The area of an unbalanced htree (rows != columns) + * depends on how data is traversed. + * In the following function, if ( no. of rows < no. of columns), + * then data first traverse in excess hor. links until vertical + * and horizontal nodes are same. + * If no. of rows is bigger, then data traverse in + * a hor. link followed by a ver. link in a repeated + * fashion (similar to a balanced tree) until there are no + * hor. links left. After this it goes through the remaining vertical + * links. 
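+ * e.g., on the way out of the array many mats drive a shared output bus, so
+ * each node must be able to float its output; output_buffer() therefore
+ * models a tristate stage (not/nand/nor plus driver) instead of the plain
+ * nand used by in_htree().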
+ */
+void Htree2::out_htree()
+{
+  //temp var
+  double s1 = 0, s2 = 0, s3 = 0;
+  double l_eff = 0;
+  Wire *wtemp1 = 0, *wtemp2 = 0, *wtemp3 = 0;
+  double len = 0, ht = 0;
+  int option = 0;
+
+  int h = (int) _log2(ndwl/2);
+  int v = (int) _log2(ndbl/2);
+  double len_temp;
+  double ht_temp;
+  if (uca_tree)
+  {
+    ht_temp = (mat_height*ndbl/2 +/* since uca_tree models interbank tree, mat_height => bank height */
+        ((add_bits + data_in_bits + data_out_bits + (search_data_in_bits + search_data_out_bits)) * g_tp.wire_outside_mat.pitch *
+         2 * (1-pow(0.5,h))))/2;
+    len_temp = (mat_width*ndwl/2 +
+        ((add_bits + data_in_bits + data_out_bits + (search_data_in_bits + search_data_out_bits)) * g_tp.wire_outside_mat.pitch *
+         2 * (1-pow(0.5,v))))/2;
+  }
+  else
+  {
+    if (ndwl == ndbl) {
+      ht_temp = ((mat_height*ndbl/2) +
+          ((add_bits+ (search_data_in_bits + search_data_out_bits)) * (ndbl/2-1) * g_tp.wire_outside_mat.pitch) +
+          ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * h)
+          )/2;
+      len_temp = (mat_width*ndwl/2 +
+          ((add_bits + (search_data_in_bits + search_data_out_bits)) * (ndwl/2-1) * g_tp.wire_outside_mat.pitch) +
+          ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * v))/2;
+
+    }
+    else if (ndwl > ndbl) {
+      double excess_part = (_log2(ndwl/2) - _log2(ndbl/2));
+      ht_temp = ((mat_height*ndbl/2) +
+          ((add_bits + (search_data_in_bits + search_data_out_bits)) * ((ndbl/2-1) + excess_part) * g_tp.wire_outside_mat.pitch) +
+          (data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch *
+          (2*(1 - pow(0.5, h-v)) + pow(0.5, v-h) * v))/2;
+      len_temp = (mat_width*ndwl/2 +
+          ((add_bits + (search_data_in_bits + search_data_out_bits))* (ndwl/2-1) * g_tp.wire_outside_mat.pitch) +
+          ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * v))/2;
+    }
+    else {
+      double excess_part = (_log2(ndbl/2) - _log2(ndwl/2));
+      ht_temp = ((mat_height*ndbl/2) +
+          ((add_bits + (search_data_in_bits + search_data_out_bits))* ((ndwl/2-1) + excess_part) * g_tp.wire_outside_mat.pitch) +
+          ((data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * h)
+          )/2;
+      len_temp = (mat_width*ndwl/2 +
+          ((add_bits + (search_data_in_bits + search_data_out_bits))* ((ndwl/2-1) + excess_part) * g_tp.wire_outside_mat.pitch) +
+          (data_in_bits + data_out_bits) * g_tp.wire_outside_mat.pitch * (h + 2*(1-pow(0.5, v-h))))/2;
+    }
+  }
+  area.h = ht_temp * 2;
+  area.w = len_temp * 2;
+  delay = 0;
+  power.readOp.dynamic = 0;
+  power.readOp.leakage = 0;
+  power.readOp.gate_leakage = 0;
+  //cout<<"power.readOp.gate_leakage"<<power.readOp.gate_leakage<<endl;
+  len = len_temp;
+  ht = ht_temp/2;
+
+  while (v > 0 || h > 0)
+  { //finds delay/power of each link in the tree
+    if (wtemp1) delete wtemp1;
+    if (wtemp2) delete wtemp2;
+    if (wtemp3) delete wtemp3;
+
+    if(h > v) {
+      //the iteration considers only one horizontal link
+      wtemp1 = new Wire(wt, len); // hor
+      wtemp2 = new Wire(wt, len/2); // ver
+      len_temp = len;
+      len /= 2;
+      wtemp3 = 0;
+      h--;
+      option = 0;
+    }
+    else if (v>0 && h>0) {
+      //considers one horizontal link and one vertical link
+      wtemp1 = new Wire(wt, len); // hor
+      wtemp2 = new Wire(wt, ht); // ver
+      wtemp3 = new Wire(wt, len/2); // next hor
+      len_temp = len;
+      ht_temp = ht;
+      len /= 2;
+      ht /= 2;
+      v--;
+      h--;
+      option = 1;
+    }
+    else {
+      // considers only one vertical link
+      assert(h == 0);
+      wtemp1 = new Wire(wt, ht); // ver
+      wtemp2 = new Wire(wt, ht/2); // next ver
+      ht_temp = ht;
+      ht /= 2;
+      wtemp3 = 0;
+      v--;
+      option = 2;
+    }
+    delay += wtemp1->delay;
+    power.readOp.dynamic += wtemp1->power.readOp.dynamic;
+    power.searchOp.dynamic += wtemp1->power.readOp.dynamic*init_wire_bw;
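+    // searchOp dynamic energy is scaled by init_wire_bw (the bus width at the
+    // root of the tree); the leakage terms below use wire_bw, which doubles
+    // as vertical branches are taken further down the tree.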
+    power.readOp.leakage += wtemp1->power.readOp.leakage*wire_bw;
+    power.readOp.gate_leakage += wtemp1->power.readOp.gate_leakage*wire_bw;
+    //cout<<"power.readOp.gate_leakage"<<power.readOp.gate_leakage<<endl;
+
+    if ((uca_tree == false && option == 2) || search_tree==true)
+    {
+      wire_bw*=2;
+    }
+
+    if (uca_tree == false)
+    {
+      if (len_temp > wtemp1->repeater_spacing)
+      {
+        s1 = wtemp1->repeater_size;
+        l_eff = wtemp1->repeater_spacing;
+      }
+      else
+      {
+        s1 = (len_temp/wtemp1->repeater_spacing) * wtemp1->repeater_size;
+        l_eff = len_temp;
+      }
+      if (ht_temp > wtemp2->repeater_spacing)
+      {
+        s2 = wtemp2->repeater_size;
+      }
+      else
+      {
+        s2 = (len_temp/wtemp2->repeater_spacing) * wtemp2->repeater_size;
+      }
+      // first level
+      output_buffer(s1, s2, l_eff);
+    }
+
+
+    if (option != 1)
+    {
+      continue;
+    }
+
+    // second level
+    delay += wtemp2->delay;
+    power.readOp.dynamic += wtemp2->power.readOp.dynamic;
+    power.searchOp.dynamic += wtemp2->power.readOp.dynamic*init_wire_bw;
+    power.readOp.leakage += wtemp2->power.readOp.leakage*wire_bw;
+    power.readOp.gate_leakage += wtemp2->power.readOp.gate_leakage*wire_bw;
+    //cout<<"power.readOp.gate_leakage"<<power.readOp.gate_leakage<<endl;
+
+    if (uca_tree)
+    {
+      power.readOp.leakage += (wtemp2->power.readOp.leakage*wire_bw);
+      power.readOp.gate_leakage += wtemp2->power.readOp.gate_leakage*wire_bw;
+    }
+    else
+    {
+      power.readOp.leakage += (wtemp2->power.readOp.leakage*wire_bw);
+      power.readOp.gate_leakage += wtemp2->power.readOp.gate_leakage*wire_bw;
+      wire_bw*=2;
+
+      if (ht_temp > wtemp3->repeater_spacing)
+      {
+        s3 = wtemp3->repeater_size;
+        l_eff = wtemp3->repeater_spacing;
+      }
+      else
+      {
+        s3 = (len_temp/wtemp3->repeater_spacing) * wtemp3->repeater_size;
+        l_eff = ht_temp;
+      }
+
+      output_buffer(s2, s3, l_eff);
+    }
+    //cout<<"power.readOp.leakage"<<power.readOp.leakage<<endl;
+    //cout<<"power.readOp.gate_leakage"<<power.readOp.gate_leakage<<endl;
+  }
+  //cout<<"power.readOp.gate_leakage"<<power.readOp.gate_leakage<<endl;
+
+  if (wtemp1) delete wtemp1;
+  if (wtemp2) delete wtemp2;
+  if (wtemp3) delete wtemp3;
+}
diff --git a/Project_FARSI/cacti_for_FARSI/io.cc b/Project_FARSI/cacti_for_FARSI/io.cc
new file mode 100644
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/io.cc
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
+ *
+ ***************************************************************************/
+
+#include <iostream>
+#include <stdio.h>
+#include <string.h>
+#include <assert.h>
+
+
+#include "io.h"
+#include "area.h"
+#include "basic_circuit.h"
+#include "parameter.h"
+#include "Ucache.h"
+#include "nuca.h"
+#include "crossbar.h"
+#include "arbiter.h"
+//#include "highradix.h"
+#include "TSV.h"
+#include "memorybus.h"
+#include "version_cacti.h"
+
+#include "extio.h"
+#include "extio_technology.h"
+#include "memcad.h"
+
+using namespace std;
+
+
+InputParameter::InputParameter()
+: array_power_gated(false),
+  bitline_floating(false),
+  wl_power_gated(false),
+  cl_power_gated(false),
+  interconect_power_gated(false),
+  power_gating(false),
+  cl_vertical (true)
+{
+
+}
+
+/* Parses "cache.cfg" file */
+  void
+InputParameter::parse_cfg(const string & in_file)
+{
+  FILE *fp = fopen(in_file.c_str(), "r");
+  char line[5000];
+  char jk[5000];
+  char temp_var[5000];
+
+  if(!fp) {
+    cout << in_file << " is missing!\n";
+    exit(-1);
+  }
+
+  while(fscanf(fp, "%[^\n]\n", line) != EOF) {
+
+    if (!strncmp("-size", line, strlen("-size"))) {
+      sscanf(line, "-size %[(:-~)*]%u", jk, &(cache_sz));
+      if (g_ip->print_detail_debug)
+        cout << "cache size: " << g_ip->cache_sz << "GB" << endl;
+      continue;
+    }
+
+
+
+    if (!strncmp("-page size", line, strlen("-page size"))) {
+      sscanf(line, "-page size %[(:-~)*]%u", jk, &(page_sz_bits));
+      continue;
+    }
+
+    if (!strncmp("-burst length", line, strlen("-burst length"))) {
+      sscanf(line, "-burst %[(:-~)*]%u", jk, &(burst_len));
+      continue;
+    }
+
+    if (!strncmp("-internal prefetch width", line, strlen("-internal prefetch width"))) {
+      sscanf(line, "-internal prefetch %[(:-~)*]%u", jk, &(int_prefetch_w));
+      continue;
+    }
+
+    if (!strncmp("-block", line, strlen("-block"))) {
+      sscanf(line, "-block size (bytes) %d", &(line_sz));
+      continue;
+    }
+
+    if (!strncmp("-associativity", line, strlen("-associativity"))) {
+      sscanf(line, "-associativity %d", &(assoc));
+      continue;
+    }
+
+    if (!strncmp("-read-write", line, strlen("-read-write"))) {
+      sscanf(line,
"-read-write port %d", &(num_rw_ports)); + continue; + } + + if (!strncmp("-exclusive read", line, strlen("exclusive read"))) { + sscanf(line, "-exclusive read port %d", &(num_rd_ports)); + continue; + } + + if(!strncmp("-exclusive write", line, strlen("-exclusive write"))) { + sscanf(line, "-exclusive write port %d", &(num_wr_ports)); + continue; + } + + if (!strncmp("-single ended", line, strlen("-single ended"))) { + sscanf(line, "-single %[(:-~)*]%d", jk, + &(num_se_rd_ports)); + continue; + } + + if (!strncmp("-search", line, strlen("-search"))) { + sscanf(line, "-search port %d", &(num_search_ports)); + continue; + } + + if (!strncmp("-UCA bank", line, strlen("-UCA bank"))) { + sscanf(line, "-UCA bank%[((:-~)| )*]%d", jk, &(nbanks)); + continue; + } + + if (!strncmp("-technology", line, strlen("-technology"))) { + sscanf(line, "-technology (u) %lf", &(F_sz_um)); + F_sz_nm = F_sz_um*1000; + continue; + } + + if (!strncmp("-output/input", line, strlen("-output/input"))) { + sscanf(line, "-output/input bus %[(:-~)*]%d", jk, &(out_w)); + continue; + } + + if (!strncmp("-operating temperature", line, strlen("-operating temperature"))) { + sscanf(line, "-operating temperature %[(:-~)*]%d", jk, &(temp)); + continue; + } + + if (!strncmp("-cache type", line, strlen("-cache type"))) { + sscanf(line, "-cache type%[^\"]\"%[^\"]\"", jk, temp_var); + + if (!strncmp("cache", temp_var, sizeof("cache"))) { + is_cache = true; + } + else + { + is_cache = false; + } + + if (!strncmp("main memory", temp_var, sizeof("main memory"))) { + is_main_mem = true; + } + else { + is_main_mem = false; + } + + if (!strncmp("3D memory or 2D main memory", temp_var, sizeof("3D memory or 2D main memory"))) { + is_3d_mem = true; + is_main_mem = true; + } + else { + is_3d_mem = false; + //is_main_mem = false; + } + + if (g_ip->print_detail_debug) + {cout << "io.cc: is_3d_mem = " << is_3d_mem << endl;} + + if (!strncmp("cam", temp_var, sizeof("cam"))) { + pure_cam = true; + } + else { + pure_cam = false; + } + + if (!strncmp("ram", temp_var, sizeof("ram"))) { + pure_ram = true; + } + else { + if (!is_main_mem) + pure_ram = false; + else + pure_ram = true; + } + + continue; + } + + if (!strncmp("-print option", line, strlen("-print option"))) { + sscanf(line, "-print option%[^\"]\"%[^\"]\"", jk, temp_var); + + if (!strncmp("debug detail", temp_var, sizeof("debug detail"))) { + print_detail_debug = true; + } + else { + print_detail_debug = false; + } + if (g_ip->print_detail_debug) + {cout << "io.cc: print_detail_debug = " << print_detail_debug << endl;} + continue; + } + + if (!strncmp("-burst depth", line, strlen("-burst depth"))) { + sscanf(line, "-burst %[(:-~)*]%u", jk, &(burst_depth)); + continue; + } + + if (!strncmp("-IO width", line, strlen("-IO width"))) { + sscanf(line, "-IO %[(:-~)*]%u", jk, &(io_width)); + continue; + } + + if (!strncmp("-system frequency", line, strlen("-system frequency"))) { + sscanf(line, "-system frequency %[(:-~)*]%u", jk, &(sys_freq_MHz)); + if(g_ip->print_detail_debug) + cout << "system frequency: " << g_ip->sys_freq_MHz << endl; + continue; + } + + + + if (!strncmp("-stacked die", line, strlen("-stacked die"))) { + sscanf(line, "-stacked die %[(:-~)*]%u", jk, &(num_die_3d)); + if(g_ip->print_detail_debug) + cout << "num_die_3d: " << g_ip->num_die_3d << endl; + continue; + } + + if (!strncmp("-partitioning granularity", line, strlen("-partitioning granularity"))) { + sscanf(line, "-partitioning %[(:-~)*]%u", jk, &(partition_gran)); + if(g_ip->print_detail_debug) + cout << "partitioning 
granularity: " << g_ip->partition_gran << endl; + continue; + } + + if (!strncmp("-TSV projection", line, strlen("-TSV projection"))) { + sscanf(line, "-TSV %[(:-~)*]%u", jk, &(TSV_proj_type)); + if(g_ip->print_detail_debug) + cout << "TSV projection: " << g_ip->TSV_proj_type << endl; + continue; + } + + + //g_ip->print_detail_debug = debug_detail; + + + //g_ip->partition_gran = 1; + + // --- These two parameters are supposed for bank level partitioning, currently not shown to public + g_ip->num_tier_row_sprd = 1; + g_ip->num_tier_col_sprd = 1; + + if (!strncmp("-tag size", line, strlen("-tag size"))) { + sscanf(line, "-tag size%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("default", temp_var, sizeof("default"))) { + specific_tag = false; + tag_w = 42; /* the actual value is calculated + * later based on the cache size, bank count, and associativity + */ + } + else { + specific_tag = true; + sscanf(line, "-tag size (b) %d", &(tag_w)); + } + continue; + } + + if (!strncmp("-access mode", line, strlen("-access mode"))) { + sscanf(line, "-access %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("fast", temp_var, strlen("fast"))) { + access_mode = 2; + } + else if (!strncmp("sequential", temp_var, strlen("sequential"))) { + access_mode = 1; + } + else if(!strncmp("normal", temp_var, strlen("normal"))) { + access_mode = 0; + } + else { + cout << "ERROR: Invalid access mode!\n"; + exit(0); + } + continue; + } + + if (!strncmp("-Data array cell type", line, strlen("-Data array cell type"))) { + sscanf(line, "-Data array cell type %[^\"]\"%[^\"]\"", jk, temp_var); + + if(!strncmp("itrs-hp", temp_var, strlen("itrs-hp"))) { + data_arr_ram_cell_tech_type = 0; + } + else if(!strncmp("itrs-lstp", temp_var, strlen("itrs-lstp"))) { + data_arr_ram_cell_tech_type = 1; + } + else if(!strncmp("itrs-lop", temp_var, strlen("itrs-lop"))) { + data_arr_ram_cell_tech_type = 2; + } + else if(!strncmp("lp-dram", temp_var, strlen("lp-dram"))) { + data_arr_ram_cell_tech_type = 3; + } + else if(!strncmp("comm-dram", temp_var, strlen("comm-dram"))) { + data_arr_ram_cell_tech_type = 4; + } + else { + cout << "ERROR: Invalid type!\n"; + exit(0); + } + continue; + } + + if (!strncmp("-Data array peripheral type", line, strlen("-Data array peripheral type"))) { + sscanf(line, "-Data array peripheral type %[^\"]\"%[^\"]\"", jk, temp_var); + + if(!strncmp("itrs-hp", temp_var, strlen("itrs-hp"))) { + data_arr_peri_global_tech_type = 0; + } + else if(!strncmp("itrs-lstp", temp_var, strlen("itrs-lstp"))) { + data_arr_peri_global_tech_type = 1; + } + else if(!strncmp("itrs-lop", temp_var, strlen("itrs-lop"))) { + data_arr_peri_global_tech_type = 2; + } + else { + cout << "ERROR: Invalid type!\n"; + exit(0); + } + continue; + } + + if (!strncmp("-Tag array cell type", line, strlen("-Tag array cell type"))) { + sscanf(line, "-Tag array cell type %[^\"]\"%[^\"]\"", jk, temp_var); + + if(!strncmp("itrs-hp", temp_var, strlen("itrs-hp"))) { + tag_arr_ram_cell_tech_type = 0; + } + else if(!strncmp("itrs-lstp", temp_var, strlen("itrs-lstp"))) { + tag_arr_ram_cell_tech_type = 1; + } + else if(!strncmp("itrs-lop", temp_var, strlen("itrs-lop"))) { + tag_arr_ram_cell_tech_type = 2; + } + else if(!strncmp("lp-dram", temp_var, strlen("lp-dram"))) { + tag_arr_ram_cell_tech_type = 3; + } + else if(!strncmp("comm-dram", temp_var, strlen("comm-dram"))) { + tag_arr_ram_cell_tech_type = 4; + } + else { + cout << "ERROR: Invalid type!\n"; + exit(0); + } + continue; + } + + if (!strncmp("-Tag array peripheral type", line, strlen("-Tag array 
peripheral type"))) { + sscanf(line, "-Tag array peripheral type %[^\"]\"%[^\"]\"", jk, temp_var); + + if(!strncmp("itrs-hp", temp_var, strlen("itrs-hp"))) { + tag_arr_peri_global_tech_type = 0; + } + else if(!strncmp("itrs-lstp", temp_var, strlen("itrs-lstp"))) { + tag_arr_peri_global_tech_type = 1; + } + else if(!strncmp("itrs-lop", temp_var, strlen("itrs-lop"))) { + tag_arr_peri_global_tech_type = 2; + } + else { + cout << "ERROR: Invalid type!\n"; + exit(0); + } + continue; + } + if(!strncmp("-design", line, strlen("-design"))) { + sscanf(line, "-%[((:-~)| |,)*]%d:%d:%d:%d:%d", jk, + &(delay_wt), &(dynamic_power_wt), + &(leakage_power_wt), + &(cycle_time_wt), &(area_wt)); + continue; + } + + if(!strncmp("-deviate", line, strlen("-deviate"))) { + sscanf(line, "-%[((:-~)| |,)*]%d:%d:%d:%d:%d", jk, + &(delay_dev), &(dynamic_power_dev), + &(leakage_power_dev), + &(cycle_time_dev), &(area_dev)); + continue; + } + + if(!strncmp("-Optimize", line, strlen("-Optimize"))) { + sscanf(line, "-Optimize %[^\"]\"%[^\"]\"", jk, temp_var); + + if(!strncmp("ED^2", temp_var, strlen("ED^2"))) { + ed = 2; + } + else if(!strncmp("ED", temp_var, strlen("ED"))) { + ed = 1; + } + else { + ed = 0; + } + } + + if(!strncmp("-NUCAdesign", line, strlen("-NUCAdesign"))) { + sscanf(line, "-%[((:-~)| |,)*]%d:%d:%d:%d:%d", jk, + &(delay_wt_nuca), &(dynamic_power_wt_nuca), + &(leakage_power_wt_nuca), + &(cycle_time_wt_nuca), &(area_wt_nuca)); + continue; + } + + if(!strncmp("-NUCAdeviate", line, strlen("-NUCAdeviate"))) { + sscanf(line, "-%[((:-~)| |,)*]%d:%d:%d:%d:%d", jk, + &(delay_dev_nuca), &(dynamic_power_dev_nuca), + &(leakage_power_dev_nuca), + &(cycle_time_dev_nuca), &(area_dev_nuca)); + continue; + } + + if(!strncmp("-Cache model", line, strlen("-cache model"))) { + sscanf(line, "-Cache model %[^\"]\"%[^\"]\"", jk, temp_var); + + if (!strncmp("UCA", temp_var, strlen("UCA"))) { + nuca = 0; + } + else { + nuca = 1; + } + continue; + } + + if(!strncmp("-NUCA bank", line, strlen("-NUCA bank"))) { + sscanf(line, "-NUCA bank count %d", &(nuca_bank_count)); + + if (nuca_bank_count != 0) { + force_nuca_bank = 1; + } + continue; + } + + if(!strncmp("-Wire inside mat", line, strlen("-Wire inside mat"))) { + sscanf(line, "-Wire%[^\"]\"%[^\"]\"", jk, temp_var); + + if (!strncmp("global", temp_var, strlen("global"))) { + wire_is_mat_type = 2; + continue; + } + else if (!strncmp("local", temp_var, strlen("local"))) { + wire_is_mat_type = 0; + continue; + } + else { + wire_is_mat_type = 1; + continue; + } + } + + if(!strncmp("-Wire outside mat", line, strlen("-Wire outside mat"))) { + sscanf(line, "-Wire%[^\"]\"%[^\"]\"", jk, temp_var); + + if (!strncmp("global", temp_var, strlen("global"))) { + wire_os_mat_type = 2; + } + else { + wire_os_mat_type = 1; + } + continue; + } + + if(!strncmp("-Interconnect projection", line, strlen("-Interconnect projection"))) { + sscanf(line, "-Interconnect projection%[^\"]\"%[^\"]\"", jk, temp_var); + + if (!strncmp("aggressive", temp_var, strlen("aggressive"))) { + ic_proj_type = 0; + } + else { + ic_proj_type = 1; + } + continue; + } + + if(!strncmp("-Wire signaling", line, strlen("-wire signaling"))) { + sscanf(line, "-Wire%[^\"]\"%[^\"]\"", jk, temp_var); + + if (!strncmp("default", temp_var, strlen("default"))) { + force_wiretype = 0; + wt = Global; + } + else if (!(strncmp("Global_10", temp_var, strlen("Global_10")))) { + force_wiretype = 1; + wt = Global_10; + } + else if (!(strncmp("Global_20", temp_var, strlen("Global_20")))) { + force_wiretype = 1; + wt = Global_20; + } + else if 
(!(strncmp("Global_30", temp_var, strlen("Global_30")))) { + force_wiretype = 1; + wt = Global_30; + } + else if (!(strncmp("Global_5", temp_var, strlen("Global_5")))) { + force_wiretype = 1; + wt = Global_5; + } + else if (!(strncmp("Global", temp_var, strlen("Global")))) { + force_wiretype = 1; + wt = Global; + } + else if (!(strncmp("fullswing", temp_var, strlen("fullswing")))) { + force_wiretype = 1; + wt = Full_swing; + } + else if (!(strncmp("lowswing", temp_var, strlen("lowswing")))) { + force_wiretype = 1; + wt = Low_swing; + } + else { + cout << "Unknown wire type!\n"; + exit(0); + } + continue; + } + + + + if(!strncmp("-Core", line, strlen("-Core"))) { + sscanf(line, "-Core count %d\n", &(cores)); + if (cores > 16) { + printf("No. of cores should be less than 16!\n"); + } + continue; + } + + if(!strncmp("-Cache level", line, strlen("-Cache level"))) { + sscanf(line, "-Cache l%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("L2", temp_var, strlen("L2"))) { + cache_level = 0; + } + else { + cache_level = 1; + } + } + + if(!strncmp("-Print level", line, strlen("-Print level"))) { + sscanf(line, "-Print l%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("DETAILED", temp_var, strlen("DETAILED"))) { + print_detail = 1; + } + else { + print_detail = 0; + } + + } + if(!strncmp("-Add ECC", line, strlen("-Add ECC"))) { + sscanf(line, "-Add ECC %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + add_ecc_b_ = true; + } + else { + add_ecc_b_ = false; + } + } + + if(!strncmp("-CLDriver vertical", line, strlen("-CLDriver vertical"))) { + sscanf(line, "-CLDriver vertical %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + cl_vertical = true; + } + else { + cl_vertical = false; + } + } + + if(!strncmp("-Array Power Gating", line, strlen("-Array Power Gating"))) { + sscanf(line, "-Array Power Gating %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + array_power_gated = true; + } + else { + array_power_gated = false; + } + } + + if(!strncmp("-Bitline floating", line, strlen("-Bitline floating"))) { + sscanf(line, "-Bitline floating %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + bitline_floating = true; + } + else { + bitline_floating = false; + } + } + + if(!strncmp("-WL Power Gating", line, strlen("-WL Power Gating"))) { + sscanf(line, "-WL Power Gating %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + wl_power_gated = true; + } + else { + wl_power_gated = false; + } + } + + if(!strncmp("-CL Power Gating", line, strlen("-CL Power Gating"))) { + sscanf(line, "-CL Power Gating %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + cl_power_gated = true; + } + else { + cl_power_gated = false; + } + } + + if(!strncmp("-Interconnect Power Gating", line, strlen("-Interconnect Power Gating"))) { + sscanf(line, "-Interconnect Power Gating %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + interconect_power_gated = true; + } + else { + interconect_power_gated = false; + } + } + + if(!strncmp("-Power Gating Performance Loss", line, strlen("-Power Gating Performance Loss"))) { + sscanf(line, "-Power Gating Performance Loss %lf", &(perfloss)); + continue; + } + + if(!strncmp("-Print input parameters", line, strlen("-Print input parameters"))) { + sscanf(line, "-Print input %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { 
+ print_input_args = true; + } + else { + print_input_args = false; + } + } + + if(!strncmp("-Force cache config", line, strlen("-Force cache config"))) { + sscanf(line, "-Force cache %[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("true", temp_var, strlen("true"))) { + force_cache_config = true; + } + else { + force_cache_config = false; + } + } + + if(!strncmp("-Ndbl", line, strlen("-Ndbl"))) { + sscanf(line, "-Ndbl %d\n", &(ndbl)); + continue; + } + if(!strncmp("-Ndwl", line, strlen("-Ndwl"))) { + sscanf(line, "-Ndwl %d\n", &(ndwl)); + continue; + } + if(!strncmp("-Nspd", line, strlen("-Nspd"))) { + sscanf(line, "-Nspd %d\n", &(nspd)); + continue; + } + if(!strncmp("-Ndsam1", line, strlen("-Ndsam1"))) { + sscanf(line, "-Ndsam1 %d\n", &(ndsam1)); + continue; + } + if(!strncmp("-Ndsam2", line, strlen("-Ndsam2"))) { + sscanf(line, "-Ndsam2 %d\n", &(ndsam2)); + continue; + } + if(!strncmp("-Ndcm", line, strlen("-Ndcm"))) { + sscanf(line, "-Ndcm %d\n", &(ndcm)); + continue; + } + + // Parameters related to off-chip interconnect + + if(!strncmp("-dram type", line, strlen("-dram type"))) { + sscanf(line, "-dram type%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("DDR3", temp_var, strlen("DDR3"))) + { + io_type = DDR3; + } + else if(!strncmp("DDR4", temp_var, strlen("DDR4"))) + { + io_type = DDR4; + } + else if(!strncmp("LPDDR2", temp_var, strlen("LPDDR2"))) + { + io_type = LPDDR2; + } + else if(!strncmp("WideIO", temp_var, strlen("WideIO"))) + { + io_type = WideIO; + } + else if(!strncmp("Low_Swing_Diff", temp_var, strlen("Low_Swing_Diff"))) + { + io_type = Low_Swing_Diff; + } + else if(!strncmp("Serial", temp_var, strlen("Serial"))) + { + io_type = Serial; + } + else + { + cout << "Invalid Input for dram type!" << endl; + exit(1); + } + // sscanf(line, "-io_type \"%c\"\n", &(io_type)); + } + if(!strncmp("-io state", line, strlen("-io state"))) { + sscanf(line, "-io state%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("READ", temp_var, strlen("READ"))) + { + iostate = READ; + } + else if(!strncmp("WRITE", temp_var, strlen("WRITE"))) + { + iostate = WRITE; + } + else if(!strncmp("IDLE", temp_var, strlen("IDLE"))) + { + iostate = IDLE; + } + else if(!strncmp("SLEEP", temp_var, strlen("SLEEP"))) + { + iostate = SLEEP; + } + else + { + cout << "Invalid Input for io state!" << endl; + exit(1); + } + //sscanf(line, "-iostate \"%c\"\n", &(iostate)); + } + if(!strncmp("-addr_timing", line, strlen("-addr_timing"))) { + sscanf(line, "-addr_timing %lf", &(addr_timing)); + } + if(!strncmp("-dram ecc", line, strlen("-dram ecc"))) { + sscanf(line, "-dram ecc%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("NO_ECC", temp_var, strlen("NO_ECC"))) + { + dram_ecc = NO_ECC; + } + else if(!strncmp("SECDED", temp_var, strlen("SECDED"))) + { + dram_ecc = SECDED; + } + else if(!strncmp("CHIP_KILL", temp_var, strlen("CHIP_KILL"))) + { + dram_ecc = CHIP_KILL; + } + else + { + cout << "Invalid Input for dram ecc!" << endl; + exit(1); + } + //sscanf(line, "-dram_ecc \"%c\"\n", &(dram_ecc)); + } + if(!strncmp("-dram dimm", line, strlen("-dram dimm"))) { + sscanf(line, "-dram dimm%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("UDIMM", temp_var, strlen("UDIMM"))) + { + dram_dimm = UDIMM; + } + else if(!strncmp("RDIMM", temp_var, strlen("RDIMM"))) + { + dram_dimm = RDIMM; + } + else if(!strncmp("LRDIMM", temp_var, strlen("LRDIMM"))) + { + dram_dimm = LRDIMM; + } + else + { + cout << "Invalid Input for dram dimm!" 
<< endl; + exit(1); + } + //sscanf(line, "-dram_ecc \"%c\"\n", &(dram_ecc)); + } + + + if(!strncmp("-bus_bw", line, strlen("-bus_bw"))) { + sscanf(line, "-bus_bw %lf", &(bus_bw)); + } + if(!strncmp("-duty_cycle", line, strlen("-duty_cycle"))) { + sscanf(line, "-duty_cycle %lf", &(duty_cycle)); + } + if(!strncmp("-mem_density", line, strlen("-mem_density"))) { + sscanf(line, "-mem_density %lf", &(mem_density)); + } + if(!strncmp("-activity_dq", line, strlen("-activity_dq"))) { + sscanf(line, "-activity_dq %lf", &activity_dq); + } + if(!strncmp("-activity_ca", line, strlen("-activity_ca"))) { + sscanf(line, "-activity_ca %lf", &activity_ca); + } + if(!strncmp("-bus_freq", line, strlen("-bus_freq"))) { + sscanf(line, "-bus_freq %lf", &bus_freq); + } + if(!strncmp("-num_dq", line, strlen("-num_dq"))) { + sscanf(line, "-num_dq %d", &num_dq); + } + if(!strncmp("-num_dqs", line, strlen("-num_dqs"))) { + sscanf(line, "-num_dqs %d", &num_dqs); + } + if(!strncmp("-num_ca", line, strlen("-num_ca"))) { + sscanf(line, "-num_ca %d", &num_ca); + } + if(!strncmp("-num_clk", line, strlen("-num_clk"))) { + sscanf(line, "-num_clk %d", &num_clk); + if(num_clk<=0) + { + cout << "num_clk should be greater than zero!\n"; + exit(1); + } + } + if(!strncmp("-num_mem_dq", line, strlen("-num_mem_dq"))) { + sscanf(line, "-num_mem_dq %d", &num_mem_dq); + } + if(!strncmp("-mem_data_width", line, strlen("-mem_data_width"))) { + sscanf(line, "-mem_data_width %d", &mem_data_width); + } + + // added just for memcad + + if(!strncmp("-num_bobs", line, strlen("-num_bobs"))) { + sscanf(line, "-num_bobs %d", &num_bobs); + } + if(!strncmp("-capacity", line, strlen("-capacity"))) { + sscanf(line, "-capacity %d", &capacity); + } + if(!strncmp("-num_channels_per_bob", line, strlen("-num_channels_per_bob"))) { + sscanf(line, "-num_channels_per_bob %d", &num_channels_per_bob); + } + if(!strncmp("-first metric", line, strlen("-first metric"))) { + sscanf(line, "-first metric%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("Cost", temp_var, strlen("Cost"))) + { + first_metric = Cost; + } + else if(!strncmp("Energy", temp_var, strlen("Energy"))) + { + first_metric = Energy; + } + else if(!strncmp("Bandwidth", temp_var, strlen("Bandwidth"))) + { + first_metric = Bandwidth; + } + else + { + cout << "Invalid Input for first metric!" << endl; + exit(1); + } + + } + if(!strncmp("-second metric", line, strlen("-second metric"))) { + sscanf(line, "-second metric%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("Cost", temp_var, strlen("Cost"))) + { + second_metric = Cost; + } + else if(!strncmp("Energy", temp_var, strlen("Energy"))) + { + second_metric = Energy; + } + else if(!strncmp("Bandwidth", temp_var, strlen("Bandwidth"))) + { + second_metric = Bandwidth; + } + else + { + cout << "Invalid Input for second metric!" << endl; + exit(1); + } + + } + if(!strncmp("-third metric", line, strlen("-third metric"))) { + sscanf(line, "-third metric%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("Cost", temp_var, strlen("Cost"))) + { + third_metric = Cost; + } + else if(!strncmp("Energy", temp_var, strlen("Energy"))) + { + third_metric = Energy; + } + else if(!strncmp("Bandwidth", temp_var, strlen("Bandwidth"))) + { + third_metric = Bandwidth; + } + else + { + cout << "Invalid Input for third metric!" 
<< endl; + exit(1); + } + + } + if(!strncmp("-DIMM model", line, strlen("-DIMM model"))) { + sscanf(line, "-DIMM model%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("JUST_UDIMM", temp_var, strlen("JUST_UDIMM"))) + { + dimm_model = JUST_UDIMM; + } + else if(!strncmp("JUST_RDIMM", temp_var, strlen("JUST_RDIMM"))) + { + dimm_model = JUST_RDIMM; + } + else if(!strncmp("JUST_LRDIMM", temp_var, strlen("JUST_LRDIMM"))) + { + dimm_model = JUST_LRDIMM; + } + else if(!strncmp("ALL", temp_var, strlen("ALL"))) + { + dimm_model = ALL; + } + else + { + cout << "Invalid Input for DIMM model!" << endl; + exit(1); + } + + } + if(!strncmp("-Low Power Permitted", line, strlen("-Low Power Permitted"))) { + sscanf(line, "-Low Power Permitted%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("T", temp_var, strlen("T"))) + { + low_power_permitted = true; + } + else if(!strncmp("F", temp_var, strlen("F"))) + { + low_power_permitted = false; + } + else + { + cout << "Invalid Input for Low Power Permitted!" << endl; + exit(1); + } + + } + if(!strncmp("-load", line, strlen("-load"))) { + sscanf(line, "-load %lf", &(load)); + } + if(!strncmp("-row_buffer_hit_rate", line, strlen("-row_buffer_hit_rate"))) { + sscanf(line, "-row_buffer_hit_rate %lf", &(row_buffer_hit_rate)); + } + if(!strncmp("-rd_2_wr_ratio", line, strlen("-rd_2_wr_ratio"))) { + sscanf(line, "-rd_2_wr_ratio %lf", &(rd_2_wr_ratio)); + } + if(!strncmp("-same_bw_in_bob", line, strlen("-same_bw_in_bob"))) { + sscanf(line, "-same_bw_in_bob%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("T", temp_var, strlen("T"))) + { + same_bw_in_bob = true; + } + else if(!strncmp("F", temp_var, strlen("F"))) + { + same_bw_in_bob = false; + } + else + { + cout << "Invalid Input for same_bw_in_bob!" << endl; + exit(1); + } + + } + if(!strncmp("-mirror_in_bob", line, strlen("-mirror_in_bob"))) { + sscanf(line, "-mirror_in_bob%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("T", temp_var, strlen("T"))) + { + mirror_in_bob = true; + } + else if(!strncmp("F", temp_var, strlen("F"))) + { + mirror_in_bob = false; + } + else + { + cout << "Invalid Input for mirror_in_bob!" << endl; + exit(1); + } + + } + if(!strncmp("-total_power", line, strlen("-total_power"))) { + sscanf(line, "-total_power%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("T", temp_var, strlen("T"))) + { + total_power = true; + } + else if(!strncmp("F", temp_var, strlen("F"))) + { + total_power = false; + } + else + { + cout << "Invalid Input for total_power!" << endl; + exit(1); + } + + } + if(!strncmp("-verbose", line, strlen("-verbose"))) { + sscanf(line, "-verbose%[^\"]\"%[^\"]\"", jk, temp_var); + if (!strncmp("T", temp_var, strlen("T"))) + { + verbose = true; + } + else if(!strncmp("F", temp_var, strlen("F"))) + { + verbose = false; + } + else + { + cout << "Invalid Input for same_bw_in_bob!" 
<< endl; + exit(1); + } + + } + + + + + } + rpters_in_htree = true; + fclose(fp); +} + + void +InputParameter::display_ip() +{ + cout << "Cache size : " << cache_sz << endl; + cout << "Block size : " << line_sz << endl; + cout << "Associativity : " << assoc << endl; + cout << "Read only ports : " << num_rd_ports << endl; + cout << "Write only ports : " << num_wr_ports << endl; + cout << "Read write ports : " << num_rw_ports << endl; + cout << "Single ended read ports : " << num_se_rd_ports << endl; + if (fully_assoc||pure_cam) + { + cout << "Search ports : " << num_search_ports << endl; + } + cout << "Cache banks (UCA) : " << nbanks << endl; + cout << "Technology : " << F_sz_um << endl; + cout << "Temperature : " << temp << endl; + cout << "Tag size : " << tag_w << endl; + if (is_cache) + { + cout << "array type : " << "Cache" << endl; + } + if (pure_ram) + { + cout << "array type : " << "Scratch RAM" << endl; + } + if (pure_cam) + { + cout << "array type : " << "CAM" << endl; + } + cout << "Model as memory : " << is_main_mem << endl; + cout << "Model as 3D memory : " << is_3d_mem << endl; + cout << "Access mode : " << access_mode << endl; + cout << "Data array cell type : " << data_arr_ram_cell_tech_type << endl; + cout << "Data array peripheral type : " << data_arr_peri_global_tech_type << endl; + cout << "Tag array cell type : " << tag_arr_ram_cell_tech_type << endl; + cout << "Tag array peripheral type : " << tag_arr_peri_global_tech_type << endl; + cout << "Optimization target : " << ed << endl; + cout << "Design objective (UCA wt) : " << delay_wt << " " + << dynamic_power_wt << " " << leakage_power_wt << " " << cycle_time_wt + << " " << area_wt << endl; + cout << "Design objective (UCA dev) : " << delay_dev << " " + << dynamic_power_dev << " " << leakage_power_dev << " " << cycle_time_dev + << " " << area_dev << endl; + if (nuca) + { + cout << "Cores : " << cores << endl; + + + cout << "Design objective (NUCA wt) : " << delay_wt_nuca << " " + << dynamic_power_wt_nuca << " " << leakage_power_wt_nuca << " " << cycle_time_wt_nuca + << " " << area_wt_nuca << endl; + cout << "Design objective (NUCA dev) : " << delay_dev_nuca << " " + << dynamic_power_dev_nuca << " " << leakage_power_dev_nuca << " " << cycle_time_dev_nuca + << " " << area_dev_nuca << endl; + } + cout << "Cache model : " << nuca << endl; + cout << "Nuca bank : " << nuca_bank_count << endl; + cout << "Wire inside mat : " << wire_is_mat_type << endl; + cout << "Wire outside mat : " << wire_os_mat_type << endl; + cout << "Interconnect projection : " << ic_proj_type << endl; + cout << "Wire signaling : " << force_wiretype << endl; + cout << "Print level : " << print_detail << endl; + cout << "ECC overhead : " << add_ecc_b_ << endl; + cout << "Page size : " << page_sz_bits << endl; + cout << "Burst length : " << burst_len << endl; + cout << "Internal prefetch width : " << int_prefetch_w << endl; + cout << "Force cache config : " << g_ip->force_cache_config << endl; + if (g_ip->force_cache_config) { + cout << "Ndwl : " << g_ip->ndwl << endl; + cout << "Ndbl : " << g_ip->ndbl << endl; + cout << "Nspd : " << g_ip->nspd << endl; + cout << "Ndcm : " << g_ip->ndcm << endl; + cout << "Ndsam1 : " << g_ip->ndsam1 << endl; + cout << "Ndsam2 : " << g_ip->ndsam2 << endl; + } + cout << "Subarray Driver direction : " << g_ip->cl_vertical << endl; + + // CACTI-I/O + cout << "iostate : " ; + switch(iostate) + { + case(READ): cout << "READ" << endl; break; + case(WRITE): cout << "WRITE" << endl; break; + case(IDLE): cout << "IDLE" << endl; 
break; + case(SLEEP): cout << "SLEEP" << endl; break; + default: assert(false); + } + cout << "dram_ecc : " ; + switch(dram_ecc) + { + case(NO_ECC): cout << "NO_ECC" << endl; break; + case(SECDED): cout << "SECDED" << endl; break; + case(CHIP_KILL): cout << "CHIP_KILL" << endl; break; + default: assert(false); + } + cout << "io_type : " ; + switch(io_type) + { + case(DDR3): cout << "DDR3" << endl; break; + case(DDR4): cout << "DDR4" << endl; break; + case(LPDDR2): cout << "LPDDR2" << endl; break; + case(WideIO): cout << "WideIO" << endl; break; + case(Low_Swing_Diff): cout << "Low_Swing_Diff" << endl; break; + default: assert(false); + } + cout << "dram_dimm : " ; + switch(dram_dimm) + { + case(UDIMM): cout << "UDIMM" << endl; break; + case(RDIMM): cout << "RDIMM" << endl; break; + case(LRDIMM): cout << "LRDIMM" << endl; break; + default: assert(false); + } + + + +} + + + +powerComponents operator+(const powerComponents & x, const powerComponents & y) +{ + powerComponents z; + + z.dynamic = x.dynamic + y.dynamic; + z.leakage = x.leakage + y.leakage; + z.gate_leakage = x.gate_leakage + y.gate_leakage; + z.short_circuit = x.short_circuit + y.short_circuit; + z.longer_channel_leakage = x.longer_channel_leakage + y.longer_channel_leakage; + + return z; +} + +powerComponents operator*(const powerComponents & x, double const * const y) +{ + powerComponents z; + + z.dynamic = x.dynamic*y[0]; + z.leakage = x.leakage*y[1]; + z.gate_leakage = x.gate_leakage*y[2]; + z.short_circuit = x.short_circuit*y[3]; + z.longer_channel_leakage = x.longer_channel_leakage*y[1];//longer channel leakage has the same behavior as normal leakage + + return z; +} + + +powerDef operator+(const powerDef & x, const powerDef & y) +{ + powerDef z; + + z.readOp = x.readOp + y.readOp; + z.writeOp = x.writeOp + y.writeOp; + z.searchOp = x.searchOp + y.searchOp; + return z; +} + +powerDef operator*(const powerDef & x, double const * const y) +{ + powerDef z; + + z.readOp = x.readOp*y; + z.writeOp = x.writeOp*y; + z.searchOp = x.searchOp*y; + return z; +} + +uca_org_t cacti_interface(const string & infile_name) +{ + + //cout<<"TSV_proj_type: " << g_ip->TSV_proj_type << endl; + uca_org_t fin_res; + //uca_org_t result; + fin_res.valid = false; + + g_ip = new InputParameter(); + g_ip->parse_cfg(infile_name); + if(!g_ip->error_checking()) + exit(0); + // if (g_ip->print_input_args) + //g_ip->display_ip(); + + + init_tech_params(g_ip->F_sz_um, false); + Wire winit; // Do not delete this line. It initializes wires. +// cout << winit.wire_res(256*8*64e-9) << endl; +// exit(0); + + + //CACTI3DD + // --- These two parameters are supposed for two different TSV technologies within one DRAM fabrication, currently assume one individual TSV geometry size for cost efficiency + g_ip->tsv_is_subarray_type = g_ip->TSV_proj_type; + g_ip->tsv_os_bank_type = g_ip->TSV_proj_type; + TSV tsv_test(Coarse);// ********* double len_ /* in um*/, double diam_, double TSV_pitch_, + if(g_ip->print_detail_debug) + { + tsv_test.print_TSV(); + } + +// For HighRadix Only +// //// Wire wirea(g_ip->wt, 1000); +// //// wirea.print_wire(); +// //// cout << "Wire Area " << wirea.area.get_area() << " sq. 
u" << endl; +// // winit.print_wire(); +// // +// HighRadix *hr; +// hr = new HighRadix(); +// hr->compute_power(); +// hr->print_router(); +// exit(0); +// +// double sub_switch_sz = 2; +// double rows = 32; +// for (int i=0; i<6; i++) { +// sub_switch_sz = pow(2, i); +// rows = 64/sub_switch_sz; +// hr = new HighRadix(sub_switch_sz, rows, .8/* freq */, 64, 2, 64, 0.7); +// hr->compute_power(); +// hr->print_router(); +// delete hr; +// } +// // HighRadix yarc; +// // yarc.compute_power(); +// // yarc.print_router(); +// winit.print_wire(); +// exit(0); +// For HighRadix Only End + + if (g_ip->nuca == 1) + { + Nuca n(&g_tp.peri_global); + n.sim_nuca(); + } + + //g_ip->display_ip(); + + + + IOTechParam iot(g_ip, g_ip->io_type, g_ip->num_mem_dq, g_ip->mem_data_width, g_ip->num_dq,g_ip->dram_dimm, 1,g_ip->bus_freq ); + Extio testextio(&iot); + testextio.extio_area(); + testextio.extio_eye(); + testextio.extio_power_dynamic(); + testextio.extio_power_phy(); + testextio.extio_power_term(); + + + /* + int freq[][4]={{400,533,667,800},{800,933,1066,1200}}; + + Mem_IO_type types[2]={DDR3,DDR4}; + + int max_load[3]={3,3,8}; + + for(int j=0;j<1;j++) + { + for(int connection=0;connection<3;connection++) + { + for(int frq=3;frq<4;frq++) + { + for(int load=1;load<=max_load[connection];load++) + { + IOTechParam iot(g_ip, types[j], load, 8, 72, connection, load, freq[j][frq]); + Extio testextio(&iot); + // testextio.extio_area(); + // testextio.extio_eye(); + testextio.extio_power_dynamic(); + testextio.extio_power_phy(); + testextio.extio_power_term(); + + } + cout << endl; + } + cout << endl; + } + cout << endl; + } + */ + + ///double total_io_p, total_phy_p, total_io_area, total_vmargin, total_tmargin; + //testextio.extio_power_area_timing(total_io_p, total_phy_p, total_io_area, total_vmargin, total_tmargin); + + solve(&fin_res); + + //output_UCA(&fin_res); + output_data_csv(fin_res, infile_name + ".out"); + + + // Memcad Optimization + MemCadParameters memcad_params(g_ip); + solve_memcad(&memcad_params); + + + delete (g_ip); + return fin_res; +} + +//CACTI3DD's plain interface, please keep !!! 
+uca_org_t cacti_interface( + int dram_cap_tot_byte, + int line_size, + int associativity, + int rw_ports, + int excl_read_ports,// para5 + int excl_write_ports, + int single_ended_read_ports, + int search_ports, + int banks, + double tech_node,//para10 + int output_width, + int specific_tag, + int tag_width, + int access_mode, + int cache, //para15 + int main_mem, + int obj_func_delay, + int obj_func_dynamic_power, + int obj_func_leakage_power, + int obj_func_cycle_time, //para20 + int obj_func_area, + int dev_func_delay, + int dev_func_dynamic_power, + int dev_func_leakage_power, + int dev_func_area, //para25 + int dev_func_cycle_time, + int ed_ed2_none, // 0 - ED, 1 - ED^2, 2 - use weight and deviate + int temp, + int wt, //0 - default(search across everything), 1 - global, 2 - 5% delay penalty, 3 - 10%, 4 - 20 %, 5 - 30%, 6 - low-swing + int data_arr_ram_cell_tech_flavor_in,//para30 + int data_arr_peri_global_tech_flavor_in, + int tag_arr_ram_cell_tech_flavor_in, + int tag_arr_peri_global_tech_flavor_in, + int interconnect_projection_type_in, + int wire_inside_mat_type_in,//para35 + int wire_outside_mat_type_in, + int REPEATERS_IN_HTREE_SEGMENTS_in, + int VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in, + int BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in, + int PAGE_SIZE_BITS_in,//para40 + int BURST_LENGTH_in, + int INTERNAL_PREFETCH_WIDTH_in, + int force_wiretype, + int wiretype, + int force_config,//para45 + int ndwl, + int ndbl, + int nspd, + int ndcm, + int ndsam1,//para50 + int ndsam2, + int ecc, + int is_3d_dram, + int burst_depth, + int IO_width, + int sys_freq, + int debug_detail, + int num_dies, + int tsv_gran_is_subarray, + int tsv_gran_os_bank, + int num_tier_row_sprd, + int num_tier_col_sprd, + int partition_level + ) +{ + g_ip = new InputParameter(); + + uca_org_t fin_res; + fin_res.valid = false; + + g_ip->data_arr_ram_cell_tech_type = data_arr_ram_cell_tech_flavor_in; + g_ip->data_arr_peri_global_tech_type = data_arr_peri_global_tech_flavor_in; + g_ip->tag_arr_ram_cell_tech_type = tag_arr_ram_cell_tech_flavor_in; + g_ip->tag_arr_peri_global_tech_type = tag_arr_peri_global_tech_flavor_in; + + g_ip->ic_proj_type = interconnect_projection_type_in; + g_ip->wire_is_mat_type = wire_inside_mat_type_in; + g_ip->wire_os_mat_type = wire_outside_mat_type_in; + g_ip->burst_len = BURST_LENGTH_in; + g_ip->int_prefetch_w = INTERNAL_PREFETCH_WIDTH_in; + g_ip->page_sz_bits = PAGE_SIZE_BITS_in; + + g_ip->num_die_3d = num_dies; + g_ip->cache_sz = dram_cap_tot_byte; + g_ip->line_sz = line_size; + g_ip->assoc = associativity; + g_ip->nbanks = banks; + g_ip->out_w = output_width; + g_ip->specific_tag = specific_tag; + if (specific_tag == 0) { + g_ip->tag_w = 42; + } + else { + g_ip->tag_w = tag_width; + } + + g_ip->access_mode = access_mode; + g_ip->delay_wt = obj_func_delay; + g_ip->dynamic_power_wt = obj_func_dynamic_power; + g_ip->leakage_power_wt = obj_func_leakage_power; + g_ip->area_wt = obj_func_area; + g_ip->cycle_time_wt = obj_func_cycle_time; + g_ip->delay_dev = dev_func_delay; + g_ip->dynamic_power_dev = dev_func_dynamic_power; + g_ip->leakage_power_dev = dev_func_leakage_power; + g_ip->area_dev = dev_func_area; + g_ip->cycle_time_dev = dev_func_cycle_time; + g_ip->temp = temp; + g_ip->ed = ed_ed2_none; + + g_ip->F_sz_nm = tech_node; + g_ip->F_sz_um = tech_node / 1000; + g_ip->is_main_mem = (main_mem != 0) ? true : false; + g_ip->is_cache = (cache ==1) ? true : false; + g_ip->pure_ram = (cache ==0) ? true : false; + g_ip->pure_cam = (cache ==2) ? 
true : false; + g_ip->rpters_in_htree = (REPEATERS_IN_HTREE_SEGMENTS_in != 0) ? true : false; + g_ip->ver_htree_wires_over_array = VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in; + g_ip->broadcast_addr_din_over_ver_htrees = BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in; + + g_ip->num_rw_ports = rw_ports; + g_ip->num_rd_ports = excl_read_ports; + g_ip->num_wr_ports = excl_write_ports; + g_ip->num_se_rd_ports = single_ended_read_ports; + g_ip->num_search_ports = search_ports; + + g_ip->print_detail = 1; + g_ip->nuca = 0; + + if (force_wiretype == 0) + { + g_ip->wt = Global; + g_ip->force_wiretype = false; + } + else + { g_ip->force_wiretype = true; + if (wiretype==10) { + g_ip->wt = Global_10; + } + if (wiretype==20) { + g_ip->wt = Global_20; + } + if (wiretype==30) { + g_ip->wt = Global_30; + } + if (wiretype==5) { + g_ip->wt = Global_5; + } + if (wiretype==0) { + g_ip->wt = Low_swing; + } + } + //g_ip->wt = Global_5; + if (force_config == 0) + { + g_ip->force_cache_config = false; + } + else + { + g_ip->force_cache_config = true; + g_ip->ndbl=ndbl; + g_ip->ndwl=ndwl; + g_ip->nspd=nspd; + g_ip->ndcm=ndcm; + g_ip->ndsam1=ndsam1; + g_ip->ndsam2=ndsam2; + + + } + + if (ecc==0){ + g_ip->add_ecc_b_=false; + } + else + { + g_ip->add_ecc_b_=true; + } + + //CACTI3DD + g_ip->is_3d_mem = is_3d_dram; + g_ip->burst_depth = burst_depth; + g_ip->io_width =IO_width; + g_ip->sys_freq_MHz = sys_freq; + g_ip->print_detail_debug = debug_detail; + + g_ip->tsv_is_subarray_type = tsv_gran_is_subarray; + g_ip->tsv_os_bank_type = tsv_gran_os_bank; + + g_ip->partition_gran = partition_level; + g_ip->num_tier_row_sprd = num_tier_row_sprd; + g_ip->num_tier_col_sprd = num_tier_col_sprd; + if(partition_level == 3) + g_ip->fine_gran_bank_lvl = true; + else + g_ip->fine_gran_bank_lvl = false; + + if(!g_ip->error_checking()) + exit(0); + + init_tech_params(g_ip->F_sz_um, false); + Wire winit; // Do not delete this line. It initializes wires. + + //tsv + //TSV tsv_test(Coarse); + //tsv_test.print_TSV(); + + g_ip->display_ip(); + solve(&fin_res); + output_UCA(&fin_res); + output_data_csv_3dd(fin_res); + delete (g_ip); + + return fin_res; +} + +//cacti6.5's plain interface, please keep !!! 
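+// Note on the objective arguments below: with ed_ed2_none == 2 the solver
+// uses the obj_func_* values as relative weights for delay, dynamic power,
+// leakage power, area and cycle time, and the dev_func_* values as the
+// maximum percentage deviation from each metric's optimum that a candidate
+// organization may have and still be accepted.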
+uca_org_t cacti_interface( + int cache_size, + int line_size, + int associativity, + int rw_ports, + int excl_read_ports, + int excl_write_ports, + int single_ended_read_ports, + int banks, + double tech_node, // in nm + int page_sz, + int burst_length, + int pre_width, + int output_width, + int specific_tag, + int tag_width, + int access_mode, //0 normal, 1 seq, 2 fast + int cache, //scratch ram or cache + int main_mem, + int obj_func_delay, + int obj_func_dynamic_power, + int obj_func_leakage_power, + int obj_func_area, + int obj_func_cycle_time, + int dev_func_delay, + int dev_func_dynamic_power, + int dev_func_leakage_power, + int dev_func_area, + int dev_func_cycle_time, + int ed_ed2_none, // 0 - ED, 1 - ED^2, 2 - use weight and deviate + int temp, + int wt, //0 - default(search across everything), 1 - global, 2 - 5% delay penalty, 3 - 10%, 4 - 20 %, 5 - 30%, 6 - low-swing + int data_arr_ram_cell_tech_flavor_in, // 0-4 + int data_arr_peri_global_tech_flavor_in, + int tag_arr_ram_cell_tech_flavor_in, + int tag_arr_peri_global_tech_flavor_in, + int interconnect_projection_type_in, // 0 - aggressive, 1 - normal + int wire_inside_mat_type_in, + int wire_outside_mat_type_in, + int is_nuca, // 0 - UCA, 1 - NUCA + int core_count, + int cache_level, // 0 - L2, 1 - L3 + int nuca_bank_count, + int nuca_obj_func_delay, + int nuca_obj_func_dynamic_power, + int nuca_obj_func_leakage_power, + int nuca_obj_func_area, + int nuca_obj_func_cycle_time, + int nuca_dev_func_delay, + int nuca_dev_func_dynamic_power, + int nuca_dev_func_leakage_power, + int nuca_dev_func_area, + int nuca_dev_func_cycle_time, + int REPEATERS_IN_HTREE_SEGMENTS_in,//TODO for now only wires with repeaters are supported + int p_input) +{ + g_ip = new InputParameter(); + g_ip->add_ecc_b_ = true; + + g_ip->data_arr_ram_cell_tech_type = data_arr_ram_cell_tech_flavor_in; + g_ip->data_arr_peri_global_tech_type = data_arr_peri_global_tech_flavor_in; + g_ip->tag_arr_ram_cell_tech_type = tag_arr_ram_cell_tech_flavor_in; + g_ip->tag_arr_peri_global_tech_type = tag_arr_peri_global_tech_flavor_in; + + g_ip->ic_proj_type = interconnect_projection_type_in; + g_ip->wire_is_mat_type = wire_inside_mat_type_in; + g_ip->wire_os_mat_type = wire_outside_mat_type_in; + g_ip->burst_len = burst_length; + g_ip->int_prefetch_w = pre_width; + g_ip->page_sz_bits = page_sz; + + g_ip->cache_sz = cache_size; + g_ip->line_sz = line_size; + g_ip->assoc = associativity; + g_ip->nbanks = banks; + g_ip->out_w = output_width; + g_ip->specific_tag = specific_tag; + if (tag_width == 0) { + g_ip->tag_w = 42; + } + else { + g_ip->tag_w = tag_width; + } + + g_ip->access_mode = access_mode; + g_ip->delay_wt = obj_func_delay; + g_ip->dynamic_power_wt = obj_func_dynamic_power; + g_ip->leakage_power_wt = obj_func_leakage_power; + g_ip->area_wt = obj_func_area; + g_ip->cycle_time_wt = obj_func_cycle_time; + g_ip->delay_dev = dev_func_delay; + g_ip->dynamic_power_dev = dev_func_dynamic_power; + g_ip->leakage_power_dev = dev_func_leakage_power; + g_ip->area_dev = dev_func_area; + g_ip->cycle_time_dev = dev_func_cycle_time; + g_ip->ed = ed_ed2_none; + + switch(wt) { + case (0): + g_ip->force_wiretype = 0; + g_ip->wt = Global; + break; + case (1): + g_ip->force_wiretype = 1; + g_ip->wt = Global; + break; + case (2): + g_ip->force_wiretype = 1; + g_ip->wt = Global_5; + break; + case (3): + g_ip->force_wiretype = 1; + g_ip->wt = Global_10; + break; + case (4): + g_ip->force_wiretype = 1; + g_ip->wt = Global_20; + break; + case (5): + g_ip->force_wiretype = 1; + g_ip->wt = 
Global_30; + break; + case (6): + g_ip->force_wiretype = 1; + g_ip->wt = Low_swing; + break; + default: + cout << "Unknown wire type!\n"; + exit(0); + } + + g_ip->delay_wt_nuca = nuca_obj_func_delay; + g_ip->dynamic_power_wt_nuca = nuca_obj_func_dynamic_power; + g_ip->leakage_power_wt_nuca = nuca_obj_func_leakage_power; + g_ip->area_wt_nuca = nuca_obj_func_area; + g_ip->cycle_time_wt_nuca = nuca_obj_func_cycle_time; + g_ip->delay_dev_nuca = dev_func_delay; + g_ip->dynamic_power_dev_nuca = nuca_dev_func_dynamic_power; + g_ip->leakage_power_dev_nuca = nuca_dev_func_leakage_power; + g_ip->area_dev_nuca = nuca_dev_func_area; + g_ip->cycle_time_dev_nuca = nuca_dev_func_cycle_time; + g_ip->nuca = is_nuca; + g_ip->nuca_bank_count = nuca_bank_count; + if(nuca_bank_count > 0) { + g_ip->force_nuca_bank = 1; + } + g_ip->cores = core_count; + g_ip->cache_level = cache_level; + + g_ip->temp = temp; + + g_ip->F_sz_nm = tech_node; + g_ip->F_sz_um = tech_node / 1000; + g_ip->is_main_mem = (main_mem != 0) ? true : false; + g_ip->is_cache = (cache != 0) ? true : false; + g_ip->rpters_in_htree = (REPEATERS_IN_HTREE_SEGMENTS_in != 0) ? true : false; + + g_ip->num_rw_ports = rw_ports; + g_ip->num_rd_ports = excl_read_ports; + g_ip->num_wr_ports = excl_write_ports; + g_ip->num_se_rd_ports = single_ended_read_ports; + g_ip->print_detail = 1; + g_ip->nuca = 0; + + g_ip->wt = Global_5; + g_ip->force_cache_config = false; + g_ip->force_wiretype = false; + g_ip->print_input_args = p_input; + + + uca_org_t fin_res; + fin_res.valid = false; + + if (g_ip->error_checking() == false) exit(0); + if (g_ip->print_input_args) + g_ip->display_ip(); + init_tech_params(g_ip->F_sz_um, false); + Wire winit; // Do not delete this line. It initializes wires. + + if (g_ip->nuca == 1) + { + Nuca n(&g_tp.peri_global); + n.sim_nuca(); + } + solve(&fin_res); + + output_UCA(&fin_res); + + delete (g_ip); + return fin_res; +} + +//McPAT's plain interface, please keep !!! 
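+//
+// Illustrative note (not from the original CACTI sources): in the McPAT-style
+// overload below, force_wiretype/wiretype select the H-tree wire model --
+// wiretype 0 maps to Low_swing, and 5/10/20/30 map to Global_5/Global_10/
+// Global_20/Global_30 (global wires with that delay penalty in percent);
+// with force_wiretype == 0 the tool searches delay-optimized global wires.
+// Likewise, force_config == 1 pins the array partition (ndwl/ndbl/nspd/ndcm/
+// ndsam1/ndsam2) instead of letting CACTI sweep it. A hedged sketch of the
+// trailing arguments only, with hypothetical values:
+//
+//   ..., /*force_wiretype*/ 1, /*wiretype*/ 10,   // force Global_10 wires
+//   /*force_config*/ 1, /*ndwl*/ 8, /*ndbl*/ 8, /*nspd*/ 1,
+//   /*ndcm*/ 1, /*ndsam1*/ 4, /*ndsam2*/ 1, /*ecc*/ 1);
+//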
+uca_org_t cacti_interface( + int cache_size, + int line_size, + int associativity, + int rw_ports, + int excl_read_ports,// para5 + int excl_write_ports, + int single_ended_read_ports, + int search_ports, + int banks, + double tech_node,//para10 + int output_width, + int specific_tag, + int tag_width, + int access_mode, + int cache, //para15 + int main_mem, + int obj_func_delay, + int obj_func_dynamic_power, + int obj_func_leakage_power, + int obj_func_cycle_time, //para20 + int obj_func_area, + int dev_func_delay, + int dev_func_dynamic_power, + int dev_func_leakage_power, + int dev_func_area, //para25 + int dev_func_cycle_time, + int ed_ed2_none, // 0 - ED, 1 - ED^2, 2 - use weight and deviate + int temp, + int wt, //0 - default(search across everything), 1 - global, 2 - 5% delay penalty, 3 - 10%, 4 - 20 %, 5 - 30%, 6 - low-swing + int data_arr_ram_cell_tech_flavor_in,//para30 + int data_arr_peri_global_tech_flavor_in, + int tag_arr_ram_cell_tech_flavor_in, + int tag_arr_peri_global_tech_flavor_in, + int interconnect_projection_type_in, + int wire_inside_mat_type_in,//para35 + int wire_outside_mat_type_in, + int REPEATERS_IN_HTREE_SEGMENTS_in, + int VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in, + int BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in, + int PAGE_SIZE_BITS_in,//para40 + int BURST_LENGTH_in, + int INTERNAL_PREFETCH_WIDTH_in, + int force_wiretype, + int wiretype, + int force_config,//para45 + int ndwl, + int ndbl, + int nspd, + int ndcm, + int ndsam1,//para50 + int ndsam2, + int ecc) +{ + g_ip = new InputParameter(); + + uca_org_t fin_res; + fin_res.valid = false; + + g_ip->data_arr_ram_cell_tech_type = data_arr_ram_cell_tech_flavor_in; + g_ip->data_arr_peri_global_tech_type = data_arr_peri_global_tech_flavor_in; + g_ip->tag_arr_ram_cell_tech_type = tag_arr_ram_cell_tech_flavor_in; + g_ip->tag_arr_peri_global_tech_type = tag_arr_peri_global_tech_flavor_in; + + g_ip->ic_proj_type = interconnect_projection_type_in; + g_ip->wire_is_mat_type = wire_inside_mat_type_in; + g_ip->wire_os_mat_type = wire_outside_mat_type_in; + g_ip->burst_len = BURST_LENGTH_in; + g_ip->int_prefetch_w = INTERNAL_PREFETCH_WIDTH_in; + g_ip->page_sz_bits = PAGE_SIZE_BITS_in; + + g_ip->cache_sz = cache_size; + g_ip->line_sz = line_size; + g_ip->assoc = associativity; + g_ip->nbanks = banks; + g_ip->out_w = output_width; + g_ip->specific_tag = specific_tag; + if (specific_tag == 0) { + g_ip->tag_w = 42; + } + else { + g_ip->tag_w = tag_width; + } + + g_ip->access_mode = access_mode; + g_ip->delay_wt = obj_func_delay; + g_ip->dynamic_power_wt = obj_func_dynamic_power; + g_ip->leakage_power_wt = obj_func_leakage_power; + g_ip->area_wt = obj_func_area; + g_ip->cycle_time_wt = obj_func_cycle_time; + g_ip->delay_dev = dev_func_delay; + g_ip->dynamic_power_dev = dev_func_dynamic_power; + g_ip->leakage_power_dev = dev_func_leakage_power; + g_ip->area_dev = dev_func_area; + g_ip->cycle_time_dev = dev_func_cycle_time; + g_ip->temp = temp; + g_ip->ed = ed_ed2_none; + + g_ip->F_sz_nm = tech_node; + g_ip->F_sz_um = tech_node / 1000; + g_ip->is_main_mem = (main_mem != 0) ? true : false; + g_ip->is_cache = (cache ==1) ? true : false; + g_ip->pure_ram = (cache ==0) ? true : false; + g_ip->pure_cam = (cache ==2) ? true : false; + g_ip->rpters_in_htree = (REPEATERS_IN_HTREE_SEGMENTS_in != 0) ? 
true : false; + g_ip->ver_htree_wires_over_array = VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in; + g_ip->broadcast_addr_din_over_ver_htrees = BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in; + + g_ip->num_rw_ports = rw_ports; + g_ip->num_rd_ports = excl_read_ports; + g_ip->num_wr_ports = excl_write_ports; + g_ip->num_se_rd_ports = single_ended_read_ports; + g_ip->num_search_ports = search_ports; + + g_ip->print_detail = 1; + g_ip->nuca = 0; + + if (force_wiretype == 0) + { + g_ip->wt = Global; + g_ip->force_wiretype = false; + } + else + { g_ip->force_wiretype = true; + if (wiretype==10) { + g_ip->wt = Global_10; + } + if (wiretype==20) { + g_ip->wt = Global_20; + } + if (wiretype==30) { + g_ip->wt = Global_30; + } + if (wiretype==5) { + g_ip->wt = Global_5; + } + if (wiretype==0) { + g_ip->wt = Low_swing; + } + } + //g_ip->wt = Global_5; + if (force_config == 0) + { + g_ip->force_cache_config = false; + } + else + { + g_ip->force_cache_config = true; + g_ip->ndbl=ndbl; + g_ip->ndwl=ndwl; + g_ip->nspd=nspd; + g_ip->ndcm=ndcm; + g_ip->ndsam1=ndsam1; + g_ip->ndsam2=ndsam2; + + + } + + if (ecc==0){ + g_ip->add_ecc_b_=false; + } + else + { + g_ip->add_ecc_b_=true; + } + + + if(!g_ip->error_checking()) + exit(0); + + init_tech_params(g_ip->F_sz_um, false); + Wire winit; // Do not delete this line. It initializes wires. + + g_ip->display_ip(); + solve(&fin_res); + output_UCA(&fin_res); + output_data_csv(fin_res); + delete (g_ip); + + return fin_res; +} + + + +bool InputParameter::error_checking() +{ + int A; + bool seq_access = false; + fast_access = true; + + switch (access_mode) + { + case 0: + seq_access = false; + fast_access = false; + break; + case 1: + seq_access = true; + fast_access = false; + break; + case 2: + seq_access = false; + fast_access = true; + break; + } + + if(is_main_mem) + { + if(ic_proj_type == 0 && !g_ip->is_3d_mem) + { + cerr << "DRAM model supports only conservative interconnect projection!\n\n"; + return false; + } + } + + + uint32_t B = line_sz; + + if (B < 1) + { + cerr << "Block size must >= 1" << endl; + return false; + } + else if (B*8 < out_w) + { + cerr << "Block size must be at least " << out_w/8 << endl; + return false; + } + + if (F_sz_um <= 0) + { + cerr << "Feature size must be > 0" << endl; + return false; + } + else if (F_sz_um > 0.091) + { + cerr << "Feature size must be <= 90 nm" << endl; + return false; + } + + + uint32_t RWP = num_rw_ports; + uint32_t ERP = num_rd_ports; + uint32_t EWP = num_wr_ports; + uint32_t NSER = num_se_rd_ports; + uint32_t SCHP = num_search_ports; + +//TODO: revisit this. This is an important feature. thought this should be used +// // If multiple banks and multiple ports are specified, then if number of ports is less than or equal to +// // the number of banks, we assume that the multiple ports are implemented via the multiple banks. +// // In such a case we assume that each bank has 1 RWP port. 
+// if ((RWP + ERP + EWP) <= nbanks && nbanks>1) +// { +// RWP = 1; +// ERP = 0; +// EWP = 0; +// NSER = 0; +// } +// else if ((RWP < 0) || (EWP < 0) || (ERP < 0)) +// { +// cerr << "Ports must >=0" << endl; +// return false; +// } +// else if (RWP > 2) +// { +// cerr << "Maximum of 2 read/write ports" << endl; +// return false; +// } +// else if ((RWP+ERP+EWP) < 1) + // Changed to new implementation: + // The number of ports specified at input is per bank + if ((RWP+ERP+EWP) < 1) + { + cerr << "Must have at least one port" << endl; + return false; + } + + if (is_pow2(nbanks) == false) + { + cerr << "Number of subbanks should be greater than or equal to 1 and should be a power of 2" << endl; + return false; + } + + int C = cache_sz/nbanks; + if (C < 64 && !g_ip->is_3d_mem) + { + cerr << "Cache size must >=64" << endl; + return false; + } + +//TODO: revisit this +// if (pure_ram==true && assoc!=1) +// { +// cerr << "Pure RAM must have assoc as 1" << endl; +// return false; +// } + + //fully assoc and cam check + if (is_cache && assoc==0) + fully_assoc =true; + else + fully_assoc = false; + + if (pure_cam==true && assoc!=0) + { + cerr << "Pure CAM must have associativity as 0" << endl; + return false; + } + + if (assoc==0 && (pure_cam==false && is_cache ==false)) + { + cerr << "Only CAM or Fully associative cache can have associativity as 0" << endl; + return false; + } + + if ((fully_assoc==true || pure_cam==true) + && (data_arr_ram_cell_tech_type!= tag_arr_ram_cell_tech_type + || data_arr_peri_global_tech_type != tag_arr_peri_global_tech_type )) + { + cerr << "CAM and fully associative cache must have same device type for both data and tag array" << endl; + return false; + } + + if ((fully_assoc==true || pure_cam==true) + && (data_arr_ram_cell_tech_type== lp_dram || data_arr_ram_cell_tech_type== comm_dram)) + { + cerr << "DRAM based CAM and fully associative cache are not supported" << endl; + return false; + } + + if ((fully_assoc==true || pure_cam==true) + && (is_main_mem==true)) + { + cerr << "CAM and fully associative cache cannot be as main memory" << endl; + return false; + } + + if ((fully_assoc || pure_cam) && SCHP<1) + { + cerr << "CAM and fully associative must have at least 1 search port" << endl; + return false; + } + + if (RWP==0 && ERP==0 && SCHP>0 && ((fully_assoc || pure_cam))) + { + ERP=SCHP; + } + +// if ((!(fully_assoc || pure_cam)) && SCHP>=1) +// { +// cerr << "None CAM and fully associative cannot have search ports" << endl; +// return false; +// } + + if (assoc == 0) + { + A = C/B; + //fully_assoc = true; + } + else + { + if (assoc == 1) + { + A = 1; + //fully_assoc = false; + } + else + { + //fully_assoc = false; + A = assoc; + if (is_pow2(A) == false) + { + cerr << "Associativity must be a power of 2" << endl; + return false; + } + } + } + + if (C/(B*A) <= 1 && assoc!=0 && !g_ip->is_3d_mem) + { + cerr << "Number of sets is too small: " << endl; + cerr << " Need to either increase cache size, or decrease associativity or block size" << endl; + cerr << " (or use fully associative cache)" << endl; + return false; + } + + block_sz = B; + + /*dt: testing sequential access mode*/ + if(seq_access) + { + tag_assoc = A; + data_assoc = 1; + is_seq_acc = true; + } + else + { + tag_assoc = A; + data_assoc = A; + is_seq_acc = false; + } + + if (assoc==0) + { + data_assoc = 1; + } + num_rw_ports = RWP; + num_rd_ports = ERP; + num_wr_ports = EWP; + num_se_rd_ports = NSER; + if (!(fully_assoc || pure_cam)) + num_search_ports = 0; + nsets = C/(B*A); + + if (temp < 300 || temp > 400 
|| temp%10 != 0) + { + cerr << temp << " Temperature must be between 300 and 400 Kelvin and multiple of 10." << endl; + return false; + } + + if (nsets < 1 && !g_ip->is_3d_mem) + { + cerr << "Less than one set..." << endl; + return false; + } + + power_gating = (array_power_gated + || bitline_floating + || wl_power_gated + || cl_power_gated + || interconect_power_gated)?true:false; + + return true; +} + +void output_data_csv_3dd(const uca_org_t & fin_res) +{ + //TODO: the csv output should remain + fstream file("out.csv", ios::in); + bool print_index = file.fail(); + file.close(); + + file.open("out.csv", ios::out|ios::app); + if (file.fail() == true) + { + cerr << "File out.csv could not be opened successfully" << endl; + } + else + { + //print_index = false; + if (print_index == true) + { + file << "Tech node (nm), "; + file << "Number of tiers, "; + file << "Capacity (MB) per die, "; + file << "Number of banks, "; + file << "Page size in bits, "; + //file << "Output width (bits), "; + file << "Burst depth, "; + file << "IO width, "; + file << "Ndwl, "; + file << "Ndbl, "; + file << "N rows in subarray, "; + file << "N cols in subarray, "; +// file << "Access time (ns), "; +// file << "Random cycle time (ns), "; +// file << "Multisubbank interleave cycle time (ns), "; + +// file << "Delay request network (ns), "; +// file << "Delay inside mat (ns), "; +// file << "Delay reply network (ns), "; +// file << "Tag array access time (ns), "; +// file << "Data array access time (ns), "; +// file << "Refresh period (microsec), "; +// file << "DRAM array availability (%), "; + + + +// file << "Dynamic search energy (nJ), "; +// file << "Dynamic read energy (nJ), "; +// file << "Dynamic write energy (nJ), "; +// file << "Tag Dynamic read energy (nJ), "; +// file << "Data Dynamic read energy (nJ), "; +// file << "Dynamic read power (mW), "; +// file << "Standby leakage per bank(mW), "; +// file << "Leakage per bank with leak power management (mW), "; +// file << "Leakage per bank with leak power management (mW), "; +// file << "Refresh power as percentage of standby leakage, "; + file << "Area (mm2), "; + +// file << "Nspd, "; +// file << "Ndcm, "; +// file << "Ndsam_level_1, "; +// file << "Ndsam_level_2, "; + file << "Data arrary area efficiency %, "; +// file << "Ntwl, "; +// file << "Ntbl, "; +// file << "Ntspd, "; +// file << "Ntcm, "; +// file << "Ntsam_level_1, "; +// file << "Ntsam_level_2, "; +// file << "Tag arrary area efficiency %, "; + +// file << "Resistance per unit micron (ohm-micron), "; +// file << "Capacitance per unit micron (fF per micron), "; +// file << "Unit-length wire delay (ps), "; +// file << "FO4 delay (ps), "; +// file << "delay route to bank (including crossb delay) (ps), "; +// file << "Crossbar delay (ps), "; +// file << "Dyn read energy per access from closed page (nJ), "; +// file << "Dyn read energy per access from open page (nJ), "; +// file << "Leak power of an subbank with page closed (mW), "; +// file << "Leak power of a subbank with page open (mW), "; +// file << "Leak power of request and reply networks (mW), "; +// file << "Number of subbanks, "; + + file << "Number of TSVs in total, "; + file << "Delay of TSVs (ns) worst case, "; + file << "Area of TSVs (mm2) in total, "; + file << "Energy of TSVs (nJ) per access, "; + + file << "t_RCD (ns), "; + file << "t_RAS (ns), "; + file << "t_RC (ns), "; + file << "t_CAS (ns), "; + file << "t_RP (ns), "; + + + file << "Activate energy (nJ), "; + file << "Read energy (nJ), "; + file << "Write energy (nJ), "; + file 
<< "Precharge energy (nJ), "; + //file << "tRCD, "; + //file << "CAS latency, "; + //file << "Precharge delay, "; +// file << "Perc dyn energy bitlines, "; +// file << "perc dyn energy wordlines, "; +// file << "perc dyn energy outside mat, "; +// file << "Area opt (perc), "; +// file << "Delay opt (perc), "; +// file << "Repeater opt (perc), "; + //file << "Aspect ratio"; + file << "t_RRD (ns), "; + file << "Number tiers for a row, "; + file << "Number tiers for a column, "; + file << "delay_row_activate_net, " ; + file << "delay_row_predecode_driver_and_block, " ; + file << "delay_row_decoder, " ; + file << "delay_local_wordline , " ; + file << "delay_bitlines, " ; + file << "delay_sense_amp, " ; + + file << "delay_column_access_net, " ; + file << "delay_column_predecoder, " ; + file << "delay_column_decoder, " ; + file << "delay_column_selectline, " ; + file << "delay_datapath_net, " ; + file << "delay_global_data, " ; + file << "delay_local_data_and_drv, " ; + file << "delay_data_buffer, " ; + file << "delay_subarray_output_driver, " ; + + file << "energy_row_activate_net, "; + file << "energy_row_predecode_driver_and_block, "; + file << "energy_row_decoder, "; + file << "energy_local_wordline, "; + file << "energy_bitlines, "; + file << "energy_sense_amp, "; + + file << "energy_column_access_net, "; + file << "energy_column_predecoder, "; + file << "energy_column_decoder, "; + file << "energy_column_selectline, "; + file << "energy_datapath_net, "; + file << "energy_global_data, "; + file << "energy_local_data_and_drv, "; + file << "energy_subarray_output_driver, "; + file << "energy_data_buffer, "; + + file << "area_subarray, "; + file << "area_lwl_drv, "; + file << "area_row_predec_dec, "; + file << "area_col_predec_dec, "; + file << "area_bus, "; + file << "area_address_bus, "; + file << "area_data_bus, "; + file << "area_data_drv, "; + file << "area_IOSA, "; + file << endl; + } + file << g_ip->F_sz_nm << ", "; + file << g_ip->num_die_3d << ", "; + file << g_ip->cache_sz * 1024 / g_ip->num_die_3d << ", "; + file << g_ip->nbanks << ", "; + file << g_ip->page_sz_bits << ", " ; +// file << g_ip->tag_assoc << ", "; + //file << g_ip->out_w << ", "; + file << g_ip->burst_depth << ", "; + file << g_ip->io_width << ", "; + + file << fin_res.data_array2->Ndwl << ", "; + file << fin_res.data_array2->Ndbl << ", "; + file << fin_res.data_array2->num_row_subarray << ", "; + file << fin_res.data_array2->num_col_subarray << ", "; +// file << fin_res.access_time*1e+9 << ", "; +// file << fin_res.cycle_time*1e+9 << ", "; +// file << fin_res.data_array2->multisubbank_interleave_cycle_time*1e+9 << ", "; +// file << fin_res.data_array2->delay_request_network*1e+9 << ", "; +// file << fin_res.data_array2->delay_inside_mat*1e+9 << ", "; +// file << fin_res.data_array2.delay_reply_network*1e+9 << ", "; + +// if (!(g_ip->fully_assoc || g_ip->pure_cam || g_ip->pure_ram)) +// { +// file << fin_res.tag_array2->access_time*1e+9 << ", "; +// } +// else +// { +// file << 0 << ", "; +// } +// file << fin_res.data_array2->access_time*1e+9 << ", "; +// file << fin_res.data_array2->dram_refresh_period*1e+6 << ", "; +// file << fin_res.data_array2->dram_array_availability << ", "; +/* if (g_ip->fully_assoc || g_ip->pure_cam) + { + file << fin_res.power.searchOp.dynamic*1e+9 << ", "; + } + else + { + file << "N/A" << ", "; + } + */ +// file << fin_res.power.readOp.dynamic*1e+9 << ", "; +// file << fin_res.power.writeOp.dynamic*1e+9 << ", "; +// if (!(g_ip->fully_assoc || g_ip->pure_cam || g_ip->pure_ram)) +// { +// 
file << fin_res.tag_array2->power.readOp.dynamic*1e+9 << ", "; +// } +// else +// { +// file << "NA" << ", "; +// } +// file << fin_res.data_array2->power.readOp.dynamic*1e+9 << ", "; +// if (g_ip->fully_assoc || g_ip->pure_cam) +// { +// file << fin_res.power.searchOp.dynamic*1000/fin_res.cycle_time << ", "; +// } +// else +// { +// file << fin_res.power.readOp.dynamic*1000/fin_res.cycle_time << ", "; +// } + +// file <<( fin_res.power.readOp.leakage + fin_res.power.readOp.gate_leakage )*1000 << ", "; +// file << fin_res.leak_power_with_sleep_transistors_in_mats*1000 << ", "; +// file << fin_res.data_array.refresh_power / fin_res.data_array.total_power.readOp.leakage << ", "; + file << fin_res.data_array2->area *1e-6 << ", "; + +// file << fin_res.data_array2->Nspd << ", "; +// file << fin_res.data_array2->deg_bl_muxing << ", "; +// file << fin_res.data_array2->Ndsam_lev_1 << ", "; +// file << fin_res.data_array2->Ndsam_lev_2 << ", "; + file << fin_res.data_array2->area_efficiency << ", "; +/* if (!(g_ip->fully_assoc || g_ip->pure_cam || g_ip->pure_ram)) + { + file << fin_res.tag_array2->Ndwl << ", "; + file << fin_res.tag_array2->Ndbl << ", "; + file << fin_res.tag_array2->Nspd << ", "; + file << fin_res.tag_array2->deg_bl_muxing << ", "; + file << fin_res.tag_array2->Ndsam_lev_1 << ", "; + file << fin_res.tag_array2->Ndsam_lev_2 << ", "; + file << fin_res.tag_array2->area_efficiency << ", "; + } + else + { + file << "N/A" << ", "; + file << "N/A"<< ", "; + file << "N/A" << ", "; + file << "N/A" << ", "; + file << "N/A" << ", "; + file << "N/A" << ", "; + file << "N/A" << ", "; + } +*/ + file << fin_res.data_array2->num_TSV_tot << ", "; + file << fin_res.data_array2->delay_TSV_tot *1e9 << ", "; + file << fin_res.data_array2->area_TSV_tot *1e-6 << ", "; + file << fin_res.data_array2->dyn_pow_TSV_per_access *1e9 << ", "; + + file << fin_res.data_array2->t_RCD *1e9 << ", "; + file << fin_res.data_array2->t_RAS *1e9 << ", "; + file << fin_res.data_array2->t_RC *1e9 << ", "; + file << fin_res.data_array2->t_CAS *1e9 << ", "; + file << fin_res.data_array2->t_RP *1e9 << ", "; + + + +// file << g_tp.wire_inside_mat.R_per_um << ", "; +// file << g_tp.wire_inside_mat.C_per_um / 1e-15 << ", "; +// file << g_tp.unit_len_wire_del / 1e-12 << ", "; +// file << g_tp.FO4 / 1e-12 << ", "; +// file << fin_res.data_array.delay_route_to_bank / 1e-9 << ", "; +// file << fin_res.data_array.delay_crossbar / 1e-9 << ", "; +// file << fin_res.data_array.dyn_read_energy_from_closed_page / 1e-9 << ", "; +// file << fin_res.data_array.dyn_read_energy_from_open_page / 1e-9 << ", "; +// file << fin_res.data_array.leak_power_subbank_closed_page / 1e-3 << ", "; +// file << fin_res.data_array.leak_power_subbank_open_page / 1e-3 << ", "; +// file << fin_res.data_array.leak_power_request_and_reply_networks / 1e-3 << ", "; +// file << fin_res.data_array.number_subbanks << ", " ; + //file << fin_res.data_array.page_size_in_bits << ", " ; + + file << fin_res.data_array2->activate_energy * 1e9 << ", " ; + file << fin_res.data_array2->read_energy * 1e9 << ", " ; + file << fin_res.data_array2->write_energy * 1e9 << ", " ; + file << fin_res.data_array2->precharge_energy * 1e9 << ", " ; + //file << fin_res.data_array.trcd * 1e9 << ", " ; + //file << fin_res.data_array.cas_latency * 1e9 << ", " ; + //file << fin_res.data_array.precharge_delay * 1e9 << ", " ; + //file << fin_res.data_array.all_banks_height / fin_res.data_array.all_banks_width; + + file << fin_res.data_array2->t_RRD * 1e9 << ", " ; + file << g_ip->num_tier_row_sprd 
<< ", " ; + file << g_ip->num_tier_col_sprd << ", " ; + + file << fin_res.data_array2->delay_row_activate_net * 1e9 << ", " ; + file << fin_res.data_array2->delay_row_predecode_driver_and_block * 1e9 << ", " ; + file << fin_res.data_array2->delay_row_decoder * 1e9 << ", " ; + file << fin_res.data_array2->delay_local_wordline * 1e9 << ", " ; + file << fin_res.data_array2->delay_bitlines * 1e9 << ", " ; + file << fin_res.data_array2->delay_sense_amp * 1e9 << ", " ; + file << fin_res.data_array2->delay_column_access_net * 1e9 << ", " ; + file << fin_res.data_array2->delay_column_predecoder * 1e9 << ", " ; + file << fin_res.data_array2->delay_column_decoder * 1e9 << ", " ; + file << fin_res.data_array2->delay_column_selectline * 1e9 << ", " ; + file << fin_res.data_array2->delay_datapath_net * 1e9 << ", " ; + file << fin_res.data_array2->delay_global_data * 1e9 << ", " ; + file << fin_res.data_array2->delay_local_data_and_drv * 1e9 << ", " ; + file << fin_res.data_array2->delay_data_buffer * 1e9 << ", " ; + file << fin_res.data_array2->delay_subarray_output_driver * 1e9 << ", " ; + + file << fin_res.data_array2->energy_row_activate_net * 1e9 << ", " ; + file << fin_res.data_array2->energy_row_predecode_driver_and_block * 1e9 << ", " ; + file << fin_res.data_array2->energy_row_decoder * 1e9 << ", " ; + file << fin_res.data_array2->energy_local_wordline * 1e9 << ", " ; + file << fin_res.data_array2->energy_bitlines * 1e9 << ", " ; + file << fin_res.data_array2->energy_sense_amp * 1e9 << ", " ; + + file << fin_res.data_array2->energy_column_access_net * 1e9 << ", " ; + file << fin_res.data_array2->energy_column_predecoder * 1e9 << ", " ; + file << fin_res.data_array2->energy_column_decoder * 1e9 << ", " ; + file << fin_res.data_array2->energy_column_selectline * 1e9 << ", " ; + file << fin_res.data_array2->energy_datapath_net * 1e9 << ", " ; + file << fin_res.data_array2->energy_global_data * 1e9 << ", " ; + file << fin_res.data_array2->energy_local_data_and_drv * 1e9 << ", " ; + file << fin_res.data_array2->energy_subarray_output_driver * 1e9 << ", " ; + file << fin_res.data_array2->energy_data_buffer * 1e9 << ", " ; + + file << fin_res.data_array2->area_subarray / 1e6 << ", " ; + file << fin_res.data_array2->area_lwl_drv / 1e6 << ", " ; + file << fin_res.data_array2->area_row_predec_dec / 1e6 << ", " ; + file << fin_res.data_array2->area_col_predec_dec / 1e6 << ", " ; + file << fin_res.data_array2->area_bus / 1e6 << ", " ; + file << fin_res.data_array2->area_address_bus / 1e6 << ", " ; + file << fin_res.data_array2->area_data_bus / 1e6 << ", " ; + file << fin_res.data_array2->area_data_drv / 1e6 << ", " ; + file << fin_res.data_array2->area_IOSA / 1e6 << ", " ; + file << fin_res.data_array2->area_sense_amp / 1e6 << ", " ; + file<F_sz_nm << ", "; + file << g_ip->cache_sz << ", "; + file << g_ip->nbanks << ", "; + file << g_ip->tag_assoc << ", "; + file << g_ip->out_w << ", "; + file << fin_res.access_time*1e+9 << ", "; + file << fin_res.cycle_time*1e+9 << ", "; +// file << fin_res.data_array2->multisubbank_interleave_cycle_time*1e+9 << ", "; +// file << fin_res.data_array2->delay_request_network*1e+9 << ", "; +// file << fin_res.data_array2->delay_inside_mat*1e+9 << ", "; +// file << fin_res.data_array2.delay_reply_network*1e+9 << ", "; + +// if (!(g_ip->fully_assoc || g_ip->pure_cam || g_ip->pure_ram)) +// { +// file << fin_res.tag_array2->access_time*1e+9 << ", "; +// } +// else +// { +// file << 0 << ", "; +// } +// file << fin_res.data_array2->access_time*1e+9 << ", "; +// file << 
+// file << fin_res.data_array2->dram_refresh_period*1e+6 << ", ";
+// file << fin_res.data_array2->dram_array_availability << ", ";
+ if (g_ip->fully_assoc || g_ip->pure_cam)
+ {
+ file << fin_res.power.searchOp.dynamic*1e+9 << ", ";
+ }
+ else
+ {
+ file << "N/A" << ", ";
+ }
+ file << fin_res.power.readOp.dynamic*1e+9 << ", ";
+ file << fin_res.power.writeOp.dynamic*1e+9 << ", ";
+// if (!(g_ip->fully_assoc || g_ip->pure_cam || g_ip->pure_ram))
+// {
+// file << fin_res.tag_array2->power.readOp.dynamic*1e+9 << ", ";
+// }
+// else
+// {
+// file << "NA" << ", ";
+// }
+// file << fin_res.data_array2->power.readOp.dynamic*1e+9 << ", ";
+// if (g_ip->fully_assoc || g_ip->pure_cam)
+// {
+// file << fin_res.power.searchOp.dynamic*1000/fin_res.cycle_time << ", ";
+// }
+// else
+// {
+// file << fin_res.power.readOp.dynamic*1000/fin_res.cycle_time << ", ";
+// }
+
+ file <<( fin_res.power.readOp.leakage + fin_res.power.readOp.gate_leakage )*1000 << ", ";
+// file << fin_res.leak_power_with_sleep_transistors_in_mats*1000 << ", ";
+// file << fin_res.data_array.refresh_power / fin_res.data_array.total_power.readOp.leakage << ", ";
+ file << fin_res.area*1e-6 << ", ";
+
+ file << fin_res.data_array2->Ndwl << ", ";
+ file << fin_res.data_array2->Ndbl << ", ";
+ file << fin_res.data_array2->Nspd << ", ";
+ file << fin_res.data_array2->deg_bl_muxing << ", ";
+ file << fin_res.data_array2->Ndsam_lev_1 << ", ";
+ file << fin_res.data_array2->Ndsam_lev_2 << ", ";
+ file << fin_res.data_array2->area_efficiency << ", ";
+ if (!(g_ip->fully_assoc || g_ip->pure_cam || g_ip->pure_ram))
+ {
+ file << fin_res.tag_array2->Ndwl << ", ";
+ file << fin_res.tag_array2->Ndbl << ", ";
+ file << fin_res.tag_array2->Nspd << ", ";
+ file << fin_res.tag_array2->deg_bl_muxing << ", ";
+ file << fin_res.tag_array2->Ndsam_lev_1 << ", ";
+ file << fin_res.tag_array2->Ndsam_lev_2 << ", ";
+ file << fin_res.tag_array2->area_efficiency << ", ";
+ }
+ else
+ {
+ file << "N/A" << ", ";
+ file << "N/A"<< ", ";
+ file << "N/A" << ", ";
+ file << "N/A" << ", ";
+ file << "N/A" << ", ";
+ file << "N/A" << ", ";
+ file << "N/A" << ", ";
+ }
+
+// file << g_tp.wire_inside_mat.R_per_um << ", ";
+// file << g_tp.wire_inside_mat.C_per_um / 1e-15 << ", ";
+// file << g_tp.unit_len_wire_del / 1e-12 << ", ";
+// file << g_tp.FO4 / 1e-12 << ", ";
+// file << fin_res.data_array.delay_route_to_bank / 1e-9 << ", ";
+// file << fin_res.data_array.delay_crossbar / 1e-9 << ", ";
+// file << fin_res.data_array.dyn_read_energy_from_closed_page / 1e-9 << ", ";
+// file << fin_res.data_array.dyn_read_energy_from_open_page / 1e-9 << ", ";
+// file << fin_res.data_array.leak_power_subbank_closed_page / 1e-3 << ", ";
+// file << fin_res.data_array.leak_power_subbank_open_page / 1e-3 << ", ";
+// file << fin_res.data_array.leak_power_request_and_reply_networks / 1e-3 << ", ";
+// file << fin_res.data_array.number_subbanks << ", " ;
+// file << fin_res.data_array.page_size_in_bits << ", " ;
+// file << fin_res.data_array.activate_energy * 1e9 << ", " ;
+// file << fin_res.data_array.read_energy * 1e9 << ", " ;
+// file << fin_res.data_array.write_energy * 1e9 << ", " ;
+// file << fin_res.data_array.precharge_energy * 1e9 << ", " ;
+// file << fin_res.data_array.trcd * 1e9 << ", " ;
+// file << fin_res.data_array.cas_latency * 1e9 << ", " ;
+// file << fin_res.data_array.precharge_delay * 1e9 << ", " ;
+// file << fin_res.data_array.all_banks_height / fin_res.data_array.all_banks_width;
+ file << endl;
+ }
+ file.close();
+}
+
+void output_UCA(uca_org_t *fr)
+{
+ if (g_ip->is_3d_mem)
+ {
+
+ cout<<"------- CACTI (version "<< VER_MAJOR_CACTI
<<"."<< VER_MINOR_CACTI<<"."VER_COMMENT_CACTI + << " of " << VER_UPDATE_CACTI << ") 3D DRAM Main Memory -------"<cache_sz) << endl; + if(g_ip->num_die_3d>1) + { + cout << " Stacked die count: " << (int) g_ip->num_die_3d << endl; + if(g_ip->TSV_proj_type == 1) + cout << " TSV projection: industrial conservative" << endl; + else + cout << " TSV projection: ITRS aggressive" << endl; + } + cout << " Number of banks: " << (int) g_ip->nbanks << endl; + cout << " Technology size (nm): " << g_ip->F_sz_nm << endl; + cout << " Page size (bits): " << g_ip->page_sz_bits << endl; + cout << " Burst depth: " << g_ip->burst_depth << endl; + cout << " Chip IO width: " << g_ip->io_width << endl; + cout << " Best Ndwl: " << fr->data_array2->Ndwl << endl; + cout << " Best Ndbl: " << fr->data_array2->Ndbl << endl; + cout << " # rows in subarray: " << fr->data_array2->num_row_subarray << endl; + cout << " # columns in subarray: " << fr->data_array2->num_col_subarray << endl; + + cout <<"\nResults:\n"; + cout<<"Timing Components:"<data_array2->t_RCD * 1e9 << " ns" <data_array2->t_RAS * 1e9 << " ns" <data_array2->t_RC * 1e9 << " ns" <data_array2->t_CAS * 1e9 << " ns" <data_array2->t_RP* 1e9 << " ns" <data_array2->t_RRD* 1e9 << " ns" <data_array2->t_RRD * 1e9 << " ns"<data_array2->activate_energy * 1e9 << " nJ" <data_array2->read_energy * 1e9 << " nJ" <data_array2->write_energy * 1e9 << " nJ" <data_array2->precharge_energy * 1e9 << " nJ" <data_array2->activate_power * 1e3 << " mW" <data_array2->read_power * 1e3 << " mW" <data_array2->write_power * 1e3 << " mW" <burst_depth)/(g_ip->sys_freq_MHz*1e6)/2) * 1e3 << " mW" <data_array2->area/1e6<<" mm2"<partition_gran>0) ? fr->data_array2->area : (fr->data_array2->area/0.5); + double DRAM_area_per_die = (g_ip->partition_gran>0) ? fr->data_array2->area : (fr->data_array2->area + fr->data_array2->area_ram_cells*0.65); + //double DRAM_area_per_die = (g_ip->partition_gran>0) ? fr->data_array2->area : (fr->data_array2->area + 2.5e9*(double)(g_ip->F_sz_um)*(g_ip->F_sz_um)); + double area_efficiency_per_die = (g_ip->partition_gran>0) ? fr->data_array2->area_efficiency : (fr->data_array2->area_ram_cells / DRAM_area_per_die *100); + double DRAM_width = (g_ip->partition_gran>0) ? 
+ cout<<" DRAM core area: "<< fr->data_array2->area/1e6 <<" mm2"<<endl;
+ if (g_ip->partition_gran == 0)
+ cout<<" DRAM area per die: "<< DRAM_area_per_die/1e6 <<" mm2"<<endl;
+ cout<<" Area efficiency: "<< area_efficiency_per_die <<" %"<<endl;
+ cout<<" DRAM die width: "<< DRAM_width/1e3 <<" mm"<<endl;
+ cout<<" DRAM die height: "<< fr->data_array2->all_banks_height/1e3 <<" mm"<<endl;
+
+ if(g_ip->num_die_3d>1)
+ {
+ cout<<"TSV Components:"<<endl;
+ cout<<" TSV area overhead: "<< fr->data_array2->area_TSV_tot /1e6 <<" mm2"<<endl;
+ cout<<" TSV latency overhead: "<< fr->data_array2->delay_TSV_tot * 1e9 <<" ns"<<endl;
+ cout<<" TSV energy overhead per access: "<< fr->data_array2->dyn_pow_TSV_per_access * 1e9 <<" nJ"<<endl;
+ }
+ }
+ else //if (!g_ip->is_3d_mem)
+ {
+ // if (NUCA)
+ if (0) {
+ cout << "\n\n Detailed Bank Stats:\n";
+ cout << " Bank Size (bytes): %d\n" <<
+ (int) (g_ip->cache_sz);
+ }
+ else {
+ if (g_ip->data_arr_ram_cell_tech_type == 3) {
+ cout << "\n---------- CACTI (version "<< VER_MAJOR_CACTI <<"."<< VER_MINOR_CACTI<<"."VER_COMMENT_CACTI
+ << " of " << VER_UPDATE_CACTI << "), Uniform Cache Access " <<
+ "Logic Process Based DRAM Model ----------\n";
+ }
+ else if (g_ip->data_arr_ram_cell_tech_type == 4) {
+ cout << "\n---------- CACTI (version "<< VER_MAJOR_CACTI <<"."<< VER_MINOR_CACTI<<"."VER_COMMENT_CACTI
+ << " of " << VER_UPDATE_CACTI << "), Uniform " <<
+ "Cache Access Commodity DRAM Model ----------\n";
+ }
+ else {
+ cout << "\n---------- CACTI (version "<< VER_MAJOR_CACTI <<"."<< VER_MINOR_CACTI<<"."VER_COMMENT_CACTI
+ << " of " << VER_UPDATE_CACTI << "), Uniform Cache Access "
+ "SRAM Model ----------\n";
+ }
+ cout << "\nCache Parameters:\n";
+ cout << " Total cache size (bytes): " <<
+ (int) (g_ip->cache_sz) << endl;
+ }
+
+ cout << " Number of banks: " << (int) g_ip->nbanks << endl;
+ if (g_ip->fully_assoc|| g_ip->pure_cam)
+ cout << " Associativity: fully associative\n";
+ else {
+ if (g_ip->tag_assoc == 1)
+ cout << " Associativity: direct mapped\n";
+ else
+ cout << " Associativity: " <<
+ g_ip->tag_assoc << endl;
+ }
+
+
+ cout << " Block size (bytes): " << g_ip->line_sz << endl;
+ cout << " Read/write Ports: " <<
+ g_ip->num_rw_ports << endl;
+ cout << " Read ports: " <<
+ g_ip->num_rd_ports << endl;
+ cout << " Write ports: " <<
+ g_ip->num_wr_ports << endl;
+ if (g_ip->fully_assoc|| g_ip->pure_cam)
+ cout << " search ports: " <<
+ g_ip->num_search_ports << endl;
+ cout << " Technology size (nm): " <<
+ g_ip->F_sz_nm << endl << endl;
+
+ cout << " Access time (ns): " << fr->access_time*1e9 << endl;
+ cout << " Cycle time (ns): " << fr->cycle_time*1e9 << endl;
+ if (g_ip->data_arr_ram_cell_tech_type >= 4) {
+ cout << " Precharge Delay (ns): " << fr->data_array2->precharge_delay*1e9 << endl;
+ cout << " Activate Energy (nJ): " << fr->data_array2->activate_energy*1e9 << endl;
+ cout << " Read Energy (nJ): " << fr->data_array2->read_energy*1e9 << endl;
+ cout << " Write Energy (nJ): " << fr->data_array2->write_energy*1e9 << endl;
+ cout << " Precharge Energy (nJ): " << fr->data_array2->precharge_energy*1e9 << endl;
+ cout << " Leakage Power Closed Page (mW): " << fr->data_array2->leak_power_subbank_closed_page*1e3 << endl;
+ cout << " Leakage Power Open Page (mW): " << fr->data_array2->leak_power_subbank_open_page*1e3 << endl;
+ cout << " Leakage Power I/O (mW): " << fr->data_array2->leak_power_request_and_reply_networks*1e3 << endl;
+ cout << " Refresh power (mW): " <<
+ fr->data_array2->refresh_power*1e3 << endl;
+ }
+ else {
+ if ((g_ip->fully_assoc|| g_ip->pure_cam))
+ {
+ cout << " Total dynamic associative search energy per access (nJ): " <<
+ fr->power.searchOp.dynamic*1e9 << endl;
+// cout << " Total dynamic read energy per access (nJ): " <<
+// fr->power.readOp.dynamic*1e9 << endl;
cout << " Total dynamic write energy per access (nJ): " << +// fr->power.writeOp.dynamic*1e9 << endl; + } +// else +// { + cout << " Total dynamic read energy per access (nJ): " << + fr->power.readOp.dynamic*1e9 << endl; + cout << " Total dynamic write energy per access (nJ): " << + fr->power.writeOp.dynamic*1e9 << endl; +// } + cout << " Total leakage power of a bank" + " (mW): " << fr->power.readOp.leakage*1e3 << endl; + cout << " Total gate leakage power of a bank" + " (mW): " << fr->power.readOp.gate_leakage*1e3 << endl; + } + + if (g_ip->data_arr_ram_cell_tech_type ==3 || g_ip->data_arr_ram_cell_tech_type ==4) + { + } + cout << " Cache height x width (mm): " << + fr->cache_ht*1e-3 << " x " << fr->cache_len*1e-3 << endl << endl; + + + cout << " Best Ndwl : " << fr->data_array2->Ndwl << endl; + cout << " Best Ndbl : " << fr->data_array2->Ndbl << endl; + cout << " Best Nspd : " << fr->data_array2->Nspd << endl; + cout << " Best Ndcm : " << fr->data_array2->deg_bl_muxing << endl; + cout << " Best Ndsam L1 : " << fr->data_array2->Ndsam_lev_1 << endl; + cout << " Best Ndsam L2 : " << fr->data_array2->Ndsam_lev_2 << endl << endl; + + if ((!(g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc)) && !g_ip->is_main_mem) + { + cout << " Best Ntwl : " << fr->tag_array2->Ndwl << endl; + cout << " Best Ntbl : " << fr->tag_array2->Ndbl << endl; + cout << " Best Ntspd : " << fr->tag_array2->Nspd << endl; + cout << " Best Ntcm : " << fr->tag_array2->deg_bl_muxing << endl; + cout << " Best Ntsam L1 : " << fr->tag_array2->Ndsam_lev_1 << endl; + cout << " Best Ntsam L2 : " << fr->tag_array2->Ndsam_lev_2 << endl; + } + + switch (fr->data_array2->wt) { + case (0): + cout << " Data array, H-tree wire type: Delay optimized global wires\n"; + break; + case (1): + cout << " Data array, H-tree wire type: Global wires with 5\% delay penalty\n"; + break; + case (2): + cout << " Data array, H-tree wire type: Global wires with 10\% delay penalty\n"; + break; + case (3): + cout << " Data array, H-tree wire type: Global wires with 20\% delay penalty\n"; + break; + case (4): + cout << " Data array, H-tree wire type: Global wires with 30\% delay penalty\n"; + break; + case (5): + cout << " Data array, wire type: Low swing wires\n"; + break; + default: + cout << "ERROR - Unknown wire type " << (int) fr->data_array2->wt <pure_ram|| g_ip->pure_cam || g_ip->fully_assoc)) { + switch (fr->tag_array2->wt) { + case (0): + cout << " Tag array, H-tree wire type: Delay optimized global wires\n"; + break; + case (1): + cout << " Tag array, H-tree wire type: Global wires with 5\% delay penalty\n"; + break; + case (2): + cout << " Tag array, H-tree wire type: Global wires with 10\% delay penalty\n"; + break; + case (3): + cout << " Tag array, H-tree wire type: Global wires with 20\% delay penalty\n"; + break; + case (4): + cout << " Tag array, H-tree wire type: Global wires with 30\% delay penalty\n"; + break; + case (5): + cout << " Tag array, wire type: Low swing wires\n"; + break; + default: + cout << "ERROR - Unknown wire type " << (int) fr->tag_array2->wt <is_3d_mem) + if (g_ip->print_detail) + { + //if(g_ip->fully_assoc) return; + + if (g_ip->is_3d_mem) + { + cout << endl << endl << "3D DRAM Detail Components:" << endl << endl; + cout << endl << "Time Components:" << endl << endl; + cout << "\t row activation bus delay (ns): " << fr->data_array2->delay_row_activate_net*1e9 << endl; + cout << "\t row predecoder delay (ns): " << fr->data_array2->delay_row_predecode_driver_and_block*1e9 << endl; + cout << "\t row decoder delay 
(ns): " << fr->data_array2->delay_row_decoder*1e9 << endl; + cout << "\t local wordline delay (ns): " << fr->data_array2->delay_local_wordline*1e9 << endl; + cout << "\t bitline delay (ns): " << fr->data_array2->delay_bitlines*1e9 << endl; + cout << "\t sense amp delay (ns): " << fr->data_array2->delay_sense_amp*1e9 << endl; + cout << "\t column access bus delay (ns): " << fr->data_array2->delay_column_access_net*1e9 << endl; + cout << "\t column predecoder delay (ns): " << fr->data_array2->delay_column_predecoder*1e9 << endl; + cout << "\t column decoder delay (ns): " << fr->data_array2->delay_column_decoder*1e9 << endl; + //cout << "\t column selectline delay (ns): " << fr->data_array2->delay_column_selectline*1e9 << endl; + cout << "\t datapath bus delay (ns): " << fr->data_array2->delay_datapath_net*1e9 << endl; + cout << "\t global dataline delay (ns): " << fr->data_array2->delay_global_data*1e9 << endl; + cout << "\t local dataline delay (ns): " << fr->data_array2->delay_local_data_and_drv*1e9 << endl; + cout << "\t data buffer delay (ns): " << fr->data_array2->delay_data_buffer*1e9 << endl; + cout << "\t subarray output driver delay (ns): " << fr->data_array2->delay_subarray_output_driver*1e9 << endl; + + cout << endl << "Energy Components:" << endl << endl; + cout << "\t row activation bus energy (nJ): " << fr->data_array2->energy_row_activate_net*1e9 << endl; + cout << "\t row predecoder energy (nJ): " << fr->data_array2->energy_row_predecode_driver_and_block*1e9 << endl; + cout << "\t row decoder energy (nJ): " << fr->data_array2->energy_row_decoder*1e9 << endl; + cout << "\t local wordline energy (nJ): " << fr->data_array2->energy_local_wordline*1e9 << endl; + cout << "\t bitline energy (nJ): " << fr->data_array2->energy_bitlines*1e9 << endl; + cout << "\t sense amp energy (nJ): " << fr->data_array2->energy_sense_amp*1e9 << endl; + cout << "\t column access bus energy (nJ): " << fr->data_array2->energy_column_access_net*1e9 << endl; + cout << "\t column predecoder energy (nJ): " << fr->data_array2->energy_column_predecoder*1e9 << endl; + cout << "\t column decoder energy (nJ): " << fr->data_array2->energy_column_decoder*1e9 << endl; + cout << "\t column selectline energy (nJ): " << fr->data_array2->energy_column_selectline*1e9 << endl; + cout << "\t datapath bus energy (nJ): " << fr->data_array2->energy_datapath_net*1e9 << endl; + cout << "\t global dataline energy (nJ): " << fr->data_array2->energy_global_data*1e9 << endl; + cout << "\t local dataline energy (nJ): " << fr->data_array2->energy_local_data_and_drv*1e9 << endl; + cout << "\t data buffer energy (nJ): " << fr->data_array2->energy_subarray_output_driver*1e9 << endl; + //cout << "\t subarray output driver energy (nJ): " << fr->data_array2->energy_data_buffer*1e9 << endl; + + cout << endl << "Area Components:" << endl << endl; + //cout << "\t subarray area (mm2): " << fr->data_array2->area_subarray/1e6 << endl; + cout << "\t DRAM cell area (mm2): " << fr->data_array2->area_ram_cells/1e6 << endl; + cout << "\t local WL driver area (mm2): " << fr->data_array2->area_lwl_drv/1e6 << endl; + cout << "\t subarray sense amp area (mm2): " << fr->data_array2->area_sense_amp/1e6 << endl; + cout << "\t row predecoder/decoder area (mm2): " << fr->data_array2->area_row_predec_dec/1e6 << endl; + cout << "\t column predecoder/decoder area (mm2): " << fr->data_array2->area_col_predec_dec/1e6 << endl; + cout << "\t center stripe bus area (mm2): " << fr->data_array2->area_bus/1e6 << endl; + cout << "\t address bus area (mm2): " << 
fr->data_array2->area_address_bus/1e6 << endl;
+ cout << "\t data bus area (mm2): " << fr->data_array2->area_data_bus/1e6 << endl;
+ cout << "\t data driver area (mm2): " << fr->data_array2->area_data_drv/1e6 << endl;
+ cout << "\t IO secondary sense amp area (mm2): " << fr->data_array2->area_IOSA/1e6 << endl;
+ cout << "\t TSV area (mm2): "<< fr->data_array2->area_TSV_tot /1e6 << endl;
+
+ }
+ else //if (!g_ip->is_3d_mem)
+ {
+ if (g_ip->power_gating)
+ {
+ /* Energy/Power stats */
+ cout << endl << endl << "Power-gating Components:" << endl << endl;
+ /* Data array power-gating stats */
+ if (!(g_ip->pure_cam || g_ip->fully_assoc))
+ cout << " Data array: " << endl;
+ else if (g_ip->pure_cam)
+ cout << " CAM array: " << endl;
+ else
+ cout << " Fully associative cache array: " << endl;
+
+ cout << "\t Sub-array Sleep Tx size (um) - " <<
+ fr->data_array2->sram_sleep_tx_width << endl;
+
+ // cout << "\t Sub-array Sleep Tx total size (um) - " <<
+ // fr->data_array2->sram_sleep_tx_width << endl;
+
+ cout << "\t Sub-array Sleep Tx total area (mm^2) - " <<
+ fr->data_array2->sram_sleep_tx_area*1e-6 << endl;
+
+ cout << "\t Sub-array wakeup time (ns) - " <<
+ fr->data_array2->sram_sleep_wakeup_latency*1e9 << endl;
+
+ cout << "\t Sub-array Tx energy (nJ) - " <<
+ fr->data_array2->sram_sleep_wakeup_energy*1e9 << endl;
+ //+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+ cout << endl;
+ cout << "\t WL Sleep Tx size (um) - " <<
+ fr->data_array2->wl_sleep_tx_width << endl;
+
+ // cout << "\t WL Sleep total Tx size (um) - " <<
+ // fr->data_array2->wl_sleep_tx_width << endl;
+
+ cout << "\t WL Sleep Tx total area (mm^2) - " <<
+ fr->data_array2->wl_sleep_tx_area*1e-6 << endl;
+
+ cout << "\t WL wakeup time (ns) - " <<
+ fr->data_array2->wl_sleep_wakeup_latency*1e9 << endl;
+
+ cout << "\t WL Tx energy (nJ) - " <<
+ fr->data_array2->wl_sleep_wakeup_energy*1e9 << endl;
+ //+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+ cout << endl;
+ cout << "\t BL floating wakeup time (ns) - " <<
+ fr->data_array2->bl_floating_wakeup_latency*1e9 << endl;
+
+ cout << "\t BL floating Tx energy (nJ) - " <<
+ fr->data_array2->bl_floating_wakeup_energy*1e9 << endl;
+ //+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+ cout << endl;
+
+ cout << "\t Active mats per access - " << fr->data_array2->num_active_mats<<endl;
+ cout << "\t Active subarrays per mat - " << fr->data_array2->num_submarray_mats<<endl;
+ cout << endl;
+ /* Tag array power-gating stats */
+ if ((!(g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc)) && !g_ip->is_main_mem)
+ {
+ cout << " Tag array: " << endl;
+ cout << "\t Sub-array Sleep Tx size (um) - " <<
+ fr->tag_array2->sram_sleep_tx_width << endl;
+
+ // cout << "\t Sub-array Sleep Tx total size (um) - " <<
+ // fr->tag_array2->sram_sleep_tx_width << endl;
+
+ cout << "\t Sub-array Sleep Tx total area (mm^2) - " <<
+ fr->tag_array2->sram_sleep_tx_area*1e-6 << endl;
+
+ cout << "\t Sub-array wakeup time (ns) - " <<
+ fr->tag_array2->sram_sleep_wakeup_latency*1e9 << endl;
+
+ cout << "\t Sub-array Tx energy (nJ) - " <<
+ fr->tag_array2->sram_sleep_wakeup_energy*1e9 << endl;
+ //+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+ cout << endl;
+ cout << "\t WL Sleep Tx size (um) - " <<
+ fr->tag_array2->wl_sleep_tx_width << endl;
+
+ // cout << "\t WL Sleep total Tx size (um) - " <<
+ // fr->tag_array2->wl_sleep_tx_width << endl;
+
+ cout << "\t WL Sleep Tx total area (mm^2) - " <<
+ fr->tag_array2->wl_sleep_tx_area*1e-6 << endl;
+
+ cout << "\t WL wakeup time (ns) - " <<
+ fr->tag_array2->wl_sleep_wakeup_latency*1e9 << endl;
+
+ cout << "\t WL Tx energy (nJ) - " <<
+ fr->tag_array2->wl_sleep_wakeup_energy*1e9 << endl;
+ //+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+ cout << endl;
+ cout << "\t BL floating wakeup time (ns) - " <<
+ fr->tag_array2->bl_floating_wakeup_latency*1e9 << endl;
+
+ cout << "\t BL floating Tx energy (nJ) - " <<
+ fr->tag_array2->bl_floating_wakeup_energy*1e9 << endl;
+ //+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+ cout << endl;
+
+ cout << "\t Active mats per access - " << fr->tag_array2->num_active_mats<<endl;
+ cout << "\t Active subarrays per mat - " << fr->tag_array2->num_submarray_mats<<endl;
+ cout << endl;
+ }
+ }
+
+ /* Timing stats */
+ cout << endl << endl << "Timing Components:" << endl << endl;
+ cout << " Data side (with Output driver) (ns): " <<
+ fr->data_array2->access_time/1e-9 << endl;
+
+ cout << "\tH-tree input delay (ns): " <<
+ fr->data_array2->delay_route_to_bank * 1e9 +
+ fr->data_array2->delay_input_htree * 1e9 << endl;
+
+ if (!(g_ip->pure_cam || g_ip->fully_assoc))
+ {
+ cout << "\tDecoder + wordline delay (ns): " <<
+ fr->data_array2->delay_row_predecode_driver_and_block * 1e9 +
+ fr->data_array2->delay_row_decoder * 1e9 << endl;
+ }
+ else
+ {
+ cout << "\tCAM search delay (ns): " <<
+ fr->data_array2->delay_matchlines * 1e9 << endl;
+ }
+
+ cout << "\tBitline delay (ns): " <<
+ fr->data_array2->delay_bitlines/1e-9 << endl;
+
+ cout << "\tSense Amplifier delay (ns): " <<
+ fr->data_array2->delay_sense_amp * 1e9 << endl;
+
+
+ cout << "\tH-tree output delay (ns): " <<
+ fr->data_array2->delay_subarray_output_driver * 1e9 +
+ fr->data_array2->delay_dout_htree * 1e9 << endl;
+
+ if ((!(g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc)) && !g_ip->is_main_mem)
+ {
+ /* tag array stats */
+ cout << endl << " Tag side (with Output driver) (ns): " <<
+ fr->tag_array2->access_time/1e-9 << endl;
+
+ cout << "\tH-tree input delay (ns): " <<
+ fr->tag_array2->delay_route_to_bank * 1e9 +
+ fr->tag_array2->delay_input_htree * 1e9 << endl;
+
+ cout << "\tDecoder + wordline delay (ns): " <<
+ fr->tag_array2->delay_row_predecode_driver_and_block * 1e9 +
+ fr->tag_array2->delay_row_decoder * 1e9 << endl;
+
+ cout << "\tBitline delay (ns): " <<
+ fr->tag_array2->delay_bitlines/1e-9 << endl;
+
+ cout << "\tSense Amplifier delay (ns): " <<
+ fr->tag_array2->delay_sense_amp * 1e9 << endl;
+
+ cout << "\tComparator delay (ns): " <<
+ fr->tag_array2->delay_comparator * 1e9 << endl;
+
+ cout << "\tH-tree output delay (ns): " <<
+ fr->tag_array2->delay_subarray_output_driver * 1e9 +
+ fr->tag_array2->delay_dout_htree * 1e9 << endl;
+ }
+
+
+
+ /* Energy/Power stats */
+ cout << endl << endl << "Power Components:" << endl << endl;
+
+ if (!(g_ip->pure_cam || g_ip->fully_assoc))
+ {
+ cout << " Data array: Total dynamic read energy/access (nJ): " <<
+ fr->data_array2->power.readOp.dynamic * 1e9 << endl;
+
+ cout << "\tTotal energy in H-tree (that includes both "
+ "address and data transfer) (nJ): " <<
+ (fr->data_array2->power_addr_input_htree.readOp.dynamic +
+ fr->data_array2->power_data_output_htree.readOp.dynamic +
+ fr->data_array2->power_routing_to_bank.readOp.dynamic) * 1e9 << endl;
+
+ cout << "\tOutput Htree inside bank Energy (nJ): " <<
+ fr->data_array2->power_data_output_htree.readOp.dynamic * 1e9 << endl;
+ cout << "\tDecoder (nJ): " <<
+ fr->data_array2->power_row_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_row_predecoder_blocks.readOp.dynamic * 1e9 << endl;
+ cout << "\tWordline (nJ): " <<
+ fr->data_array2->power_row_decoders.readOp.dynamic * 1e9 << endl;
+ cout << "\tBitline mux & associated drivers (nJ): " <<
+ fr->data_array2->power_bit_mux_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_bit_mux_predecoder_blocks.readOp.dynamic * 1e9 +
fr->data_array2->power_bit_mux_decoders.readOp.dynamic * 1e9 << endl;
+ cout << "\tSense amp mux & associated drivers (nJ): " <<
+ fr->data_array2->power_senseamp_mux_lev_1_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_1_predecoder_blocks.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_1_decoders.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_2_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_2_predecoder_blocks.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_2_decoders.readOp.dynamic * 1e9 << endl;
+
+ cout << "\tBitlines precharge and equalization circuit (nJ): " <<
+ fr->data_array2->power_prechg_eq_drivers.readOp.dynamic * 1e9 << endl;
+ cout << "\tBitlines (nJ): " <<
+ fr->data_array2->power_bitlines.readOp.dynamic * 1e9 << endl;
+ cout << "\tSense amplifier energy (nJ): " <<
+ fr->data_array2->power_sense_amps.readOp.dynamic * 1e9 << endl;
+ cout << "\tSub-array output driver (nJ): " <<
+ fr->data_array2->power_output_drivers_at_subarray.readOp.dynamic * 1e9 << endl;
+
+ cout << "\tTotal leakage power of a bank (mW): " <<
+ fr->data_array2->power.readOp.leakage * 1e3 << endl;
+ cout << "\tTotal leakage power in H-tree (that includes both "
+ "address and data network) ((mW)): " <<
+ (fr->data_array2->power_addr_input_htree.readOp.leakage +
+ fr->data_array2->power_data_output_htree.readOp.leakage +
+ fr->data_array2->power_routing_to_bank.readOp.leakage) * 1e3 << endl;
+
+ cout << "\tTotal leakage power in cells (mW): " <<
+ (fr->data_array2->array_leakage) * 1e3 << endl;
+ cout << "\tTotal leakage power in row logic(mW): " <<
+ (fr->data_array2->wl_leakage) * 1e3 << endl;
+ cout << "\tTotal leakage power in column logic(mW): " <<
+ (fr->data_array2->cl_leakage) * 1e3 << endl;
+
+ cout << "\tTotal gate leakage power in H-tree (that includes both "
+ "address and data network) ((mW)): " <<
+ (fr->data_array2->power_addr_input_htree.readOp.gate_leakage +
+ fr->data_array2->power_data_output_htree.readOp.gate_leakage +
+ fr->data_array2->power_routing_to_bank.readOp.gate_leakage) * 1e3 << endl;
+ }
+
+ else if (g_ip->pure_cam)
+ {
+
+ cout << " CAM array:"<<endl;
+ cout << " Total dynamic associative search energy/access (nJ): " <<
+ fr->data_array2->power.searchOp.dynamic * 1e9 << endl;
+ cout << "\tTotal energy in H-tree (that includes both "
+ "match key and data transfer) (nJ): " <<
+ (fr->data_array2->power_htree_in_search.searchOp.dynamic +
+ fr->data_array2->power_htree_out_search.searchOp.dynamic +
+ fr->data_array2->power_routing_to_bank.searchOp.dynamic) * 1e9 << endl;
+ cout << "\tKeyword input and result output Htrees inside bank Energy (nJ): " <<
+ (fr->data_array2->power_htree_in_search.searchOp.dynamic +
+ fr->data_array2->power_htree_out_search.searchOp.dynamic) * 1e9 << endl;
+ cout << "\tSearchlines (nJ): " <<
+ fr->data_array2->power_searchline.searchOp.dynamic * 1e9 +
+ fr->data_array2->power_searchline_precharge.searchOp.dynamic * 1e9 << endl;
+ cout << "\tMatchlines (nJ): " <<
+ fr->data_array2->power_matchlines.searchOp.dynamic * 1e9 +
+ fr->data_array2->power_matchline_precharge.searchOp.dynamic * 1e9 << endl;
+ cout << "\tSub-array output driver (nJ): " <<
+ fr->data_array2->power_output_drivers_at_subarray.searchOp.dynamic * 1e9 << endl;
+
+
+ cout << endl << " Total dynamic read energy/access (nJ): " <<
+ fr->data_array2->power.readOp.dynamic * 1e9 << endl;
+ cout << "\tTotal energy in H-tree (that includes both "
+ "address and data transfer) (nJ): " <<
+ (fr->data_array2->power_addr_input_htree.readOp.dynamic +
+ fr->data_array2->power_data_output_htree.readOp.dynamic +
fr->data_array2->power_routing_to_bank.readOp.dynamic) * 1e9 << endl;
+ cout << "\tOutput Htree inside bank Energy (nJ): " <<
+ fr->data_array2->power_data_output_htree.readOp.dynamic * 1e9 << endl;
+ cout << "\tDecoder (nJ): " <<
+ fr->data_array2->power_row_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_row_predecoder_blocks.readOp.dynamic * 1e9 << endl;
+ cout << "\tWordline (nJ): " <<
+ fr->data_array2->power_row_decoders.readOp.dynamic * 1e9 << endl;
+ cout << "\tBitline mux & associated drivers (nJ): " <<
+ fr->data_array2->power_bit_mux_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_bit_mux_predecoder_blocks.readOp.dynamic * 1e9 +
+ fr->data_array2->power_bit_mux_decoders.readOp.dynamic * 1e9 << endl;
+ cout << "\tSense amp mux & associated drivers (nJ): " <<
+ fr->data_array2->power_senseamp_mux_lev_1_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_1_predecoder_blocks.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_1_decoders.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_2_predecoder_drivers.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_2_predecoder_blocks.readOp.dynamic * 1e9 +
+ fr->data_array2->power_senseamp_mux_lev_2_decoders.readOp.dynamic * 1e9 << endl;
+ cout << "\tBitlines (nJ): " <<
+ fr->data_array2->power_bitlines.readOp.dynamic * 1e9 +
+ fr->data_array2->power_prechg_eq_drivers.readOp.dynamic * 1e9<< endl;
+ cout << "\tSense amplifier energy (nJ): " <<
+ fr->data_array2->power_sense_amps.readOp.dynamic * 1e9 << endl;
+ cout << "\tSub-array output driver (nJ): " <<
+ fr->data_array2->power_output_drivers_at_subarray.readOp.dynamic * 1e9 << endl;
+
+ cout << endl <<" Total leakage power of a bank (mW): " <<
+ fr->data_array2->power.readOp.leakage * 1e3 << endl;
+ }
+ else
+ {
+ cout << " Fully associative array:"<<endl;
+ cout << " Total dynamic associative search energy/access (nJ): " <<
+ fr->data_array2->power.searchOp.dynamic * 1e9 << endl;
+ cout << "\tTotal energy in H-tree (that includes both "
+ "match key and data transfer) (nJ): " <<
+ (fr->data_array2->power_htree_in_search.searchOp.dynamic +
+ fr->data_array2->power_htree_out_search.searchOp.dynamic +
+ fr->data_array2->power_routing_to_bank.searchOp.dynamic) * 1e9 << endl;
+ cout << "\tKeyword input and result output Htrees inside bank Energy (nJ): " <<
+ (fr->data_array2->power_htree_in_search.searchOp.dynamic +
+ fr->data_array2->power_htree_out_search.searchOp.dynamic) * 1e9 << endl;
+ cout << "\tSearchlines (nJ): " <<
+ fr->data_array2->power_searchline.searchOp.dynamic * 1e9 +
+ fr->data_array2->power_searchline_precharge.searchOp.dynamic * 1e9 << endl;
+ cout << "\tMatchlines (nJ): " <<
+ fr->data_array2->power_matchlines.searchOp.dynamic * 1e9 +
+ fr->data_array2->power_matchline_precharge.searchOp.dynamic * 1e9 << endl;
+ cout << "\tData portion wordline (nJ): " <<
+ fr->data_array2->power_matchline_to_wordline_drv.searchOp.dynamic * 1e9 << endl;
+ cout << "\tData Bitlines (nJ): " <<
+ fr->data_array2->power_bitlines.searchOp.dynamic * 1e9 +
+ fr->data_array2->power_prechg_eq_drivers.searchOp.dynamic * 1e9 << endl;
+ cout << "\tSense amplifier energy (nJ): " <<
+ fr->data_array2->power_sense_amps.searchOp.dynamic * 1e9 << endl;
+ cout << "\tSub-array output driver (nJ): " <<
+ fr->data_array2->power_output_drivers_at_subarray.searchOp.dynamic * 1e9 << endl;
+
+
+ cout << endl << " Total dynamic read energy/access (nJ): " <<
+ fr->data_array2->power.readOp.dynamic * 1e9 << endl;
+ cout << "\tTotal energy in H-tree (that includes both "
+ "address and data transfer) (nJ): " <<
(fr->data_array2->power_addr_input_htree.readOp.dynamic + + fr->data_array2->power_data_output_htree.readOp.dynamic + + fr->data_array2->power_routing_to_bank.readOp.dynamic) * 1e9 << endl; + cout << "\tOutput Htree inside bank Energy (nJ): " << + fr->data_array2->power_data_output_htree.readOp.dynamic * 1e9 << endl; + cout << "\tDecoder (nJ): " << + fr->data_array2->power_row_predecoder_drivers.readOp.dynamic * 1e9 + + fr->data_array2->power_row_predecoder_blocks.readOp.dynamic * 1e9 << endl; + cout << "\tWordline (nJ): " << + fr->data_array2->power_row_decoders.readOp.dynamic * 1e9 << endl; + cout << "\tBitline mux & associated drivers (nJ): " << + fr->data_array2->power_bit_mux_predecoder_drivers.readOp.dynamic * 1e9 + + fr->data_array2->power_bit_mux_predecoder_blocks.readOp.dynamic * 1e9 + + fr->data_array2->power_bit_mux_decoders.readOp.dynamic * 1e9 << endl; + cout << "\tSense amp mux & associated drivers (nJ): " << + fr->data_array2->power_senseamp_mux_lev_1_predecoder_drivers.readOp.dynamic * 1e9 + + fr->data_array2->power_senseamp_mux_lev_1_predecoder_blocks.readOp.dynamic * 1e9 + + fr->data_array2->power_senseamp_mux_lev_1_decoders.readOp.dynamic * 1e9 + + fr->data_array2->power_senseamp_mux_lev_2_predecoder_drivers.readOp.dynamic * 1e9 + + fr->data_array2->power_senseamp_mux_lev_2_predecoder_blocks.readOp.dynamic * 1e9 + + fr->data_array2->power_senseamp_mux_lev_2_decoders.readOp.dynamic * 1e9 << endl; + cout << "\tBitlines (nJ): " << + fr->data_array2->power_bitlines.readOp.dynamic * 1e9 + + fr->data_array2->power_prechg_eq_drivers.readOp.dynamic * 1e9<< endl; + cout << "\tSense amplifier energy (nJ): " << + fr->data_array2->power_sense_amps.readOp.dynamic * 1e9 << endl; + cout << "\tSub-array output driver (nJ): " << + fr->data_array2->power_output_drivers_at_subarray.readOp.dynamic * 1e9 << endl; + + cout << endl <<" Total leakage power of a bank (mW): " << + fr->data_array2->power.readOp.leakage * 1e3 << endl; + } + + + if ((!(g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc)) && !g_ip->is_main_mem) + { + cout << endl << " Tag array: Total dynamic read energy/access (nJ): " << + fr->tag_array2->power.readOp.dynamic * 1e9 << endl; + cout << "\tTotal leakage read/write power of a bank (mW): " << + fr->tag_array2->power.readOp.leakage * 1e3 << endl; + cout << "\tTotal energy in H-tree (that includes both " + "address and data transfer) (nJ): " << + (fr->tag_array2->power_addr_input_htree.readOp.dynamic + + fr->tag_array2->power_data_output_htree.readOp.dynamic + + fr->tag_array2->power_routing_to_bank.readOp.dynamic) * 1e9 << endl; + + cout << "\tOutput Htree inside a bank Energy (nJ): " << + fr->tag_array2->power_data_output_htree.readOp.dynamic * 1e9 << endl; + cout << "\tDecoder (nJ): " << + fr->tag_array2->power_row_predecoder_drivers.readOp.dynamic * 1e9 + + fr->tag_array2->power_row_predecoder_blocks.readOp.dynamic * 1e9 << endl; + cout << "\tWordline (nJ): " << + fr->tag_array2->power_row_decoders.readOp.dynamic * 1e9 << endl; + cout << "\tBitline mux & associated drivers (nJ): " << + fr->tag_array2->power_bit_mux_predecoder_drivers.readOp.dynamic * 1e9 + + fr->tag_array2->power_bit_mux_predecoder_blocks.readOp.dynamic * 1e9 + + fr->tag_array2->power_bit_mux_decoders.readOp.dynamic * 1e9 << endl; + cout << "\tSense amp mux & associated drivers (nJ): " << + fr->tag_array2->power_senseamp_mux_lev_1_predecoder_drivers.readOp.dynamic * 1e9 + + fr->tag_array2->power_senseamp_mux_lev_1_predecoder_blocks.readOp.dynamic * 1e9 + + 
fr->tag_array2->power_senseamp_mux_lev_1_decoders.readOp.dynamic * 1e9 + + fr->tag_array2->power_senseamp_mux_lev_2_predecoder_drivers.readOp.dynamic * 1e9 + + fr->tag_array2->power_senseamp_mux_lev_2_predecoder_blocks.readOp.dynamic * 1e9 + + fr->tag_array2->power_senseamp_mux_lev_2_decoders.readOp.dynamic * 1e9 << endl; + cout << "\tBitlines precharge and equalization circuit (nJ): " << + fr->tag_array2->power_prechg_eq_drivers.readOp.dynamic * 1e9 << endl; + cout << "\tBitlines (nJ): " << + fr->tag_array2->power_bitlines.readOp.dynamic * 1e9 << endl; + cout << "\tSense amplifier energy (nJ): " << + fr->tag_array2->power_sense_amps.readOp.dynamic * 1e9 << endl; + cout << "\tSub-array output driver (nJ): " << + fr->tag_array2->power_output_drivers_at_subarray.readOp.dynamic * 1e9 << endl; + + cout << "\tTotal leakage power of a bank (mW): " << + fr->tag_array2->power.readOp.leakage * 1e3 << endl; + cout << "\tTotal leakage power in H-tree (that includes both " + "address and data network) ((mW)): " << + (fr->tag_array2->power_addr_input_htree.readOp.leakage + + fr->tag_array2->power_data_output_htree.readOp.leakage + + fr->tag_array2->power_routing_to_bank.readOp.leakage) * 1e3 << endl; + + cout << "\tTotal leakage power in cells (mW): " << + (fr->tag_array2->array_leakage) * 1e3 << endl; + cout << "\tTotal leakage power in row logic(mW): " << + (fr->tag_array2->wl_leakage) * 1e3 << endl; + cout << "\tTotal leakage power in column logic(mW): " << + (fr->tag_array2->cl_leakage) * 1e3 << endl; + cout << "\tTotal gate leakage power in H-tree (that includes both " + "address and data network) ((mW)): " << + (fr->tag_array2->power_addr_input_htree.readOp.gate_leakage + + fr->tag_array2->power_data_output_htree.readOp.gate_leakage + + fr->tag_array2->power_routing_to_bank.readOp.gate_leakage) * 1e3 << endl; + } + + cout << endl << endl << "Area Components:" << endl << endl; + /* Data array area stats */ + if (!(g_ip->pure_cam || g_ip->fully_assoc)) + cout << " Data array: Area (mm2): " << fr->data_array2->area * 1e-6 << endl; + else if (g_ip->pure_cam) + cout << " CAM array: Area (mm2): " << fr->data_array2->area * 1e-6 << endl; + else + cout << " Fully associative cache array: Area (mm2): " << fr->data_array2->area * 1e-6 << endl; + cout << "\tHeight (mm): " << + fr->data_array2->all_banks_height*1e-3 << endl; + cout << "\tWidth (mm): " << + fr->data_array2->all_banks_width*1e-3 << endl; + if (g_ip->print_detail) { + cout << "\tArea efficiency (Memory cell area/Total area) - " << + fr->data_array2->area_efficiency << " %" << endl; + cout << "\t\tMAT Height (mm): " << + fr->data_array2->mat_height*1e-3 << endl; + cout << "\t\tMAT Length (mm): " << + fr->data_array2->mat_length*1e-3 << endl; + cout << "\t\tSubarray Height (mm): " << + fr->data_array2->subarray_height*1e-3 << endl; + cout << "\t\tSubarray Length (mm): " << + fr->data_array2->subarray_length*1e-3 << endl; + } + + /* Tag array area stats */ + if ((!(g_ip->pure_ram|| g_ip->pure_cam || g_ip->fully_assoc)) && !g_ip->is_main_mem) + { + cout << endl << " Tag array: Area (mm2): " << fr->tag_array2->area * 1e-6 << endl; + cout << "\tHeight (mm): " << + fr->tag_array2->all_banks_height*1e-3 << endl; + cout << "\tWidth (mm): " << + fr->tag_array2->all_banks_width*1e-3 << endl; + if (g_ip->print_detail) + { + cout << "\tArea efficiency (Memory cell area/Total area) - " << + fr->tag_array2->area_efficiency << " %" << endl; + cout << "\t\tMAT Height (mm): " << + fr->tag_array2->mat_height*1e-3 << endl; + cout << "\t\tMAT Length (mm): " << + 
fr->tag_array2->mat_length*1e-3 << endl; + cout << "\t\tSubarray Height (mm): " << + fr->tag_array2->subarray_height*1e-3 << endl; + cout << "\t\tSubarray Length (mm): " << + fr->tag_array2->subarray_length*1e-3 << endl; + } + } + + }//if (!g_ip->is_3d_mem) + + + + Wire wpr; + wpr.print_wire(); + + //cout << "FO4 = " << g_tp.FO4 << endl; + } +} + +//McPAT's plain interface, please keep !!! +uca_org_t cacti_interface(InputParameter * const local_interface) +{ +// g_ip = new InputParameter(); + //g_ip->add_ecc_b_ = true; + + uca_org_t fin_res; + fin_res.valid = false; + + g_ip = local_interface; + +// g_ip->data_arr_ram_cell_tech_type = data_arr_ram_cell_tech_flavor_in; +// g_ip->data_arr_peri_global_tech_type = data_arr_peri_global_tech_flavor_in; +// g_ip->tag_arr_ram_cell_tech_type = tag_arr_ram_cell_tech_flavor_in; +// g_ip->tag_arr_peri_global_tech_type = tag_arr_peri_global_tech_flavor_in; +// +// g_ip->ic_proj_type = interconnect_projection_type_in; +// g_ip->wire_is_mat_type = wire_inside_mat_type_in; +// g_ip->wire_os_mat_type = wire_outside_mat_type_in; +// g_ip->burst_len = BURST_LENGTH_in; +// g_ip->int_prefetch_w = INTERNAL_PREFETCH_WIDTH_in; +// g_ip->page_sz_bits = PAGE_SIZE_BITS_in; +// +// g_ip->cache_sz = cache_size; +// g_ip->line_sz = line_size; +// g_ip->assoc = associativity; +// g_ip->nbanks = banks; +// g_ip->out_w = output_width; +// g_ip->specific_tag = specific_tag; +// if (tag_width == 0) { +// g_ip->tag_w = 42; +// } +// else { +// g_ip->tag_w = tag_width; +// } +// +// g_ip->access_mode = access_mode; +// g_ip->delay_wt = obj_func_delay; +// g_ip->dynamic_power_wt = obj_func_dynamic_power; +// g_ip->leakage_power_wt = obj_func_leakage_power; +// g_ip->area_wt = obj_func_area; +// g_ip->cycle_time_wt = obj_func_cycle_time; +// g_ip->delay_dev = dev_func_delay; +// g_ip->dynamic_power_dev = dev_func_dynamic_power; +// g_ip->leakage_power_dev = dev_func_leakage_power; +// g_ip->area_dev = dev_func_area; +// g_ip->cycle_time_dev = dev_func_cycle_time; +// g_ip->temp = temp; +// +// g_ip->F_sz_nm = tech_node; +// g_ip->F_sz_um = tech_node / 1000; +// g_ip->is_main_mem = (main_mem != 0) ? true : false; +// g_ip->is_cache = (cache ==1) ? true : false; +// g_ip->pure_ram = (cache ==0) ? true : false; +// g_ip->pure_cam = (cache ==2) ? true : false; +// g_ip->rpters_in_htree = (REPEATERS_IN_HTREE_SEGMENTS_in != 0) ? 
true : false; +// g_ip->ver_htree_wires_over_array = VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in; +// g_ip->broadcast_addr_din_over_ver_htrees = BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in; +// +// g_ip->num_rw_ports = rw_ports; +// g_ip->num_rd_ports = excl_read_ports; +// g_ip->num_wr_ports = excl_write_ports; +// g_ip->num_se_rd_ports = single_ended_read_ports; +// g_ip->num_search_ports = search_ports; +// +// g_ip->print_detail = 1; +// g_ip->nuca = 0; +// g_ip->is_cache=true; +// +// if (force_wiretype == 0) +// { +// g_ip->wt = Global; +// g_ip->force_wiretype = false; +// } +// else +// { g_ip->force_wiretype = true; +// if (wiretype==10) { +// g_ip->wt = Global_10; +// } +// if (wiretype==20) { +// g_ip->wt = Global_20; +// } +// if (wiretype==30) { +// g_ip->wt = Global_30; +// } +// if (wiretype==5) { +// g_ip->wt = Global_5; +// } +// if (wiretype==0) { +// g_ip->wt = Low_swing; +// } +// } +// //g_ip->wt = Global_5; +// if (force_config == 0) +// { +// g_ip->force_cache_config = false; +// } +// else +// { +// g_ip->force_cache_config = true; +// g_ip->ndbl=ndbl; +// g_ip->ndwl=ndwl; +// g_ip->nspd=nspd; +// g_ip->ndcm=ndcm; +// g_ip->ndsam1=ndsam1; +// g_ip->ndsam2=ndsam2; +// +// +// } +// +// if (ecc==0){ +// g_ip->add_ecc_b_=false; +// } +// else +// { +// g_ip->add_ecc_b_=true; +// } + + + g_ip->error_checking(); + + init_tech_params(g_ip->F_sz_um, false); + Wire winit; // Do not delete this line. It initializes wires. + + solve(&fin_res); + +// g_ip->display_ip(); +// output_UCA(&fin_res); +// output_data_csv(fin_res); + + // delete (g_ip); + + return fin_res; +} + +//McPAT's plain interface, please keep !!! +uca_org_t init_interface(InputParameter* const local_interface) +{ + // g_ip = new InputParameter(); + //g_ip->add_ecc_b_ = true; + + uca_org_t fin_res; + fin_res.valid = false; + + g_ip = local_interface; + + +// g_ip->data_arr_ram_cell_tech_type = data_arr_ram_cell_tech_flavor_in; +// g_ip->data_arr_peri_global_tech_type = data_arr_peri_global_tech_flavor_in; +// g_ip->tag_arr_ram_cell_tech_type = tag_arr_ram_cell_tech_flavor_in; +// g_ip->tag_arr_peri_global_tech_type = tag_arr_peri_global_tech_flavor_in; +// +// g_ip->ic_proj_type = interconnect_projection_type_in; +// g_ip->wire_is_mat_type = wire_inside_mat_type_in; +// g_ip->wire_os_mat_type = wire_outside_mat_type_in; +// g_ip->burst_len = BURST_LENGTH_in; +// g_ip->int_prefetch_w = INTERNAL_PREFETCH_WIDTH_in; +// g_ip->page_sz_bits = PAGE_SIZE_BITS_in; +// +// g_ip->cache_sz = cache_size; +// g_ip->line_sz = line_size; +// g_ip->assoc = associativity; +// g_ip->nbanks = banks; +// g_ip->out_w = output_width; +// g_ip->specific_tag = specific_tag; +// if (tag_width == 0) { +// g_ip->tag_w = 42; +// } +// else { +// g_ip->tag_w = tag_width; +// } +// +// g_ip->access_mode = access_mode; +// g_ip->delay_wt = obj_func_delay; +// g_ip->dynamic_power_wt = obj_func_dynamic_power; +// g_ip->leakage_power_wt = obj_func_leakage_power; +// g_ip->area_wt = obj_func_area; +// g_ip->cycle_time_wt = obj_func_cycle_time; +// g_ip->delay_dev = dev_func_delay; +// g_ip->dynamic_power_dev = dev_func_dynamic_power; +// g_ip->leakage_power_dev = dev_func_leakage_power; +// g_ip->area_dev = dev_func_area; +// g_ip->cycle_time_dev = dev_func_cycle_time; +// g_ip->temp = temp; +// +// g_ip->F_sz_nm = tech_node; +// g_ip->F_sz_um = tech_node / 1000; +// g_ip->is_main_mem = (main_mem != 0) ? true : false; +// g_ip->is_cache = (cache ==1) ? true : false; +// g_ip->pure_ram = (cache ==0) ? true : false; +// g_ip->pure_cam = (cache ==2) ? 
true : false; +// g_ip->rpters_in_htree = (REPEATERS_IN_HTREE_SEGMENTS_in != 0) ? true : false; +// g_ip->ver_htree_wires_over_array = VERTICAL_HTREE_WIRES_OVER_THE_ARRAY_in; +// g_ip->broadcast_addr_din_over_ver_htrees = BROADCAST_ADDR_DATAIN_OVER_VERTICAL_HTREES_in; +// +// g_ip->num_rw_ports = rw_ports; +// g_ip->num_rd_ports = excl_read_ports; +// g_ip->num_wr_ports = excl_write_ports; +// g_ip->num_se_rd_ports = single_ended_read_ports; +// g_ip->num_search_ports = search_ports; +// +// g_ip->print_detail = 1; +// g_ip->nuca = 0; +// +// if (force_wiretype == 0) +// { +// g_ip->wt = Global; +// g_ip->force_wiretype = false; +// } +// else +// { g_ip->force_wiretype = true; +// if (wiretype==10) { +// g_ip->wt = Global_10; +// } +// if (wiretype==20) { +// g_ip->wt = Global_20; +// } +// if (wiretype==30) { +// g_ip->wt = Global_30; +// } +// if (wiretype==5) { +// g_ip->wt = Global_5; +// } +// if (wiretype==0) { +// g_ip->wt = Low_swing; +// } +// } +// //g_ip->wt = Global_5; +// if (force_config == 0) +// { +// g_ip->force_cache_config = false; +// } +// else +// { +// g_ip->force_cache_config = true; +// g_ip->ndbl=ndbl; +// g_ip->ndwl=ndwl; +// g_ip->nspd=nspd; +// g_ip->ndcm=ndcm; +// g_ip->ndsam1=ndsam1; +// g_ip->ndsam2=ndsam2; +// +// +// } +// +// if (ecc==0){ +// g_ip->add_ecc_b_=false; +// } +// else +// { +// g_ip->add_ecc_b_=true; +// } + + + g_ip->error_checking(); + + init_tech_params(g_ip->F_sz_um, false); + Wire winit; // Do not delete this line. It initializes wires. + //solve(&fin_res); + //g_ip->display_ip(); + + //solve(&fin_res); + //output_UCA(&fin_res); + //output_data_csv(fin_res); + // delete (g_ip); + + return fin_res; +} + +void reconfigure(InputParameter *local_interface, uca_org_t *fin_res) +{ + // Copy the InputParameter to global interface (g_ip) and do error checking. + g_ip = local_interface; + g_ip->error_checking(); + + // Initialize technology parameters + init_tech_params(g_ip->F_sz_um,false); + + Wire winit; // Do not delete this line. It initializes wires. + + // This corresponds to solve() in the initialization process. + update(fin_res); +} + diff --git a/Project_FARSI/cacti_for_FARSI/io.h b/Project_FARSI/cacti_for_FARSI/io.h new file mode 100644 index 00000000..7c82feea --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/io.h @@ -0,0 +1,45 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + +#ifndef __IO_H__ +#define __IO_H__ + + +#include "const.h" +#include "cacti_interface.h" + + +void output_data_csv(const uca_org_t & fin_res, string fn="out.csv"); +void output_UCA(uca_org_t * fin_res); +void output_data_csv_3dd(const uca_org_t & fin_res); + +#endif diff --git a/Project_FARSI/cacti_for_FARSI/lpddr.cfg b/Project_FARSI/cacti_for_FARSI/lpddr.cfg new file mode 100644 index 00000000..c1564875 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/lpddr.cfg @@ -0,0 +1,254 @@ +# Cache size +//-size (bytes) 2048 +//-size (bytes) 4096 +//-size (bytes) 32768 +//-size (bytes) 131072 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 67108864 +//-size (bytes) 1073741824 + +# power gating +-Array Power Gating - "false" +-WL Power Gating - "false" +-CL Power Gating - "false" +-Bitline floating - "false" +-Interconnect Power Gating - "false" +-Power Gating Performance Loss 0.01 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +//-associativity 8 +-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +-technology (u) 0.022 +//-technology (u) 0.040 +//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 8192 +-burst length 8 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 512 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity 
Ref: CACTI 5.3 report) +-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. 
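+# Illustrative sketch (an added note, not part of the shipped config): besides
+# the three keywords named above, this field may also carry an explicit
+# full-swing wire flavor; the commented-out wiretype mapping in
+# cacti_interface() earlier in this patch maps 5/10/20/30 to
+# Global_5/Global_10/Global_20/Global_30, so both of the following are valid:
+#   -Wire signaling (fullswing, lowswing, default) - "lowswing"
+#   -Wire signaling (fullswing, lowswing, default) - "Global_10"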
+-Wire signaling (fullswing, lowswing, default) - "Global_30" +//-Wire signaling (fullswing, lowswing, default) - "default" +//-Wire signaling (fullswing, lowswing, default) - "lowswing" + +//-Wire inside mat - "global" +-Wire inside mat - "semi-global" +//-Wire outside mat - "global" +-Wire outside mat - "semi-global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +-Add ECC - "true" + +//-Print level (DETAILED, CONCISE) - "CONCISE" +-Print level (DETAILED, CONCISE) - "DETAILED" + +# for debugging +//-Print input parameters - "true" +-Print input parameters - "false" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +//-Force cache config - "true" +-Force cache config - "false" +-Ndwl 1 +-Ndbl 1 +-Nspd 0 +-Ndcm 1 +-Ndsam1 0 +-Ndsam2 0 + + + +#### Default CONFIGURATION values for baseline external IO parameters to DRAM. More details can be found in the CACTI-IO technical report (), especially Chapters 2 and 3. + +# Memory Type (D=DDR3, L=LPDDR2, W=WideIO). Additional memory types can be defined by the user in extio_technology.cc, along with their technology and configuration parameters. + +//-dram_type "D" +-dram_type "L" +//-dram_type "W" +//-dram_type "S" + +# Memory State (R=Read, W=Write, I=Idle or S=Sleep) + +//-iostate "R" +-iostate "W" +//-iostate "I" +//-iostate "S" + +#Address bus timing. To alleviate the timing on the command and address bus due to high loading (shared across all memories on the channel), the interface allows for multi-cycle timing options. + +-addr_timing 0.5 //DDR +//-addr_timing 1.0 //SDR (half of DQ rate) +//-addr_timing 2.0 //2T timing (One fourth of DQ rate) +//-addr_timing 3.0 // 3T timing (One sixth of DQ rate) + +# Memory Density (Gbit per memory/DRAM die) + +-mem_density 8 Gb //Valid values 2^n Gb + +# IO frequency (MHz) (frequency of the external memory interface). + +-bus_freq 533 MHz //As of current memory standards (2013), valid range 0 to 1.5 GHz for DDR3, 0 to 533 MHz for LPDDR2, 0 - 800 MHz for WideIO and 0 - 3 GHz for Low-swing differential. However this can change, and the user is free to define valid ranges based on new memory types or extending beyond existing standards for existing dram types. + +# Duty Cycle (fraction of time in the Memory State defined above) + +-duty_cycle 1.0 //Valid range 0 to 1.0 + +# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR +#-activity_dq .50 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR + +# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_ca 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T +#-activity_ca 0.25 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T + +# Number of DQ pins + +-num_dq 72 //Number of DQ pins. Includes ECC pins. + +# Number of DQS pins. 
DQS is a data strobe that is sent along with a small number of data-lanes so the source synchronous timing is local to these DQ bits. Typically, 1 DQS per byte (8 DQ bits) is used. The DQS is also typically differential, just like the CLK pin.
+
+-num_dqs 36 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins.
+
+# Number of CA pins
+
+-num_ca 35 //Valid range 0 to 35 pins.
+#-num_ca 25 //Valid range 0 to 35 pins.
+
+# Number of CLK pins. CLK is typically a differential pair. In some cases additional CLK pairs may be used to limit the loading on the CLK pin.
+
+-num_clk 2 //2 x differential pair. Valid values: 0/2/4.
+
+# Number of Physical Ranks
+
+-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per buffer/register. If multiple LRDIMMs or buffer chips exist, the analysis for capacity and power is reported per buffer/register.
+
+# Width of the Memory Data Bus
+
+-mem_data_width 32 //x4 or x8 or x16 or x32 memories. For WideIO up to x128.
diff --git a/Project_FARSI/cacti_for_FARSI/main.cc b/Project_FARSI/cacti_for_FARSI/main.cc
new file mode 100644
index 00000000..04899f14
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/main.cc
@@ -0,0 +1,270 @@
+/*------------------------------------------------------------
+ * CACTI 6.5
+ * Copyright 2008 Hewlett-Packard Development Corporation
+ * All Rights Reserved
+ *
+ * Permission to use, copy, and modify this software and its documentation is
+ * hereby granted only under the following terms and conditions. Both the
+ * above copyright notice and this permission notice must appear in all copies
+ * of the software, derivative works or modified versions, and any portions
+ * thereof, and both notices must appear in supporting documentation.
+ *
+ * Users of this software agree to the terms and conditions set forth herein, and
+ * hereby grant back to Hewlett-Packard Company and its affiliated companies ("HP")
+ * a non-exclusive, unrestricted, royalty-free right and license under any changes,
+ * enhancements or extensions made to the core functions of the software, including
+ * but not limited to those affording compatibility with other hardware or software
+ * environments, but excluding applications which incorporate this software.
+ * Users further agree to use their best efforts to return to HP any such changes,
+ * enhancements or extensions that they make and inform HP of noteworthy uses of
+ * this software. Correspondence should be provided to HP at:
+ *
+ * Director of Intellectual Property Licensing
+ * Office of Strategy and Technology
+ * Hewlett-Packard Company
+ * 1501 Page Mill Road
+ * Palo Alto, California 94304
+ *
+ * This software may be distributed (but not offered for sale or transferred
+ * for compensation) to third parties, provided such third parties agree to
+ * abide by the terms and conditions of this notice.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND HP DISCLAIMS ALL
+ * WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES
+ * OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL HP
+ * CORPORATION BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL
+ * DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR
+ * PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS
+ * SOFTWARE.
+ *------------------------------------------------------------*/
+
+#include "io.h"
+#include <iostream>
+
+#include "Ucache.h"
+
+using namespace std;
+
+
+int main(int argc,char *argv[])
+{
+
+ uca_org_t result;
+ if (argc != 53 && argc != 55 && argc !=64)
+ {
+ bool infile_specified = false;
+ string infile_name("");
+
+ for (int32_t i = 0; i < argc; i++)
+ {
+ if (argv[i] == string("-infile"))
+ {
+ infile_specified = true;
+ i++;
+ infile_name = argv[i];
+ }
+ }
+ if (infile_specified == false)
+ {
+ cerr << " Invalid arguments -- how to use CACTI:" << endl;
+ cerr << " 1) cacti -infile <input file name>" << endl;
+ cerr << " 2) cacti arg1 ... arg52 -- please refer to the README file" << endl;
+ cerr << " No. of arguments input - " << argc << endl;
+ exit(1);
+ }
+ else
+ {
+ result = cacti_interface(infile_name);
+ }
+ }
+ else if (argc == 53)
+ {
+ result = cacti_interface(atoi(argv[ 1]),
+ atoi(argv[ 2]),
+ atoi(argv[ 3]),
+ atoi(argv[ 4]),
+ atoi(argv[ 5]),
+ atoi(argv[ 6]),
+ atoi(argv[ 7]),
+ atoi(argv[ 8]),
+ atoi(argv[ 9]),
+ atof(argv[10]),
+ atoi(argv[11]),
+ atoi(argv[12]),
+ atoi(argv[13]),
+ atoi(argv[14]),
+ atoi(argv[15]),
+ atoi(argv[16]),
+ atoi(argv[17]),
+ atoi(argv[18]),
+ atoi(argv[19]),
+ atoi(argv[20]),
+ atoi(argv[21]),
+ atoi(argv[22]),
+ atoi(argv[23]),
+ atoi(argv[24]),
+ atoi(argv[25]),
+ atoi(argv[26]),
+ atoi(argv[27]),
+ atoi(argv[28]),
+ atoi(argv[29]),
+ atoi(argv[30]),
+ atoi(argv[31]),
+ atoi(argv[32]),
+ atoi(argv[33]),
+ atoi(argv[34]),
+ atoi(argv[35]),
+ atoi(argv[36]),
+ atoi(argv[37]),
+ atoi(argv[38]),
+ atoi(argv[39]),
+ atoi(argv[40]),
+ atoi(argv[41]),
+ atoi(argv[42]),
+ atoi(argv[43]),
+ atoi(argv[44]),
+ atoi(argv[45]),
+ atoi(argv[46]),
+ atoi(argv[47]),
+ atoi(argv[48]),
+ atoi(argv[49]),
+ atoi(argv[50]),
+ atoi(argv[51]),
+ atoi(argv[52]));
+ }
+ else if (argc == 55)
+ {
+ result = cacti_interface(atoi(argv[ 1]),
+ atoi(argv[ 2]),
+ atoi(argv[ 3]),
+ atoi(argv[ 4]),
+ atoi(argv[ 5]),
+ atoi(argv[ 6]),
+ atoi(argv[ 7]),
+ atoi(argv[ 8]),
+ atof(argv[ 9]),
+ atoi(argv[10]),
+ atoi(argv[11]),
+ atoi(argv[12]),
+ atoi(argv[13]),
+ atoi(argv[14]),
+ atoi(argv[15]),
+ atoi(argv[16]),
+ atoi(argv[17]),
+ atoi(argv[18]),
+ atoi(argv[19]),
+ atoi(argv[20]),
+ atoi(argv[21]),
+ atoi(argv[22]),
+ atoi(argv[23]),
+ atoi(argv[24]),
+ atoi(argv[25]),
+ atoi(argv[26]),
+ atoi(argv[27]),
+ atoi(argv[28]),
+ atoi(argv[29]),
+ atoi(argv[30]),
+ atoi(argv[31]),
+ atoi(argv[32]),
+ atoi(argv[33]),
+ atoi(argv[34]),
+ atoi(argv[35]),
+ atoi(argv[36]),
+ atoi(argv[37]),
+ atoi(argv[38]),
+ atoi(argv[39]),
+ atoi(argv[40]),
+ atoi(argv[41]),
+ atoi(argv[42]),
+ atoi(argv[43]),
+ atoi(argv[44]),
+ atoi(argv[45]),
+ atoi(argv[46]),
+ atoi(argv[47]),
+ atoi(argv[48]),
+ atoi(argv[49]),
+ atoi(argv[50]),
+ atoi(argv[51]),
+ atoi(argv[52]),
+ atoi(argv[53]),
+ atoi(argv[54]));
+ }
+ else if (argc == 64)
+ {
+ result = cacti_interface(atoi(argv[ 1]),
+ atoi(argv[ 2]),
+ atoi(argv[ 3]),
+ atoi(argv[ 4]),
+ atoi(argv[ 5]),
+ atoi(argv[ 6]),
+ atoi(argv[ 7]),
+ atoi(argv[ 8]),
+ atof(argv[ 9]),
+ atoi(argv[10]),
+ atoi(argv[11]),
+ atoi(argv[12]),
+ atoi(argv[13]),
+ atoi(argv[14]),
+ atoi(argv[15]),
+ atoi(argv[16]),
+ atoi(argv[17]),
+ atoi(argv[18]),
+ atoi(argv[19]),
+ atoi(argv[20]),
+ atoi(argv[21]),
+ atoi(argv[22]),
+ atoi(argv[23]),
+ atoi(argv[24]),
+ atoi(argv[25]),
+ atoi(argv[26]),
+ atoi(argv[27]),
+ atoi(argv[28]),
+ atoi(argv[29]),
+ atoi(argv[30]),
+ atoi(argv[31]),
+ atoi(argv[32]),
+ atoi(argv[33]),
+ atoi(argv[34]),
+ atoi(argv[35]),
+ atoi(argv[36]),
+ atoi(argv[37]),
+ atoi(argv[38]),
+ atoi(argv[39]),
+ atoi(argv[40]),
+ atoi(argv[41]),
+ atoi(argv[42]),
+ atoi(argv[43]),
+ atoi(argv[44]),
+ atoi(argv[45]),
+ atoi(argv[46]),
+ atoi(argv[47]),
+ atoi(argv[48]),
+ atoi(argv[49]),
+ atoi(argv[50]),
+ atoi(argv[51]),
+ atoi(argv[52]),
+ atoi(argv[53]),
+ atoi(argv[54]),
+ atoi(argv[55]),
+ atoi(argv[56]),
+ atoi(argv[57]),
+ atoi(argv[58]),
+ atoi(argv[59]),
+ atoi(argv[60]),
+ atoi(argv[61]),
+ atoi(argv[62]),
+ atoi(argv[63]));
+ }
+
+ cout << "=============================================\n\n";
+ // print_g_tp(); //function to test technology parameters.
+// g_tp.display();
+ result.cleanup();
+// delete result.data_array2;
+// if (result.tag_array2!=NULL)
+// delete result.tag_array2;
+
+ return 0;
+}
+
diff --git a/Project_FARSI/cacti_for_FARSI/makefile b/Project_FARSI/cacti_for_FARSI/makefile
new file mode 100644
index 00000000..394019fd
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/makefile
@@ -0,0 +1,28 @@
+TAR = cacti
+
+.PHONY: dbg opt depend clean clean_dbg clean_opt
+
+all: dbg
+
+dbg: $(TAR).mk obj_dbg
+	@$(MAKE) TAG=dbg -C . -f $(TAR).mk
+
+opt: $(TAR).mk obj_opt
+	@$(MAKE) TAG=opt -C . -f $(TAR).mk
+
+obj_dbg:
+	mkdir $@
+
+obj_opt:
+	mkdir $@
+
+clean: clean_dbg clean_opt
+
+clean_dbg: obj_dbg
+	@$(MAKE) TAG=dbg -C . -f $(TAR).mk clean
+	rm -rf $<
+
+clean_opt: obj_opt
+	@$(MAKE) TAG=opt -C . -f $(TAR).mk clean
+	rm -rf $<
+
diff --git a/Project_FARSI/cacti_for_FARSI/mat.cc b/Project_FARSI/cacti_for_FARSI/mat.cc
new file mode 100644
index 00000000..f290dafc
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/mat.cc
@@ -0,0 +1,1940 @@
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+
+
+#include "mat.h"
+#include <cassert> // header name lost in the patch text; <cassert> assumed here for the assert() calls below
+
+
+Mat::Mat(const DynamicParameter & dyn_p)
+ :dp(dyn_p),
+ power_subarray_out_drv(),
+ delay_fa_tag(0), delay_cam(0),
+ delay_before_decoder(0), delay_bitline(0),
+ delay_wl_reset(0), delay_bl_restore(0),
+ delay_searchline(0), delay_matchchline(0),
+ delay_cam_sl_restore(0), delay_cam_ml_reset(0),
+ delay_fa_ram_wl(0),delay_hit_miss_reset(0),
+ delay_hit_miss(0),
+ subarray(dp, dp.fully_assoc),
+ power_bitline(), per_bitline_read_energy(0),
+ deg_bl_muxing(dp.deg_bl_muxing),
+ num_act_mats_hor_dir(dyn_p.num_act_mats_hor_dir),
+ delay_writeback(0),
+ cell(subarray.cell), cam_cell(subarray.cam_cell),
+ is_dram(dyn_p.is_dram),
+ pure_cam(dyn_p.pure_cam),
+ num_mats(dp.num_mats),
+ power_sa(), delay_sa(0),
+ leak_power_sense_amps_closed_page_state(0),
+ leak_power_sense_amps_open_page_state(0),
+ delay_subarray_out_drv(0),
+ delay_comparator(0), power_comparator(),
+ num_do_b_mat(dyn_p.num_do_b_mat), num_so_b_mat(dyn_p.num_so_b_mat),
+ num_subarrays_per_mat(dp.num_subarrays/dp.num_mats),
+ num_subarrays_per_row(dp.Ndwl/dp.num_mats_h_dir),
+ sram_sleep_tx(0), // null-initialized (added fix) so the destructor can delete it safely when power gating is off
+ array_leakage(0),
+ wl_leakage(0),
+ cl_leakage(0)
+ {
+ assert(num_subarrays_per_mat <= 4);
+ assert(num_subarrays_per_row <= 2);
+ is_fa = (dp.fully_assoc) ? true : false;
+ camFlag = (is_fa || pure_cam);//although cam_cell.w = cell.w for fa, we still differentiate them.
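+ // Worked example (added note, not in the original source): the FA/CAM
+ // branch below halves the per-row subarray count for wide mats, so
+ // num_subarrays_per_mat = 4 gives num_subarrays_per_row = 2, while mats
+ // with 1 or 2 subarrays keep num_subarrays_per_row equal to
+ // num_subarrays_per_mat.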
+ + if (is_fa || pure_cam) + num_subarrays_per_row = num_subarrays_per_mat>2?num_subarrays_per_mat/2:num_subarrays_per_mat; + + if (dp.use_inp_params == 1) { + RWP = dp.num_rw_ports; + ERP = dp.num_rd_ports; + EWP = dp.num_wr_ports; + SCHP = dp.num_search_ports; + } + else { + RWP = g_ip->num_rw_ports; + ERP = g_ip->num_rd_ports; + EWP = g_ip->num_wr_ports; + SCHP = g_ip->num_search_ports; + + } + + double number_sa_subarray; + + if (!is_fa && !pure_cam) + { + number_sa_subarray = subarray.num_cols / deg_bl_muxing; + } + else if (is_fa && !pure_cam) + { + number_sa_subarray = (subarray.num_cols_fa_cam + subarray.num_cols_fa_ram) / deg_bl_muxing; + } + + else + { + number_sa_subarray = (subarray.num_cols_fa_cam) / deg_bl_muxing; + } + + int num_dec_signals = subarray.num_rows; + double C_ld_bit_mux_dec_out = 0; + double C_ld_sa_mux_lev_1_dec_out = 0; + double C_ld_sa_mux_lev_2_dec_out = 0; + double R_wire_wl_drv_out; + + if (!is_fa && !pure_cam) + { + R_wire_wl_drv_out = subarray.num_cols * cell.w * g_tp.wire_local.R_per_um; + } + else if (is_fa && !pure_cam) + { + R_wire_wl_drv_out = (subarray.num_cols_fa_cam * cam_cell.w + subarray.num_cols_fa_ram * cell.w) * g_tp.wire_local.R_per_um ; + } + else + { + R_wire_wl_drv_out = (subarray.num_cols_fa_cam * cam_cell.w ) * g_tp.wire_local.R_per_um; + } + + double R_wire_bit_mux_dec_out = num_subarrays_per_row * subarray.num_cols * g_tp.wire_inside_mat.R_per_um * cell.w;//TODO:revisit for FA + double R_wire_sa_mux_dec_out = num_subarrays_per_row * subarray.num_cols * g_tp.wire_inside_mat.R_per_um * cell.w; + + if (deg_bl_muxing > 1) + { + C_ld_bit_mux_dec_out = + (2 * num_subarrays_per_mat * subarray.num_cols / deg_bl_muxing)*gate_C(g_tp.w_nmos_b_mux, 0, is_dram) + // 2 transistor per cell + num_subarrays_per_row * subarray.num_cols*g_tp.wire_inside_mat.C_per_um*cell.get_w(); + } + + if (dp.Ndsam_lev_1 > 1) + { + C_ld_sa_mux_lev_1_dec_out = + (num_subarrays_per_mat * number_sa_subarray / dp.Ndsam_lev_1)*gate_C(g_tp.w_nmos_sa_mux, 0, is_dram) + + num_subarrays_per_row * subarray.num_cols*g_tp.wire_inside_mat.C_per_um*cell.get_w(); + } + if (dp.Ndsam_lev_2 > 1) + { + C_ld_sa_mux_lev_2_dec_out = + (num_subarrays_per_mat * number_sa_subarray / (dp.Ndsam_lev_1*dp.Ndsam_lev_2))*gate_C(g_tp.w_nmos_sa_mux, 0, is_dram) + + num_subarrays_per_row * subarray.num_cols*g_tp.wire_inside_mat.C_per_um*cell.get_w(); + } + + if (num_subarrays_per_row >= 2) + { + // wire heads for both right and left side of a mat, so half the resistance + R_wire_bit_mux_dec_out /= 2.0; + R_wire_sa_mux_dec_out /= 2.0; + } + + + row_dec = new Decoder( + num_dec_signals, + false, + subarray.C_wl, + R_wire_wl_drv_out, + false/*is_fa*/, + is_dram, + true, + camFlag? cam_cell:cell); + + row_dec->nodes_DSTN = subarray.num_rows;//TODO: this is not a good way for OOO programming +// if (is_fa && (!dp.is_tag)) +// { +// row_dec->exist = true; +// } + bit_mux_dec = new Decoder( + deg_bl_muxing,// This number is 1 for FA or CAM + false, + C_ld_bit_mux_dec_out, + R_wire_bit_mux_dec_out, + false/*is_fa*/, + is_dram, + false, + camFlag? cam_cell:cell); + sa_mux_lev_1_dec = new Decoder( + dp.deg_senseamp_muxing_non_associativity, // This number is 1 for FA or CAM + dp.number_way_select_signals_mat ? true : false,//only sa_mux_lev_1_dec needs way select signal + C_ld_sa_mux_lev_1_dec_out, + R_wire_sa_mux_dec_out, + false/*is_fa*/, + is_dram, + false, + camFlag? 
cam_cell:cell); + sa_mux_lev_2_dec = new Decoder( + dp.Ndsam_lev_2, // This number is 1 for FA or CAM + false, + C_ld_sa_mux_lev_2_dec_out, + R_wire_sa_mux_dec_out, + false/*is_fa*/, + is_dram, + false, + camFlag? cam_cell:cell); + + double C_wire_predec_blk_out; + double R_wire_predec_blk_out; + + if (!is_fa && !pure_cam) + { + + C_wire_predec_blk_out = num_subarrays_per_row * subarray.num_rows * g_tp.wire_inside_mat.C_per_um * cell.h; + R_wire_predec_blk_out = num_subarrays_per_row * subarray.num_rows * g_tp.wire_inside_mat.R_per_um * cell.h; + + } + else //for pre-decode block's load is same for both FA and CAM + { + C_wire_predec_blk_out = subarray.num_rows * g_tp.wire_inside_mat.C_per_um * cam_cell.h; + R_wire_predec_blk_out = subarray.num_rows * g_tp.wire_inside_mat.R_per_um * cam_cell.h; + } + + + if (is_fa||pure_cam) + num_dec_signals += _log2(num_subarrays_per_mat); + + PredecBlk * r_predec_blk1 = new PredecBlk( + num_dec_signals, + row_dec, + C_wire_predec_blk_out, + R_wire_predec_blk_out, + num_subarrays_per_mat, + is_dram, + true); + PredecBlk * r_predec_blk2 = new PredecBlk( + num_dec_signals, + row_dec, + C_wire_predec_blk_out, + R_wire_predec_blk_out, + num_subarrays_per_mat, + is_dram, + false); + PredecBlk * b_mux_predec_blk1 = new PredecBlk(deg_bl_muxing, bit_mux_dec, 0, 0, 1, is_dram, true); + PredecBlk * b_mux_predec_blk2 = new PredecBlk(deg_bl_muxing, bit_mux_dec, 0, 0, 1, is_dram, false); + PredecBlk * sa_mux_lev_1_predec_blk1 = new PredecBlk(dyn_p.deg_senseamp_muxing_non_associativity, sa_mux_lev_1_dec, 0, 0, 1, is_dram, true); + PredecBlk * sa_mux_lev_1_predec_blk2 = new PredecBlk(dyn_p.deg_senseamp_muxing_non_associativity, sa_mux_lev_1_dec, 0, 0, 1, is_dram, false); + PredecBlk * sa_mux_lev_2_predec_blk1 = new PredecBlk(dp.Ndsam_lev_2, sa_mux_lev_2_dec, 0, 0, 1, is_dram, true); + PredecBlk * sa_mux_lev_2_predec_blk2 = new PredecBlk(dp.Ndsam_lev_2, sa_mux_lev_2_dec, 0, 0, 1, is_dram, false); + dummy_way_sel_predec_blk1 = new PredecBlk(1, sa_mux_lev_1_dec, 0, 0, 0, is_dram, true); + dummy_way_sel_predec_blk2 = new PredecBlk(1, sa_mux_lev_1_dec, 0, 0, 0, is_dram, false); + + PredecBlkDrv * r_predec_blk_drv1 = new PredecBlkDrv(0, r_predec_blk1, is_dram); + PredecBlkDrv * r_predec_blk_drv2 = new PredecBlkDrv(0, r_predec_blk2, is_dram); + PredecBlkDrv * b_mux_predec_blk_drv1 = new PredecBlkDrv(0, b_mux_predec_blk1, is_dram); + PredecBlkDrv * b_mux_predec_blk_drv2 = new PredecBlkDrv(0, b_mux_predec_blk2, is_dram); + PredecBlkDrv * sa_mux_lev_1_predec_blk_drv1 = new PredecBlkDrv(0, sa_mux_lev_1_predec_blk1, is_dram); + PredecBlkDrv * sa_mux_lev_1_predec_blk_drv2 = new PredecBlkDrv(0, sa_mux_lev_1_predec_blk2, is_dram); + PredecBlkDrv * sa_mux_lev_2_predec_blk_drv1 = new PredecBlkDrv(0, sa_mux_lev_2_predec_blk1, is_dram); + PredecBlkDrv * sa_mux_lev_2_predec_blk_drv2 = new PredecBlkDrv(0, sa_mux_lev_2_predec_blk2, is_dram); + way_sel_drv1 = new PredecBlkDrv(dyn_p.number_way_select_signals_mat, dummy_way_sel_predec_blk1, is_dram); + dummy_way_sel_predec_blk_drv2 = new PredecBlkDrv(1, dummy_way_sel_predec_blk2, is_dram); + + r_predec = new Predec(r_predec_blk_drv1, r_predec_blk_drv2); + b_mux_predec = new Predec(b_mux_predec_blk_drv1, b_mux_predec_blk_drv2); + sa_mux_lev_1_predec = new Predec(sa_mux_lev_1_predec_blk_drv1, sa_mux_lev_1_predec_blk_drv2); + sa_mux_lev_2_predec = new Predec(sa_mux_lev_2_predec_blk_drv1, sa_mux_lev_2_predec_blk_drv2); + + subarray_out_wire = new Wire(dp.wtype, g_ip->cl_vertical?subarray.area.w:subarray.area.h);//Bug should be subarray.area.w Owen 
and + //subarray_out_wire = new Wire(g_ip->wt, g_ip->cl_vertical?subarray.area.w:subarray.area.h);//Bug should be subarray.area.w Owen and + + double driver_c_gate_load; + double driver_c_wire_load; + double driver_r_wire_load; + + if (is_fa || pure_cam) + + { //Although CAM and RAM use different bl pre-charge driver, assuming the precharge p size is the same + driver_c_gate_load = (subarray.num_cols_fa_cam )* gate_C(2 * g_tp.w_pmos_bl_precharge + g_tp.w_pmos_bl_eq, 0, is_dram, false, false); + driver_c_wire_load = subarray.num_cols_fa_cam * cam_cell.w * g_tp.wire_outside_mat.C_per_um; + driver_r_wire_load = subarray.num_cols_fa_cam * cam_cell.w * g_tp.wire_outside_mat.R_per_um; + cam_bl_precharge_eq_drv = new Driver( + driver_c_gate_load, + driver_c_wire_load, + driver_r_wire_load, + is_dram); + + if (!pure_cam) + { + //This is only used for fully asso not pure CAM + driver_c_gate_load = (subarray.num_cols_fa_ram )* gate_C(2 * g_tp.w_pmos_bl_precharge + g_tp.w_pmos_bl_eq, 0, is_dram, false, false); + driver_c_wire_load = subarray.num_cols_fa_ram * cell.w * g_tp.wire_outside_mat.C_per_um; + driver_r_wire_load = subarray.num_cols_fa_ram * cell.w * g_tp.wire_outside_mat.R_per_um; + bl_precharge_eq_drv = new Driver( + driver_c_gate_load, + driver_c_wire_load, + driver_r_wire_load, + is_dram); + } + } + + else + { + driver_c_gate_load = subarray.num_cols * gate_C(2 * g_tp.w_pmos_bl_precharge + g_tp.w_pmos_bl_eq, 0, is_dram, false, false); + driver_c_wire_load = subarray.num_cols * cell.w * g_tp.wire_outside_mat.C_per_um; + driver_r_wire_load = subarray.num_cols * cell.w * g_tp.wire_outside_mat.R_per_um; + bl_precharge_eq_drv = new Driver( + driver_c_gate_load, + driver_c_wire_load, + driver_r_wire_load, + is_dram); + } + double area_row_decoder = row_dec->area.get_area() * subarray.num_rows * (RWP + ERP + EWP); + double w_row_decoder = area_row_decoder / subarray.area.get_h(); + + double h_bit_mux_sense_amp_precharge_sa_mux_write_driver_write_mux = + compute_bit_mux_sa_precharge_sa_mux_wr_drv_wr_mux_h(); + + /* This means the subarray drivers are along the vertical direction since / subarray.area.get_w() is used; + * so the subarray_out_wire (actually the drivers) under the subarray and along the x direction + * So as mentioned above @ line 271 + * subarray_out_wire = new Wire(g_ip->wt, subarray.area.h);//Bug should be subarray.area.w Owen and + * change the out_wire (driver to along y direction need carefully rethinking + * rather than just simply switch w with h ) + * */ + double h_subarray_out_drv = subarray_out_wire->area.get_area() * + (subarray.num_cols / (deg_bl_muxing * dp.Ndsam_lev_1 * dp.Ndsam_lev_2)) / subarray.area.get_w(); + + + h_subarray_out_drv *= (RWP + ERP + SCHP); + + double h_comparators = 0.0; + double w_row_predecode_output_wires = 0.0; + double h_bit_mux_dec_out_wires = 0.0; + double h_senseamp_mux_dec_out_wires = 0.0; + + if ((!is_fa)&&(dp.is_tag)) + { + //tagbits = (4 * num_cols_subarray / (deg_bl_muxing * dp.Ndsam_lev_1 * dp.Ndsam_lev_2)) / num_do_b_mat; + h_comparators = compute_comparators_height(dp.tagbits, dyn_p.num_do_b_mat, subarray.area.get_w()); + h_comparators *= (RWP + ERP); + } + + //power-gating circuit + bool is_footer = false; + double Isat_subarray = 2* simplified_nmos_Isat(g_tp.sram.cell_nmos_w, is_dram, true);//only one wordline active in a subarray 2 means two inverters in an SRAM cell + double detalV_array;//, deltaV_wl, deltaV_floatingBL; + double c_wakeup_array; + + if (!(is_fa || pure_cam) && g_ip->power_gating) + {//for SRAM only at this moment 
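+ // Descriptive sketch (added note, not in the original source): the wake-up
+ // capacitance computed below sums the drain capacitances of the SRAM cell
+ // inverters for every row on a bitline (hence the multiply by
+ // subarray.num_rows); the retained swing is Vdd - Vcc_min, and both feed
+ // the Sleep_tx sizing together with the user-specified performance-loss
+ // budget g_ip->perfloss.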
+ c_wakeup_array = drain_C_(g_tp.sram.cell_pmos_w, PCH, 1, 1, cell.h, is_dram, true);//1 inv + c_wakeup_array += 2*drain_C_(g_tp.sram.cell_pmos_w, PCH, 1, 1, cell.h, is_dram, true) + + drain_C_(g_tp.sram.cell_nmos_w, NCH, 1, 1, cell.h, is_dram, true);//1 inv + c_wakeup_array *= subarray.num_rows; + detalV_array = g_tp.sram_cell.Vdd-g_tp.sram_cell.Vcc_min; + + sram_sleep_tx = new Sleep_tx (g_ip->perfloss, + Isat_subarray, + is_footer, + c_wakeup_array, + detalV_array, + 1, + cell); + + subarray.area.set_h(subarray.area.h+ sram_sleep_tx->area.h); + + //TODO: add the sleep tx in the wl driver and + } + + + int branch_effort_predec_blk1_out = (1 << r_predec_blk2->number_input_addr_bits); + int branch_effort_predec_blk2_out = (1 << r_predec_blk1->number_input_addr_bits); + w_row_predecode_output_wires = (branch_effort_predec_blk1_out + branch_effort_predec_blk2_out) * + g_tp.wire_inside_mat.pitch * (RWP + ERP + EWP); + + + double h_non_cell_area = (num_subarrays_per_mat / num_subarrays_per_row) * + (h_bit_mux_sense_amp_precharge_sa_mux_write_driver_write_mux + + h_subarray_out_drv + h_comparators); + + double w_non_cell_area = MAX(w_row_predecode_output_wires, num_subarrays_per_row * w_row_decoder); + + if (deg_bl_muxing > 1) + { + h_bit_mux_dec_out_wires = deg_bl_muxing * g_tp.wire_inside_mat.pitch * (RWP + ERP); + } + if (dp.Ndsam_lev_1 > 1) + { + h_senseamp_mux_dec_out_wires = dp.Ndsam_lev_1 * g_tp.wire_inside_mat.pitch * (RWP + ERP); + } + if (dp.Ndsam_lev_2 > 1) + { + h_senseamp_mux_dec_out_wires += dp.Ndsam_lev_2 * g_tp.wire_inside_mat.pitch * (RWP + ERP); + } + + double h_addr_datain_wires; + if (!g_ip->ver_htree_wires_over_array) + { + h_addr_datain_wires = (dp.number_addr_bits_mat + dp.number_way_select_signals_mat + + (dp.num_di_b_mat + dp.num_do_b_mat)/num_subarrays_per_row) * + g_tp.wire_inside_mat.pitch * (RWP + ERP + EWP); + + if (is_fa || pure_cam) + { + h_addr_datain_wires = (dp.number_addr_bits_mat + dp.number_way_select_signals_mat + //TODO: revisit + (dp.num_di_b_mat+ dp.num_do_b_mat )/num_subarrays_per_row) * + g_tp.wire_inside_mat.pitch * (RWP + ERP + EWP) + + (dp.num_si_b_mat + dp.num_so_b_mat )/num_subarrays_per_row * g_tp.wire_inside_mat.pitch * SCHP; + } + //h_non_cell_area = 2 * h_bit_mux_sense_amp_precharge_sa_mux + + //MAX(h_addr_datain_wires, 2 * h_subarray_out_drv); + h_non_cell_area = (h_bit_mux_sense_amp_precharge_sa_mux_write_driver_write_mux + h_comparators + + h_subarray_out_drv) * (num_subarrays_per_mat / num_subarrays_per_row) + + h_addr_datain_wires + + h_bit_mux_dec_out_wires + + h_senseamp_mux_dec_out_wires; + + } + + // double area_rectangle_center_mat = h_non_cell_area * w_non_cell_area; + double area_mat_center_circuitry = (r_predec_blk_drv1->area.get_area() + + b_mux_predec_blk_drv1->area.get_area() + + sa_mux_lev_1_predec_blk_drv1->area.get_area() + + sa_mux_lev_2_predec_blk_drv1->area.get_area() + + way_sel_drv1->area.get_area() + + r_predec_blk_drv2->area.get_area() + + b_mux_predec_blk_drv2->area.get_area() + + sa_mux_lev_1_predec_blk_drv2->area.get_area() + + sa_mux_lev_2_predec_blk_drv2->area.get_area() + + r_predec_blk1->area.get_area() + + b_mux_predec_blk1->area.get_area() + + sa_mux_lev_1_predec_blk1->area.get_area() + + sa_mux_lev_2_predec_blk1->area.get_area() + + r_predec_blk2->area.get_area() + + b_mux_predec_blk2->area.get_area() + + sa_mux_lev_1_predec_blk2->area.get_area() + + sa_mux_lev_2_predec_blk2->area.get_area() + + bit_mux_dec->area.get_area() + + sa_mux_lev_1_dec->area.get_area() + + sa_mux_lev_2_dec->area.get_area()) * (RWP + 
ERP + EWP);
+
+ /// double area_efficiency_mat;
+
+
+// if (!is_fa)
+// {
+ assert(num_subarrays_per_mat/num_subarrays_per_row>0);
+ area.h = (num_subarrays_per_mat/num_subarrays_per_row)* subarray.area.h + h_non_cell_area;
+ area.w = num_subarrays_per_row * subarray.area.get_w() + w_non_cell_area;
+ area.w = (area.h*area.w + area_mat_center_circuitry) / area.h;
+ /// area_efficiency_mat = subarray.area.get_area() * num_subarrays_per_mat * 100.0 / area.get_area();
+
+// cout<<"h_bit_mux_sense_amp_precharge_sa_mux_write_driver_write_mux"<<h_bit_mux_sense_amp_precharge_sa_mux_write_driver_write_mux<<endl;
+
+ if (g_ip->is_3d_mem)
+ {
+ h_non_cell_area = (h_bit_mux_sense_amp_precharge_sa_mux_write_driver_write_mux +
+ h_subarray_out_drv);
+ area.h = subarray.area.h + h_non_cell_area;
+ area.w = subarray.area.w;
+ if (g_ip->print_detail_debug)
+ cout << "actual subarray width: " << cell.w * subarray.num_cols /1e3 << " mm" << endl;
+ }
+
+ if (g_ip->print_detail_debug)
+ {
+ cout<<"h_non_cell_area"<<h_non_cell_area<<endl;
+ }
+
+ assert(area.h>0);
+ assert(area.w>0);
+// }
+// else
+// {
+// area.h = (num_subarrays_per_mat / num_subarrays_per_row) * subarray.area.get_h() + h_non_cell_area;
+// area.w = num_subarrays_per_row * subarray.area.get_w() + w_non_cell_area;
+// area.w = (area.h*area.w + area_mat_center_circuitry) / area.h;
+// area_efficiency_mat = subarray.area.get_area() * num_subarrays_per_row * 100.0 / area.get_area();
+// }
+ }
+
+
+
+Mat::~Mat()
+{
+ delete row_dec;
+ delete bit_mux_dec;
+ delete sa_mux_lev_1_dec;
+ delete sa_mux_lev_2_dec;
+
+ delete r_predec->blk1;
+ delete r_predec->blk2;
+ delete b_mux_predec->blk1;
+ delete b_mux_predec->blk2;
+ delete sa_mux_lev_1_predec->blk1;
+ delete sa_mux_lev_1_predec->blk2;
+ delete sa_mux_lev_2_predec->blk1;
+ delete sa_mux_lev_2_predec->blk2;
+ delete dummy_way_sel_predec_blk1;
+ delete dummy_way_sel_predec_blk2;
+
+ delete r_predec->drv1;
+ delete r_predec->drv2;
+ delete b_mux_predec->drv1;
+ delete b_mux_predec->drv2;
+ delete sa_mux_lev_1_predec->drv1;
+ delete sa_mux_lev_1_predec->drv2;
+ delete sa_mux_lev_2_predec->drv1;
+ delete sa_mux_lev_2_predec->drv2;
+ delete way_sel_drv1;
+ delete dummy_way_sel_predec_blk_drv2;
+
+ delete r_predec;
+ delete b_mux_predec;
+ delete sa_mux_lev_1_predec;
+ delete sa_mux_lev_2_predec;
+
+ delete subarray_out_wire;
+ if (!pure_cam)
+ delete bl_precharge_eq_drv;
+
+ if (is_fa || pure_cam)
+ {
+ delete sl_precharge_eq_drv ;
+ delete sl_data_drv ;
+ delete cam_bl_precharge_eq_drv;
+ delete ml_precharge_drv;
+ delete ml_to_ram_wl_drv;
+ }
+ if (sram_sleep_tx) // fixed: was "if (!sram_sleep_tx)", which only ever deleted a null pointer and leaked the real one
+ {
+ delete sram_sleep_tx;
+ }
+}
+
+
+
+double Mat::compute_delays(double inrisetime)
+{
+ int k;
+ double rd, C_intrinsic, C_ld, tf, R_bl_precharge,r_b_metal, R_bl, C_bl;
+ double outrisetime_search, outrisetime, row_dec_outrisetime;
+ // delay calculation for tags of fully associative cache
+ if (is_fa || pure_cam)
+ {
+ //Compute search access time
+ outrisetime_search = compute_cam_delay(inrisetime);
+ if (is_fa)
+ {
+ bl_precharge_eq_drv->compute_delay(0);
+ k = ml_to_ram_wl_drv->number_gates - 1;
+ rd = tr_R_on(ml_to_ram_wl_drv->width_n[k], NCH, 1, is_dram, false, true);
+ C_intrinsic = drain_C_(ml_to_ram_wl_drv->width_n[k], PCH, 1, 1, 4*cell.h, is_dram, false, true) +
+ drain_C_(ml_to_ram_wl_drv->width_n[k], NCH, 1, 1, 4*cell.h, is_dram, false, true);
+ C_ld = ml_to_ram_wl_drv->c_gate_load+ ml_to_ram_wl_drv->c_wire_load;
+ tf = rd * (C_intrinsic + C_ld) + ml_to_ram_wl_drv->r_wire_load * C_ld / 2;
+ delay_wl_reset = horowitz(0, tf, 0.5, 0.5, RISE);
+
+ R_bl_precharge = tr_R_on(g_tp.w_pmos_bl_precharge, PCH, 1, is_dram, false, false);
+ r_b_metal = cam_cell.h *
g_tp.wire_local.R_per_um;//dummy rows in sram are filled in + R_bl = subarray.num_rows * r_b_metal; + C_bl = subarray.C_bl; + delay_bl_restore = bl_precharge_eq_drv->delay + + log((g_tp.sram.Vbitpre - 0.1 * dp.V_b_sense) / (g_tp.sram.Vbitpre - dp.V_b_sense))* + (R_bl_precharge * C_bl + R_bl * C_bl / 2); + + + outrisetime_search = compute_bitline_delay(outrisetime_search); + outrisetime_search = compute_sa_delay(outrisetime_search); + } + outrisetime_search = compute_subarray_out_drv(outrisetime_search); + subarray_out_wire->set_in_rise_time(outrisetime_search); + outrisetime_search = subarray_out_wire->signal_rise_time(); + delay_subarray_out_drv_htree = delay_subarray_out_drv + subarray_out_wire->delay; + + + //TODO: this is just for compute plain read/write energy for fa and cam, plain read/write access timing need to be revisited. + outrisetime = r_predec->compute_delays(inrisetime); + row_dec_outrisetime = row_dec->compute_delays(outrisetime); + + outrisetime = b_mux_predec->compute_delays(inrisetime); + bit_mux_dec->compute_delays(outrisetime); + + outrisetime = sa_mux_lev_1_predec->compute_delays(inrisetime); + sa_mux_lev_1_dec->compute_delays(outrisetime); + + outrisetime = sa_mux_lev_2_predec->compute_delays(inrisetime); + sa_mux_lev_2_dec->compute_delays(outrisetime); + + if (pure_cam) + { + outrisetime = compute_bitline_delay(row_dec_outrisetime); + outrisetime = compute_sa_delay(outrisetime); + } + return outrisetime_search; + } + else + { + bl_precharge_eq_drv->compute_delay(0); + if (row_dec->exist == true) + { + int k = row_dec->num_gates - 1; + double rd = tr_R_on(row_dec->w_dec_n[k], NCH, 1, is_dram, false, true); + // TODO: this 4*cell.h number must be revisited + double C_intrinsic = drain_C_(row_dec->w_dec_p[k], PCH, 1, 1, 4*cell.h, is_dram, false, true) + + drain_C_(row_dec->w_dec_n[k], NCH, 1, 1, 4*cell.h, is_dram, false, true); + double C_ld = row_dec->C_ld_dec_out; + double tf = rd * (C_intrinsic + C_ld) + row_dec->R_wire_dec_out * C_ld / 2; + delay_wl_reset = horowitz(0, tf, 0.5, 0.5, RISE); + } + double R_bl_precharge = tr_R_on(g_tp.w_pmos_bl_precharge, PCH, 1, is_dram, false, false); + double r_b_metal = cell.h * g_tp.wire_local.R_per_um; + double R_bl = subarray.num_rows * r_b_metal; + double C_bl = subarray.C_bl; + + if (is_dram) + { + delay_bl_restore = bl_precharge_eq_drv->delay + 2.3 * (R_bl_precharge * C_bl + R_bl * C_bl / 2); + } + else + { + delay_bl_restore = bl_precharge_eq_drv->delay + + log((g_tp.sram.Vbitpre - 0.1 * dp.V_b_sense) / (g_tp.sram.Vbitpre - dp.V_b_sense))* + (R_bl_precharge * C_bl + R_bl * C_bl / 2); + } + } + + + + outrisetime = r_predec->compute_delays(inrisetime); + row_dec_outrisetime = row_dec->compute_delays(outrisetime); + + outrisetime = b_mux_predec->compute_delays(inrisetime); + bit_mux_dec->compute_delays(outrisetime); + + outrisetime = sa_mux_lev_1_predec->compute_delays(inrisetime); + sa_mux_lev_1_dec->compute_delays(outrisetime); + + outrisetime = sa_mux_lev_2_predec->compute_delays(inrisetime); + sa_mux_lev_2_dec->compute_delays(outrisetime); + + //CACTI3DD + if(g_ip->is_3d_mem) + { + row_dec_outrisetime = inrisetime; + } + + outrisetime = compute_bitline_delay(row_dec_outrisetime); + outrisetime = compute_sa_delay(outrisetime); + outrisetime = compute_subarray_out_drv(outrisetime); + subarray_out_wire->set_in_rise_time(outrisetime); + outrisetime = subarray_out_wire->signal_rise_time(); + + delay_subarray_out_drv_htree = delay_subarray_out_drv + subarray_out_wire->delay; + + if (dp.is_tag == true && dp.fully_assoc == false) 
+ { + compute_comparator_delay(0); + } + + if (row_dec->exist == false) + { + delay_wl_reset = MAX(r_predec->blk1->delay, r_predec->blk2->delay); + } + return outrisetime; +} + + + +double Mat::compute_bit_mux_sa_precharge_sa_mux_wr_drv_wr_mux_h() +{ + + double height = compute_tr_width_after_folding(g_tp.w_pmos_bl_precharge, camFlag? cam_cell.w:cell.w / (2 *(RWP + ERP + SCHP))) + + compute_tr_width_after_folding(g_tp.w_pmos_bl_eq, camFlag? cam_cell.w:cell.w / (RWP + ERP + SCHP)); // precharge circuitry + + if (deg_bl_muxing > 1) + { + height += compute_tr_width_after_folding(g_tp.w_nmos_b_mux, cell.w / (2 *(RWP + ERP))); // col mux tr height + // height += deg_bl_muxing * g_tp.wire_inside_mat.pitch * (RWP + ERP); // bit mux dec out wires height + } + + height += height_sense_amplifier(/*camFlag? sram_cell.w:*/cell.w * deg_bl_muxing / (RWP + ERP)); // sense_amp_height + + if (dp.Ndsam_lev_1 > 1) + { + height += compute_tr_width_after_folding( + g_tp.w_nmos_sa_mux, cell.w * dp.Ndsam_lev_1 / (RWP + ERP)); // sense_amp_mux_height + //height_senseamp_mux_decode_output_wires = Ndsam * wire_inside_mat_pitch * (RWP + ERP); + } + + if (dp.Ndsam_lev_2 > 1) + { + height += compute_tr_width_after_folding( + g_tp.w_nmos_sa_mux, cell.w * deg_bl_muxing * dp.Ndsam_lev_1 / (RWP + ERP)); // sense_amp_mux_height + //height_senseamp_mux_decode_output_wires = Ndsam * wire_inside_mat_pitch * (RWP + ERP); + + // add height of inverter-buffers between the two levels (pass-transistors) of sense-amp mux + height += 2 * compute_tr_width_after_folding( + pmos_to_nmos_sz_ratio(is_dram) * g_tp.min_w_nmos_, cell.w * dp.Ndsam_lev_2 / (RWP + ERP)); + height += 2 * compute_tr_width_after_folding(g_tp.min_w_nmos_, cell.w * dp.Ndsam_lev_2 / (RWP + ERP)); + } + + // TODO: this should be uncommented... 
+ /*if (deg_bl_muxing * dp.Ndsam_lev_1 * dp.Ndsam_lev_2 > 1) + { + //height_write_mux_decode_output_wires = deg_bl_muxing * Ndsam * g_tp.wire_inside_mat.pitch * (RWP + EWP); + double width_write_driver_write_mux = width_write_driver_or_write_mux(); + double height_write_driver_write_mux = compute_tr_width_after_folding(2 * width_write_driver_write_mux, + cell.w * + // deg_bl_muxing * + dp.Ndsam_lev_1 * dp.Ndsam_lev_2 / (RWP + EWP)); + height += height_write_driver_write_mux; + }*/ + + if (g_ip->is_3d_mem) + { + //height_write_mux_decode_output_wires = deg_bl_muxing * Ndsam * g_tp.wire_inside_mat.pitch * (RWP + EWP); + double width_write_driver_write_mux = width_write_driver_or_write_mux(); + double height_write_driver_write_mux = compute_tr_width_after_folding(2 * width_write_driver_write_mux, cell.w); + height += height_write_driver_write_mux; + } + + return height; +} + + + +double Mat::compute_cam_delay(double inrisetime) +{ + + double out_time_ramp, this_delay; + double Rwire, tf, c_intrinsic, rd, Cwire, c_gate_load; + + + double Wfaprechp, Wdummyn, Wdummyinvn, Wdummyinvp, Waddrnandn, Waddrnandp, + Wfanorn, Wfanorp, W_hit_miss_n, W_hit_miss_p; + + /** + double Wdecdrivep, Wdecdriven, Wfadriven, Wfadrivep, Wfadrive2n, Wfadrive2p, Wfadecdrive1n, Wfadecdrive1p, + Wfadecdrive2n, Wfadecdrive2p, Wfadecdriven, Wfadecdrivep, Wfaprechn, Wfaprechp, + Wdummyn, Wdummyinvn, Wdummyinvp, Wfainvn, Wfainvp, Waddrnandn, Waddrnandp, + Wfanandn, Wfanandp, Wfanorn, Wfanorp, Wdecnandn, Wdecnandp, W_hit_miss_n, W_hit_miss_p; + **/ + + double c_matchline_metal, r_matchline_metal, c_searchline_metal, r_searchline_metal, dynSearchEng; + int Htagbits; + + double driver_c_gate_load; + double driver_c_wire_load; + double driver_r_wire_load; + //double searchline_precharge_time; + + double leak_power_cc_inverters_sram_cell = 0; + double leak_power_acc_tr_RW_or_WR_port_sram_cell = 0; + double leak_power_RD_port_sram_cell = 0; + double leak_power_SCHP_port_sram_cell = 0; + double leak_comparator_cam_cell =0; + + double gate_leak_comparator_cam_cell = 0; + double gate_leak_power_cc_inverters_sram_cell = 0; + double gate_leak_power_RD_port_sram_cell = 0; + double gate_leak_power_SCHP_port_sram_cell = 0; + + c_matchline_metal = cam_cell.get_w() * g_tp.wire_local.C_per_um; + c_searchline_metal = cam_cell.get_h() * g_tp.wire_local.C_per_um; + r_matchline_metal = cam_cell.get_w() * g_tp.wire_local.R_per_um; + r_searchline_metal = cam_cell.get_h() * g_tp.wire_local.R_per_um; + + dynSearchEng = 0.0; + delay_matchchline = 0.0; + double p_to_n_sizing_r = pmos_to_nmos_sz_ratio(is_dram); + bool linear_scaling = false; + + if (linear_scaling) + { + /// Wdecdrivep = 450 * g_ip->F_sz_um;//this was 360 micron for the 0.8 micron process + /// Wdecdriven = 300 * g_ip->F_sz_um;//this was 240 micron for the 0.8 micron process + /// Wfadriven = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process + /// Wfadrivep = 125 * g_ip->F_sz_um;//this was 100 micron for the 0.8 micron process + /// Wfadrive2n = 250 * g_ip->F_sz_um;//this was 200 micron for the 0.8 micron process + /// Wfadrive2p = 500 * g_ip->F_sz_um;//this was 400 micron for the 0.8 micron process + /// Wfadecdrive1n = 6.25 * g_ip->F_sz_um;//this was 5 micron for the 0.8 micron process + /// Wfadecdrive1p = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + /// Wfadecdrive2n = 25 * g_ip->F_sz_um;//this was 20 micron for the 0.8 micron process + /// Wfadecdrive2p = 50 * g_ip->F_sz_um;//this was 40 micron for the 0.8 micron process + /// 
Wfadecdriven = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process + /// Wfadecdrivep = 125 * g_ip->F_sz_um;//this was 100 micron for the 0.8 micron process + /// Wfaprechn = 7.5 * g_ip->F_sz_um;//this was 6 micron for the 0.8 micron process + /// Wfainvn = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + /// Wfainvp = 25 * g_ip->F_sz_um;//this was 20 micron for the 0.8 micron process + /// Wfanandn = 25 * g_ip->F_sz_um;//this was 20 micron for the 0.8 micron process + /// Wfanandp = 37.5 * g_ip->F_sz_um;//this was 30 micron for the 0.8 micron process + /// Wdecnandn = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + /// Wdecnandp = 37.5 * g_ip->F_sz_um;//this was 30 micron for the 0.8 micron process + + Wfaprechp = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + Wdummyn = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + Wdummyinvn = 75 * g_ip->F_sz_um;//this was 60 micron for the 0.8 micron process + Wdummyinvp = 100 * g_ip->F_sz_um;//this was 80 micron for the 0.8 micron process + Waddrnandn = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process + Waddrnandp = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process + Wfanorn = 6.25 * g_ip->F_sz_um;//this was 5 micron for the 0.8 micron process + Wfanorp = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + W_hit_miss_n = Wdummyn; + W_hit_miss_p = g_tp.min_w_nmos_*p_to_n_sizing_r; + //TODO: this number should updated using new layout; from the NAND to output NOR should be computed using logical effort + } + else + { + /// Wdecdrivep = 450 * g_ip->F_sz_um;//this was 360 micron for the 0.8 micron process + /// Wdecdriven = 300 * g_ip->F_sz_um;//this was 240 micron for the 0.8 micron process + /// Wfadriven = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process + /// Wfadrivep = 125 * g_ip->F_sz_um;//this was 100 micron for the 0.8 micron process + /// Wfadrive2n = 250 * g_ip->F_sz_um;//this was 200 micron for the 0.8 micron process + /// Wfadrive2p = 500 * g_ip->F_sz_um;//this was 400 micron for the 0.8 micron process + /// Wfadecdrive1n = 6.25 * g_ip->F_sz_um;//this was 5 micron for the 0.8 micron process + /// Wfadecdrive1p = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + /// Wfadecdrive2n = 25 * g_ip->F_sz_um;//this was 20 micron for the 0.8 micron process + /// Wfadecdrive2p = 50 * g_ip->F_sz_um;//this was 40 micron for the 0.8 micron process + /// Wfadecdriven = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process + /// Wfadecdrivep = 125 * g_ip->F_sz_um;//this was 100 micron for the 0.8 micron process + /// Wfaprechn = 7.5 * g_ip->F_sz_um;//this was 6 micron for the 0.8 micron process + /// Wfainvn = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + /// Wfainvp = 25 * g_ip->F_sz_um;//this was 20 micron for the 0.8 micron process + /// Wfanandn = 25 * g_ip->F_sz_um;//this was 20 micron for the 0.8 micron process + /// Wfanandp = 37.5 * g_ip->F_sz_um;//this was 30 micron for the 0.8 micron process + /// Wdecnandn = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + /// Wdecnandp = 37.5 * g_ip->F_sz_um;//this was 30 micron for the 0.8 micron process + + Wfaprechp = g_tp.w_pmos_bl_precharge;//this was 10 micron for the 0.8 micron process + Wdummyn = g_tp.cam.cell_nmos_w; + Wdummyinvn = 75 * g_ip->F_sz_um;//this was 60 micron for the 0.8 micron process + Wdummyinvp = 100 * g_ip->F_sz_um;//this 
was 80 micron for the 0.8 micron process
+    Waddrnandn = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process
+    Waddrnandp = 62.5 * g_ip->F_sz_um;//this was 50 micron for the 0.8 micron process
+    Wfanorn  = 6.25 * g_ip->F_sz_um;//this was 5 micron for the 0.8 micron process
+    Wfanorp  = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process
+    W_hit_miss_n = Wdummyn;
+    W_hit_miss_p = g_tp.min_w_nmos_*p_to_n_sizing_r;
+  }
+
+  Htagbits = (int)(ceil((double)(subarray.num_cols_fa_cam) / 2.0));
+
+  /* First stage: the searchline is precharged, then the searchline data driver drives the
+     searchline to open (on a miss) the comparators. search_line_delay, search_line_power,
+     and search_line_restore_delay are used for cycle time computation. Modeled from the
+     driver (am and an) to the comparators in all the rows including the dummy row,
+     assuming that comparators on the normal matchlines and the dummy matchline have the same sizing. */
+
+  //Searchline precharge circuitry is the same as that of the bitline; however, there is no sharing between search ports and r/w ports
+  //Searchline precharge routes horizontally
+  driver_c_gate_load = subarray.num_cols_fa_cam * gate_C(2 * g_tp.w_pmos_bl_precharge + g_tp.w_pmos_bl_eq, 0, is_dram, false, false);
+  driver_c_wire_load = subarray.num_cols_fa_cam * cam_cell.w * g_tp.wire_outside_mat.C_per_um;
+  driver_r_wire_load = subarray.num_cols_fa_cam * cam_cell.w * g_tp.wire_outside_mat.R_per_um;
+
+  sl_precharge_eq_drv = new Driver(
+      driver_c_gate_load,
+      driver_c_wire_load,
+      driver_r_wire_load,
+      is_dram);
+
+  //searchline data driver; subarray.num_rows + 1 accounts for the dummy row
+  //the data driver should only have gate_C, not 2*gate_C, since the two searchlines are differential--same as bitlines
+  driver_c_gate_load = (subarray.num_rows + 1) * gate_C(Wdummyn, 0, is_dram, false, false);
+  driver_c_wire_load = (subarray.num_rows + 1) * c_searchline_metal;
+  driver_r_wire_load = (subarray.num_rows + 1) * r_searchline_metal;
+  sl_data_drv = new Driver(
+      driver_c_gate_load,
+      driver_c_wire_load,
+      driver_r_wire_load,
+      is_dram);
+
+  sl_precharge_eq_drv->compute_delay(0);
+  double R_bl_precharge = tr_R_on(g_tp.w_pmos_bl_precharge, PCH, 1, is_dram, false, false);//assuming CAM and SRAM have the same precharge/eq driver
+  double r_b_metal = cam_cell.h * g_tp.wire_local.R_per_um;
+  double R_bl = (subarray.num_rows + 1) * r_b_metal;
+  double C_bl = subarray.C_bl_cam;
+  delay_cam_sl_restore = sl_precharge_eq_drv->delay
+                         + log(g_tp.cam.Vbitpre)* (R_bl_precharge * C_bl + R_bl * C_bl / 2);
+
+  out_time_ramp = sl_data_drv->compute_delay(inrisetime);//after entering one mat, start to consider the inrisetime from 0 (0 is passed in from outside)
+
+  //matchline ops delay
+  delay_matchchline += sl_data_drv->delay;
+
+  /* Second stage: from the transistors in the comparators (both normal rows and the dummy row)
+     to the NAND gates that combine both halves. */
+  //matchline delay, matchline power, matchline_reset for cycle time computation
+
+  //matchline precharge circuitry routes vertically
+  //There are two matchline precharge driver chains per subarray.
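+  /* Illustrative note (not part of the original CACTI code): the restore delays computed in
+     this function (delay_cam_sl_restore above, delay_cam_ml_reset below) follow a lumped RC
+     charging model, t = t_driver + ln(V_ratio) * (R_switch*C_line + R_line*C_line/2), where
+     the R_line*C_line/2 term approximates the distributed wire. A standalone sketch of the
+     same computation with purely hypothetical numbers:
+
+       double R_sw = 5.0e3, R_line = 2.0e3, C_line = 50.0e-15; // ohms, farads (made up)
+       double t_drv = 20.0e-12, v_ratio = 1.6;                 // driver delay, voltage ratio (made up)
+       double t_restore = t_drv + log(v_ratio) * (R_sw*C_line + R_line*C_line/2.0);
+  */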
+  driver_c_gate_load = (subarray.num_rows + 1) * gate_C(Wfaprechp, 0, is_dram);
+  driver_c_wire_load = (subarray.num_rows + 1) * c_searchline_metal;
+  driver_r_wire_load = (subarray.num_rows + 1) * r_searchline_metal;
+
+  ml_precharge_drv = new Driver(
+      driver_c_gate_load,
+      driver_c_wire_load,
+      driver_r_wire_load,
+      is_dram);
+
+  ml_precharge_drv->compute_delay(0);
+
+  rd = tr_R_on(Wdummyn, NCH, 2, is_dram);
+  c_intrinsic = Htagbits*(2*drain_C_(Wdummyn, NCH, 2, 1, g_tp.cell_h_def, is_dram)//TODO: the use of cell_h_def should be revisited
+                + drain_C_(Wfaprechp, PCH, 1, 1, g_tp.cell_h_def, is_dram)/Htagbits);//since each half only has one precharge tx per matchline
+  Cwire = c_matchline_metal * Htagbits;
+  Rwire = r_matchline_metal * Htagbits;
+  c_gate_load = gate_C(Waddrnandn + Waddrnandp, 0, is_dram);
+
+  double R_ml_precharge = tr_R_on(Wfaprechp, PCH, 1, is_dram);
+  //double r_ml_metal = cam_cell.w * g_tp.wire_local.R_per_um;
+  double R_ml = Rwire;
+  double C_ml = Cwire + c_intrinsic;
+  delay_cam_ml_reset = ml_precharge_drv->delay
+                       + log(g_tp.cam.Vbitpre)* (R_ml_precharge * C_ml + R_ml * C_ml / 2);//TODO: the latest CAMs have sense amps on the matchlines too
+
+  //matchline ops delay
+  tf = rd * (c_intrinsic + Cwire / 2 + c_gate_load) + Rwire * (Cwire / 2 + c_gate_load);
+  this_delay = horowitz(out_time_ramp, tf, VTHFA2, VTHFA3, FALL);
+  delay_matchchline += this_delay;
+  out_time_ramp = this_delay / VTHFA3;
+
+  dynSearchEng += ((c_intrinsic + Cwire + c_gate_load)*(subarray.num_rows +1)) //+ 2*drain_C_(Wdummyn, NCH, 2, 1, g_tp.cell_h_def, is_dram))//TODO: needs to be made precise
+                  * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd *2;//* Ntbl;//each subarray has two halves
+
+  /* Third stage: from the NAND2 gates to the drivers in the dummy row. */
+  rd = tr_R_on(Waddrnandn, NCH, 2, is_dram);
+  c_intrinsic = drain_C_(Waddrnandn, NCH, 2, 1, g_tp.cell_h_def, is_dram)
+                + drain_C_(Waddrnandp, PCH, 1, 1, g_tp.cell_h_def, is_dram)*2;
+  c_gate_load = gate_C(Wdummyinvn + Wdummyinvp, 0, is_dram);
+  tf = rd * (c_intrinsic + c_gate_load);
+  this_delay = horowitz(out_time_ramp, tf, VTHFA3, VTHFA4, RISE);
+  out_time_ramp = this_delay / (1 - VTHFA4);
+  delay_matchchline += this_delay;
+
+  //only the dummy row has the extra inverter between the NAND and NOR gates
+  dynSearchEng += (c_intrinsic* (subarray.num_rows+1)+ c_gate_load*2) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd;// * Ntbl;
+
+  /* Fourth stage: from the driver in the dummy matchline to the NOR2 gate which drives the wordline of the data portion. */
+  rd = tr_R_on(Wdummyinvn, NCH, 1, is_dram);
+  c_intrinsic = drain_C_(Wdummyinvn, NCH, 1, 1, g_tp.cell_h_def, is_dram) + drain_C_(Wdummyinvp, NCH, 1, 1, g_tp.cell_h_def, is_dram);
+  Cwire = c_matchline_metal * Htagbits + c_searchline_metal * (subarray.num_rows+1)/2;
+  Rwire = r_matchline_metal * Htagbits + r_searchline_metal * (subarray.num_rows+1)/2;
+  c_gate_load = gate_C(Wfanorn + Wfanorp, 0, is_dram);
+  tf = rd * (c_intrinsic + Cwire + c_gate_load) + Rwire * (Cwire / 2 + c_gate_load);
+  this_delay = horowitz(out_time_ramp, tf, VTHFA4, VTHFA5, FALL);
+  out_time_ramp = this_delay / VTHFA5;
+  delay_matchchline += this_delay;
+
+  dynSearchEng += (c_intrinsic + Cwire + subarray.num_rows*c_gate_load) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd;//* Ntbl;
+
+  /* Final stage: from the NOR gate to the driver of the wordline of the data portion. */
+
+  //driver from the matchline NOR gate onto the wordline of the RAM section
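+  /* Illustrative note (not part of the original CACTI code): each matchline stage above is
+     timed with the Horowitz gate-delay approximation: given a stage time constant tf = R*C
+     and the previous stage's output ramp, horowitz() returns the delay to the switching
+     threshold; for a step input it reduces to roughly tf*|ln(vth)|. A hypothetical call,
+     mirroring the pattern used throughout this function:
+
+       double tf = rd * (c_intrinsic + c_gate_load);            // stage RC time constant
+       double d  = horowitz(out_time_ramp, tf, 0.5, 0.5, RISE); // delay to the 50% threshold
+       out_time_ramp = d / (1.0 - 0.5);                         // ramp handed to the next stage
+  */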
+ driver_c_gate_load = gate_C(W_hit_miss_n, 0, is_dram, false, false);//nmos of the pull down logic + driver_c_wire_load = subarray.C_wl_ram; + driver_r_wire_load = subarray.R_wl_ram; + + ml_to_ram_wl_drv = new Driver( + driver_c_gate_load, + driver_c_wire_load, + driver_r_wire_load, + is_dram); + + + + rd = tr_R_on(Wfanorn, NCH, 1, is_dram); + c_intrinsic = 2* drain_C_(Wfanorn, NCH, 1, 1, g_tp.cell_h_def, is_dram) + drain_C_(Wfanorp, NCH, 1, 1, g_tp.cell_h_def, is_dram); + c_gate_load = gate_C(ml_to_ram_wl_drv->width_n[0] + ml_to_ram_wl_drv->width_p[0], 0, is_dram); + tf = rd * (c_intrinsic + c_gate_load); + this_delay = horowitz (out_time_ramp, tf, 0.5, 0.5, RISE); + out_time_ramp = this_delay / (1-0.5); + delay_matchchline += this_delay; + + out_time_ramp = ml_to_ram_wl_drv->compute_delay(out_time_ramp); + + //c_gate_load energy is computed in ml_to_ram_wl_drv + dynSearchEng += (c_intrinsic) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd;//* Ntbl; + + + /* peripheral-- hitting logic "CMOS VLSI Design Fig11.51*/ + /*Precharge the hitting logic */ + c_intrinsic = 2*drain_C_(W_hit_miss_p, NCH, 2, 1, g_tp.cell_h_def, is_dram); + Cwire = c_searchline_metal * subarray.num_rows; + Rwire = r_searchline_metal * subarray.num_rows; + c_gate_load = drain_C_(W_hit_miss_n, NCH, 1, 1, g_tp.cell_h_def, is_dram)* subarray.num_rows; + + rd = tr_R_on(W_hit_miss_p, PCH, 1, is_dram, false, false); + //double r_ml_metal = cam_cell.w * g_tp.wire_local.R_per_um; + double R_hit_miss = Rwire; + double C_hit_miss = Cwire + c_intrinsic; + delay_hit_miss_reset = log(g_tp.cam.Vbitpre)* (rd * C_hit_miss + R_hit_miss * C_hit_miss / 2); + dynSearchEng += (c_intrinsic + Cwire + c_gate_load) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd; + + /*hitting logic evaluation */ + c_intrinsic = 2*drain_C_(W_hit_miss_n, NCH, 2, 1, g_tp.cell_h_def, is_dram); + Cwire = c_searchline_metal * subarray.num_rows; + Rwire = r_searchline_metal * subarray.num_rows; + c_gate_load = drain_C_(W_hit_miss_n, NCH, 1, 1, g_tp.cell_h_def, is_dram)* subarray.num_rows; + + rd = tr_R_on(W_hit_miss_n, PCH, 1, is_dram, false, false); + tf = rd * (c_intrinsic + Cwire / 2 + c_gate_load) + Rwire * (Cwire / 2 + c_gate_load); + + delay_hit_miss = horowitz(0, tf, 0.5, 0.5, FALL); + + if (is_fa) + delay_matchchline += MAX(ml_to_ram_wl_drv->delay, delay_hit_miss); + + dynSearchEng += (c_intrinsic + Cwire + c_gate_load) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd; + + /* TODO: peripheral-- Priority Encoder, usually this is not necessary in processor components*/ + + power_matchline.searchOp.dynamic = dynSearchEng; + + //leakage in one subarray + double Iport = cmos_Isub_leakage(g_tp.cam.cell_a_w, 0, 1, nmos, false, true);//TODO: how much is the idle time? just by *2? 
+  double Iport_erp = cmos_Isub_leakage(g_tp.cam.cell_a_w, 0, 2, nmos, false, true);
+  double Icell = cmos_Isub_leakage(g_tp.cam.cell_nmos_w, g_tp.cam.cell_pmos_w, 1, inv, false, true)*2;
+  double Icell_comparator = cmos_Isub_leakage(Wdummyn, Wdummyn, 1, inv, false, true)*2;//approx XOR with Inv
+
+  leak_power_cc_inverters_sram_cell         = Icell * g_tp.cam_cell.Vdd;
+  leak_comparator_cam_cell                  = Icell_comparator * g_tp.cam_cell.Vdd;
+  leak_power_acc_tr_RW_or_WR_port_sram_cell = Iport * g_tp.cam_cell.Vdd;
+  leak_power_RD_port_sram_cell              = Iport_erp * g_tp.cam_cell.Vdd;
+  leak_power_SCHP_port_sram_cell            = 0;//search port and r/w port are separate, therefore no access txs in search ports
+
+  power_matchline.searchOp.leakage += leak_power_cc_inverters_sram_cell
+                                    + leak_comparator_cam_cell
+                                    + leak_power_acc_tr_RW_or_WR_port_sram_cell
+                                    + leak_power_acc_tr_RW_or_WR_port_sram_cell * (RWP + EWP - 1)
+                                    + leak_power_RD_port_sram_cell * ERP
+                                    + leak_power_SCHP_port_sram_cell*SCHP;
+//  power_matchline.searchOp.leakage += leak_comparator_cam_cell;
+  power_matchline.searchOp.leakage *= (subarray.num_rows+1) * subarray.num_cols_fa_cam;//TODO: make the dummy line contribution precise
+  power_matchline.searchOp.leakage += (subarray.num_rows+1) * cmos_Isub_leakage(0, Wfaprechp, 1, pmos) * g_tp.cam_cell.Vdd;
+  power_matchline.searchOp.leakage += (subarray.num_rows+1) * cmos_Isub_leakage(Waddrnandn, Waddrnandp, 2, nand) * g_tp.cam_cell.Vdd;
+  power_matchline.searchOp.leakage += (subarray.num_rows+1) * cmos_Isub_leakage(Wfanorn, Wfanorp,2, nor) * g_tp.cam_cell.Vdd;
+  //In idle states, the hit/miss txs are closed (on), therefore no Isub
+  power_matchline.searchOp.leakage += 0;// subarray.num_rows * cmos_Isub_leakage(W_hit_miss_n, 0,1, nmos) * g_tp.cam_cell.Vdd+
+                                        // + cmos_Isub_leakage(0, W_hit_miss_p,1, pmos) * g_tp.cam_cell.Vdd;
+
+  //in idle state, Ig_on can only exist in the access transistors of read-only ports
+  double Ig_port_erp = cmos_Ig_leakage(g_tp.cam.cell_a_w, 0, 1, nmos, false, true);
+  double Ig_cell = cmos_Ig_leakage(g_tp.cam.cell_nmos_w, g_tp.cam.cell_pmos_w, 1, inv, false, true)*2;
+  double Ig_cell_comparator = cmos_Ig_leakage(Wdummyn, Wdummyn, 1, inv, false, true)*2;// cmos_Ig_leakage(Wdummyn, 0, 2, nmos)*2;
+
+  gate_leak_comparator_cam_cell          = Ig_cell_comparator* g_tp.cam_cell.Vdd;
+  gate_leak_power_cc_inverters_sram_cell = Ig_cell*g_tp.cam_cell.Vdd;
+  gate_leak_power_RD_port_sram_cell      = Ig_port_erp*g_tp.sram_cell.Vdd;
+  gate_leak_power_SCHP_port_sram_cell    = 0;
+
+  //cout<<"power_matchline.searchOp.leakage"<<power_matchline.searchOp.leakage<<endl;
+
+  leak_power_cc_inverters_sram_cell         = Icell * (g_ip->array_power_gated? g_tp.sram_cell.Vcc_min : g_tp.sram_cell.Vdd);
+  leak_power_acc_tr_RW_or_WR_port_sram_cell = Iport * (g_ip->bitline_floating? g_tp.sram.Vbitfloating : g_tp.sram_cell.Vdd);
+  leak_power_RD_port_sram_cell              = Iport_erp * (g_ip->bitline_floating?
g_tp.sram.Vbitfloating : g_tp.sram_cell.Vdd); +// +// leak_power_cc_inverters_sram_cell_gated = leak_power_cc_inverters_sram_cell/g_tp.sram_cell.Vdd*g_tp.sram_cell.Vcc_min; +// leak_power_acc_tr_RW_or_WR_port_sram_cell_floating = leak_power_acc_tr_RW_or_WR_port_sram_cell/g_tp.sram_cell.Vdd*g_tp.sram.Vbitfloating; +// leak_power_RD_port_sram_cell_floating = leak_power_RD_port_sram_cell_floating/g_tp.sram_cell.Vdd*g_tp.sram.Vbitfloating; +// + + + //in idle state, Ig_on only possibly exist in access transistors of read only ports + double Ig_port_erp = cmos_Ig_leakage(g_tp.sram.cell_a_w, 0, 1, nmos,false, true); + double Ig_cell = cmos_Ig_leakage(g_tp.sram.cell_nmos_w, g_tp.sram.cell_pmos_w, 1, inv,false, true); + + gate_leak_power_cc_inverters_sram_cell = Ig_cell*g_tp.sram_cell.Vdd; + gate_leak_power_RD_port_sram_cell = Ig_port_erp*g_tp.sram_cell.Vdd; + } + + + double C_drain_bit_mux = drain_C_(g_tp.w_nmos_b_mux, NCH, 1, 0, camFlag? cam_cell.w:cell.w / (2 *(RWP + ERP + SCHP)), is_dram); + double R_bit_mux = tr_R_on(g_tp.w_nmos_b_mux, NCH, 1, is_dram); + double C_drain_sense_amp_iso = drain_C_(g_tp.w_iso, PCH, 1, 0, camFlag? cam_cell.w:cell.w * deg_bl_muxing / (RWP + ERP + SCHP), is_dram); + double R_sense_amp_iso = tr_R_on(g_tp.w_iso, PCH, 1, is_dram); + double C_sense_amp_latch = gate_C(g_tp.w_sense_p + g_tp.w_sense_n, 0, is_dram) + + drain_C_(g_tp.w_sense_n, NCH, 1, 0, camFlag? cam_cell.w:cell.w * deg_bl_muxing / (RWP + ERP + SCHP), is_dram) + + drain_C_(g_tp.w_sense_p, PCH, 1, 0, camFlag? cam_cell.w:cell.w * deg_bl_muxing / (RWP + ERP + SCHP), is_dram); + double C_drain_sense_amp_mux = drain_C_(g_tp.w_nmos_sa_mux, NCH, 1, 0, camFlag? cam_cell.w:cell.w * deg_bl_muxing / (RWP + ERP + SCHP), is_dram); + + if (is_dram) + { + double fraction = dp.V_b_sense / ((g_tp.dram_cell_Vdd/2) * g_tp.dram_cell_C /(g_tp.dram_cell_C + C_bl)); + //tstep = 2.3 * fraction * r_dev * + tstep = fraction * r_dev * (g_ip->is_3d_mem==1?1:2.3) * + (g_tp.dram_cell_C * (C_bl + 2*C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux)) / + (g_tp.dram_cell_C + (C_bl + 2*C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux)); + delay_writeback = tstep; + dynRdEnergy += (C_bl + 2*C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) * + (g_tp.dram_cell_Vdd / 2) * g_tp.dram_cell_Vdd /* subarray.num_cols * num_subarrays_per_mat*/; + dynWriteEnergy += (C_bl + 2*C_drain_sense_amp_iso + C_sense_amp_latch) * + (g_tp.dram_cell_Vdd / 2) * g_tp.dram_cell_Vdd /* subarray.num_cols * num_subarrays_per_mat*/ * num_act_mats_hor_dir*100; + per_bitline_read_energy = (C_bl + 2*C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) * + (g_tp.dram_cell_Vdd / 2) * g_tp.dram_cell_Vdd; + } + else + { + double tau; + + if (deg_bl_muxing > 1) + { + tau = (R_cell_pull_down + R_cell_acc) * + (C_bl + 2*C_drain_bit_mux + 2*C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) + + R_bl * (C_bl/2 + 2*C_drain_bit_mux + 2*C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) + + R_bit_mux * (C_drain_bit_mux + 2*C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) + + R_sense_amp_iso * (C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux); + dynRdEnergy += (C_bl + 2 * C_drain_bit_mux) * 2 * dp.V_b_sense * g_tp.sram_cell.Vdd /* + subarray.num_cols * num_subarrays_per_mat*/; + blfloating_c += (C_bl + 2 * C_drain_bit_mux) * 2; + dynRdEnergy += (2 * C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) * + 2 * dp.V_b_sense * g_tp.sram_cell.Vdd 
* (1.0/*subarray.num_cols * num_subarrays_per_mat*/ / deg_bl_muxing);
+      blfloating_c += (2 * C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) *2;
+      dynWriteEnergy += ((1.0/*subarray.num_cols *num_subarrays_per_mat*/ / deg_bl_muxing) / deg_senseamp_muxing) *
+                        num_act_mats_hor_dir * (C_bl + 2*C_drain_bit_mux) * g_tp.sram_cell.Vdd * g_tp.sram_cell.Vdd*2;
+      //write ops are differential for SRAM
+    }
+    else
+    {
+      tau = (R_cell_pull_down + R_cell_acc) *
+            (C_bl + C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) + R_bl * C_bl / 2 +
+            R_sense_amp_iso * (C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux);
+      dynRdEnergy += (C_bl + 2 * C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) *
+                     2 * dp.V_b_sense * g_tp.sram_cell.Vdd /* subarray.num_cols * num_subarrays_per_mat*/;
+
+      blfloating_c += (C_bl + 2 * C_drain_sense_amp_iso + C_sense_amp_latch + C_drain_sense_amp_mux) * 2;
+      dynWriteEnergy += (((1.0/*subarray.num_cols * num_subarrays_per_mat*/ / deg_bl_muxing) / deg_senseamp_muxing) *
+                        num_act_mats_hor_dir * C_bl) * g_tp.sram_cell.Vdd * g_tp.sram_cell.Vdd*2;
+    }
+    tstep = tau * log(V_b_pre / (V_b_pre - dp.V_b_sense));
+
+//    if (g_ip->array_power_gated)
+//      power_bitline.readOp.leakage =
+//        leak_power_cc_inverters_sram_cell_gated +
+//        leak_power_acc_tr_RW_or_WR_port_sram_cell_floating +
+//        leak_power_acc_tr_RW_or_WR_port_sram_cell_floating * (RWP + EWP - 1) +
+//        leak_power_RD_port_sram_cell_floating * ERP;
+//    else
+    power_bitline.readOp.leakage =
+      leak_power_cc_inverters_sram_cell +
+      leak_power_acc_tr_RW_or_WR_port_sram_cell +
+      leak_power_acc_tr_RW_or_WR_port_sram_cell * (RWP + EWP - 1) +
+      leak_power_RD_port_sram_cell * ERP;
+
+    power_bitline.readOp.gate_leakage = gate_leak_power_cc_inverters_sram_cell +
+                                        gate_leak_power_RD_port_sram_cell * ERP;
+  }
+
+//  cout<<"leak_power_cc_inverters_sram_cell"<<leak_power_cc_inverters_sram_cell<<endl;
+
+
+
+double Mat::compute_subarray_out_drv(double inrisetime)
+{
+  double C_ld, rd, tf, this_delay;
+  double p_to_n_sz_r = pmos_to_nmos_sz_ratio(is_dram);
+
+  // signal through the subarray output driver (a repeated buffer chain) to the edge of the subarray
+  rd = tr_R_on(subarray_out_wire->repeater_size * g_tp.min_w_nmos_, NCH, 1, is_dram);
+  C_ld = gate_C(subarray_out_wire->repeater_size * g_tp.min_w_nmos_ * (1 + p_to_n_sz_r), 0.0, is_dram) +
+         gate_C(subarray_out_wire->repeater_size *(subarray_out_wire->wire_length/subarray_out_wire->repeater_spacing) * g_tp.min_w_nmos_ * (1 + p_to_n_sz_r), 0.0, is_dram);
+  tf = rd * C_ld;
+  this_delay = horowitz(inrisetime, tf, 0.5, 0.5, RISE);
+  delay_subarray_out_drv += this_delay;
+  inrisetime = this_delay/(1.0 - 0.5);
+  power_subarray_out_drv.readOp.dynamic += C_ld * 0.5 * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd;
+  power_subarray_out_drv.readOp.leakage += 0; // for now, let leakage of the pass transistor be 0
+  power_subarray_out_drv.readOp.gate_leakage += cmos_Ig_leakage(g_tp.w_nmos_sa_mux, 0, 1, nmos)* g_tp.peri_global.Vdd;
+
+  return inrisetime;
+}
+
+
+
+double Mat::compute_comparator_delay(double inrisetime)
+{
+  int A = g_ip->tag_assoc;
+
+  int tagbits_ = dp.tagbits / 4; // assuming there are 4 quarter comparators; the input tagbits is already
+                                 // a multiple of 4
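+  /* Illustrative note (not part of the original source): the comparator is modeled as four
+     "quarter" comparators per way, each covering dp.tagbits/4 bits; this is why the dynamic
+     and leakage terms below are scaled by 4 * A, where A is the tag associativity. For a
+     hypothetical 28-bit tag, tagbits_ = 28/4 = 7 bits per quarter comparator. */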
+ + /* First Inverter */ + double Ceq = gate_C(g_tp.w_comp_inv_n2+g_tp.w_comp_inv_p2, 0, is_dram) + + drain_C_(g_tp.w_comp_inv_p1, PCH, 1, 1, g_tp.cell_h_def, is_dram) + + drain_C_(g_tp.w_comp_inv_n1, NCH, 1, 1, g_tp.cell_h_def, is_dram); + double Req = tr_R_on(g_tp.w_comp_inv_p1, PCH, 1, is_dram); + double tf = Req*Ceq; + double st1del = horowitz(inrisetime,tf,VTHCOMPINV,VTHCOMPINV,FALL); + double nextinputtime = st1del/VTHCOMPINV; + power_comparator.readOp.dynamic += 0.5 * Ceq * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd * 4 * A; + + //For each degree of associativity + //there are 4 such quarter comparators + double lkgCurrent = cmos_Isub_leakage(g_tp.w_comp_inv_n1, g_tp.w_comp_inv_p1, 1, inv, is_dram)* 4 * A; + double gatelkgCurrent = cmos_Ig_leakage(g_tp.w_comp_inv_n1, g_tp.w_comp_inv_p1, 1, inv, is_dram)* 4 * A; + /* Second Inverter */ + Ceq = gate_C(g_tp.w_comp_inv_n3+g_tp.w_comp_inv_p3, 0, is_dram) + + drain_C_(g_tp.w_comp_inv_p2, PCH, 1, 1, g_tp.cell_h_def, is_dram) + + drain_C_(g_tp.w_comp_inv_n2, NCH, 1, 1, g_tp.cell_h_def, is_dram); + Req = tr_R_on(g_tp.w_comp_inv_n2, NCH, 1, is_dram); + tf = Req*Ceq; + double st2del = horowitz(nextinputtime,tf,VTHCOMPINV,VTHCOMPINV,RISE); + nextinputtime = st2del/(1.0-VTHCOMPINV); + power_comparator.readOp.dynamic += 0.5 * Ceq * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd * 4 * A; + lkgCurrent += cmos_Isub_leakage(g_tp.w_comp_inv_n2, g_tp.w_comp_inv_p2, 1, inv, is_dram)* 4 * A; + gatelkgCurrent += cmos_Ig_leakage(g_tp.w_comp_inv_n2, g_tp.w_comp_inv_p2, 1, inv, is_dram)* 4 * A; + + /* Third Inverter */ + Ceq = gate_C(g_tp.w_eval_inv_n+g_tp.w_eval_inv_p, 0, is_dram) + + drain_C_(g_tp.w_comp_inv_p3, PCH, 1, 1, g_tp.cell_h_def, is_dram) + + drain_C_(g_tp.w_comp_inv_n3, NCH, 1, 1, g_tp.cell_h_def, is_dram); + Req = tr_R_on(g_tp.w_comp_inv_p3, PCH, 1, is_dram); + tf = Req*Ceq; + double st3del = horowitz(nextinputtime,tf,VTHCOMPINV,VTHEVALINV,FALL); + nextinputtime = st3del/(VTHEVALINV); + power_comparator.readOp.dynamic += 0.5 * Ceq * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd * 4 * A; + lkgCurrent += cmos_Isub_leakage(g_tp.w_comp_inv_n3, g_tp.w_comp_inv_p3, 1, inv, is_dram)* 4 * A; + gatelkgCurrent += cmos_Ig_leakage(g_tp.w_comp_inv_n3, g_tp.w_comp_inv_p3, 1, inv, is_dram)* 4 * A; + + /* Final Inverter (virtual ground driver) discharging compare part */ + double r1 = tr_R_on(g_tp.w_comp_n,NCH,2, is_dram); + double r2 = tr_R_on(g_tp.w_eval_inv_n,NCH,1, is_dram); /* was switch */ + double c2 = (tagbits_)*(drain_C_(g_tp.w_comp_n,NCH,1, 1, g_tp.cell_h_def, is_dram) + + drain_C_(g_tp.w_comp_n,NCH,2, 1, g_tp.cell_h_def, is_dram)) + + drain_C_(g_tp.w_eval_inv_p,PCH,1, 1, g_tp.cell_h_def, is_dram) + + drain_C_(g_tp.w_eval_inv_n,NCH,1, 1, g_tp.cell_h_def, is_dram); + double c1 = (tagbits_)*(drain_C_(g_tp.w_comp_n,NCH,1, 1, g_tp.cell_h_def, is_dram) + + drain_C_(g_tp.w_comp_n,NCH,2, 1, g_tp.cell_h_def, is_dram)) + + drain_C_(g_tp.w_comp_p,PCH,1, 1, g_tp.cell_h_def, is_dram) + + gate_C(WmuxdrvNANDn+WmuxdrvNANDp,0, is_dram); + power_comparator.readOp.dynamic += 0.5 * c2 * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd * 4 * A; + power_comparator.readOp.dynamic += c1 * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd * (A - 1); + lkgCurrent += cmos_Isub_leakage(g_tp.w_eval_inv_n, g_tp.w_eval_inv_p, 1, inv, is_dram)* 4 * A; + lkgCurrent += cmos_Isub_leakage(g_tp.w_comp_n, g_tp.w_comp_n, 1, inv, is_dram)* 4 * A; // stack factor of 0.2 + + gatelkgCurrent += cmos_Ig_leakage(g_tp.w_eval_inv_n, g_tp.w_eval_inv_p, 1, inv, is_dram)* 4 * A; + gatelkgCurrent += 
cmos_Ig_leakage(g_tp.w_comp_n, g_tp.w_comp_n, 1, inv, is_dram)* 4 * A;//for gate leakage this equals to a inverter + + /* time to go to threshold of mux driver */ + double tstep = (r2*c2+(r1+r2)*c1)*log(1.0/VTHMUXNAND); + /* take into account non-zero input rise time */ + double m = g_tp.peri_global.Vdd/nextinputtime; + double Tcomparatorni; + + if((tstep) <= (0.5*(g_tp.peri_global.Vdd-g_tp.peri_global.Vth)/m)) + { + double a = m; + double b = 2*((g_tp.peri_global.Vdd*VTHEVALINV)-g_tp.peri_global.Vth); + double c = -2*(tstep)*(g_tp.peri_global.Vdd-g_tp.peri_global.Vth)+1/m*((g_tp.peri_global.Vdd*VTHEVALINV)-g_tp.peri_global.Vth)*((g_tp.peri_global.Vdd*VTHEVALINV)-g_tp.peri_global.Vth); + Tcomparatorni = (-b+sqrt(b*b-4*a*c))/(2*a); + } + else + { + Tcomparatorni = (tstep) + (g_tp.peri_global.Vdd+g_tp.peri_global.Vth)/(2*m) - (g_tp.peri_global.Vdd*VTHEVALINV)/m; + } + delay_comparator = Tcomparatorni+st1del+st2del+st3del; + power_comparator.readOp.leakage = lkgCurrent * g_tp.peri_global.Vdd; + power_comparator.readOp.gate_leakage = gatelkgCurrent * g_tp.peri_global.Vdd; + + return Tcomparatorni / (1.0 - VTHMUXNAND);; +} + + + +void Mat::compute_power_energy() +{ + //for cam and FA, power.readOp is the plain read power, power.searchOp is the associative search related power + //when search all subarrays and all mats are fully active + //when plain read/write only one subarray in a single mat is active. + + // add energy consumed in predecoder drivers. This unit is shared by all subarrays in a mat. + // FIXME + //CACTI3DD + if (g_ip->is_3d_mem) + { + if (g_ip->print_detail_debug) + cout << "mat.cc: subarray.num_cols = " << subarray.num_cols << endl; + power_bl_precharge_eq_drv.readOp.dynamic = bl_precharge_eq_drv->power.readOp.dynamic; + //power_bl_precharge_eq_drv = num_subarrays_per_mat; + + power_sa.readOp.dynamic *= subarray.num_cols; + + power_bitline.readOp.dynamic *= subarray.num_cols; + + power_subarray_out_drv.readOp.dynamic = power_subarray_out_drv.readOp.dynamic * g_ip->io_width * g_ip->burst_depth;//* subarray.num_cols; + + if (g_ip->print_detail_debug) + { + //cout<<"mat.cc: g_ip->burst_len = "<< g_ip->burst_len << endl; + cout<<"mat.cc: power_bl_precharge_eq_drv.readOp.dynamic = "<< power_bl_precharge_eq_drv.readOp.dynamic * 1e9 << " nJ" <power.readOp.dynamic + + b_mux_predec->power.readOp.dynamic + + sa_mux_lev_1_predec->power.readOp.dynamic + + sa_mux_lev_2_predec->power.readOp.dynamic; + + // add energy consumed in decoders + power_row_decoders.readOp.dynamic = row_dec->power.readOp.dynamic; + if (!(is_fa||pure_cam)) + power_row_decoders.readOp.dynamic *= num_subarrays_per_mat; + + // add energy consumed in bitline prechagers, SAs, and bitlines + if (!(is_fa||pure_cam)) + { + // add energy consumed in bitline prechagers + power_bl_precharge_eq_drv.readOp.dynamic = bl_precharge_eq_drv->power.readOp.dynamic; + power_bl_precharge_eq_drv.readOp.dynamic *= num_subarrays_per_mat; + + //Add sense amps energy + num_sa_subarray = subarray.num_cols / deg_bl_muxing; + power_sa.readOp.dynamic *= num_sa_subarray*num_subarrays_per_mat ; + + // add energy consumed in bitlines + //cout<<"bitline power"<power.readOp.dynamic) * num_do_b_mat; + + power.readOp.dynamic += power_bl_precharge_eq_drv.readOp.dynamic + + power_sa.readOp.dynamic + + power_bitline.readOp.dynamic + + power_subarray_out_drv.readOp.dynamic; + + power.readOp.dynamic += power_row_decoders.readOp.dynamic + + bit_mux_dec->power.readOp.dynamic + + sa_mux_lev_1_dec->power.readOp.dynamic + + sa_mux_lev_2_dec->power.readOp.dynamic 
+ + power_comparator.readOp.dynamic; + } + + else if (is_fa) + { + //for plain read/write only one subarray in a mat is active + // add energy consumed in bitline prechagers + power_bl_precharge_eq_drv.readOp.dynamic = bl_precharge_eq_drv->power.readOp.dynamic + + cam_bl_precharge_eq_drv->power.readOp.dynamic; + power_bl_precharge_eq_drv.searchOp.dynamic = bl_precharge_eq_drv->power.readOp.dynamic; + + //Add sense amps energy + num_sa_subarray = (subarray.num_cols_fa_cam + subarray.num_cols_fa_ram)/ deg_bl_muxing; + num_sa_subarray_search = subarray.num_cols_fa_ram/ deg_bl_muxing; + power_sa.searchOp.dynamic = power_sa.readOp.dynamic*num_sa_subarray_search; + power_sa.readOp.dynamic *= num_sa_subarray; + + + // add energy consumed in bitlines + power_bitline.searchOp.dynamic = power_bitline.readOp.dynamic; + power_bitline.readOp.dynamic *= (subarray.num_cols_fa_cam+subarray.num_cols_fa_ram); + power_bitline.writeOp.dynamic *= (subarray.num_cols_fa_cam+subarray.num_cols_fa_ram); + power_bitline.searchOp.dynamic *= subarray.num_cols_fa_ram; + + //Add subarray output energy + power_subarray_out_drv.searchOp.dynamic = + (power_subarray_out_drv.readOp.dynamic + subarray_out_wire->power.readOp.dynamic) * num_so_b_mat; + power_subarray_out_drv.readOp.dynamic = + (power_subarray_out_drv.readOp.dynamic + subarray_out_wire->power.readOp.dynamic) * num_do_b_mat; + + + power.readOp.dynamic += power_bl_precharge_eq_drv.readOp.dynamic + + power_sa.readOp.dynamic + + power_bitline.readOp.dynamic + + power_subarray_out_drv.readOp.dynamic; + + power.readOp.dynamic += power_row_decoders.readOp.dynamic + + bit_mux_dec->power.readOp.dynamic + + sa_mux_lev_1_dec->power.readOp.dynamic + + sa_mux_lev_2_dec->power.readOp.dynamic + + power_comparator.readOp.dynamic; + + //add energy consumed inside cam + power_matchline.searchOp.dynamic *= num_subarrays_per_mat; + power_searchline_precharge = sl_precharge_eq_drv->power; + power_searchline_precharge.searchOp.dynamic = power_searchline_precharge.readOp.dynamic * num_subarrays_per_mat; + power_searchline = sl_data_drv->power; + power_searchline.searchOp.dynamic = power_searchline.readOp.dynamic*subarray.num_cols_fa_cam* num_subarrays_per_mat;; + power_matchline_precharge = ml_precharge_drv->power; + power_matchline_precharge.searchOp.dynamic = power_matchline_precharge.readOp.dynamic* num_subarrays_per_mat; + power_ml_to_ram_wl_drv= ml_to_ram_wl_drv->power; + power_ml_to_ram_wl_drv.searchOp.dynamic= ml_to_ram_wl_drv->power.readOp.dynamic; + + power_cam_all_active.searchOp.dynamic = power_matchline.searchOp.dynamic; + power_cam_all_active.searchOp.dynamic +=power_searchline_precharge.searchOp.dynamic; + power_cam_all_active.searchOp.dynamic +=power_searchline.searchOp.dynamic; + power_cam_all_active.searchOp.dynamic +=power_matchline_precharge.searchOp.dynamic; + + power.searchOp.dynamic += power_cam_all_active.searchOp.dynamic; + //power.searchOp.dynamic += ml_to_ram_wl_drv->power.readOp.dynamic; + + } + else + { + // add energy consumed in bitline prechagers + power_bl_precharge_eq_drv.readOp.dynamic = cam_bl_precharge_eq_drv->power.readOp.dynamic; + //power_bl_precharge_eq_drv.readOp.dynamic *= num_subarrays_per_mat; + //power_bl_precharge_eq_drv.searchOp.dynamic = cam_bl_precharge_eq_drv->power.readOp.dynamic; + //power_bl_precharge_eq_drv.searchOp.dynamic *= num_subarrays_per_mat; + + //Add sense amps energy + num_sa_subarray = subarray.num_cols_fa_cam/ deg_bl_muxing; + power_sa.readOp.dynamic *= num_sa_subarray;//*num_subarrays_per_mat; + 
power_sa.searchOp.dynamic = 0; + + power_bitline.readOp.dynamic *= subarray.num_cols_fa_cam; + power_bitline.searchOp.dynamic = 0; + power_bitline.writeOp.dynamic *= subarray.num_cols_fa_cam; + + power_subarray_out_drv.searchOp.dynamic = + (power_subarray_out_drv.readOp.dynamic + subarray_out_wire->power.readOp.dynamic) * num_so_b_mat; + power_subarray_out_drv.readOp.dynamic = + (power_subarray_out_drv.readOp.dynamic + subarray_out_wire->power.readOp.dynamic) * num_do_b_mat; + + power.readOp.dynamic += power_bl_precharge_eq_drv.readOp.dynamic + + power_sa.readOp.dynamic + + power_bitline.readOp.dynamic + + power_subarray_out_drv.readOp.dynamic; + + power.readOp.dynamic += power_row_decoders.readOp.dynamic + + bit_mux_dec->power.readOp.dynamic + + sa_mux_lev_1_dec->power.readOp.dynamic + + sa_mux_lev_2_dec->power.readOp.dynamic + + power_comparator.readOp.dynamic; + + + ////add energy consumed inside cam + power_matchline.searchOp.dynamic *= num_subarrays_per_mat; + power_searchline_precharge = sl_precharge_eq_drv->power; + power_searchline_precharge.searchOp.dynamic = power_searchline_precharge.readOp.dynamic * num_subarrays_per_mat; + power_searchline = sl_data_drv->power; + power_searchline.searchOp.dynamic = power_searchline.readOp.dynamic*subarray.num_cols_fa_cam* num_subarrays_per_mat;; + power_matchline_precharge = ml_precharge_drv->power; + power_matchline_precharge.searchOp.dynamic = power_matchline_precharge.readOp.dynamic* num_subarrays_per_mat; + power_ml_to_ram_wl_drv= ml_to_ram_wl_drv->power; + power_ml_to_ram_wl_drv.searchOp.dynamic= ml_to_ram_wl_drv->power.readOp.dynamic; + + power_cam_all_active.searchOp.dynamic = power_matchline.searchOp.dynamic; + power_cam_all_active.searchOp.dynamic +=power_searchline_precharge.searchOp.dynamic; + power_cam_all_active.searchOp.dynamic +=power_searchline.searchOp.dynamic; + power_cam_all_active.searchOp.dynamic +=power_matchline_precharge.searchOp.dynamic; + + power.searchOp.dynamic += power_cam_all_active.searchOp.dynamic; + //power.searchOp.dynamic += ml_to_ram_wl_drv->power.readOp.dynamic; + + } + + }//CACTI3DD + + int number_output_drivers_subarray; + + +// // calculate leakage power + if (!(is_fa || pure_cam)) + { + number_output_drivers_subarray = num_sa_subarray / (dp.Ndsam_lev_1 * dp.Ndsam_lev_2); + + power_bitline.readOp.leakage *= subarray.num_rows * subarray.num_cols * num_subarrays_per_mat; + power_bl_precharge_eq_drv.readOp.leakage = bl_precharge_eq_drv->power.readOp.leakage * num_subarrays_per_mat; + power_sa.readOp.leakage *= num_sa_subarray*num_subarrays_per_mat*(RWP + ERP); + + //num_sa_subarray = subarray.num_cols / deg_bl_muxing; + power_subarray_out_drv.readOp.leakage = + (power_subarray_out_drv.readOp.leakage + subarray_out_wire->power.readOp.leakage) * + number_output_drivers_subarray * num_subarrays_per_mat * (RWP + ERP); + + power.readOp.leakage += power_bitline.readOp.leakage + + power_bl_precharge_eq_drv.readOp.leakage + + power_sa.readOp.leakage + + power_subarray_out_drv.readOp.leakage; + + power_comparator.readOp.leakage *= num_do_b_mat * (RWP + ERP); + power.readOp.leakage += power_comparator.readOp.leakage; + + array_leakage = power_bitline.readOp.leakage; + + cl_leakage = + power_bl_precharge_eq_drv.readOp.leakage + + power_sa.readOp.leakage + + power_subarray_out_drv.readOp.leakage + + power_comparator.readOp.leakage; + + + + //Decoder blocks + power_row_decoders.readOp.leakage = row_dec->power.readOp.leakage * subarray.num_rows * num_subarrays_per_mat; + power_bit_mux_decoders.readOp.leakage = 
bit_mux_dec->power.readOp.leakage * deg_bl_muxing; + power_sa_mux_lev_1_decoders.readOp.leakage = sa_mux_lev_1_dec->power.readOp.leakage * dp.Ndsam_lev_1; + power_sa_mux_lev_2_decoders.readOp.leakage = sa_mux_lev_2_dec->power.readOp.leakage * dp.Ndsam_lev_2; + + if (!g_ip->wl_power_gated) + power.readOp.leakage += r_predec->power.readOp.leakage + + b_mux_predec->power.readOp.leakage + + sa_mux_lev_1_predec->power.readOp.leakage + + sa_mux_lev_2_predec->power.readOp.leakage + + power_row_decoders.readOp.leakage + + power_bit_mux_decoders.readOp.leakage + + power_sa_mux_lev_1_decoders.readOp.leakage + + power_sa_mux_lev_2_decoders.readOp.leakage; + else + power.readOp.leakage += (r_predec->power.readOp.leakage + + b_mux_predec->power.readOp.leakage + + sa_mux_lev_1_predec->power.readOp.leakage + + sa_mux_lev_2_predec->power.readOp.leakage + + power_row_decoders.readOp.leakage + + power_bit_mux_decoders.readOp.leakage + + power_sa_mux_lev_1_decoders.readOp.leakage + + power_sa_mux_lev_2_decoders.readOp.leakage)/g_tp.peri_global.Vdd*g_tp.peri_global.Vcc_min; + + wl_leakage = r_predec->power.readOp.leakage + + b_mux_predec->power.readOp.leakage + + sa_mux_lev_1_predec->power.readOp.leakage + + sa_mux_lev_2_predec->power.readOp.leakage + + power_row_decoders.readOp.leakage + + power_bit_mux_decoders.readOp.leakage + + power_sa_mux_lev_1_decoders.readOp.leakage + + power_sa_mux_lev_2_decoders.readOp.leakage; + + //++++Below is gate leakage + power_bitline.readOp.gate_leakage *= subarray.num_rows * subarray.num_cols * num_subarrays_per_mat; + power_bl_precharge_eq_drv.readOp.gate_leakage = bl_precharge_eq_drv->power.readOp.gate_leakage * num_subarrays_per_mat; + power_sa.readOp.gate_leakage *= num_sa_subarray*num_subarrays_per_mat*(RWP + ERP); + + //num_sa_subarray = subarray.num_cols / deg_bl_muxing; + power_subarray_out_drv.readOp.gate_leakage = + (power_subarray_out_drv.readOp.gate_leakage + subarray_out_wire->power.readOp.gate_leakage) * + number_output_drivers_subarray * num_subarrays_per_mat * (RWP + ERP); + + power.readOp.gate_leakage += power_bitline.readOp.gate_leakage + + power_bl_precharge_eq_drv.readOp.gate_leakage + + power_sa.readOp.gate_leakage + + power_subarray_out_drv.readOp.gate_leakage; + //cout<<"leakage"<power_gating) + { + + //cout<<"leakage1"<area.get_area()*subarray.num_cols * num_subarrays_per_mat*dp.num_mats; + array_wakeup_e.readOp.dynamic = sram_sleep_tx->wakeup_power.readOp.dynamic * num_subarrays_per_mat*subarray.num_cols*dp.num_act_mats_hor_dir; + array_wakeup_t = sram_sleep_tx->wakeup_delay; + + wl_sleep_tx_area = row_dec->sleeptx->area.get_area()*subarray.num_rows * num_subarrays_per_mat*dp.num_mats; + wl_wakeup_e.readOp.dynamic = row_dec->sleeptx->wakeup_power.readOp.dynamic * num_subarrays_per_mat*subarray.num_rows*dp.num_act_mats_hor_dir; + wl_wakeup_t = row_dec->sleeptx->wakeup_delay; + + } + + // gate_leakage power + power_row_decoders.readOp.gate_leakage = row_dec->power.readOp.gate_leakage * subarray.num_rows * num_subarrays_per_mat; + power_bit_mux_decoders.readOp.gate_leakage = bit_mux_dec->power.readOp.gate_leakage * deg_bl_muxing; + power_sa_mux_lev_1_decoders.readOp.gate_leakage = sa_mux_lev_1_dec->power.readOp.gate_leakage * dp.Ndsam_lev_1; + power_sa_mux_lev_2_decoders.readOp.gate_leakage = sa_mux_lev_2_dec->power.readOp.gate_leakage * dp.Ndsam_lev_2; + + power.readOp.gate_leakage += r_predec->power.readOp.gate_leakage + + b_mux_predec->power.readOp.gate_leakage + + sa_mux_lev_1_predec->power.readOp.gate_leakage + + 
sa_mux_lev_2_predec->power.readOp.gate_leakage + + power_row_decoders.readOp.gate_leakage + + power_bit_mux_decoders.readOp.gate_leakage + + power_sa_mux_lev_1_decoders.readOp.gate_leakage + + power_sa_mux_lev_2_decoders.readOp.gate_leakage; + } + else if (is_fa) + { + int number_output_drivers_subarray = num_sa_subarray;// / (dp.Ndsam_lev_1 * dp.Ndsam_lev_2); + + power_bitline.readOp.leakage *= subarray.num_rows * subarray.num_cols * num_subarrays_per_mat; + power_bl_precharge_eq_drv.readOp.leakage = bl_precharge_eq_drv->power.readOp.leakage * num_subarrays_per_mat; + power_bl_precharge_eq_drv.searchOp.leakage = cam_bl_precharge_eq_drv->power.readOp.leakage * num_subarrays_per_mat; + power_sa.readOp.leakage *= num_sa_subarray*num_subarrays_per_mat*(RWP + ERP + SCHP); + + //cout<<"leakage3"<power.readOp.leakage) * + number_output_drivers_subarray * num_subarrays_per_mat * (RWP + ERP + SCHP); + + power.readOp.leakage += power_bitline.readOp.leakage + + power_bl_precharge_eq_drv.readOp.leakage + + power_bl_precharge_eq_drv.searchOp.leakage + + power_sa.readOp.leakage + + power_subarray_out_drv.readOp.leakage; + + //cout<<"leakage4"<power.readOp.leakage * subarray.num_rows * num_subarrays_per_mat; + power.readOp.leakage += r_predec->power.readOp.leakage + + power_row_decoders.readOp.leakage; + + //cout<<"leakage5"<power.readOp.leakage; + power_cam_all_active.searchOp.leakage +=sl_data_drv->power.readOp.leakage*subarray.num_cols_fa_cam; + power_cam_all_active.searchOp.leakage +=ml_precharge_drv->power.readOp.dynamic; + power_cam_all_active.searchOp.leakage *= num_subarrays_per_mat; + + power.readOp.leakage += power_cam_all_active.searchOp.leakage; + +// cout<<"leakage6"<power.readOp.gate_leakage * num_subarrays_per_mat; + power_bl_precharge_eq_drv.searchOp.gate_leakage = cam_bl_precharge_eq_drv->power.readOp.gate_leakage * num_subarrays_per_mat; + power_sa.readOp.gate_leakage *= num_sa_subarray*num_subarrays_per_mat*(RWP + ERP + SCHP); + + //cout<<"leakage3"<power.readOp.gate_leakage) * + number_output_drivers_subarray * num_subarrays_per_mat * (RWP + ERP + SCHP); + + power.readOp.gate_leakage += power_bitline.readOp.gate_leakage + + power_bl_precharge_eq_drv.readOp.gate_leakage + + power_bl_precharge_eq_drv.searchOp.gate_leakage + + power_sa.readOp.gate_leakage + + power_subarray_out_drv.readOp.gate_leakage; + + //cout<<"leakage4"<power.readOp.gate_leakage * subarray.num_rows * num_subarrays_per_mat; + power.readOp.gate_leakage += r_predec->power.readOp.gate_leakage + + power_row_decoders.readOp.gate_leakage; + + //cout<<"leakage5"<power.readOp.gate_leakage; + power_cam_all_active.searchOp.gate_leakage +=sl_data_drv->power.readOp.gate_leakage*subarray.num_cols_fa_cam; + power_cam_all_active.searchOp.gate_leakage +=ml_precharge_drv->power.readOp.dynamic; + power_cam_all_active.searchOp.gate_leakage *= num_subarrays_per_mat; + + power.readOp.gate_leakage += power_cam_all_active.searchOp.gate_leakage; + + } + else + { + int number_output_drivers_subarray = num_sa_subarray;// / (dp.Ndsam_lev_1 * dp.Ndsam_lev_2); + + //power_bitline.readOp.leakage *= subarray.num_rows * subarray.num_cols * num_subarrays_per_mat; + //power_bl_precharge_eq_drv.readOp.leakage = bl_precharge_eq_drv->power.readOp.leakage * num_subarrays_per_mat; + power_bl_precharge_eq_drv.searchOp.leakage = cam_bl_precharge_eq_drv->power.readOp.leakage * num_subarrays_per_mat; + power_sa.readOp.leakage *= num_sa_subarray*num_subarrays_per_mat*(RWP + ERP + SCHP); + + + power_subarray_out_drv.readOp.leakage = + 
(power_subarray_out_drv.readOp.leakage + subarray_out_wire->power.readOp.leakage) * + number_output_drivers_subarray * num_subarrays_per_mat * (RWP + ERP + SCHP); + + power.readOp.leakage += //power_bitline.readOp.leakage + + //power_bl_precharge_eq_drv.readOp.leakage + + power_bl_precharge_eq_drv.searchOp.leakage + + power_sa.readOp.leakage + + power_subarray_out_drv.readOp.leakage; + + // leakage power + power_row_decoders.readOp.leakage = row_dec->power.readOp.leakage * subarray.num_rows * num_subarrays_per_mat*(RWP + ERP + EWP); + power.readOp.leakage += r_predec->power.readOp.leakage + + power_row_decoders.readOp.leakage; + + //inside cam + power_cam_all_active.searchOp.leakage = power_matchline.searchOp.leakage; + power_cam_all_active.searchOp.leakage +=sl_precharge_eq_drv->power.readOp.leakage; + power_cam_all_active.searchOp.leakage +=sl_data_drv->power.readOp.leakage*subarray.num_cols_fa_cam; + power_cam_all_active.searchOp.leakage +=ml_precharge_drv->power.readOp.dynamic; + power_cam_all_active.searchOp.leakage *= num_subarrays_per_mat; + + power.readOp.leakage += power_cam_all_active.searchOp.leakage; + + //+++Below is gate leakage + power_bl_precharge_eq_drv.searchOp.gate_leakage = cam_bl_precharge_eq_drv->power.readOp.gate_leakage * num_subarrays_per_mat; + power_sa.readOp.gate_leakage *= num_sa_subarray*num_subarrays_per_mat*(RWP + ERP + SCHP); + + + power_subarray_out_drv.readOp.gate_leakage = + (power_subarray_out_drv.readOp.gate_leakage + subarray_out_wire->power.readOp.gate_leakage) * + number_output_drivers_subarray * num_subarrays_per_mat * (RWP + ERP + SCHP); + + power.readOp.gate_leakage += //power_bitline.readOp.gate_leakage + + //power_bl_precharge_eq_drv.readOp.gate_leakage + + power_bl_precharge_eq_drv.searchOp.gate_leakage + + power_sa.readOp.gate_leakage + + power_subarray_out_drv.readOp.gate_leakage; + + // gate_leakage power + power_row_decoders.readOp.gate_leakage = row_dec->power.readOp.gate_leakage * subarray.num_rows * num_subarrays_per_mat*(RWP + ERP + EWP); + power.readOp.gate_leakage += r_predec->power.readOp.gate_leakage + + power_row_decoders.readOp.gate_leakage; + + //inside cam + power_cam_all_active.searchOp.gate_leakage = power_matchline.searchOp.gate_leakage; + power_cam_all_active.searchOp.gate_leakage +=sl_precharge_eq_drv->power.readOp.gate_leakage; + power_cam_all_active.searchOp.gate_leakage +=sl_data_drv->power.readOp.gate_leakage*subarray.num_cols_fa_cam; + power_cam_all_active.searchOp.gate_leakage +=ml_precharge_drv->power.readOp.dynamic; + power_cam_all_active.searchOp.gate_leakage *= num_subarrays_per_mat; + + power.readOp.gate_leakage += power_cam_all_active.searchOp.gate_leakage; + } +} + diff --git a/Project_FARSI/cacti_for_FARSI/mat.h b/Project_FARSI/cacti_for_FARSI/mat.h new file mode 100644 index 00000000..c265e509 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/mat.h @@ -0,0 +1,176 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#ifndef __MAT_H__ +#define __MAT_H__ + +#include "component.h" +#include "decoder.h" +#include "wire.h" +#include "subarray.h" +#include "powergating.h" + +class Mat : public Component +{ + public: + Mat(const DynamicParameter & dyn_p); + ~Mat(); + double compute_delays(double inrisetime); // return outrisetime + void compute_power_energy(); + + const DynamicParameter & dp; + + // TODO: clean up pointers and powerDefs below + Decoder * row_dec; + Decoder * bit_mux_dec; + Decoder * sa_mux_lev_1_dec; + Decoder * sa_mux_lev_2_dec; + PredecBlk * dummy_way_sel_predec_blk1; + PredecBlk * dummy_way_sel_predec_blk2; + PredecBlkDrv * way_sel_drv1; + PredecBlkDrv * dummy_way_sel_predec_blk_drv2; + + Predec * r_predec; + Predec * b_mux_predec; + Predec * sa_mux_lev_1_predec; + Predec * sa_mux_lev_2_predec; + + Wire * subarray_out_wire; + Driver * bl_precharge_eq_drv; + Driver * cam_bl_precharge_eq_drv;//bitline pre-charge circuit is separated for CAM and RAM arrays. 
+    Driver * ml_precharge_drv;//matchline precharge driver
+    Driver * sl_precharge_eq_drv;//searchline precharge driver
+    Driver * sl_data_drv;//searchline data driver
+    Driver * ml_to_ram_wl_drv;//matchline-to-RAM wordline driver
+
+
+    powerDef power_row_decoders;
+    powerDef power_bit_mux_decoders;
+    powerDef power_sa_mux_lev_1_decoders;
+    powerDef power_sa_mux_lev_2_decoders;
+    powerDef power_fa_cam;  // TODO: leakage power is not computed yet
+    powerDef power_bl_precharge_eq_drv;
+    powerDef power_subarray_out_drv;
+    powerDef power_cam_all_active;
+    powerDef power_searchline_precharge;
+    powerDef power_matchline_precharge;
+    powerDef power_ml_to_ram_wl_drv;
+
+    double   delay_fa_tag, delay_cam;
+    double   delay_before_decoder;
+    double   delay_bitline;
+    double   delay_wl_reset;
+    double   delay_bl_restore;
+
+    double   delay_searchline;
+    double   delay_matchchline;
+    double   delay_cam_sl_restore;
+    double   delay_cam_ml_reset;
+    double   delay_fa_ram_wl;
+
+    double   delay_hit_miss_reset;
+    double   delay_hit_miss;
+
+    Subarray subarray;
+    powerDef power_bitline, power_searchline, power_matchline, power_bitline_gated;
+    double   per_bitline_read_energy;
+    int      deg_bl_muxing;
+    int      num_act_mats_hor_dir;
+    double   delay_writeback;
+    Area     cell,cam_cell;
+    bool     is_dram,is_fa, pure_cam, camFlag;
+    int      num_mats;
+    powerDef power_sa;
+    double   delay_sa;
+    double   leak_power_sense_amps_closed_page_state;
+    double   leak_power_sense_amps_open_page_state;
+    double   delay_subarray_out_drv;
+    double   delay_subarray_out_drv_htree;
+    double   delay_comparator;
+    powerDef power_comparator;
+    int      num_do_b_mat;
+    int      num_so_b_mat;
+    int      num_sa_subarray;
+    int      num_sa_subarray_search;
+    double   C_bl;
+
+    uint32_t num_subarrays_per_mat;  // the number of subarrays in a mat
+    uint32_t num_subarrays_per_row;  // the number of subarrays in a row of a mat
+
+    double array_leakage;
+    double wl_leakage;
+    double cl_leakage;
+
+    Sleep_tx * sram_sleep_tx;
+    Sleep_tx * wl_sleep_tx;
+    Sleep_tx * cl_sleep_tx;
+
+    powerDef array_wakeup_e;
+    double array_wakeup_t;
+    double array_sleep_tx_area;
+
+    powerDef blfloating_wakeup_e;
+    double blfloating_wakeup_t;
+    double blfloating_sleep_tx_area;
+
+    powerDef wl_wakeup_e;
+    double wl_wakeup_t;
+    double wl_sleep_tx_area;
+
+    powerDef cl_wakeup_e;
+    double cl_wakeup_t;
+    double cl_sleep_tx_area;
+
+    double compute_bitline_delay(double inrisetime);
+    double compute_sa_delay(double inrisetime);
+    double compute_subarray_out_drv(double inrisetime);
+
+  private:
+    double compute_bit_mux_sa_precharge_sa_mux_wr_drv_wr_mux_h();
+    double width_write_driver_or_write_mux();
+    double compute_comparators_height(int tagbits, int number_ways_in_mat, double subarray_mem_cell_area_w);
+    double compute_cam_delay(double inrisetime);
+    //double compute_bitline_delay(double inrisetime);
+    //double compute_sa_delay(double inrisetime);
+    //double compute_subarray_out_drv(double inrisetime);
+    double compute_comparator_delay(double inrisetime);
+
+    int RWP;
+    int ERP;
+    int EWP;
+    int SCHP;
+};
+
+
+
+#endif
diff --git a/Project_FARSI/cacti_for_FARSI/memcad.cc b/Project_FARSI/cacti_for_FARSI/memcad.cc
new file mode 100644
index 00000000..64bf32aa
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/memcad.cc
@@ -0,0 +1,599 @@
+#include "memcad.h"
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <string>
+#include <vector>
+#include <algorithm>
+
+using namespace std;
+
+
+vector<channel_conf*> *memcad_all_channels;
+
+vector<bob_conf*> *memcad_all_bobs;
+
+vector<memory_conf*> *memcad_all_memories;
+
+vector<memory_conf*> *memcad_best_results;
+
+bool compare_channels(channel_conf* first, channel_conf* second)
+{
+  if(first->capacity !=
+bool compare_channels(channel_conf* first, channel_conf* second)
+{
+    if(first->capacity != second->capacity)
+        return (first->capacity < second->capacity);
+
+    MemCad_metrics first_metric = first->memcad_params->first_metric;
+    MemCad_metrics second_metric = first->memcad_params->second_metric;
+    MemCad_metrics third_metric = first->memcad_params->third_metric;
+
+    switch(first_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+
+    switch(second_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+
+    switch(third_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+    // equal on every metric: return false so the comparator is a strict weak
+    // ordering, as std::sort requires
+    return false;
+}
+
+
+void prune_channels()
+{
+    // the list is already sorted by capacity and then by the user's metrics,
+    // so the first entry seen for each capacity is the best one: keep it and
+    // drop the rest
+    vector<channel_conf*> * temp = new vector<channel_conf*>();
+    int last_added = -1;
+    for(unsigned int i=0;i< memcad_all_channels->size();i++)
+    {
+        if(last_added != (*memcad_all_channels)[i]->capacity)
+        {
+            temp->push_back(clone((*memcad_all_channels)[i]));
+            last_added = (*memcad_all_channels)[i]->capacity;
+        }
+    }
+
+    for(unsigned int i=0;i< memcad_all_channels->size();i++)
+    {
+        delete (*memcad_all_channels)[i];
+    }
+    memcad_all_channels->clear();
+    delete memcad_all_channels;
+    memcad_all_channels = temp;
+}
+
+void find_all_channels(MemCadParameters * memcad_params)
+{
+    int DIMM_size[]={0,4,8,16,32,64};
+    Mem_IO_type current_io_type = memcad_params->io_type;
+    DIMM_Model current_dimm_model = memcad_params->dimm_model;
+
+    memcad_all_channels= new vector<channel_conf*>();
+
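+    // The nested loops below enumerate every distinct channel configuration
+    // as a multiset of DIMM capacities: the indices satisfy d1 <= d2 <= d3,
+    // so {4GB,8GB} and {8GB,4GB} are generated only once, and index 0 encodes
+    // an empty slot. With 6 possible sizes and 3 slots that is C(6+3-1,3) = 56
+    // capacity mixes per DIMM family (UDIMM/RDIMM/LRDIMM) before the
+    // frequency sweep and the optional low-power variants multiply them out.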
+    // channels can have up to 3 DIMMs per channel
+    // d_i is the capacity of the i-th DIMM in the channel
+    for(int d1=0; d1<6;d1++)
+    {
+        for(int d2=d1;d2<6;d2++)
+        {
+            for(int d3=d2;d3<6;d3++)
+            {
+                // channel capacity should not exceed the entire memory capacity
+                if((DIMM_size[d1]+DIMM_size[d2]+DIMM_size[d3])>memcad_params->capacity)
+                    continue;
+
+                if( ((current_dimm_model== JUST_LRDIMM) || (current_dimm_model== ALL))
+                    && ((d1==0) || (MemoryParameters::cost[current_io_type][2][d1-1]<INF)))
+                {
+                    int num_dimm_per_channel = 0;
+                    vector<int> dimm_cap;
+                    dimm_cap.push_back(DIMM_size[d1]); if(d1>0) num_dimm_per_channel++;
+                    dimm_cap.push_back(DIMM_size[d2]); if(d2>0) num_dimm_per_channel++;
+                    dimm_cap.push_back(DIMM_size[d3]); if(d3>0) num_dimm_per_channel++;
+
+                    // empty channels (all three slots unpopulated) are kept here so
+                    // BoBs with idle channels can still be formed; they default to
+                    // the lightest-load frequency bin instead of indexing out of range
+                    int max_index = bw_index(current_io_type, MemoryParameters::bandwidth_load[current_io_type][num_dimm_per_channel>0 ? 4-num_dimm_per_channel : 3]);
+                    for(int bw_id=0;bw_id<=max_index; ++bw_id)
+                    {
+                        int bandwidth = MemoryParameters::bandwidth_load[current_io_type][bw_id];
+                        channel_conf * new_channel = new channel_conf(memcad_params, dimm_cap, bandwidth, LRDIMM, false);
+                        if(new_channel->cost<INF)
+                        {
+                            memcad_all_channels->push_back(new_channel);
+                        }
+
+                        if((DIMM_size[d1]+DIMM_size[d2]+DIMM_size[d3])==0)
+                            continue;
+
+                        if(memcad_params->low_power_permitted)
+                        {
+                            new_channel = new channel_conf(memcad_params, dimm_cap, bandwidth, LRDIMM, true);
+                            if(new_channel->cost<INF)
+                            {
+                                memcad_all_channels->push_back(new_channel);
+                            }
+                        }
+
+                    }
+                }
+
+                if( ((current_dimm_model== JUST_RDIMM) || (current_dimm_model== ALL))
+                    && ((d1==0) || (MemoryParameters::cost[current_io_type][1][d1-1]<INF)))
+                {
+                    int num_dimm_per_channel = 0;
+                    vector<int> dimm_cap;
+                    dimm_cap.push_back(DIMM_size[d1]); if(d1>0) num_dimm_per_channel++;
+                    dimm_cap.push_back(DIMM_size[d2]); if(d2>0) num_dimm_per_channel++;
+                    dimm_cap.push_back(DIMM_size[d3]); if(d3>0) num_dimm_per_channel++;
+
+                    if((DIMM_size[d1]+DIMM_size[d2]+DIMM_size[d3])==0)
+                        continue;
+
+                    int max_index = bw_index(current_io_type, MemoryParameters::bandwidth_load[current_io_type][4-num_dimm_per_channel]);
+
+                    for(int bw_id=0;bw_id<=max_index; ++bw_id)
+                    {
+                        int bandwidth = MemoryParameters::bandwidth_load[current_io_type][bw_id];
+                        channel_conf * new_channel = new channel_conf(memcad_params, dimm_cap, bandwidth, RDIMM, false);
+                        if(new_channel->cost<INF)
+                        {
+                            memcad_all_channels->push_back(new_channel);
+                        }
+
+                        if(memcad_params->low_power_permitted)
+                        {
+                            new_channel = new channel_conf(memcad_params, dimm_cap, bandwidth, RDIMM, true);
+                            if(new_channel->cost<INF)
+                                memcad_all_channels->push_back(new_channel);
+                        }
+                    }
+                }
+
+                if( ((current_dimm_model== JUST_UDIMM) || (current_dimm_model== ALL))
+                    && ((d1==0) || (MemoryParameters::cost[current_io_type][0][d1-1]<INF)))
+                {
+                    int num_dimm_per_channel = 0;
+                    vector<int> dimm_cap;
+                    dimm_cap.push_back(DIMM_size[d1]); if(d1>0) num_dimm_per_channel++;
+                    dimm_cap.push_back(DIMM_size[d2]); if(d2>0) num_dimm_per_channel++;
+                    dimm_cap.push_back(DIMM_size[d3]); if(d3>0) num_dimm_per_channel++;
+
+                    if((DIMM_size[d1]+DIMM_size[d2]+DIMM_size[d3])==0)
+                        continue;
+                    int max_index = bw_index(current_io_type, MemoryParameters::bandwidth_load[current_io_type][4-num_dimm_per_channel]);
+                    for(int bw_id=0;bw_id<=max_index; ++bw_id)
+                    {
+                        int bandwidth = MemoryParameters::bandwidth_load[current_io_type][bw_id];
+                        channel_conf * new_channel = new channel_conf(memcad_params, dimm_cap, bandwidth, UDIMM, false);
+                        if(new_channel->cost<INF)
+                        {
+                            memcad_all_channels->push_back(new_channel);
+                        }
+
+                        if(memcad_params->low_power_permitted)
+                        {
+                            new_channel = new channel_conf(memcad_params, dimm_cap, bandwidth, UDIMM, true);
+                            if(new_channel->cost<INF)
+                                memcad_all_channels->push_back(new_channel);
+                        }
+                    }
+                }
+
+            }
+        }
+    }
+
+    sort(memcad_all_channels->begin(), memcad_all_channels->end(), compare_channels);
+
+    prune_channels();
+
+    if(memcad_params->verbose)
+    {
+        for(unsigned int i=0;i<memcad_all_channels->size();i++)
+        {
+            cout << *(*memcad_all_channels)[i] << endl;
+        }
+    }
+}
+
+bool compare_channels_bw(channel_conf* first, channel_conf* second)
+{
+    return (first->bandwidth < second->bandwidth);
+}
+
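+// From here the same sort-then-prune pipeline repeats one level up, for BoBs
+// (boards of num_channels_per_bob channels) and then for complete memories.
+// Keeping only the best configuration per capacity is safe because any design
+// that uses a pruned candidate can substitute the kept one of equal capacity
+// without getting worse on the chosen metric ordering.
+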
+bool compare_bobs(bob_conf* first, bob_conf* second)
+{
+    if(first->capacity != second->capacity)
+        return (first->capacity < second->capacity);
+
+    MemCad_metrics first_metric = first->memcad_params->first_metric;
+    MemCad_metrics second_metric = first->memcad_params->second_metric;
+    MemCad_metrics third_metric = first->memcad_params->third_metric;
+
+    switch(first_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+
+    switch(second_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+
+    switch(third_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+    // equal on every metric: return false to keep the ordering strict weak
+    return false;
+}
+
+void prune_bobs()
+{
+    vector<bob_conf*> * temp = new vector<bob_conf*>();
+    int last_added = -1;
+    for(unsigned int i=0;i< memcad_all_bobs->size();i++)
+    {
+        if(last_added != (*memcad_all_bobs)[i]->capacity)
+        {
+            temp->push_back(clone((*memcad_all_bobs)[i]));
+            last_added = (*memcad_all_bobs)[i]->capacity;
+        }
+    }
+
+    for(unsigned int i=0;i< memcad_all_bobs->size();i++)
+    {
+        delete (*memcad_all_bobs)[i];
+    }
+    memcad_all_bobs->clear();
+    delete memcad_all_bobs;
+    memcad_all_bobs = temp;
+}
+
+void find_bobs_recursive(MemCadParameters * memcad_params,int start,int end,int nb, list<int> *channel_index)
+{
+    if(nb==1)
+    {
+        for(int i=start; i<=end;++i)
+        {
+            channel_index->push_back(i);
+
+            vector<channel_conf*> temp;
+            for(list<int>::iterator it= channel_index->begin(); it!= channel_index->end(); it++)
+            {
+                int idx = *it;
+                temp.push_back((*memcad_all_channels)[idx]);
+            }
+            memcad_all_bobs->push_back(new bob_conf(memcad_params, &temp));
+            temp.clear();
+
+            channel_index->pop_back();
+        }
+        return;
+    }
+    for(int i=start;i<=end;++i)
+    {
+        channel_index->push_back(i);
+        find_bobs_recursive(memcad_params,i,end,nb-1,channel_index);
+        channel_index->pop_back();
+    }
+}
+
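+// find_all_bobs builds BoB candidates from the pruned channel list in one of
+// two modes: with mirror_in_bob, each BoB simply replicates a single channel
+// configuration num_channels_per_bob times; with same_bw_in_bob, channels are
+// first grouped into bands of equal bandwidth and find_bobs_recursive then
+// enumerates index combinations i1 <= i2 <= ... within each band, so only
+// same-frequency channels are mixed on one board.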
+void find_all_bobs(MemCadParameters * memcad_params)
+{
+    memcad_all_bobs = new vector<bob_conf*>();
+    if(memcad_params->mirror_in_bob)
+    {
+        for(unsigned int i=0;i<memcad_all_channels->size();++i)
+        {
+            vector<channel_conf*> channels;
+            for(int j=0;j<memcad_params->num_channels_per_bob;j++)
+                channels.push_back((*memcad_all_channels)[i]);
+            memcad_all_bobs->push_back(new bob_conf(memcad_params, &channels));
+            channels.clear();
+        }
+    }
+    else if(memcad_params->same_bw_in_bob)
+    {
+        sort(memcad_all_channels->begin(), memcad_all_channels->end(), compare_channels_bw);
+        vector<int> start_index; start_index.push_back(0);
+        vector<int> end_index;
+        int last_bw =(*memcad_all_channels)[0]->bandwidth;
+        for(unsigned int i=0;i< memcad_all_channels->size();i++)
+        {
+            if(last_bw!=(*memcad_all_channels)[i]->bandwidth)
+            {
+                end_index.push_back(i-1);
+                start_index.push_back(i);
+                last_bw = (*memcad_all_channels)[i]->bandwidth;
+            }
+        }
+        end_index.push_back(memcad_all_channels->size()-1);
+
+        list<int> channel_index;
+
+        for(unsigned int i=0;i< start_index.size();++i)
+        {
+            find_bobs_recursive(memcad_params,start_index[i],end_index[i],memcad_params->num_channels_per_bob, &channel_index);
+        }
+    }
+    else
+    {
+        cout << "We do not support different frequencies per channel in a BoB!" << endl;
+        assert(false);
+    }
+
+    sort(memcad_all_bobs->begin(), memcad_all_bobs->end(), compare_bobs);
+    prune_bobs();
+    if(memcad_params->verbose)
+    {
+        for(unsigned int i=0;i<memcad_all_bobs->size();i++)
+        {
+            cout << *(*memcad_all_bobs)[i] << endl;
+        }
+    }
+}
+
+void find_mems_recursive(MemCadParameters * memcad_params, int remaining_capacity, int start, int nb, list<int>* bobs_index)
+{
+    if(nb==1)
+    {
+        for(unsigned int i=start; i< memcad_all_bobs->size();++i)
+        {
+            if((*memcad_all_bobs)[i]->capacity != remaining_capacity)
+                continue;
+
+            bobs_index->push_back(i);
+            vector<bob_conf*> temp;
+            for(list<int>::iterator it= bobs_index->begin(); it!= bobs_index->end(); it++)
+            {
+                int index = *it;
+                temp.push_back((*memcad_all_bobs)[index]);
+            }
+            memcad_all_memories->push_back(new memory_conf(memcad_params, &temp));
+            temp.clear();
+            bobs_index->pop_back();
+        }
+        return;
+    }
+
+    for(unsigned int i=start; i<memcad_all_bobs->size();i++)
+    {
+        if((*memcad_all_bobs)[i]->capacity > remaining_capacity)
+            continue;
+
+        int new_remaining_capacity = remaining_capacity-(*memcad_all_bobs)[i]->capacity;
+        bobs_index->push_back(i);
+        find_mems_recursive(memcad_params, new_remaining_capacity, i, nb-1, bobs_index);
+        bobs_index->pop_back();
+    }
+}
+
+//void find_mems_recursive(MemCadParameters * memcad_params, int start, int
+
+bool compare_memories(memory_conf* first, memory_conf* second)
+{
+    if(first->capacity != second->capacity)
+        return (first->capacity < second->capacity);
+
+    MemCad_metrics first_metric = first->memcad_params->first_metric;
+    MemCad_metrics second_metric = first->memcad_params->second_metric;
+    MemCad_metrics third_metric = first->memcad_params->third_metric;
+
+    switch(first_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+
+    switch(second_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+
+    switch(third_metric)
+    {
+        case(Cost):
+            if(first->cost != second->cost)
+                return (first->cost < second->cost);
+            break;
+        case(Bandwidth):
+            if(first->bandwidth != second->bandwidth)
+                return (first->bandwidth > second->bandwidth);
+            break;
+        case(Energy):
+            if( fabs(first->energy_per_access - second->energy_per_access)>EPS)
+                return (first->energy_per_access < second->energy_per_access);
+            break;
+        default:
+            assert(false);
+    }
+    // equal on every metric: return false to keep the ordering strict weak
+    return false;
+}
+
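+// find_mems_recursive partitions the target capacity across num_bobs BoBs.
+// It walks BoB indices in non-decreasing order (so each multiset of BoBs is
+// produced once) and only the last pick must match the remaining capacity
+// exactly; e.g. for a 64GB target over two BoBs it emits every pair (i,j),
+// i <= j, with cap(i) + cap(j) == 64. BoBs larger than the remaining
+// capacity are skipped early.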
+bool find_all_memories(MemCadParameters * memcad_params)
+{
+    memcad_all_memories = new vector<memory_conf*>();
+
+    list<int> bobs_index;
+    find_mems_recursive(memcad_params, memcad_params->capacity, 0,memcad_params->num_bobs, &bobs_index);
+
+    sort(memcad_all_memories->begin(), memcad_all_memories->end(), compare_memories);
+
+    if(memcad_params->verbose)
+    {
+        cout << "all possible results:" << endl;
+        for(unsigned int i=0;i<memcad_all_memories->size();i++)
+        {
+            cout << *(*memcad_all_memories)[i] << endl;
+        }
+    }
+    if(memcad_all_memories->size()==0)
+    {
+        cout << "No result found " << endl;
+        return false;
+    }
+    cout << "top 3 best memory configurations are:" << endl;
+    int min_num_results = (memcad_all_memories->size()>3?3:(int)memcad_all_memories->size());
+    for(int i=0;i<min_num_results;++i)
+    {
+        cout << *(*memcad_all_memories)[i] << endl;
+    }
+    return true;
+}
+
+void clean_results()
+{
+    for(unsigned int i=0;i<memcad_all_channels->size();++i)
+    {
+        delete (*memcad_all_channels)[i];
+    }
+    delete memcad_all_channels;
+
+    for(unsigned int i=0;i<memcad_all_bobs->size();++i)
+    {
+        delete (*memcad_all_bobs)[i];
+    }
+    delete memcad_all_bobs;
+
+    for(unsigned int i=0;i<memcad_all_memories->size();++i)
+    {
+        delete (*memcad_all_memories)[i];
+    }
+    delete memcad_all_memories;
+}
+
+
+void solve_memcad(MemCadParameters * memcad_params)
+{
+    find_all_channels(memcad_params);
+    find_all_bobs(memcad_params);
+    find_all_memories(memcad_params);
+    clean_results();
+}
+
diff --git a/Project_FARSI/cacti_for_FARSI/memcad.h b/Project_FARSI/cacti_for_FARSI/memcad.h
new file mode 100644
index 00000000..fa534e34
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/memcad.h
@@ -0,0 +1,30 @@
+#ifndef __MEMCAD_H__
+#define __MEMCAD_H__
+
+#include "memcad_parameters.h"
+#include <vector>
+
+
+extern vector<channel_conf*> *memcad_all_channels;
+
+extern vector<bob_conf*> *memcad_all_bobs;
+
+extern vector<memory_conf*> *memcad_all_memories;
+
+extern vector<memory_conf*> *memcad_best_results;
+
+
+
+void find_all_channels(MemCadParameters * memcad_params);
+
+void find_all_bobs(MemCadParameters * memcad_params);
+
+bool find_all_memories(MemCadParameters * memcad_params);
+
+void clean_results();
+
+void solve_memcad(MemCadParameters * memcad_params);
+
+#endif
+
+
diff --git a/Project_FARSI/cacti_for_FARSI/memcad_parameters.cc b/Project_FARSI/cacti_for_FARSI/memcad_parameters.cc
new file mode 100644
index 00000000..295e4315
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/memcad_parameters.cc
@@ -0,0 +1,466 @@
+#include "memcad_parameters.h"
+#include <iostream>
+#include <cmath>
+#include <cassert>
+
+MemCadParameters::MemCadParameters(InputParameter * g_ip)
+{
+    // default values
+    io_type=DDR4;             // DDR3 vs. DDR4
+    capacity=400;             // in GB
+    num_bobs=4;               // default=4
+    num_channels_per_bob=2;   // 1 means no bob
+    capacity_wise=true;       // true means the load on each channel is proportional to its capacity.
+    first_metric=Cost;
+    second_metric=Bandwidth;
+    third_metric=Energy;
+    dimm_model=ALL;
+    low_power_permitted=false;
+    load=0.9;                 // between 0 and 1
+    row_buffer_hit_rate=1;
+    rd_2_wr_ratio=2;
+    same_bw_in_bob=true;      // true if all the channels in the bob have the same bandwidth
+    mirror_in_bob=true;       // true if all the channels in the bob have the same configs
+    total_power=false;        // false means just considering I/O power.
+ verbose=false; + // values for input + io_type=g_ip->io_type; + capacity=g_ip->capacity; + num_bobs=g_ip->num_bobs; + num_channels_per_bob=g_ip->num_channels_per_bob; + first_metric=g_ip->first_metric; + second_metric=g_ip->second_metric; + third_metric=g_ip->third_metric; + dimm_model=g_ip->dimm_model; + ///low_power_permitted=g_ip->low_power_permitted; + ///load=g_ip->load; + ///row_buffer_hit_rate=g_ip->row_buffer_hit_rate; + ///rd_2_wr_ratio=g_ip->rd_2_wr_ratio; + ///same_bw_in_bob=g_ip->same_bw_in_bob; + mirror_in_bob=g_ip->mirror_in_bob; + ///total_power=g_ip->total_power; + verbose=g_ip->verbose; + +} + +void MemCadParameters::print_inputs() +{ + +} + +bool MemCadParameters::sanity_check() +{ + + return true; +} + + +double MemoryParameters::VDD[2][2][4]= //[lp:hp][ddr3:ddr4][frequency index] +{ + { + {1.5,1.5,1.5,1.5}, + {1.2,1.2,1.2,1.2} + }, + { + {1.35,1.35,1.35,1.35}, + {1.0,1.0,1.0,1.0} + } +}; + +double MemoryParameters::IDD0[2][4]= +{ + {55,60,65,75}, + {58,58,60,64} +}; + +double MemoryParameters::IDD2P0[2][4]= +{ + {20,20,20,20}, + {20,20,20,20} +}; + +double MemoryParameters::IDD2P1[2][4]= +{ + {30,30,32,37}, + {30,30,30,32} +}; + +double MemoryParameters::IDD2N[2][4]= +{ + {40,42,45,50}, + {44,44,46,50} +}; + +double MemoryParameters::IDD3P[2][4]= +{ + {45,50,55,60}, + {44,44,44,44} +}; + +double MemoryParameters::IDD3N[2][4]= +{ + {42,47,52,57}, + {44,44,44,44} +}; + +double MemoryParameters::IDD4R[2][4]= +{ + {120,135,155,175}, + {140,140,150,160} +}; + +double MemoryParameters::IDD4W[2][4]= +{ + {100,125,145,165}, + {156,156,176,196} +}; + +double MemoryParameters::IDD5[2][4]= +{ + {150,205,210,220}, + {190,190,190,192} +}; + +double MemoryParameters::io_energy_read[2][3][3][4] =// [ddr3:ddr4][udimm:rdimm:lrdimm][load 1:2:3][frequency 0:1:2:3] +{ + { //ddr3 + {//udimm + {2592.33, 2593.33, 3288.784, 4348.612}, + {2638.23, 2640.23, 3941.584, 5415.492}, + {2978.659, 2981.659, 4816.644, 6964.162} + + }, + {//rdimm + {2592.33, 3087.071, 3865.044, 4844.982}, + {2932.759, 3733.318, 4237.634, 5415.492}, + {3572.509, 4603.109, 5300.004, 6964.162} + }, + {//lrdimm + {4628.966, 6357.625, 7079.348, 9680.454}, + {5368.26, 6418.788, 7428.058, 10057.164}, + {5708.689, 7065.038, 7808.678, 10627.674} + + } + + }, + { //ddr + {//udimm + {2135.906, 2633.317, 2750.919, 2869.406}, + {2458.714, 2695.791, 2822.298, 3211.111}, + {2622.85, 3030.048, 3160.265, 3534.448} + + }, + {//rdimm + {2135.906, 2633.317, 2750.919, 2869.406}, + {2458.714, 2695.791, 3088.886, 3211.111}, + {2622.85, 3030.048, 3312.468, 3758.445} + + }, + {//lrdimm + {4226.903, 5015.342, 5490.61, 5979.864}, + {4280.471, 5319.132, 5668.945, 6060.216}, + {4603.279, 5381.605, 5740.325, 6401.926} + + } + + } +}; + +double MemoryParameters::io_energy_write[2][3][3][4] = +{ + { //ddr3 + {//udimm + {2758.951, 2984.854, 3571.804, 4838.902}, + {2804.851, 3768.524, 4352.214, 5580.362}, + {3213.897, 3829.684, 5425.854, 6933.512} + + }, + {//rdimm + {2758.951, 3346.104, 3931.154, 4838.902}, + {3167.997, 4114.754, 4696.724, 5580.362}, + {3561.831, 3829.684, 6039.994, 8075.542} + + }, + {//lrdimm + {4872.238, 5374.314, 7013.868, 9267.574}, + {5701.502, 6214.348, 7449.758, 10045.004}, + {5747.402, 6998.018, 8230.168, 10786.464} + + } + + }, + { //ddr4 + {//udimm + {2525.129, 2840.853, 2979.037, 3293.608}, + {2933.756, 3080.126, 3226.497, 3979.698}, + {3293.964, 3753.37, 3906.137, 4312.448} + + }, + {//rdimm + {2525.129, 2840.853, 3155.117, 3293.608}, + {2933.756, 3080.126, 3834.757, 3979.698}, + {3293.964, 3753.37, 4413.037, 5358.078} + 
+        },
+        {//lrdimm
+            {4816.453, 5692.314, 5996.134, 6652.936},
+            {4870.021, 5754.788, 6067.514, 6908.636},
+            {5298.373, 5994.07, 6491.054, 7594.726}
+        }
+    }
+};
+
+double MemoryParameters::T_RAS[2] = {35,35};
+
+double MemoryParameters::T_RC[2] = {47.5,47.5};
+
+double MemoryParameters::T_RP[2] = {13,13};
+
+double MemoryParameters::T_RFC[2] = {340,260};
+
+double MemoryParameters::T_REFI[2] = {7800,7800};
+
+int MemoryParameters::bandwidth_load[2][4]={{400,533,667,800},{800,933,1066,1200}};
+
+double MemoryParameters::cost[2][3][5] =
+{
+    {
+        {40.38,76.13,INF,INF,INF},
+        {42.24,64.17,122.6,304.3,INF},
+        {INF,INF,211.3,287.5,1079.5}
+    },
+    {
+        {25.99,45.99,INF,INF,INF},
+        {32.99,60.45,126,296.3,INF},
+        {INF,INF,278.99,333,1474}
+    }
+};
+
+
+
+///////////////////////////////////////////////////////////////////////////////////
+
+// placeholder: the background (non-I/O) DRAM power model is not implemented;
+// the search currently accounts for I/O energy only (see total_power above)
+double calculate_power(double load, double row_buffer_hr, double rd_wr_ratio, int chips_per_rank, int frequency_index, int lp)
+{
+    return 0;
+}
+
+int bw_index(Mem_IO_type type, int bandwidth)
+{
+    // map a channel frequency to its column in the parameter tables
+    if(type==DDR3)
+    {
+        if(bandwidth<=400)
+            return 0;
+        else if(bandwidth <= 533)
+            return 1;
+        else if(bandwidth <= 667)
+            return 2;
+        else
+            return 3;
+    }
+    else
+    {
+        if(bandwidth<=800)
+            return 0;
+        else if(bandwidth <= 933)
+            return 1;
+        else if(bandwidth <= 1066)
+            return 2;
+        else
+            return 3;
+    }
+    return 0;
+}
+
+channel_conf::channel_conf(MemCadParameters * memcad_params, const vector<int>& dimm_cap, int bandwidth, Mem_DIMM type, bool low_power)
+:memcad_params(memcad_params),type(type),low_power(low_power),bandwidth(bandwidth),latency(0),valid(true)
+{
+    //assert(memcad_params);
+    assert(dimm_cap.size() <=DIMM_PER_CHANNEL);
+    assert(memcad_params->io_type<2); // So far, we just support DDR3 and DDR4.
+    // updating capacity
+    num_dimm_per_channel=0;
+    capacity =0;
+    for(int i=0;i<5;i++) histogram_capacity[i]=0;
+    for(unsigned int i=0;i<dimm_cap.size();++i)
+    {
+        if(dimm_cap[i]==0)
+            continue;
+        capacity += dimm_cap[i];
+        int index = (int)log2((double)dimm_cap[i]) - 2; // 4GB -> 0, 8GB -> 1, ..., 64GB -> 4
+        assert((index>=0) && (index<5));
+        histogram_capacity[index]++;
+        num_dimm_per_channel++;
+    }
+
+    // an empty channel provides no bandwidth
+    if(num_dimm_per_channel==0 && bandwidth>0)
+        bandwidth =0;
+
+    //bandwidth = MemoryParameters::bandwidth_load[memcad_params->io_type][4-num_dimm_per_channel];
+    // updating channel cost
+    cost =0;
+    for(int i=0;i<5;++i)
+        cost += histogram_capacity[i] * MemoryParameters::cost[memcad_params->io_type][type][i];
+
+    // update energy
+    calc_power();
+}
+
+void channel_conf::calc_power()
+{
+    if(num_dimm_per_channel==0)
+    {
+        // an empty channel consumes no I/O energy (this also avoids indexing
+        // the energy tables with num_dimm_per_channel-1 == -1)
+        energy_per_read = energy_per_write = energy_per_access = 0;
+        return;
+    }
+
+    double read_ratio = memcad_params->rd_2_wr_ratio/(1.0+memcad_params->rd_2_wr_ratio);
+    double write_ratio = 1.0/(1.0+memcad_params->rd_2_wr_ratio);
+    Mem_IO_type current_io_type = memcad_params->io_type;
+    double capacity_ratio = (capacity/(double) memcad_params->capacity );
+
+    double T_BURST = 4; // memory cycles
+
+    energy_per_read = MemoryParameters::io_energy_read[current_io_type][type][num_dimm_per_channel-1][bw_index(current_io_type,bandwidth)];
+    energy_per_read /= (bandwidth/T_BURST);
+
+    energy_per_write = MemoryParameters::io_energy_write[current_io_type][type][num_dimm_per_channel-1][bw_index(current_io_type,bandwidth)];
+    energy_per_write /= (bandwidth/T_BURST);
+    if(memcad_params->capacity_wise)
+    {
+        energy_per_read *= capacity_ratio;
+        energy_per_write *= capacity_ratio;
+    }
+
+    energy_per_access = read_ratio* energy_per_read + write_ratio*energy_per_write;
+}
+
+channel_conf* clone(channel_conf* origin)
+{
+    vector<int> temp;
+    int size =4;
+    for(int i=0;i<5;++i)
+    {
+        for(int j=0;j<origin->histogram_capacity[i];++j)
+        {
+            temp.push_back(size);
+        }
+        size *=2;
+    }
+    channel_conf * new_channel = new channel_conf(origin->memcad_params,temp,origin->bandwidth, origin->type,origin->low_power);
+    return new_channel;
+}
+
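+// How a channel's energy numbers are formed: io_energy_read/write are indexed
+// by [DDR generation][DIMM family][#DIMMs-1][frequency bin], and calc_power
+// scales the table entry by the burst time (T_BURST cycles / bandwidth) and,
+// when capacity_wise is set, by the channel's share of total capacity. Reads
+// and writes are then blended by rd_2_wr_ratio; with the default ratio of 2,
+// energy_per_access = (2/3)*energy_per_read + (1/3)*energy_per_write. For
+// example, a 1-DIMM DDR3 UDIMM channel at 800 MT/s uses the table value
+// 4348.612, giving 4348.612 / (800/4) ~= 21.7 per read in the units printed
+// by operator<< below.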
<< "cap: " << ch_cnf.capacity << " GB "; + os << "bw: " << ch_cnf.bandwidth << " (MHz) "; + os << "cost: $" << ch_cnf.cost << " "; + os << "dpc: " << ch_cnf.num_dimm_per_channel << " "; + os << "energy: " << ch_cnf.energy_per_access << " (nJ) "; + os << " DIMM: " << ((ch_cnf.type==UDIMM)?" UDIMM ":((ch_cnf.type==RDIMM)?" RDIMM ":"LRDIMM ")); + os << " low power: " << ((ch_cnf.low_power)? "T ":"F "); + os << "[ "; + for(int i=0;i<5;i++) + os << ch_cnf.histogram_capacity[i] << "(" << (1<<(i+2)) << "GB) "; + os << "]"; + return os; +} + + +bob_conf::bob_conf(MemCadParameters * memcad_params, vector * in_channels) +:memcad_params(memcad_params),num_channels(0),capacity(0),bandwidth(0) +,energy_per_read(0),energy_per_write(0),energy_per_access(0),cost(0),latency(0),valid(true) +{ + + assert(in_channels->size() <= MAX_NUM_CHANNELS_PER_BOB); + for(int i=0;isize();++i) + { + channels[i] = (*in_channels)[i]; + num_channels++; + capacity += (*in_channels)[i]->capacity; + cost += (*in_channels)[i]->cost; + bandwidth += (*in_channels)[i]->bandwidth; + energy_per_read += (*in_channels)[i]->energy_per_read; + energy_per_write += (*in_channels)[i]->energy_per_write; + energy_per_access += (*in_channels)[i]->energy_per_access; + } +} + +bob_conf* clone(bob_conf* origin) +{ + vector temp; + for(int i=0;ichannels)[i]==0 ) + break; + temp.push_back( (origin->channels)[i] ); + } + + bob_conf * new_bob = new bob_conf(origin->memcad_params,&temp); + return new_bob; +} + +ostream & operator <<(ostream &os, const bob_conf& bob_cnf) +{ + os << " " << "BoB " ; + os << "cap: " << bob_cnf.capacity << " GB "; + os << "num_channels: " << bob_cnf.num_channels << " "; + os << "bw: " << bob_cnf.bandwidth << " (MHz) "; + os << "cost: $" << bob_cnf.cost << " "; + os << "energy: " << bob_cnf.energy_per_access << " (nJ) "; + os << endl; + os << " " << " ==============" << endl; + for(int i=0;i * in_bobs) +:memcad_params(memcad_params),num_bobs(0),capacity(0),bandwidth(0) +,energy_per_read(0),energy_per_write(0),energy_per_access(0),cost(0),latency(0),valid(true) +{ + assert(in_bobs->size() <= MAX_NUM_BOBS); + for(int i=0;isize();++i) + { + bobs[i] = (*in_bobs)[i]; + num_bobs++; + capacity += (*in_bobs)[i]->capacity; + cost += (*in_bobs)[i]->cost; + bandwidth += (*in_bobs)[i]->bandwidth; + energy_per_read += (*in_bobs)[i]->energy_per_read; + energy_per_write += (*in_bobs)[i]->energy_per_write; + energy_per_access += (*in_bobs)[i]->energy_per_access; + } +} + +ostream & operator <<(ostream &os, const memory_conf& mem_cnf) +{ + os << "Memory " ; + os << "cap: " << mem_cnf.capacity << " GB "; + os << "num_bobs: " << mem_cnf.num_bobs << " "; + os << "bw: " << mem_cnf.bandwidth << " (MHz) "; + os << "cost: $" << mem_cnf.cost << " "; + os << "energy: " << mem_cnf.energy_per_access << " (nJ) "; + os << endl; + os << " {" << endl; + for(int i=0;i +#include +#include "cacti_interface.h" +#include "const.h" +#include "parameter.h" + +using namespace std; + +///#define INF 1000000 +#define EPS 0.0000001 + +#define MAX_DIMM_PER_CHANNEL 3 +#define MAX_CAP_PER_DIMM 64 +#define MAX_RANKS_PER_DIMM 4 +#define MIN_BW_PER_CHANNEL 400 +#define MAX_DDR3_CHANNEL_BW 800 +#define MAX_DDR4_CHANNEL_BW 1600 +#define MAX_NUM_CHANNELS_PER_BOB 2 +#define MAX_NUM_BOBS 6 +#define DIMM_PER_CHANNEL 3 + +/* +enum Mem_IO_type +{ + DDR3, + DDR4, + LPDDR2, + WideIO, + Low_Swing_Diff, + Serial +}; + +enum Mem_DIMM +{ + UDIMM, + RDIMM, + LRDIMM +}; +*/ + + + +class MemCadParameters +{ + public: + + Mem_IO_type io_type; // DDR3 vs. 
+class MemCadParameters
+{
+  public:
+
+    Mem_IO_type io_type;        // DDR3 vs. DDR4
+
+    int capacity;               // in GB
+
+    int num_bobs;               // default=4
+
+    ///int bw_per_channel;      // default=1600 MHz;
+
+    ///bool with_bob;
+
+    int num_channels_per_bob;   // 1 means no bob
+
+    bool capacity_wise;         // true means the load on each channel is proportional to its capacity.
+
+    ///int min_bandwith;
+
+    MemCad_metrics first_metric;
+
+    MemCad_metrics second_metric;
+
+    MemCad_metrics third_metric;
+
+    DIMM_Model dimm_model;
+
+    bool low_power_permitted;   // Not yet implemented. It determines acceptable VDDs.
+
+    double load;                // between 0 and 1
+
+    double row_buffer_hit_rate;
+
+    double rd_2_wr_ratio;
+
+    bool same_bw_in_bob;        // true if all the channels in the bob have the same bandwidth.
+
+    bool mirror_in_bob;         // true if all the channels in the bob have the same configs
+
+    bool total_power;           // false means just considering I/O power
+
+    bool verbose;
+
+    // Functions
+    MemCadParameters(InputParameter * g_ip);
+    void print_inputs();
+    bool sanity_check();
+};
+
+
+//////////////////////////////////////////////////////////////////////////////////
+
+class MemoryParameters
+{
+  public:
+    // Power parameters
+    static double VDD[2][2][4];
+
+    static double IDD0[2][4];
+
+    static double IDD1[2][4];
+
+    static double IDD2P0[2][4];
+
+    static double IDD2P1[2][4];
+
+    static double IDD2N[2][4];
+
+    static double IDD3P[2][4];
+
+    static double IDD3N[2][4];
+
+    static double IDD4R[2][4];
+
+    static double IDD4W[2][4];
+
+    static double IDD5[2][4];
+
+    static double io_energy_read[2][3][3][4];
+
+    static double io_energy_write[2][3][3][4];
+
+    // Timing parameters
+    static double T_RAS[2];
+
+    static double T_RC[2];
+
+    static double T_RP[2];
+
+    static double T_RFC[2];
+
+    static double T_REFI[2];
+
+    // Bandwidth parameters
+    static int bandwidth_load[2][4];
+
+    // Cost parameters
+    static double cost[2][3][5];
+
+
+    // Functions
+    MemoryParameters();
+
+    int bw_index(Mem_IO_type type, int bandwidth);
+};
+
+///////////////////////////////////////////////////////////////////////////
+
+int bw_index(Mem_IO_type type, int bandwidth);
+
+
+///////////////////////////////////////////////////////////////////////////
+
+class channel_conf
+{
+  public:
+    MemCadParameters *memcad_params;
+
+    Mem_DIMM type;
+    int num_dimm_per_channel;
+    int histogram_capacity[5];  // 0->4GB, 1->8GB, 2->16GB, 3->32GB, 4->64GB
+    bool low_power;
+
+    int capacity;
+    int bandwidth;
+    double energy_per_read;
+    double energy_per_write;
+    double energy_per_access;
+
+    double cost;
+    double latency;
+
+    bool valid;
+    // Functions
+    channel_conf(MemCadParameters * memcad_params, const vector<int>& dimm_cap, int bandwidth, Mem_DIMM type, bool low_power);
+
+    void calc_power();
+
+    friend channel_conf* clone(channel_conf*);
+    friend ostream & operator<<(ostream &os, const channel_conf& ch_cnf);
+};
+
+
+///////////////////////////////////////////////////////////////////////////
+
+class bob_conf
+{
+  public:
+    MemCadParameters *memcad_params;
+    int num_channels;
+    channel_conf *channels[MAX_NUM_CHANNELS_PER_BOB];
+
+    int capacity;
+    int bandwidth;
+    double energy_per_read;
+    double energy_per_write;
+    double energy_per_access;
+
+    double cost;
+    double latency;
+
+    bool valid;
+
+    bob_conf(MemCadParameters * memcad_params, vector<channel_conf*> * channels);
+
+    friend bob_conf* clone(bob_conf*);
+    friend ostream & operator <<(ostream &os, const bob_conf& bob_cnf);
+};
+
+///////////////////////////////////////////////////////////////////////////
+
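+// memory_conf is the top of the hierarchy; like bob_conf, its capacity, cost,
+// bandwidth, and energy fields are plain sums over its children, and the
+// latency field is carried along but not yet used by the search.
+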
+class memory_conf
+{
+  public:
+    MemCadParameters *memcad_params;
+    int num_bobs;
+    bob_conf* bobs[MAX_NUM_BOBS];
+
+    int capacity;
+    int bandwidth;
+    double energy_per_read;
+    double energy_per_write;
+    double energy_per_access;
+
+    double cost;
+    double latency;
+
+    bool valid;
+
+    memory_conf(MemCadParameters * memcad_params, vector<bob_conf*> * bobs);
+    friend ostream & operator <<(ostream &os, const memory_conf& bob_cnf);
+};
+
+
+
+
+
+
+#endif
+
+
diff --git a/Project_FARSI/cacti_for_FARSI/memorybus.cc b/Project_FARSI/cacti_for_FARSI/memorybus.cc
new file mode 100644
index 00000000..c626c924
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/memorybus.cc
@@ -0,0 +1,741 @@
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
+ *
+ ***************************************************************************/
+
+#include "memorybus.h"
+#include "wire.h"
+#include <cmath>
+#include <iostream>
+#include <cassert>
+
+Memorybus::Memorybus(
+    enum Wire_type wire_model, double mat_w, double mat_h, double subarray_w_, double subarray_h_,
+    int _row_add_bits, int _col_add_bits, int _data_bits, int _ndbl, int _ndwl, /*enum Htree_type htree_type,*/
+    enum Memorybus_type membus_type_, const DynamicParameter & dp_,
+    /*TechnologyParameter::*/DeviceType *dt):
+    dp(dp_),
+    in_rise_time(0), out_rise_time(0),
+    is_dram(dp.is_dram),
+    membus_type(membus_type_),
+    mat_width(mat_w), mat_height(mat_h), subarray_width(subarray_w_), subarray_height(subarray_h_),
+    data_bits(_data_bits), ndbl(_ndbl), ndwl(_ndwl),
+    wt(wire_model), deviceType(dt)
+{
+    if (g_ip->print_detail_debug)
+        cout << "memorybus.cc: membus_type = " << membus_type << endl;
+    power.readOp.dynamic = 0;
+    power.readOp.leakage = 0;
+    power.readOp.gate_leakage = 0;
+    power.searchOp.dynamic =0;
+    delay = 0;
+
+    cell.h = g_tp.dram.b_h;
+    cell.w = g_tp.dram.b_w;
+
+    if (!g_ip->is_3d_mem)
+        assert(ndbl >= 2 && ndwl >= 2);
+
+    if (g_ip->print_detail_debug)
+    {
+        cout << "burst length: " << g_ip->burst_depth << endl;
+        cout << "chip IO width: " << g_ip->io_width << endl;
+    }
+    chip_IO_width = g_ip->io_width; //g_ip->out_w; //x4,
x8, x16 chip + burst_length = g_ip->burst_depth; //g_ip->burst_len; //DDR2 4, DDR3 8 + data_bits = chip_IO_width * burst_length; + + row_add_bits = _row_add_bits; + col_add_bits = _col_add_bits; + + + max_unpipelined_link_delay = 0; //TODO + min_w_nmos = g_tp.min_w_nmos_; + min_w_pmos = deviceType->n_to_p_eff_curr_drv_ratio * min_w_nmos; + + + semi_repeated_global_line = 0; // 1: semi-repeated global line, repeaters in decoder stripes; 0: Non-repeated global line, slower + ndwl = _ndwl/ g_ip->num_tier_row_sprd; + ndbl = _ndbl/ g_ip->num_tier_col_sprd; + num_subarray_global_IO = ndbl>16?16:ndbl; + + switch (membus_type) + { + case Data_path: + data_bits = chip_IO_width * burst_length; + Network(); + break; + case Row_add_path: + add_bits = _row_add_bits; + num_dec_signals = dp.num_r_subarray * ndbl; + Network(); + break; + case Col_add_path: + add_bits = _col_add_bits; + num_dec_signals = dp.num_c_subarray * ndwl / data_bits; + Network(); + break; + default: + assert(0); + break; + } + + assert(power.readOp.dynamic >= 0); + assert(power.readOp.leakage >= 0); +} + +Memorybus::~Memorybus() +{ + delete center_stripe; + delete bank_bus; + switch (membus_type) + { + case Data_path: + delete local_data; + delete global_data; + delete local_data_drv; + if(semi_repeated_global_line) + delete global_data_drv; + delete out_seg; + break; + case Row_add_path: + delete global_WL; + delete add_predec; + delete add_dec; + delete lwl_drv; + break; + case Col_add_path: + delete column_sel; + delete add_predec; + delete add_dec; + break; + default: + assert(0); + break; + } +} + +// ---For 3D DRAM, the bank height and length is reduced to 1/num_tier_row_sprd and 1/num_tier_col_sprd. +// ---As a result, ndwl and ndbl are also reduced to the same ratio, but he number of banks increase to the product of these two parameters +void Memorybus::Network() +{ + //double POLY_RESISTIVITY = 0.148; //ohm-micron + double R_wire_dec_out = 0; + double C_ld_dec_out = 0; + double bank_bus_length = 0; + double area_bank_vertical_peripheral_circuitry = 0, area_bank_horizontal_peripheral_circuitry = 0; + + area_sense_amp = (mat_height - subarray_height) * mat_width * ndbl * ndwl; + area_subarray = subarray_height * subarray_width * ndbl * ndwl; + + // ---Because in 3D DRAM mat only has one subarray, but contains the subarray peripheral circuits such as SA. Detail see mat.cc is_3d_mem part. 
+ subarray_height = mat_height; + subarray_width = mat_width; + + if(g_ip->partition_gran == 0)// Coarse_rank_level: add/data bus around + { + height_bank = subarray_height * ndbl + (col_add_bits + row_add_bits)*g_tp.wire_outside_mat.pitch/2 + data_bits*g_tp.wire_outside_mat.pitch; + length_bank = subarray_width * ndwl + (col_add_bits + row_add_bits)*g_tp.wire_outside_mat.pitch/2 + data_bits*g_tp.wire_outside_mat.pitch; + area_address_bus = (row_add_bits + col_add_bits) *g_tp.wire_outside_mat.pitch * sqrt(length_bank * height_bank); + area_data_bus = data_bits *g_tp.wire_outside_mat.pitch * sqrt(length_bank * height_bank); + } + else if(g_ip->partition_gran == 1)//Fine_rank_level: add bus replaced by TSVs + { + height_bank = subarray_height * ndbl; + length_bank = subarray_width * ndwl; + area_address_bus = 0; + area_data_bus = data_bits *g_tp.wire_outside_mat.pitch * sqrt(length_bank * height_bank); + } + else if(g_ip->partition_gran == 2)//Coarse_bank_level: add/data bus replaced by TSVs + { + height_bank = subarray_height * ndbl; + length_bank = subarray_width * ndwl; + area_address_bus = 0; + area_data_bus = 0; + } + + + + + if (g_ip->print_detail_debug) + { + cout << "memorybus.cc: N subarrays per mat = " << dp.num_subarrays / dp.num_mats << endl; + cout << "memorybus.cc: g_tp.wire_local.pitch = " << g_tp.wire_local.pitch /1e3 << " mm" << endl; + cout << "memorybus.cc: subarray_width = " << subarray_width /1e3 << " mm" << endl; + cout << "memorybus.cc: subarray_height = " << subarray_height /1e3 << " mm" << endl; + cout << "memorybus.cc: mat_height = " << mat_height /1e3 << " mm" << endl; + cout << "memorybus.cc: mat_width = " << mat_width /1e3 << " mm" << endl; + cout << "memorybus.cc: height_bank = " << height_bank /1e3 << " mm" << endl; + cout << "memorybus.cc: length_bank = " << length_bank /1e3 << " mm" << endl; + } + + int num_banks_hor_dir = 1 << (int)ceil((double)_log2( g_ip->nbanks * g_ip->num_tier_row_sprd )/2 ) ; + int num_banks_ver_dir = 1 << (int)ceil((double)_log2( g_ip->nbanks * g_ip->num_tier_col_sprd * g_ip->num_tier_row_sprd /num_banks_hor_dir ) ); + + if (g_ip->print_detail_debug) + { + cout<<"horz bank #: "<nbanks = " << g_ip->nbanks << endl; + cout << "memorybus.cc: num_banks_hor_dir = " << num_banks_hor_dir << endl; + } + + // ************************************* Wire Interconnections ***************************************** + double center_stripe_length = 0.5 * double(num_banks_hor_dir) * height_bank; + if(g_ip->print_detail_debug) + { + cout << "memorybus.cc: center_stripe wire length = " << center_stripe_length << " um"<< endl; + } + center_stripe = new Wire(wt, center_stripe_length); + area_bus = 2.0 * center_stripe_length * (row_add_bits + col_add_bits + data_bits) *g_tp.wire_outside_mat.pitch / g_ip->nbanks; + + //if (g_ip->partition_gran == 0) + //area_bus = (row_add_bits + col_add_bits) *g_tp.wire_outside_mat.pitch * center_stripe_length; + if (membus_type == Row_add_path) + { + int num_lwl_per_gwl = 4; + global_WL = new Wire(wt, length_bank, 1, 1, 1, inside_mat, CU_RESISTIVITY, &(g_tp.peri_global)); + //local_WL = new Wire(wt, length_bank/num_lwl_drv, local_wires, POLY_RESISTIVITY, &(g_tp.dram_wl)); + num_lwl_drv = ndwl; + //C_GWL = num_lwl_drv * gate_C(g_tp.min_w_nmos_+min_w_pmos,0) + c_w_metal * dp.num_c_subarray * ndwl; + if(semi_repeated_global_line) + { + C_GWL = (double)num_lwl_per_gwl * gate_C(g_tp.min_w_nmos_+min_w_pmos,0) + g_tp.wire_inside_mat.C_per_um * (subarray_width + g_tp.wire_local.pitch); + R_GWL = g_tp.wire_inside_mat.R_per_um * 
(subarray_width + g_tp.wire_local.pitch); + } + else + { + C_GWL = (double)num_lwl_drv * num_lwl_per_gwl * gate_C(g_tp.min_w_nmos_+min_w_pmos,0) + g_tp.wire_inside_mat.C_per_um * length_bank; + R_GWL = length_bank * g_tp.wire_inside_mat.R_per_um; + } + + lwl_driver_c_gate_load = dp.num_c_subarray * gate_C_pass(g_tp.dram.cell_a_w, g_tp.dram.b_w, true, true); + //lwl_driver_c_wire_load = subarray_width * g_tp.wire_local.C_per_um; + //lwl_driver_r_wire_load = subarray_width * g_tp.wire_local.R_per_um; + + if (g_ip->print_detail_debug) + { + cout<<"C_GWL: "<repeater_size = " << column_sel->repeater_size << endl; + + bank_bus_length = double(num_banks_ver_dir) * 0.5 * MAX(length_bank, height_bank); + bank_bus = new Wire(wt, bank_bus_length); + } + else if (membus_type == Data_path) + { + local_data = new Wire(wt, subarray_width, 1, 1, 1, inside_mat, CU_RESISTIVITY, &(g_tp.peri_global)); + global_data = new Wire(wt, sqrt(length_bank * height_bank), 1, 1, 1, outside_mat, CU_RESISTIVITY, &(g_tp.peri_global)); + + if(semi_repeated_global_line) + { + C_global_data = g_tp.wire_inside_mat.C_per_um * (subarray_height + g_tp.wire_local.pitch); + R_global_data = g_tp.wire_inside_mat.R_per_um * (subarray_height + g_tp.wire_local.pitch) ; + + } + else + { + C_global_data = g_tp.wire_inside_mat.C_per_um * height_bank /2; + R_global_data = g_tp.wire_inside_mat.R_per_um * height_bank /2; + } + + global_data_drv = new Driver( + 0, + C_global_data, + R_global_data, + is_dram); + global_data_drv->compute_delay(0); + global_data_drv->compute_area(); + //---Unrepeated local dataline + double local_data_c_gate_load = dp.num_c_subarray * drain_C_(g_tp.w_nmos_sa_mux, NCH, 1, 0, cell.w, is_dram); + //double local_data_c_gate_load = 0; + double local_data_c_wire_load = dp.num_c_subarray * g_tp.dram.b_w * g_tp.wire_inside_mat.C_per_um; + double local_data_r_wire_load = dp.num_c_subarray * g_tp.dram.b_w * g_tp.wire_inside_mat.R_per_um; + //double local_data_r_gate_load = tr_R_on(g_tp.w_nmos_sa_mux, NCH, 1, is_dram); + double local_data_r_gate_load = 0; + + double tf = (local_data_c_gate_load + local_data_c_wire_load) * (local_data_r_wire_load + local_data_r_gate_load); + double this_delay = horowitz(0, tf, 0.5, 0.5, RISE); + //double local_data_outrisetime = this_delay/(1.0-0.5); + + //---Unrepeated and undriven local dataline, not significant growth + //local_data->delay = this_delay; + //local_data->power.readOp.dynamic = (local_data_c_gate_load + local_data_c_wire_load) * g_tp.peri_global.Vdd * g_tp.peri_global.Vdd; + + + double data_drv_c_gate_load = local_data_c_gate_load; + double data_drv_c_wire_load = local_data_c_wire_load; + double data_drv_r_wire_load = local_data_r_gate_load + local_data_r_wire_load; + + //---Assume unrepeated global data path, too high RC + //double data_drv_c_wire_load = height_bank * g_tp.wire_outside_mat.C_per_um; + //double data_drv_r_wire_load = height_bank * g_tp.wire_inside_mat.R_per_um; + + + local_data_drv = new Driver( + data_drv_c_gate_load, + data_drv_c_wire_load, + data_drv_r_wire_load, + is_dram); + local_data_drv->compute_delay(0); + local_data_drv->compute_area(); + + if (g_ip->print_detail_debug) + { + cout<<"C: "<delay * 1e9 <<" ns"<repeater_size * gate_C(g_tp.min_w_nmos_+min_w_pmos,0), + global_data->repeater_spacing * g_tp.wire_inside_mat.C_per_um, + global_data->repeater_spacing * g_tp.wire_inside_mat.R_per_um, + is_dram);*/ + + //bank_bus_length = double(num_banks_ver_dir) * 0.5 * (height_bank + 
0.5*double(row_add_bits+col_add_bits+data_bits)*g_tp.wire_outside_mat.pitch) - height_bank + length_bank; + bank_bus_length = double(num_banks_ver_dir) * 0.5 * MAX(length_bank, height_bank); + bank_bus = new Wire(wt, bank_bus_length); + if (g_ip->print_detail_debug) + cout << "memorybus.cc: bank_bus_length = " << bank_bus_length << endl; + + out_seg = new Wire(wt, 0.25 * num_banks_hor_dir * (length_bank + (row_add_bits+col_add_bits+data_bits)*g_tp.wire_outside_mat.pitch) ); + area_IOSA = (875+500)*g_ip->F_sz_um*g_ip->F_sz_um * data_bits;//Reference: + area_data_drv = local_data_drv->area.get_area() * data_bits; + if(ndbl>16) + { + area_IOSA *= (double)ndbl/16.0; + area_data_drv *= (double)ndbl/16.0; + } + area_local_dataline = data_bits*subarray_width *g_tp.wire_local.pitch*ndbl; + + } + + + // Row decoder + if (membus_type == Row_add_path || membus_type == Col_add_path ) + { + + if (g_ip->print_detail_debug) + { + cout << "memorybus.cc: num_dec_signals = " << num_dec_signals << endl; + cout << "memorybus.cc: C_ld_dec_out = " << C_ld_dec_out << endl; + cout << "memorybus.cc: R_wire_dec_out = " << R_wire_dec_out << endl; + cout << "memorybus.cc: is_dram = " << is_dram << endl; + cout << "memorybus.cc: cell.h = " << cell.h << endl; + } + + add_dec = new Decoder( + (num_dec_signals>16)?num_dec_signals:16, + false, + C_ld_dec_out, + R_wire_dec_out, + false, + is_dram, + membus_type == Row_add_path?true:false, + cell); + + + + // Predecoder and decoder for GWL + double C_wire_predec_blk_out; + double R_wire_predec_blk_out; + C_wire_predec_blk_out = 0; // num_subarrays_per_row * dp.num_r_subarray * g_tp.wire_inside_mat.C_per_um * cell.h; + R_wire_predec_blk_out = 0; // num_subarrays_per_row * dp.num_r_subarray * g_tp.wire_inside_mat.R_per_um * cell.h; + + + //int num_subarrays_per_mat = dp.num_subarrays/dp.num_mats; + int num_dec_per_predec = 1; + PredecBlk * add_predec_blk1 = new PredecBlk( + num_dec_signals, + add_dec, + C_wire_predec_blk_out, + R_wire_predec_blk_out, + num_dec_per_predec, + is_dram, + true); + + + + PredecBlk * add_predec_blk2 = new PredecBlk( + num_dec_signals, + add_dec, + C_wire_predec_blk_out, + R_wire_predec_blk_out, + num_dec_per_predec, + is_dram, + false); + + + + PredecBlkDrv * add_predec_blk_drv1 = new PredecBlkDrv(0, add_predec_blk1, is_dram); + PredecBlkDrv * add_predec_blk_drv2 = new PredecBlkDrv(0, add_predec_blk2, is_dram); + + add_predec = new Predec(add_predec_blk_drv1, add_predec_blk_drv2); + + + + if (membus_type == Row_add_path) + { + area_row_predec_dec = add_predec_blk_drv1->area.get_area() + add_predec_blk_drv2->area.get_area() + + add_predec_blk1->area.get_area() + add_predec_blk2->area.get_area() + num_dec_signals * add_dec->area.get_area(); + + + area_lwl_drv = num_lwl_drv/2.0 * dp.num_r_subarray * ndbl * lwl_drv->area.get_area(); //num_lwl_drv is ndwl/the lwl driver count one gwl connects. two adjacent lwls share one driver. 
+ + if (g_ip->print_detail_debug) + { + cout<<"memorybus.cc: area_bank_vertical_peripheral_circuitry = " << area_bank_vertical_peripheral_circuitry /1e6<<" mm2"<area.get_area() + add_predec_blk_drv2->area.get_area() + + add_predec_blk1->area.get_area() + add_predec_blk2->area.get_area() + num_dec_signals * add_dec->area.get_area(); + if(ndbl>16) + { + area_col_predec_dec *= (double)ndbl/16.0; + } + } + + area_bank_vertical_peripheral_circuitry = area_row_predec_dec + area_lwl_drv + area_address_bus + area_data_bus ; + area_bank_horizontal_peripheral_circuitry = area_col_predec_dec + area_data_drv + (area_bus + area_IOSA)/g_ip->nbanks; + + if (g_ip->print_detail_debug) + { + cout<<"memorybus.cc: add_predec_blk_drv1->area = " << add_predec_blk_drv1->area.get_area() /1e6<<" mm2"<area = " << add_predec_blk_drv2->area.get_area() /1e6<<" mm2"<area = " << add_predec_blk1->area.get_area() /1e6<<" mm2"<area = " << add_predec_blk2->area.get_area() /1e6<<" mm2"<area = " << num_dec_signals * add_dec->area.get_area() /1e6<<" mm2"<delay + bank_bus->delay; + delay += delay_bus; + //outrisetime = local_data_drv->compute_delay(inrisetime); + //local_data_drv_outrisetime = local_data_drv->delay; + delay_global_data = (semi_repeated_global_line >0) ? (global_data_drv->delay*num_subarray_global_IO) : (global_data_drv->delay + global_data->delay); + if(g_ip->partition_gran==0 || g_ip->partition_gran==1) + delay += delay_global_data; + //delay += local_data->delay; + delay_local_data = local_data_drv->delay; + delay += delay_local_data; + delay_data_buffer = 2 * 1e-6/(double)g_ip->sys_freq_MHz; + //delay += bank.mat.delay_subarray_out_drv_htree; + delay += delay_data_buffer; + //cout << 1e3/(double)g_ip->sys_freq_MHz<< endl; + //delay += out_seg->delay * burst_length; + if (g_ip->print_detail_debug) + cout << "memorybus.cc: data path delay = " << delay << endl; + out_rise_time = 0; + } + else + { + delay = 0; + delay_bus = center_stripe->delay + bank_bus->delay; + delay += delay_bus; + predec_outrisetime = add_predec->compute_delays(inrisetime); + add_dec_outrisetime = add_dec->compute_delays(predec_outrisetime); + delay_add_predecoder = add_predec->delay; + delay += delay_add_predecoder; + + if (membus_type == Row_add_path) + { + if(semi_repeated_global_line) + { + delay_add_decoder = add_dec->delay * ndwl; + if(g_ip->page_sz_bits > 8192) + delay_add_decoder /= (double)(g_ip->page_sz_bits / 8192); + } + else + { + delay_add_decoder = add_dec->delay; + } + delay += delay_add_decoder; + // There is no function to compute_delay in wire.cc, need to double check if center_stripe->delay and bank_bus->delay is correct. 
+ lwl_drv_outrisetime = lwl_drv->compute_delay(add_dec_outrisetime); + ///tf = (lwl_driver_c_gate_load + lwl_driver_c_wire_load) * lwl_driver_r_wire_load; + // ### no need for global_WL->delay + // delay_WL = global_WL->delay + lwl_drv->delay + horowitz(lwl_drv_outrisetime, tf, 0.5, 0.5, RISE); + delay_lwl_drv = lwl_drv->delay; + if(!g_ip->fine_gran_bank_lvl) + delay += delay_lwl_drv; + if (g_ip->print_detail_debug) + cout << "memorybus.cc: row add path delay = " << delay << endl; + + out_rise_time = lwl_drv_outrisetime; + } + + else if (membus_type == Col_add_path) + { + if(semi_repeated_global_line) + { + delay_add_decoder = add_dec->delay * num_subarray_global_IO; + } + else + { + delay += column_sel->delay; + delay_add_decoder = add_dec->delay; + } + delay += delay_add_decoder; + + out_rise_time = 0; + if (g_ip->print_detail_debug) + { + //cout << "memorybus.cc, compute_delays col: center_stripe->delay = " << center_stripe->delay << endl; + //cout << "memorybus.cc, compute_delays col: bank_bus->delay = " << bank_bus->delay << endl; + //cout << "memorybus.cc, compute_delays col: add_predec->delay = " << add_predec->delay << endl; + //cout << "memorybus.cc, compute_delays col: add_dec->delay = " << add_dec->delay << endl; + + cout << "memorybus.cc: column add path delay = " << delay << endl; + } + + } + else + { + assert(0); + } + } + + + // Double check! + out_rise_time = delay / (1.0-0.5); + // Is delay_wl_reset necessary here? Is the 'false' condition appropriate? See the same code as in mat.cc + /*if (add_dec->exist == false) + { + int delay_wl_reset = MAX(add_predec->blk1->delay, add_predec->blk2->delay); + //delay += delay_wl_reset; + }*/ + + return out_rise_time; +} + + + + +void Memorybus::compute_power_energy() +{ + double coeff1[4] = {(double)add_bits, (double)add_bits, (double)add_bits, (double)add_bits}; + double coeff2[4] = {(double)data_bits, (double)data_bits, (double)data_bits, (double)data_bits}; + double coeff3[4] = {(double)num_lwl_drv, (double)num_lwl_drv, (double)num_lwl_drv, (double)num_lwl_drv}; + double coeff4[4] = {(double)burst_length*chip_IO_width, (double)burst_length*chip_IO_width, + (double)burst_length*chip_IO_width, (double)burst_length*chip_IO_width}; + double coeff5[4] = {(double)ndwl, (double)ndwl, (double)ndwl, (double)ndwl}; + double coeff6[4] = {(double)num_subarray_global_IO, (double)num_subarray_global_IO, (double)num_subarray_global_IO, (double)num_subarray_global_IO}; + + //double coeff4[4] = {(double)num_dec_signals, (double)num_dec_signals, (double)num_dec_signals, (double)num_dec_signals}; + switch (membus_type) + { + case Data_path: + power_bus = (center_stripe->power + bank_bus->power) * coeff2; + power_local_data = local_data_drv->power * coeff2; + power_global_data = semi_repeated_global_line >0 ? 
(global_data_drv->power*coeff2) : (global_data_drv->power+global_data->power); + + power_global_data.readOp.dynamic = power_global_data.readOp.dynamic + 1.8/1e3*deviceType->Vdd*10.0/1e9/64*data_bits; + power = power_bus + power_local_data; + if(!g_ip->fine_gran_bank_lvl) + power = power + power_global_data; + //power += local_data->power; + + power_burst = out_seg->power * coeff4;//Account for burst read, approxmate the wire length by the center stripe + //power = power + power_burst; + if(g_ip->print_detail_debug) + { + cout << "memorybus.cc: data path center stripe energy = " << center_stripe->power.readOp.dynamic*1e9 << " nJ" << endl; + cout << "memorybus.cc: data path bank bus energy = " << bank_bus->power.readOp.dynamic*1e9 << " nJ" << endl; + cout << "memorybus.cc: data path data driver energy = " << local_data_drv->power.readOp.dynamic*1e9 << " nJ" << endl; + } + break; + case Row_add_path: + power_bus = (center_stripe->power + bank_bus->power) * coeff1; + power_add_predecoder = add_predec->power; + if(semi_repeated_global_line) + { + power_add_decoders = add_dec->power * coeff5; + //power_add_decoders.readOp.dynamic /= (g_ip->page_sz_bits > 8192)?((double)g_ip->page_sz_bits/8192):1; + if(g_ip->page_sz_bits > 8192) + power_add_decoders.readOp.dynamic /= (double)(g_ip->page_sz_bits / 8192); + } + else + power_add_decoders = add_dec->power;// * (1<< add_predec->blk1->number_input_addr_bits); + power_lwl_drv = lwl_drv->power * coeff3; + //power_local_WL.readOp.dynamic = num_lwl_drv * C_LWL * deviceType->Vdd * deviceType->Vdd; + power = power_bus + power_add_predecoder + power_add_decoders + power_lwl_drv; + break; + case Col_add_path: + power_bus = (center_stripe->power + bank_bus->power) * coeff1;// + column_sel->power * double(chip_IO_width * burst_length); + power_add_predecoder = add_predec->power; + if(semi_repeated_global_line) + { + power_add_decoders = add_dec->power * coeff6; + power_add_decoders.readOp.dynamic = power_add_decoders.readOp.dynamic * g_ip->page_sz_bits / data_bits; + power_col_sel.readOp.dynamic = 0; + } + else + { + power_add_decoders = add_dec->power;// * (1<< add_predec->blk1->number_input_addr_bits); + power_col_sel.readOp.dynamic = column_sel->power.readOp.dynamic * g_ip->page_sz_bits / data_bits; + } + power = power_bus + power_add_predecoder + power_add_decoders; + if(!g_ip->fine_gran_bank_lvl) + power = power + power_col_sel; + break; + default: + assert(0); + break; + } + + return; + +} + + + diff --git a/Project_FARSI/cacti_for_FARSI/memorybus.h b/Project_FARSI/cacti_for_FARSI/memorybus.h new file mode 100644 index 00000000..b4eb280a --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/memorybus.h @@ -0,0 +1,150 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + +#ifndef __MEMORYBUS_H__ +#define __MEMORYBUS_H__ + +#include "basic_circuit.h" +#include "component.h" +#include "parameter.h" +//#include "assert.h" +#include "cacti_interface.h" +//#include "wire.h" +class Wire; +//#include "area.h" +#include "decoder.h" + +class Memorybus : public Component +{ + public: + Memorybus(enum Wire_type wire_model, double mat_w, double mat_h, double subarray_w, double subarray_h, + int _row_add_bits, int _col_add_bits, int _data_bits, int _ndbl, int _ndwl, /*enum Htree_type htree_type,*/ + enum Memorybus_type membus_type, const DynamicParameter & dp_, + /*TechnologyParameter::*/DeviceType *dt = &(g_tp.peri_global) + ); + ~Memorybus(); + + //void in_membus(); + //void out_membus(); + void Network(); + + // repeaters only at h-tree nodes + void limited_in_membus(); + void limited_out_membus(); + void input_nand(double s1, double s2, double l); + //void output_buffer(double s1, double s2, double l); + + const DynamicParameter & dp; + + double in_rise_time, out_rise_time; + + void set_in_rise_time(double rt) + { + in_rise_time = rt; + } + + double max_unpipelined_link_delay; + powerDef power_bit; + void memory_bus(); + + double height_bank, length_bank; // The actual height and length of a single bank including all wires between subarrays. + Wire * center_stripe; + Wire * bank_bus; + Wire * global_WL; //3 hierarchical connection wires. 
+ Wire * column_sel; + Wire * local_data; + Wire * global_data; + Wire * out_seg; + // Driver for LWL connecting GWL, same as in mat.cc + double lwl_driver_c_gate_load, lwl_driver_c_wire_load, lwl_driver_r_wire_load; + + powerDef power_bus; + powerDef power_lwl_drv; + powerDef power_add_decoders; + powerDef power_global_WL; + powerDef power_local_WL; + powerDef power_add_predecoder; + powerDef power_burst; + powerDef power_col_sel; + powerDef power_local_data; + powerDef power_global_data; + double delay_bus, delay_add_predecoder, delay_add_decoder, delay_lwl_drv, delay_global_data, delay_local_data, delay_data_buffer; + double area_lwl_drv, area_row_predec_dec, area_col_predec_dec, area_subarray, area_bus, area_address_bus, area_data_bus, area_data_drv, area_IOSA, area_local_dataline, area_sense_amp; + + + Area cell; + bool is_dram; + + Driver * lwl_drv, * local_data_drv, * global_data_drv ; + Predec * add_predec; + Decoder * add_dec; + + double compute_delays(double inrisetime); // return outrisetime + void compute_power_energy(); // + + + + + private: + double wire_bw; + double init_wire_bw; // bus width at root + enum Memorybus_type membus_type; +// double htree_hnodes; +// double htree_vnodes; + double mat_width; + double mat_height; + double subarray_width, subarray_height; + //int add_bits, data_in_bits,search_data_in_bits,data_out_bits, search_data_out_bits; + int row_add_bits, col_add_bits; + int add_bits, data_bits, num_dec_signals; + int semi_repeated_global_line; + + int ndbl, ndwl; +// bool uca_tree; // should have full bandwidth to access all banks in the array simultaneously +// bool search_tree; + + enum Wire_type wt; + double min_w_nmos; + double min_w_pmos; + + int num_lwl_drv; //Ratio between GWL and LWL, how many local WL drives each GWL drives. + int chip_IO_width; + int burst_length; + int num_subarray_global_IO; + + double C_GWL, C_LWL, R_GWL, R_LWL, C_colsel, R_colsel, C_global_data, R_global_data; // Capacitance of global/local WLs. + + /*TechnologyParameter::*/DeviceType *deviceType; +}; + +#endif + diff --git a/Project_FARSI/cacti_for_FARSI/nuca.cc b/Project_FARSI/cacti_for_FARSI/nuca.cc new file mode 100644 index 00000000..05b1bbc5 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/nuca.cc @@ -0,0 +1,611 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#include "nuca.h" +#include "Ucache.h" +#include + +unsigned int MIN_BANKSIZE=65536; +#define FIXED_OVERHEAD 55e-12 /* clock skew and jitter in s. Ref: Hrishikesh et al ISCA 01 */ +#define LATCH_DELAY 28e-12 /* latch delay in s (later should use FO4 TODO) */ +#define CONTR_2_BANK_LAT 0 + +int cont_stats[2 /*l2 or l3*/][5/* cores */][ROUTER_TYPES][7 /*banks*/][8 /* cycle time */]; + + Nuca::Nuca( + /*TechnologyParameter::*/DeviceType *dt + ):deviceType(dt) +{ + init_cont(); +} + +void +Nuca::init_cont() +{ + FILE *cont; + char line[5000]; + char jk[5000]; + cont = fopen("contention.dat", "r"); + if (!cont) { + cout << "contention.dat file is missing!\n"; + exit(0); + } + + for(int i=0; i<2; i++) { + for(int j=2; j<5; j++) { + for(int k=0; k nuca_list; + Router *router_s[ROUTER_TYPES]; + router_s[0] = new Router(64.0, 8, 4, &(g_tp.peri_global)); + router_s[0]->print_router(); + router_s[1] = new Router(128.0, 8, 4, &(g_tp.peri_global)); + router_s[1]->print_router(); + router_s[2] = new Router(256.0, 8, 4, &(g_tp.peri_global)); + router_s[2]->print_router(); + + int core_in; // to store no. of cores + + /* to search diff grid organizations */ + double curr_hop, totno_hops, totno_hhops, totno_vhops, tot_lat, + curr_acclat; + double avg_lat, avg_hop, avg_hhop, avg_vhop, avg_dyn_power, + avg_leakage_power; + + double opt_acclat = INF;//, opt_avg_lat = INF, opt_tot_lat = INF; + int opt_rows = 0; + int opt_columns = 0; +// double opt_totno_hops = 0; + double opt_avg_hop = 0; + double opt_dyn_power = 0, opt_leakage_power = 0; + min_values_t minval; + + int bank_start = 0; + + int flit_width = 0; + + /* vertical and horizontal hop latency values */ + int ver_hop_lat, hor_hop_lat; /* in cycles */ + + + /* no. of different bank sizes to consider */ + int iterations; + + + g_ip->nuca_cache_sz = g_ip->cache_sz; + nuca_list.push_back(new nuca_org_t()); + + if (g_ip->cache_level == 0) l2_c = 1; + else l2_c = 0; + + if (g_ip->cores <= 4) core_in = 2; + else if (g_ip->cores <= 8) core_in = 3; + else if (g_ip->cores <= 16) core_in = 4; + else {cout << "Number of cores should be <= 16!\n"; exit(0);} + + + // set the lower bound to an appropriate value. this depends on cache associativity + if (g_ip->assoc > 2) { + i = 2; + while (i != g_ip->assoc) { + MIN_BANKSIZE *= 2; + i *= 2; + } + } + + iterations = (int)logtwo((int)g_ip->cache_sz/MIN_BANKSIZE); + + if (g_ip->force_wiretype) + { + if (g_ip->wt == Low_swing) { + wt_min = Low_swing; + wt_max = Low_swing; + } + else { + wt_min = Global; + wt_max = Low_swing-1; + } + } + else { + wt_min = Global; + wt_max = Low_swing; + } + if (g_ip->nuca_bank_count != 0) { // simulate just one bank + if (g_ip->nuca_bank_count != 2 && g_ip->nuca_bank_count != 4 && + g_ip->nuca_bank_count != 8 && g_ip->nuca_bank_count != 16 && + g_ip->nuca_bank_count != 32 && g_ip->nuca_bank_count != 64) { + fprintf(stderr,"Incorrect bank count value! 
Please fix the value in cache.cfg\n"); + } + bank_start = (int)logtwo((double)g_ip->nuca_bank_count); + iterations = bank_start+1; + g_ip->cache_sz = g_ip->cache_sz/g_ip->nuca_bank_count; + } + cout << "Simulating various NUCA configurations\n"; + for (it=bank_start; itnuca_cache_sz/g_ip->cache_sz; + cout << "====" << g_ip->cache_sz << "\n"; + + for (wr=wt_min; wr<=wt_max; wr++) { + + for (ro=0; roflit_size; //initialize router + nuca_list.back()->nuca_pda.cycle_time = router_s[ro]->cycle_time; + + /* calculate router and wire parameters */ + + double vlength = ures.cache_ht; /* length of the wire (u)*/ + double hlength = ures.cache_len; // u + + /* find delay, area, and power for wires */ + wire_vertical[wr] = new Wire((enum Wire_type) wr, vlength); + wire_horizontal[wr] = new Wire((enum Wire_type) wr, hlength); + + + hor_hop_lat = calc_cycles(wire_horizontal[wr]->delay, + 1/(nuca_list.back()->nuca_pda.cycle_time*.001)); + ver_hop_lat = calc_cycles(wire_vertical[wr]->delay, + 1/(nuca_list.back()->nuca_pda.cycle_time*.001)); + + /* + * assume a grid like topology and explore for optimal network + * configuration using different row and column count values. + */ + for (c=1; c<=(unsigned int)bank_count; c++) { + while (bank_count%c != 0) c++; + r = bank_count/c; + + /* + * to find the avg access latency of a NUCA cache, uncontended + * access time to each bank from the + * cache controller is calculated. + * avg latency = + * sum of the access latencies to individual banks)/bank + * count value. + */ + totno_hops = totno_hhops = totno_vhops = tot_lat = 0; +/// k = 1; + for (i=0; idelay*avg_hop) + + calc_cycles(ures.access_time, + 1/(nuca_list.back()->nuca_pda.cycle_time*.001)); + + /* avg access lat of nuca */ + avg_dyn_power = + avg_hop * + (router_s[ro]->power.readOp.dynamic) + avg_hhop * + (wire_horizontal[wr]->power.readOp.dynamic) * + (g_ip->block_sz*8 + 64) + avg_vhop * + (wire_vertical[wr]->power.readOp.dynamic) * + (g_ip->block_sz*8 + 64) + ures.power.readOp.dynamic; + + avg_leakage_power = + bank_count * router_s[ro]->power.readOp.leakage + + avg_hhop * (wire_horizontal[wr]->power.readOp.leakage* + wire_horizontal[wr]->delay) * flit_width + + avg_vhop * (wire_vertical[wr]->power.readOp.leakage * + wire_horizontal[wr]->delay); + + if (curr_acclat < opt_acclat) { + opt_acclat = curr_acclat; +/// opt_tot_lat = tot_lat; +/// opt_avg_lat = avg_lat; +/// opt_totno_hops = totno_hops; + opt_avg_hop = avg_hop; + opt_rows = r; + opt_columns = c; + opt_dyn_power = avg_dyn_power; + opt_leakage_power = avg_leakage_power; + } + totno_hops = 0; + tot_lat = 0; + totno_hhops = 0; + totno_vhops = 0; + } + nuca_list.back()->wire_pda.power.readOp.dynamic = + opt_avg_hop * flit_width * + (wire_horizontal[wr]->power.readOp.dynamic + + wire_vertical[wr]->power.readOp.dynamic); + nuca_list.back()->avg_hops = opt_avg_hop; + /* network delay/power */ + nuca_list.back()->h_wire = wire_horizontal[wr]; + nuca_list.back()->v_wire = wire_vertical[wr]; + nuca_list.back()->router = router_s[ro]; + /* bank delay/power */ + + nuca_list.back()->bank_pda.delay = ures.access_time; + nuca_list.back()->bank_pda.power = ures.power; + nuca_list.back()->bank_pda.area.h = ures.cache_ht; + nuca_list.back()->bank_pda.area.w = ures.cache_len; + nuca_list.back()->bank_pda.cycle_time = ures.cycle_time; + + num_cyc = calc_cycles(nuca_list.back()->bank_pda.delay /*s*/, + 1/(nuca_list.back()->nuca_pda.cycle_time*.001/*GHz*/)); + if(num_cyc%2 != 0) num_cyc++; + if (num_cyc > 16) num_cyc = 16; // we have data only up to 16 cycles + + 
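+ /* Descriptive note: contention (in cycles) is added from the table read
+ * out of contention.dat, indexed by cache level (l2_c), core-count
+ * bucket (core_in), router type (ro), bank-count iteration (it,
+ * saturated for large bank counts) and the num_cyc/2-1 cycle-time
+ * bucket, matching the even, <=16-cycle clamp applied above. */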
if (it < 7) { + nuca_list.back()->nuca_pda.delay = opt_acclat + + cont_stats[l2_c][core_in][ro][it][num_cyc/2-1]; + nuca_list.back()->contention = + cont_stats[l2_c][core_in][ro][it][num_cyc/2-1]; + } + else { + nuca_list.back()->nuca_pda.delay = opt_acclat + + cont_stats[l2_c][core_in][ro][7][num_cyc/2-1]; + nuca_list.back()->contention = + cont_stats[l2_c][core_in][ro][7][num_cyc/2-1]; + } + nuca_list.back()->nuca_pda.power.readOp.dynamic = opt_dyn_power; + nuca_list.back()->nuca_pda.power.readOp.leakage = opt_leakage_power; + + /* array organization */ + nuca_list.back()->bank_count = bank_count; + nuca_list.back()->rows = opt_rows; + nuca_list.back()->columns = opt_columns; + calculate_nuca_area (nuca_list.back()); + + minval.update_min_values(nuca_list.back()); + nuca_list.push_back(new nuca_org_t()); + opt_acclat = BIGNUM; + + } + } + g_ip->cache_sz /= 2; + } + + delete(nuca_list.back()); + nuca_list.pop_back(); + opt_n = find_optimal_nuca(&nuca_list, &minval); + print_nuca(opt_n); + g_ip->cache_sz = g_ip->nuca_cache_sz/opt_n->bank_count; + + list::iterator niter; + for (niter = nuca_list.begin(); niter != nuca_list.end(); ++niter) + { + delete *niter; + } + nuca_list.clear(); + + for(int i=0; i < ROUTER_TYPES; i++) + { + delete router_s[i]; + } + g_ip->display_ip(); + // g_ip->force_cache_config = true; + // g_ip->ndwl = 8; + // g_ip->ndbl = 16; + // g_ip->nspd = 4; + // g_ip->ndcm = 1; + // g_ip->ndsam1 = 8; + // g_ip->ndsam2 = 32; + +} + + + void +Nuca::print_nuca (nuca_org_t *fr) +{ + printf("\n---------- CACTI version 6.5, Non-uniform Cache Access " + "----------\n\n"); + printf("Optimal number of banks - %d\n", fr->bank_count); + printf("Grid organization rows x columns - %d x %d\n", + fr->rows, fr->columns); + printf("Network frequency - %g GHz\n", + (1/fr->nuca_pda.cycle_time)*1e3); + printf("Cache dimension (mm x mm) - %g x %g\n", + fr->nuca_pda.area.h*1e-3, + fr->nuca_pda.area.w*1e-3); + + fr->router->print_router(); + + printf("\n\nWire stats:\n"); + if (fr->h_wire->wt == Global) { + printf("\tWire type - Full swing global wires with least " + "possible delay\n"); + } + else if (fr->h_wire->wt == Global_5) { + printf("\tWire type - Full swing global wires with " + "5%% delay penalty\n"); + } + else if (fr->h_wire->wt == Global_10) { + printf("\tWire type - Full swing global wires with " + "10%% delay penalty\n"); + } + else if (fr->h_wire->wt == Global_20) { + printf("\tWire type - Full swing global wires with " + "20%% delay penalty\n"); + } + else if (fr->h_wire->wt == Global_30) { + printf("\tWire type - Full swing global wires with " + "30%% delay penalty\n"); + } + else if(fr->h_wire->wt == Low_swing) { + printf("\tWire type - Low swing wires\n"); + } + + printf("\tHorizontal link delay - %g (ns)\n", + fr->h_wire->delay*1e9); + printf("\tVertical link delay - %g (ns)\n", + fr->v_wire->delay*1e9); + printf("\tDelay/length - %g (ns/mm)\n", + fr->h_wire->delay*1e9/fr->bank_pda.area.w); + printf("\tHorizontal link energy -dynamic/access %g (nJ)\n" + "\t -leakage %g (nW)\n\n", + fr->h_wire->power.readOp.dynamic*1e9, + fr->h_wire->power.readOp.leakage*1e9); + printf("\tVertical link energy -dynamic/access %g (nJ)\n" + "\t -leakage %g (nW)\n\n", + fr->v_wire->power.readOp.dynamic*1e9, + fr->v_wire->power.readOp.leakage*1e9); + printf("\n\n"); + fr->v_wire->print_wire(); + printf("\n\nBank stats:\n"); +} + + + nuca_org_t * +Nuca::find_optimal_nuca (list *n, min_values_t *minval) +{ + double cost = 0; + double min_cost = BIGNUM; + nuca_org_t *res = NULL; + float d, a, dp, lp, 
c; + int v; + dp = g_ip->dynamic_power_wt_nuca; + lp = g_ip->leakage_power_wt_nuca; + a = g_ip->area_wt_nuca; + d = g_ip->delay_wt_nuca; + c = g_ip->cycle_time_wt_nuca; + + list::iterator niter; + + + for (niter = n->begin(); niter != n->end(); niter++) { + fprintf(stderr, "\n-----------------------------" + "---------------\n"); + + + printf("NUCA___stats %d \tbankcount: lat = %g \tdynP = %g \twt = %d\t " + "bank_dpower = %g \tleak = %g \tcycle = %g\n", + (*niter)->bank_count, + (*niter)->nuca_pda.delay, + (*niter)->nuca_pda.power.readOp.dynamic, + (*niter)->h_wire->wt, + (*niter)->bank_pda.power.readOp.dynamic, + (*niter)->nuca_pda.power.readOp.leakage, + (*niter)->nuca_pda.cycle_time); + + + if (g_ip->ed == 1) { + cost = ((*niter)->nuca_pda.delay/minval->min_delay)* + ((*niter)->nuca_pda.power.readOp.dynamic/minval->min_dyn); + if (min_cost > cost) { + min_cost = cost; + res = ((*niter)); + } + } + else if (g_ip->ed == 2) { + cost = ((*niter)->nuca_pda.delay/minval->min_delay)* + ((*niter)->nuca_pda.delay/minval->min_delay)* + ((*niter)->nuca_pda.power.readOp.dynamic/minval->min_dyn); + if (min_cost > cost) { + min_cost = cost; + res = ((*niter)); + } + } + else { + /* + * check whether the current organization + * meets the input deviation constraints + */ + v = check_nuca_org((*niter), minval); + if (minval->min_leakage == 0) minval->min_leakage = 0.1; //FIXME remove this after leakage modeling + + if (v) { + cost = (d * ((*niter)->nuca_pda.delay/minval->min_delay) + + c * ((*niter)->nuca_pda.cycle_time/minval->min_cyc) + + dp * ((*niter)->nuca_pda.power.readOp.dynamic/minval->min_dyn) + + lp * ((*niter)->nuca_pda.power.readOp.leakage/minval->min_leakage) + + a * ((*niter)->nuca_pda.area.get_area()/minval->min_area)); + fprintf(stderr, "cost = %g\n", cost); + + if (min_cost > cost) { + min_cost = cost; + res = ((*niter)); + } + } + else { + niter = n->erase(niter); + if (niter !=n->begin()) + niter --; + } + } + } + return res; +} + + int +Nuca::check_nuca_org (nuca_org_t *n, min_values_t *minval) +{ + if (((n->nuca_pda.delay - minval->min_delay)*100/minval->min_delay) > g_ip->delay_dev_nuca) { + return 0; + } + if (((n->nuca_pda.power.readOp.dynamic - minval->min_dyn)/minval->min_dyn)*100 > + g_ip->dynamic_power_dev_nuca) { + return 0; + } + if (((n->nuca_pda.power.readOp.leakage - minval->min_leakage)/minval->min_leakage)*100 > + g_ip->leakage_power_dev_nuca) { + return 0; + } + if (((n->nuca_pda.cycle_time - minval->min_cyc)/minval->min_cyc)*100 > + g_ip->cycle_time_dev_nuca) { + return 0; + } + if (((n->nuca_pda.area.get_area() - minval->min_area)/minval->min_area)*100 > + g_ip->area_dev_nuca) { + return 0; + } + return 1; +} + + void +Nuca::calculate_nuca_area (nuca_org_t *nuca) +{ + nuca->nuca_pda.area.h= + nuca->rows * ((nuca->h_wire->wire_width + + nuca->h_wire->wire_spacing) + * nuca->router->flit_size + + nuca->bank_pda.area.h); + + nuca->nuca_pda.area.w = + nuca->columns * ((nuca->v_wire->wire_width + + nuca->v_wire->wire_spacing) + * nuca->router->flit_size + + nuca->bank_pda.area.w); +} + diff --git a/Project_FARSI/cacti_for_FARSI/nuca.h b/Project_FARSI/cacti_for_FARSI/nuca.h new file mode 100644 index 00000000..d8849d29 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/nuca.h @@ -0,0 +1,102 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + +#ifndef __NUCA_H__ +#define __NUCA_H__ + +#include "basic_circuit.h" +#include "component.h" +#include "parameter.h" +#include "assert.h" +#include "cacti_interface.h" +#include "wire.h" +#include "mat.h" +#include "io.h" +#include "router.h" +#include + + + +class nuca_org_t { + public: + ~nuca_org_t(); +// int size; + /* area, power, access time, and cycle time stats */ + Component nuca_pda; + Component bank_pda; + Component wire_pda; + Wire *h_wire; + Wire *v_wire; + Router *router; + /* for particular network configuration + * calculated based on a cycle accurate + * simulation Ref: CACTI 6 - Tech report + */ + double contention; + + /* grid network stats */ + double avg_hops; + int rows; + int columns; + int bank_count; +}; + + + +class Nuca : public Component +{ + public: + Nuca( + /*TechnologyParameter::*/DeviceType *dt= &(g_tp.peri_global) +); + void print_router(); + ~Nuca(); + void sim_nuca(); + void init_cont(); + int calc_cycles(double lat, double oper_freq); + void calculate_nuca_area (nuca_org_t *nuca); + int check_nuca_org (nuca_org_t *n, min_values_t *minval); + nuca_org_t * find_optimal_nuca (list *n, min_values_t *minval); + void print_nuca(nuca_org_t *n); + void print_cont_stats(); + + private: + + /*TechnologyParameter::*/DeviceType *deviceType; + int wt_min, wt_max; + Wire *wire_vertical[WIRE_TYPES], + *wire_horizontal[WIRE_TYPES]; + +}; + + +#endif diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/TSV.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/TSV.o new file mode 100644 index 00000000..20610a0f Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/TSV.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/Ucache.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/Ucache.o new file mode 100644 index 00000000..058d47d2 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/Ucache.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/arbiter.o 
b/Project_FARSI/cacti_for_FARSI/obj_dbg/arbiter.o new file mode 100644 index 00000000..cd713554 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/arbiter.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/area.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/area.o new file mode 100644 index 00000000..eee554fb Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/area.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/bank.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/bank.o new file mode 100644 index 00000000..253ef89b Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/bank.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/basic_circuit.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/basic_circuit.o new file mode 100644 index 00000000..0ecf642b Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/basic_circuit.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/cacti b/Project_FARSI/cacti_for_FARSI/obj_dbg/cacti new file mode 100755 index 00000000..334437c3 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/cacti differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/cacti_interface.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/cacti_interface.o new file mode 100644 index 00000000..e519a15b Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/cacti_interface.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/component.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/component.o new file mode 100644 index 00000000..74fccf6a Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/component.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/crossbar.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/crossbar.o new file mode 100644 index 00000000..5c9fac73 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/crossbar.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/decoder.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/decoder.o new file mode 100644 index 00000000..20cbe61c Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/decoder.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/extio.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/extio.o new file mode 100644 index 00000000..4d85efdf Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/extio.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/extio_technology.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/extio_technology.o new file mode 100644 index 00000000..6c705059 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/extio_technology.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/htree2.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/htree2.o new file mode 100644 index 00000000..bb0e5e95 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/htree2.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/io.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/io.o new file mode 100644 index 00000000..5c75b08f Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/io.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/main.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/main.o new file mode 100644 index 00000000..06d5584c Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/main.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/mat.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/mat.o new file mode 100644 index 00000000..106b603a Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/mat.o 
differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/memcad.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/memcad.o new file mode 100644 index 00000000..26546d85 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/memcad.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/memcad_parameters.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/memcad_parameters.o new file mode 100644 index 00000000..13590bcf Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/memcad_parameters.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/memorybus.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/memorybus.o new file mode 100644 index 00000000..1a3e7d03 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/memorybus.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/nuca.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/nuca.o new file mode 100644 index 00000000..9ccac147 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/nuca.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/parameter.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/parameter.o new file mode 100644 index 00000000..4cf635ca Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/parameter.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/powergating.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/powergating.o new file mode 100644 index 00000000..ff8a8ef6 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/powergating.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/router.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/router.o new file mode 100644 index 00000000..14ee0844 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/router.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/subarray.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/subarray.o new file mode 100644 index 00000000..c0e75355 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/subarray.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/technology.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/technology.o new file mode 100644 index 00000000..8c9aaee3 Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/technology.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/uca.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/uca.o new file mode 100644 index 00000000..d229f9fa Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/uca.o differ diff --git a/Project_FARSI/cacti_for_FARSI/obj_dbg/wire.o b/Project_FARSI/cacti_for_FARSI/obj_dbg/wire.o new file mode 100644 index 00000000..4a025a3e Binary files /dev/null and b/Project_FARSI/cacti_for_FARSI/obj_dbg/wire.o differ diff --git a/Project_FARSI/cacti_for_FARSI/parameter.cc b/Project_FARSI/cacti_for_FARSI/parameter.cc new file mode 100644 index 00000000..3300b958 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/parameter.cc @@ -0,0 +1,2837 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#include +#include +#include + +#include "parameter.h" +#include "area.h" + +#include "basic_circuit.h" +#include + +using namespace std; + + +InputParameter * g_ip; +TechnologyParameter g_tp; + +// ali +bool is_equal(double first, double second) +{ + + if((first == 0) && (second ==0)) + { + return true; + } + + if((second==0) || (second!=second)) + return true; + + if((first!=first) || (second!=second)) // both are NaNs + { + return true; + } + if(first==0) + { + if(fabs(first-second)<(second*0.000001)) + return true; + } + else + { + if(fabs(first-second)<(first*0.000001)) + return true; + } + + return false; +} + +/** +void DeviceType::display(uint32_t indent) const +{ + string indent_str(indent, ' '); + + cout << indent_str << "C_g_ideal = " << setw(12) << C_g_ideal << " F/um" << endl; + cout << indent_str << "C_fringe = " << setw(12) << C_fringe << " F/um" << endl; + cout << indent_str << "C_overlap = " << setw(12) << C_overlap << " F/um" << endl; + cout << indent_str << "C_junc = " << setw(12) << C_junc << " F/um^2" << endl; + cout << indent_str << "C_junc_sw = " << setw(12) << C_junc_sidewall << " F/um^2" << endl; + cout << indent_str << "l_phy = " << setw(12) << l_phy << " um" << endl; + cout << indent_str << "l_elec = " << setw(12) << l_elec << " um" << endl; + cout << indent_str << "R_nch_on = " << setw(12) << R_nch_on << " ohm-um" << endl; + cout << indent_str << "R_pch_on = " << setw(12) << R_pch_on << " ohm-um" << endl; + cout << indent_str << "Vdd = " << setw(12) << Vdd << " V" << endl; + cout << indent_str << "Vth = " << setw(12) << Vth << " V" << endl; + cout << indent_str << "I_on_n = " << setw(12) << I_on_n << " A/um" << endl; + cout << indent_str << "I_on_p = " << setw(12) << I_on_p << " A/um" << endl; + cout << indent_str << "I_off_n = " << setw(12) << I_off_n << " A/um" << endl; + cout << indent_str << "I_off_p = " << setw(12) << I_off_p << " A/um" << endl; + cout << indent_str << 
"C_ox = " << setw(12) << C_ox << " F/um^2" << endl; + cout << indent_str << "t_ox = " << setw(12) << t_ox << " um" << endl; + cout << indent_str << "n_to_p_eff_curr_drv_ratio = " << n_to_p_eff_curr_drv_ratio << endl; +} +**/ +bool DeviceType::isEqual(const DeviceType & dev) +{ + if( !is_equal(C_g_ideal,dev.C_g_ideal)) {display(0); cout << "\n\n\n"; dev.display(0); assert(false);} + if( !is_equal(C_fringe,dev.C_fringe)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + + if( !is_equal(C_overlap , dev.C_overlap)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(C_junc , dev.C_junc)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(C_junc_sidewall , dev.C_junc_sidewall)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(l_phy , dev.l_phy)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(l_elec , dev.l_elec)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(R_nch_on , dev.R_nch_on)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(R_pch_on , dev.R_pch_on)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(Vdd , dev.Vdd)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(Vth , dev.Vth)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} +//// if( !is_equal(Vcc_min , dev.Vcc_min)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(I_on_n , dev.I_on_n)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(I_on_p , dev.I_on_p)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(I_off_n , dev.I_off_n)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(I_off_p , dev.I_off_p)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(I_g_on_n , dev.I_g_on_n)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(I_g_on_p , dev.I_g_on_p)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(C_ox , dev.C_ox)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(t_ox , dev.t_ox)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(n_to_p_eff_curr_drv_ratio , dev.n_to_p_eff_curr_drv_ratio)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + if( !is_equal(long_channel_leakage_reduction , dev.long_channel_leakage_reduction)) {display(0);cout << "\n\n\n"; dev.display(0); assert(false);} + if( !is_equal(Mobility_n , dev.Mobility_n)) {display(0); cout << "\n\n\n"; dev.display(0);assert(false);} + + // auxilary parameters + ///if( !is_equal(Vdsat , dev.Vdsat)) {display(0); cout << "\n\n\n"; dev.display(0); assert(false);} + ///if( !is_equal(gmp_to_gmn_multiplier , dev.gmp_to_gmn_multiplier)) {display(0); cout << "\n\n\n"; dev.display(0); assert(false);} + + return true; +} + +double scan_single_input_double(char* line, const char* name, const char* unit_name, bool print) +{ + double temp; + char unit[300]; + memset(unit,0,300); + sscanf(&line[strlen(name)], "%*[ \t]%s%*[ \t]%lf",unit,&temp); + if(print) + cout << name << ": " << temp << " " << unit << endl; + return temp; +} + +double scan_five_input_double(char* line, const char* name, const char* unit_name, int flavor, bool print) +{ + double temp[5]; + char unit[300]; + memset(unit,0,300); + sscanf(&line[strlen(name)], "%*[ \t]%s%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf%*[ 
\t]%lf" + ,unit,&(temp[0]),&(temp[1]),&(temp[2]),&(temp[3]), &(temp[4]) ); + + if (print) + cout << name << "[" << flavor <<"]: " << temp[flavor] << " " << unit<< endl; + return temp[flavor]; + +} + +void scan_five_input_double_temperature(char* line, const char* name, const char* unit_name, int flavor, unsigned int temperature, bool print, double & result) +{ + double temp[5]; + unsigned int thermal_temp; + char unit[300]; + memset(unit,0,300); + sscanf(&line[strlen(name)], "%*[ \t]%s%*[ \t]%u%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf" + ,unit,&thermal_temp,&(temp[0]),&(temp[1]),&(temp[2]),&(temp[3]), &(temp[4]) ); + + + if(thermal_temp==(temperature-300)) + { + if (print) + cout << name << ": " << temp[flavor] << " "<< unit << endl; + + result = temp[flavor]; + } + +} + +void DeviceType::assign(const string & in_file, int tech_flavor, unsigned int temperature) +{ + FILE *fp = fopen(in_file.c_str(), "r"); + char line[5000]; + //char temp_var[5000]; + + //double temp[5]; + //unsigned int thermal_temp; + + double nmos_effective_resistance_multiplier; + + if(!fp) { + cout << in_file << " is missing!\n"; + exit(-1); + } + + while(fscanf(fp, "%[^\n]\n", line) != EOF) + { + if (!strncmp("-C_g_ideal", line, strlen("-C_g_ideal"))) + { + C_g_ideal=scan_five_input_double(line,"-C_g_ideal","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-C_fringe", line, strlen("-C_fringe"))) + { + C_fringe=scan_five_input_double(line,"-C_fringe","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-C_junc_sw", line, strlen("-C_junc_sw"))) + { + C_junc_sidewall =scan_five_input_double(line,"-C_junc_sw","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-C_junc", line, strlen("-C_junc"))) + { + C_junc=scan_five_input_double(line,"-C_junc","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + + if (!strncmp("-l_phy", line, strlen("-l_phy"))) + { + l_phy=scan_five_input_double(line,"-l_phy","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-l_elec", line, strlen("-l_elec"))) + { + l_elec=scan_five_input_double(line,"-l_elec","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-nmos_effective_resistance_multiplier", line, strlen("-nmos_effective_resistance_multiplier"))) + { + nmos_effective_resistance_multiplier=scan_five_input_double(line,"-nmos_effective_resistance_multiplier","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-Vdd", line, strlen("-Vdd"))) + { + Vdd=scan_five_input_double(line,"-Vdd","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-Vth", line, strlen("-Vth"))) + { + Vth=scan_five_input_double(line,"-Vth","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-Vdsat", line, strlen("-Vdsat"))) + { + Vdsat=scan_five_input_double(line,"-Vdsat","V",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-I_on_n", line, strlen("-I_on_n"))) + { + I_on_n=scan_five_input_double(line,"-I_on_n","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-I_on_p", line, strlen("-I_on_p"))) + { + I_on_p = scan_five_input_double(line,"-I_on_p","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-I_off_n", line, strlen("-I_off_n"))) + { + scan_five_input_double_temperature(line,"-I_off_n","F/um",tech_flavor,temperature,g_ip->print_detail_debug,I_off_n); + continue; + } + if (!strncmp("-I_g_on_n", line, strlen("-I_g_on_n"))) + { 
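+ // Descriptive note: as with -I_off_n above, gate leakage is tabulated
+ // per temperature; the *_temperature scanner stores a value only for
+ // the row whose temperature column equals the requested temperature - 300.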
+ scan_five_input_double_temperature(line,"-I_g_on_n","F/um",tech_flavor,temperature,g_ip->print_detail_debug,I_g_on_n); + continue; + } + if (!strncmp("-C_ox", line, strlen("-C_ox"))) + { + C_ox=scan_five_input_double(line,"-C_ox","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-t_ox", line, strlen("-t_ox"))) + { + t_ox=scan_five_input_double(line,"-t_ox","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-n2p_drv_rt", line, strlen("-n2p_drv_rt"))) + { + n_to_p_eff_curr_drv_ratio=scan_five_input_double(line,"-n2p_drv_rt","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-lch_lk_rdc", line, strlen("-lch_lk_rdc"))) + { + long_channel_leakage_reduction=scan_five_input_double(line,"-lch_lk_rdc","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-Mobility_n", line, strlen("-Mobility_n"))) + { + Mobility_n=scan_five_input_double(line,"-Mobility_n","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-gmp_to_gmn_multiplier", line, strlen("-gmp_to_gmn_multiplier"))) + { + gmp_to_gmn_multiplier=scan_five_input_double(line,"-gmp_to_gmn_multiplier","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-n_to_p_eff_curr_drv_ratio", line, strlen("-n_to_p_eff_curr_drv_ratio"))) + { + n_to_p_eff_curr_drv_ratio=scan_five_input_double(line,"-n_to_p_eff_curr_drv_ratio","F/um",tech_flavor,g_ip->print_detail_debug); + continue; + } + + } + + C_overlap = 0.2*C_g_ideal; + if(tech_flavor>=3) + R_nch_on = nmos_effective_resistance_multiplier * g_tp.vpp / I_on_n;//ohm-micron + else + R_nch_on = nmos_effective_resistance_multiplier * Vdd / I_on_n;//ohm-micron + R_pch_on = n_to_p_eff_curr_drv_ratio * R_nch_on;//ohm-micron + I_off_p = I_off_n; + I_g_on_p = I_g_on_n; + if(g_ip->print_detail_debug) + { + ///cout << nmos_effective_resistance_multiplier << " -- " << Vdd << " -- " << I_on_n << " -- " << n_to_p_eff_curr_drv_ratio << endl; + cout << "C_overlap: " << C_overlap << " F/um" << endl; + cout << "R_nch_on: " << R_nch_on << " ohm-micron" << endl; + cout << "R_pch_on: " << R_pch_on << " ohm-micron" << endl; + } + + fclose(fp); + +} + + +void DeviceType::interpolate(double alpha, const DeviceType& dev1, const DeviceType& dev2) +{ + C_g_ideal = alpha*dev1.C_g_ideal+(1-alpha)*dev2.C_g_ideal; + C_fringe = alpha*dev1.C_fringe+(1-alpha)*dev2.C_fringe; + C_overlap = alpha*dev1.C_overlap+(1-alpha)*dev2.C_overlap; + C_junc = alpha*dev1.C_junc+(1-alpha)*dev2.C_junc; + l_phy = alpha*dev1.l_phy+(1-alpha)*dev2.l_phy; + l_elec = alpha*dev1.l_elec+(1-alpha)*dev2.l_elec; + R_nch_on = alpha*dev1.R_nch_on+(1-alpha)*dev2.R_nch_on; + R_pch_on = alpha*dev1.R_pch_on+(1-alpha)*dev2.R_pch_on; + Vdd = alpha*dev1.Vdd+(1-alpha)*dev2.Vdd; + Vth = alpha*dev1.Vth+(1-alpha)*dev2.Vth; + Vcc_min = alpha*dev1.Vcc_min+(1-alpha)*dev2.Vcc_min; + I_on_n = alpha*dev1.I_on_n+(1-alpha)*dev2.I_on_n; + I_on_p = alpha*dev1.I_on_p+(1-alpha)*dev2.I_on_p; + I_off_n = alpha*dev1.I_off_n+(1-alpha)*dev2.I_off_n; + I_off_p = alpha*dev1.I_off_p+(1-alpha)*dev2.I_off_p; + I_g_on_n = alpha*dev1.I_g_on_n+(1-alpha)*dev2.I_g_on_n; + I_g_on_p = alpha*dev1.I_g_on_p+(1-alpha)*dev2.I_g_on_p; + C_ox = alpha*dev1.C_ox+(1-alpha)*dev2.C_ox; + t_ox = alpha*dev1.t_ox+(1-alpha)*dev2.t_ox; + n_to_p_eff_curr_drv_ratio = alpha*dev1.n_to_p_eff_curr_drv_ratio+(1-alpha)*dev2.n_to_p_eff_curr_drv_ratio; + long_channel_leakage_reduction = alpha*dev1.long_channel_leakage_reduction+(1-alpha)*dev2.long_channel_leakage_reduction; + 
Mobility_n = alpha*dev1.Mobility_n+(1-alpha)*dev2.Mobility_n; + Vdsat = alpha*dev1.Vdsat + (1-alpha)*dev2.Vdsat; + gmp_to_gmn_multiplier = alpha*dev1.gmp_to_gmn_multiplier + (1-alpha)*dev2.gmp_to_gmn_multiplier; + n_to_p_eff_curr_drv_ratio = alpha*dev1.n_to_p_eff_curr_drv_ratio + (1-alpha)*dev2.n_to_p_eff_curr_drv_ratio; + + C_junc_sidewall = dev1.C_junc_sidewall; +} + + +double scan_input_double_inter_type(char* line, const char * name, const char * unit_name, int proj_type, int tech_flavor, bool print) +{ + assert(proj_typeprint_detail_debug; + + while(fscanf(fp, "%[^\n]\n", line) != EOF) + { + if (!strncmp("-wire_pitch", line, strlen("-wire_pitch"))) + { + pitch =scan_input_double_inter_type(line,"-wire_pitch","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-barrier_thickness", line, strlen("-barrier_thickness"))) + { + barrier_thickness =scan_input_double_inter_type(line,"-barrier_thickness","ohm",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-dishing_thickness", line, strlen("-dishing_thickness"))) + { + dishing_thickness =scan_input_double_inter_type(line,"-dishing_thickness","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-alpha_scatter", line, strlen("-alpha_scatter"))) + { + alpha_scatter =scan_input_double_inter_type(line,"-alpha_scatter","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-aspect_ratio", line, strlen("-aspect_ratio"))) + { + aspect_ratio =scan_input_double_inter_type(line,"-aspect_ratio","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-miller_value", line, strlen("-miller_value"))) + { + miller_value =scan_input_double_inter_type(line,"-miller_value","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-horiz_dielectric_constant", line, strlen("-horiz_dielectric_constant"))) + { + horiz_dielectric_constant =scan_input_double_inter_type(line,"-horiz_dielectric_constant","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-vert_dielectric_constant", line, strlen("-vert_dielectric_constant"))) + { + vert_dielectric_constant =scan_input_double_inter_type(line,"-vert_dielectric_constant","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-ild_thickness", line, strlen("-ild_thickness"))) + { + ild_thickness =scan_input_double_inter_type(line,"-ild_thickness","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-fringe_cap", line, strlen("-fringe_cap"))) + { + fringe_cap =scan_input_double_inter_type(line,"-fringe_cap","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-wire_r_per_micron", line, strlen("-wire_r_per_micron"))) + { + R_per_um =scan_input_double_inter_type(line,"-wire_r_per_micron","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-wire_c_per_micron", line, strlen("-wire_c_per_micron"))) + { + C_per_um =scan_input_double_inter_type(line,"-wire_c_per_micron","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + if (!strncmp("-resistivity", line, strlen("-resistivity"))) + { + resistivity =scan_input_double_inter_type(line,"-resistivity","um",g_ip->ic_proj_type,tech_flavor,print); + continue; + } + } + + pitch *= g_ip->F_sz_um; + wire_width = pitch/ 2; //micron + wire_thickness = aspect_ratio * wire_width;//micron + wire_spacing = pitch - wire_width;//micron + if((projection_type!=1) || (tech_flavor!=3)) + { + R_per_um = wire_resistance(resistivity, wire_width, + 
wire_thickness, barrier_thickness, dishing_thickness, alpha_scatter);//ohm/micron
+ if(print)
+ cout << R_per_um << " = wire_resistance(" << resistivity << "," << wire_width << "," <<
+ wire_thickness << "," << barrier_thickness << "," << dishing_thickness << "," << alpha_scatter << ")\n";
+
+
+ C_per_um = wire_capacitance(wire_width, wire_thickness, wire_spacing,
+ ild_thickness, miller_value, horiz_dielectric_constant,
+ vert_dielectric_constant, fringe_cap);//F/micron.
+ if(print)
+ cout << C_per_um << " = wire_capacitance(" << wire_width << "," << wire_thickness << "," << wire_spacing
+ << "," << ild_thickness << "," << miller_value << "," << horiz_dielectric_constant
+ << "," << vert_dielectric_constant << "," << fringe_cap << ")\n";
+
+ }
+ fclose(fp);
+}
+
+bool InterconnectType::isEqual(const InterconnectType & inter)
+{
+ if( !is_equal(pitch , inter.pitch)) {display(0); assert(false);}
+ if( !is_equal(R_per_um , inter.R_per_um)) {display(0); assert(false);}
+ if( !is_equal(C_per_um , inter.C_per_um)) {display(0); assert(false);}
+ if( !is_equal(horiz_dielectric_constant , inter.horiz_dielectric_constant)) {display(0); assert(false);}
+ if( !is_equal(vert_dielectric_constant , inter.vert_dielectric_constant)) {display(0); assert(false);}
+ if( !is_equal(aspect_ratio , inter.aspect_ratio)) {display(0); assert(false);}
+ if( !is_equal(miller_value , inter.miller_value)) {display(0); assert(false);}
+ if( !is_equal(ild_thickness , inter.ild_thickness)) {display(0); assert(false);}
+
+ //auxiliary parameters
+ ///if( !is_equal(wire_width , inter.wire_width)) {display(0); assert(false);}
+ ///if( !is_equal(wire_thickness , inter.wire_thickness)) {display(0); assert(false);}
+ ///if( !is_equal(wire_spacing , inter.wire_spacing)) {display(0); assert(false);}
+ ///if( !is_equal(barrier_thickness , inter.barrier_thickness)) {display(0); assert(false);}
+ ///if( !is_equal(dishing_thickness , inter.dishing_thickness)) {display(0); assert(false);}
+ ///if( !is_equal(alpha_scatter , inter.alpha_scatter)) {display(0); assert(false);}
+ ///if( !is_equal(fringe_cap , inter.fringe_cap)) {display(0); assert(false);}
+
+ return true;
+}
+
+void InterconnectType::interpolate(double alpha, const InterconnectType & inter1, const InterconnectType & inter2)
+{
+ pitch = alpha*inter1.pitch + (1-alpha)*inter2.pitch;
+ R_per_um = alpha*inter1.R_per_um + (1-alpha)*inter2.R_per_um;
+ C_per_um = alpha*inter1.C_per_um + (1-alpha)*inter2.C_per_um;
+ horiz_dielectric_constant = alpha*inter1.horiz_dielectric_constant + (1-alpha)*inter2.horiz_dielectric_constant;
+ vert_dielectric_constant = alpha*inter1.vert_dielectric_constant + (1-alpha)*inter2.vert_dielectric_constant;
+ aspect_ratio = alpha*inter1.aspect_ratio + (1-alpha)*inter2.aspect_ratio;
+ miller_value = alpha*inter1.miller_value + (1-alpha)*inter2.miller_value;
+ ild_thickness = alpha*inter1.ild_thickness + (1-alpha)*inter2.ild_thickness;
+
+}
+
+void scan_five_input_double_mem_type(char* line, const char* name, const char* unit_name, int flavor, int cell_type, bool print, double & result)
+{
+ double temp[5];
+ int cell_type_temp;
+ char unit[300];
+ memset(unit,0,300);
+
+ sscanf(&line[strlen(name)], "%*[ \t]%s%*[ \t]%d%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf%*[ \t]%lf"
+ ,unit,&cell_type_temp,&(temp[0]),&(temp[1]),&(temp[2]),&(temp[3]), &(temp[4]) );
+
+
+ if(cell_type_temp==cell_type)
+ {
+ if (print)
+ cout << name << ": " << temp[flavor] << " "<< unit << endl;
+
+ result = temp[flavor];
+ }
+}
+
+// cell_type --> sram(0),cam(1),dram(2)
+void 
MemoryType::assign(const string & in_file, int tech_flavor, int cell_type) +{ + FILE *fp = fopen(in_file.c_str(), "r"); + char line[5000]; + //char temp_var[5000]; + + //double temp; + //unsigned int thermal_temp; + + double vdd_cell,vdd; + + if(!fp) { + cout << in_file << " is missing!\n"; + exit(-1); + } + while(fscanf(fp, "%[^\n]\n", line) != EOF) + { + if (!strncmp("-Vdd", line, strlen("-Vdd"))) + { + vdd=scan_five_input_double(line,"-Vdd","V",tech_flavor,g_ip->print_detail_debug); + continue; + } + if (!strncmp("-vdd_cell", line, strlen("-vdd_cell"))) + { + scan_five_input_double_mem_type(line,"-vdd_cell","V",tech_flavor,cell_type, g_ip->print_detail_debug,vdd_cell); + continue; + } + if (!strncmp("-Wmemcella", line, strlen("-Wmemcella"))) + { + scan_five_input_double_mem_type(line,"-Wmemcella","V",tech_flavor,cell_type, g_ip->print_detail_debug,cell_a_w); + continue; + } + if (!strncmp("-Wmemcellpmos", line, strlen("-Wmemcellpmos"))) + { + scan_five_input_double_mem_type(line,"-Wmemcellpmos","V",tech_flavor,cell_type, g_ip->print_detail_debug,cell_pmos_w); + continue; + } + if (!strncmp("-Wmemcellnmos", line, strlen("-Wmemcellnmos"))) + { + scan_five_input_double_mem_type(line,"-Wmemcellnmos","V",tech_flavor,cell_type, g_ip->print_detail_debug,cell_nmos_w); + continue; + } + if (!strncmp("-area_cell", line, strlen("-area_cell"))) + { + scan_five_input_double_mem_type(line,"-area_cell","V",tech_flavor,cell_type, g_ip->print_detail_debug,area_cell); + continue; + } + if (!strncmp("-asp_ratio_cell", line, strlen("-asp_ratio_cell"))) + { + scan_five_input_double_mem_type(line,"-asp_ratio_cell","V",tech_flavor,cell_type, g_ip->print_detail_debug,asp_ratio_cell); + continue; + } + } + if(cell_type!=2) + cell_a_w *= g_ip->F_sz_um; + cell_pmos_w *= g_ip->F_sz_um; + cell_nmos_w *= g_ip->F_sz_um; + if(cell_type!=2) + area_cell *= (g_ip->F_sz_um* g_ip->F_sz_um); + ///assert(asp_ratio_cell!=0); + b_w = sqrt(area_cell / (asp_ratio_cell)); + b_h = asp_ratio_cell * b_w; + if(cell_type==2) //dram + Vbitpre = vdd_cell; + else // sram or cam + Vbitpre = vdd; + + + Vbitfloating = Vbitpre*0.7; + + //display(5); + +} + +void MemoryType::interpolate(double alpha, const MemoryType& mem1, const MemoryType& mem2) +{ + cell_a_w = alpha * mem1.cell_a_w + (1-alpha) * mem2.cell_a_w; + cell_pmos_w = alpha * mem1.cell_pmos_w + (1-alpha) * mem2.cell_pmos_w; + cell_nmos_w = alpha * mem1.cell_nmos_w + (1-alpha) * mem2.cell_nmos_w; + + area_cell = alpha * mem1.area_cell + (1-alpha) * mem2.area_cell; + asp_ratio_cell = alpha * mem1.asp_ratio_cell + (1-alpha) * mem2.asp_ratio_cell; + + Vbitpre = mem2.Vbitpre; + Vbitfloating = Vbitpre*0.7; + // updating dependant variables after scaling/interpolating + ///assert(asp_ratio_cell!=0); + b_w = sqrt(area_cell / (asp_ratio_cell)); + b_h = asp_ratio_cell * b_w; + //display(10); +} + +bool MemoryType::isEqual(const MemoryType & mem) +{ + if( !is_equal(b_w , mem.b_w)) {display(0); cout << "\n\n\n"; mem.display(0); assert(false);} + if( !is_equal(b_h , mem.b_h)) {display(0); cout << "\n\n\n"; mem.display(0); assert(false);} + if( !is_equal(cell_a_w , mem.cell_a_w)) {display(0); cout << "\n\n\n"; mem.display(0); assert(false);} + if( !is_equal(cell_pmos_w , mem.cell_pmos_w)) {display(0); cout << "\n\n\n"; mem.display(0); assert(false);} + if( !is_equal(cell_nmos_w , mem.cell_nmos_w)) {display(0); cout << "\n\n\n"; mem.display(0); assert(false);} + if( !is_equal(Vbitpre , mem.Vbitpre)) {display(0); cout << "\n\n\n"; mem.display(0); assert(false);} + ///if( !is_equal(Vbitfloating , 
mem.Vbitfloating)) {display(0); cout << "\n\n\n"; mem.display(0); assert(false);} + + // needed to calculate b_w b_h + ///if( !is_equal(area_cell , mem.area_cell)) {display(0); assert(false);} + ///if( !is_equal(asp_ratio_cell , mem.asp_ratio_cell)) {display(0); assert(false);} + + return true; +} + +void ScalingFactor::assign(const string & in_file) +{ + FILE *fp = fopen(in_file.c_str(), "r"); + char line[5000]; + //char temp_var[5000]; + if(!fp) + { + cout << in_file << " is missing!\n"; + exit(-1); + } + + while(fscanf(fp, "%[^\n]\n", line) != EOF) + { + if (!strncmp("-logic_scaling_co_eff", line, strlen("-logic_scaling_co_eff"))) + { + logic_scaling_co_eff = scan_single_input_double(line,"-logic_scaling_co_eff","F/um", g_ip->print_detail_debug); + continue; + } + if (!strncmp("-core_tx_density", line, strlen("-core_tx_density"))) + { + core_tx_density = scan_single_input_double(line,"-core_tx_density","F/um", g_ip->print_detail_debug); + continue; + } + + } + + fclose(fp); +} + +void ScalingFactor::interpolate(double alpha, const ScalingFactor& dev1, const ScalingFactor& dev2) +{ + logic_scaling_co_eff = alpha*dev1.logic_scaling_co_eff + (1-alpha)*dev2.logic_scaling_co_eff; + core_tx_density = alpha*dev1.core_tx_density + (1-alpha)*dev2.core_tx_density; +} + +bool ScalingFactor::isEqual(const ScalingFactor & scal) +{ + if( !is_equal(logic_scaling_co_eff,scal.logic_scaling_co_eff)) { display(0); assert(false);} + if( !is_equal(core_tx_density,scal.core_tx_density)) { display(0); assert(false);} + if( !is_equal(long_channel_leakage_reduction , scal.long_channel_leakage_reduction)) { display(0); assert(false);} + return true; +} + +void TechnologyParameter::find_upper_and_lower_tech(double technology, int &tech_lo, string& in_file_lo, int &tech_hi, string& in_file_hi) +{ + if (technology < 181 && technology > 179) + { + tech_lo = 180; + in_file_lo = "tech_params/180nm.dat"; + tech_hi = 180; + in_file_hi = "tech_params/180nm.dat"; + } + else if (technology < 91 && technology > 89) + { + tech_lo = 90; + in_file_lo = "tech_params/90nm.dat"; + tech_hi = 90; + in_file_hi = "tech_params/90nm.dat"; + } + else if (technology < 66 && technology > 64) + { + tech_lo = 65; + in_file_lo = "tech_params/65nm.dat"; + tech_hi = 65; + in_file_hi = "tech_params/65nm.dat"; + } + else if (technology < 46 && technology > 44) + { + tech_lo = 45; + in_file_lo = "tech_params/45nm.dat"; + tech_hi = 45; + in_file_hi = "tech_params/45nm.dat"; + } + else if (technology < 33 && technology > 31) + { + tech_lo = 32; + in_file_lo = "tech_params/32nm.dat"; + tech_hi = 32; + in_file_hi = "tech_params/32nm.dat"; + } + else if (technology < 23 && technology > 21) + { + tech_lo = 22; + in_file_lo = "tech_params/22nm.dat"; + tech_hi = 22; + in_file_hi = "tech_params/22nm.dat"; + } + else if (technology < 180 && technology > 90) + { + tech_lo = 180; + in_file_lo = "tech_params/180nm.dat"; + tech_hi = 90; + in_file_hi = "tech_params/90nm.dat"; + } + else if (technology < 90 && technology > 65) + { + tech_lo = 90; + in_file_lo = "tech_params/90nm.dat"; + tech_hi = 65; + in_file_hi = "tech_params/65nm.dat"; + } + else if (technology < 65 && technology > 45) + { + tech_lo = 65; + in_file_lo = "tech_params/65nm.dat"; + tech_hi = 45; + in_file_hi = "tech_params/45nm.dat"; + } + else if (technology < 45 && technology > 32) + { + tech_lo = 45; + in_file_lo = "tech_params/45nm.dat"; + tech_hi = 32; + in_file_hi = "tech_params/32nm.dat"; + } + else if (technology < 32 && technology > 22) + { + tech_lo = 32; + in_file_lo = 
"tech_params/32nm.dat"; + tech_hi = 22; + in_file_hi = "tech_params/22nm.dat"; + } + /** + else if (technology < 22 && technology > 16) + { + tech_lo = 22; + in_file_lo = "tech_params/22nm.dat"; + tech_hi = 16; + in_file_hi = "tech_params/16nm.dat"; + } + **/ + else + { + cout<<"Invalid technology nodes"<tsv_is_subarray_type; + } + else + { + tsv_type = g_ip->tsv_os_bank_type; + } + fp = fopen(in_file.c_str(), "r"); + while(fscanf(fp, "%[^\n]\n", line) != EOF) + { + if (!strncmp("-tsv_pitch", line, strlen("-tsv_pitch"))) + { + tsv_pitch = scan_input_double_tsv_type(line,"-tsv_pitch","F/um", g_ip->ic_proj_type, tsv_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-tsv_diameter", line, strlen("-tsv_diameter"))) + { + tsv_diameter = scan_input_double_tsv_type(line,"-tsv_diameter","F/um", g_ip->ic_proj_type, tsv_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-tsv_length", line, strlen("-tsv_length"))) + { + tsv_length = scan_input_double_tsv_type(line,"-tsv_length","F/um", g_ip->ic_proj_type, tsv_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-tsv_dielec_thickness", line, strlen("-tsv_dielec_thickness"))) + { + tsv_dielec_thickness = scan_input_double_tsv_type(line,"-tsv_dielec_thickness","F/um", g_ip->ic_proj_type, tsv_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-tsv_contact_resistance", line, strlen("-tsv_contact_resistance"))) + { + tsv_contact_resistance = scan_input_double_tsv_type(line,"-tsv_contact_resistance","F/um", g_ip->ic_proj_type, tsv_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-tsv_depletion_width", line, strlen("-tsv_depletion_width"))) + { + tsv_depletion_width = scan_input_double_tsv_type(line,"-tsv_depletion_width","F/um", g_ip->ic_proj_type, tsv_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-tsv_liner_dielectric_cons", line, strlen("-tsv_liner_dielectric_cons"))) + { + tsv_liner_dielectric_constant = scan_input_double_tsv_type(line,"-tsv_liner_dielectric_cons","F/um", g_ip->ic_proj_type, tsv_type, g_ip->print_detail_debug); + continue; + } + + tsv_length *= g_ip->num_die_3d; + if(iter==0) + { + tsv_parasitic_resistance_fine = tsv_resistance(BULK_CU_RESISTIVITY, tsv_length, tsv_diameter, tsv_contact_resistance); + tsv_parasitic_capacitance_fine = tsv_capacitance(tsv_length, tsv_diameter, tsv_pitch, tsv_dielec_thickness, tsv_liner_dielectric_constant, tsv_depletion_width); + tsv_minimum_area_fine = tsv_area(tsv_pitch); + } + else + { + tsv_parasitic_resistance_coarse = tsv_resistance(BULK_CU_RESISTIVITY, tsv_length, tsv_diameter, tsv_contact_resistance); + tsv_parasitic_capacitance_coarse = tsv_capacitance(tsv_length, tsv_diameter, tsv_pitch, tsv_dielec_thickness, tsv_liner_dielectric_constant, tsv_depletion_width); + tsv_minimum_area_coarse = tsv_area(tsv_pitch); + } + } + + fclose(fp); + } +} + +void TechnologyParameter::init(double technology, bool is_tag) +{ + FILE *fp ; + reset(); + char line[5000]; + //char temp_var[5000]; + + uint32_t ram_cell_tech_type = (is_tag) ? g_ip->tag_arr_ram_cell_tech_type : g_ip->data_arr_ram_cell_tech_type; + uint32_t peri_global_tech_type = (is_tag) ? g_ip->tag_arr_peri_global_tech_type : g_ip->data_arr_peri_global_tech_type; + + int tech_lo, tech_hi; + string in_file_lo, in_file_hi; + + double alpha; // used for technology interpolation + + + + + technology = technology * 1000.0; // in the unit of nm + + find_upper_and_lower_tech(technology, tech_lo,in_file_lo,tech_hi,in_file_hi); + // excluding some cases. 
+  // excluding some cases.
+  if((tech_lo==22) && (tech_hi==22))
+  {
+    if (ram_cell_tech_type == 3 )
+    {
+      cout<<"current version does not support eDRAM technologies at 22nm"<<endl;
+      exit(0);
+    }
+  }
+
+  if (tech_lo == tech_hi)
+  {
+    alpha = 1;
+  }
+  else
+  {
+    // linear weight: 1 at tech_lo, 0 at tech_hi (see interpolate())
+    alpha = (technology - tech_hi) / (double)(tech_lo - tech_hi);
+  }
+
+  fp = fopen(in_file_lo.c_str(), "r");
+  while(fscanf(fp, "%[^\n]\n", line) != EOF)
+  {
+    if (!strncmp("-dram_cell_I_on", line, strlen("-dram_cell_I_on")))
+    {
+      dram_cell_I_on += alpha* scan_five_input_double(line,"-dram_cell_I_on","F/um", ram_cell_tech_type, g_ip->print_detail_debug);
+      continue;
+    }
+    if (!strncmp("-dram_cell_Vdd", line, strlen("-dram_cell_Vdd")))
+    {
+      dram_cell_Vdd += alpha* scan_five_input_double(line,"-dram_cell_Vdd","F/um", ram_cell_tech_type, g_ip->print_detail_debug);
+      continue;
+    }
+    if (!strncmp("-dram_cell_C", line, strlen("-dram_cell_C")))
+    {
+      dram_cell_C += alpha* scan_five_input_double(line,"-dram_cell_C","F/um", ram_cell_tech_type, g_ip->print_detail_debug);
+      continue;
+    }
+    if (!strncmp("-dram_cell_I_off_worst_case_len_temp", line, strlen("-dram_cell_I_off_worst_case_len_temp")))
+    {
+      dram_cell_I_off_worst_case_len_temp += alpha* scan_five_input_double(line,"-dram_cell_I_off_worst_case_len_temp","F/um", ram_cell_tech_type, g_ip->print_detail_debug);
+      continue;
+    }
+    if (!strncmp("-vpp", line, strlen("-vpp")))
+    {
+      vpp += alpha* scan_five_input_double(line,"-vpp","F/um", ram_cell_tech_type, g_ip->print_detail_debug);
+      continue;
+    }
+    if (!strncmp("-sckt_co_eff", line, strlen("-sckt_co_eff")))
+    {
+      sckt_co_eff += alpha * scan_single_input_double(line,"-sckt_co_eff","F/um", g_ip->print_detail_debug);
+      continue;
+    }
+    if (!strncmp("-chip_layout_overhead", line, strlen("-chip_layout_overhead")))
+    {
+      chip_layout_overhead += alpha * scan_single_input_double(line,"-chip_layout_overhead","F/um", g_ip->print_detail_debug);
+      continue;
+    }
+    if (!strncmp("-macro_layout_overhead", line, strlen("-macro_layout_overhead")))
+    {
+      macro_layout_overhead += alpha * scan_single_input_double(line,"-macro_layout_overhead","F/um", g_ip->print_detail_debug);
+      continue;
+    }
+  }
+  fclose(fp);
+
+  DeviceType peri_global_lo, peri_global_hi;
+  peri_global_lo.assign(in_file_lo, peri_global_tech_type, g_ip->temp);
+  peri_global_hi.assign(in_file_hi, peri_global_tech_type, g_ip->temp);
+  peri_global.interpolate(alpha,peri_global_lo,peri_global_hi);
+  // in the original code some fields of this device were not initialized;
+  // I make them 0 for compatibility.
+  ///peri_global.I_on_p = 0.0;
+
+  DeviceType sleep_tx_lo, sleep_tx_hi;
+  sleep_tx_lo.assign(in_file_lo, 1, g_ip->temp);
+  sleep_tx_hi.assign(in_file_hi, 1, g_ip->temp);
+  sleep_tx.interpolate(alpha, sleep_tx_lo, sleep_tx_hi);
+
+  DeviceType sram_cell_lo, sram_cell_hi;
+  sram_cell_lo.assign(in_file_lo, ram_cell_tech_type, g_ip->temp);
+  sram_cell_hi.assign(in_file_hi, ram_cell_tech_type, g_ip->temp);
+  sram_cell.interpolate(alpha, sram_cell_lo, sram_cell_hi);
+  // in the original code some fields of this device were not initialized;
+  // I make them 0 for compatibility.
+  //sram_cell.Vdd=0.0;
+  ///sram_cell.I_on_p=0.0;
+  ///sram_cell.C_ox=0.0;
+
+  DeviceType dram_acc_lo, dram_acc_hi;
+  dram_acc_lo.assign(in_file_lo, (ram_cell_tech_type==comm_dram? ram_cell_tech_type:dram_cell_tech_flavor), g_ip->temp);
+  dram_acc_hi.assign(in_file_hi, (ram_cell_tech_type==comm_dram?
ram_cell_tech_type:dram_cell_tech_flavor), g_ip->temp); + dram_acc.interpolate(alpha, dram_acc_lo, dram_acc_hi); + // dram_acc exceptions + //dram_acc.R_nch_on = g_tp.dram_cell_Vdd / g_tp.dram_acc.I_on_n; + //dram_acc.R_pch_on = 0; + if(tech_lo<=22) + { + } + else if(tech_lo<=32) + { + if(ram_cell_tech_type == lp_dram) + dram_acc.Vth = 0.44129; + else + dram_acc.Vth = 1.0; + } + else if(tech_lo<=45) + { + if(ram_cell_tech_type == lp_dram) + dram_acc.Vth = 0.44559; + else + dram_acc.Vth = 1.0; + } + else if(tech_lo<=65) + { + if(ram_cell_tech_type == lp_dram) + dram_acc.Vth = 0.43806; + else + dram_acc.Vth = 1.0; + } + else if(tech_lo<=90) + { + if(ram_cell_tech_type == lp_dram) + dram_acc.Vth = 0.4545; + else + dram_acc.Vth = 1.0; + } + // in the original code some field of this devide has not been initialized/ + // I make them 0 for compatibility. + dram_acc.Vdd= 0.0; + dram_acc.I_on_p = 0.0; + dram_acc.I_off_n = 0.0; + dram_acc.I_off_p = 0.0; + dram_acc.C_ox = 0.0; + dram_acc.t_ox = 0.0; + dram_acc.n_to_p_eff_curr_drv_ratio = 0.0; + + DeviceType dram_wl_lo, dram_wl_hi; + dram_wl_lo.assign(in_file_lo, (ram_cell_tech_type==comm_dram? ram_cell_tech_type:dram_cell_tech_flavor), g_ip->temp); + dram_wl_hi.assign(in_file_hi, (ram_cell_tech_type==comm_dram? ram_cell_tech_type:dram_cell_tech_flavor), g_ip->temp); + dram_wl.interpolate(alpha, dram_wl_lo, dram_wl_hi); + // in the original code some field of this devide has not been initialized/ + // I make them 0 for compatibility. + dram_wl.Vdd = 0.0; + dram_wl.Vth = 0.0; + dram_wl.I_on_p = 0.0; + dram_wl.C_ox = 0.0; + dram_wl.t_ox = 0.0; + + // if ram_cell_tech_type is not 3 or 4 ( which means edram and comm-dram) + // then reset dram_wl dram_acc + + if(ram_cell_tech_type <3) + { + dram_acc.reset(); + dram_wl.reset(); + } + + + DeviceType cam_cell_lo, cam_cell_hi; + cam_cell_lo.assign(in_file_lo, ram_cell_tech_type, g_ip->temp); + cam_cell_hi.assign(in_file_hi, ram_cell_tech_type, g_ip->temp); + cam_cell.interpolate(alpha, cam_cell_lo, cam_cell_hi); + + MemoryType dram_lo, dram_hi; + dram_lo.assign(in_file_lo, ram_cell_tech_type, 2); // cell_type = dram(2) + dram_hi.assign(in_file_hi, ram_cell_tech_type, 2); + dram.interpolate(alpha,dram_lo,dram_hi); + + MemoryType sram_lo, sram_hi; + sram_lo.assign(in_file_lo, ram_cell_tech_type, 0); // cell_type = sram(0) + sram_hi.assign(in_file_hi, ram_cell_tech_type, 0); + sram.interpolate(alpha,sram_lo,sram_hi); + // sram cell execptions + /*sram_lo.assign(in_file_lo, 0, g_ip->temp); + sram.cell_a_w =sram_lo.cell_a_w; + sram.b_h = sram_lo.b_h; + sram.b_w = sram_lo.b_w; +*/ + MemoryType cam_lo, cam_hi; + cam_lo.assign(in_file_lo, ram_cell_tech_type, 1); // cell_type = sram(0) + cam_hi.assign(in_file_hi, ram_cell_tech_type, 1); + cam.interpolate(alpha,cam_lo,cam_hi); + + + ScalingFactor scaling_factor_lo, scaling_factor_hi; + scaling_factor_lo.assign(in_file_lo); + scaling_factor_hi.assign(in_file_hi); + scaling_factor.interpolate(alpha, scaling_factor_lo,scaling_factor_hi); + + //vcc_min + peri_global.Vcc_min += (alpha * peri_global_lo.Vdd + (1-alpha)*peri_global_hi.Vdd) * 0.35; + sleep_tx.Vcc_min += (alpha*sleep_tx_lo.Vdd+(1-alpha)*sleep_tx_hi.Vdd); + sram_cell.Vcc_min += (alpha*sram_cell_lo.Vdd +(1-alpha)*sram_cell_hi.Vdd)* 0.65; + + + + fp = fopen(in_file_hi.c_str(), "r"); + + while(fscanf(fp, "%[^\n]\n", line) != EOF) + { + if (!strncmp("-sense_delay", line, strlen("-sense_delay"))) + { + sense_delay = scan_single_input_double(line,"-sense_delay","F/um", g_ip->print_detail_debug); + continue; + } + if 
(!strncmp("-sense_dy_power", line, strlen("-sense_dy_power"))) + { + sense_dy_power = scan_single_input_double(line,"-sense_dy_power","F/um", g_ip->print_detail_debug); + continue; + } + if (!strncmp("-sckt_co_eff", line, strlen("-sckt_co_eff"))) + { + sckt_co_eff += (1-alpha)* scan_single_input_double(line,"-sckt_co_eff","F/um", g_ip->print_detail_debug); + continue; + } + if (!strncmp("-chip_layout_overhead", line, strlen("-chip_layout_overhead"))) + { + chip_layout_overhead += (1-alpha)* scan_single_input_double(line,"-chip_layout_overhead","F/um", g_ip->print_detail_debug); + continue; + } + if (!strncmp("-macro_layout_overhead", line, strlen("-macro_layout_overhead"))) + { + macro_layout_overhead += (1-alpha)* scan_single_input_double(line,"-macro_layout_overhead","F/um", g_ip->print_detail_debug); + continue; + } + if (!strncmp("-dram_cell_I_on", line, strlen("-dram_cell_I_on"))) + { + dram_cell_I_on += (1-alpha) * scan_five_input_double(line,"-dram_cell_I_on","F/um", ram_cell_tech_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-dram_cell_Vdd", line, strlen("-dram_cell_Vdd"))) + { + dram_cell_Vdd += (1-alpha) * scan_five_input_double(line,"-dram_cell_Vdd","F/um", ram_cell_tech_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-dram_cell_C", line, strlen("-dram_cell_C"))) + { + dram_cell_C += (1-alpha) * scan_five_input_double(line,"-dram_cell_C","F/um", ram_cell_tech_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-dram_cell_I_off_worst_case_len_temp", line, strlen("-dram_cell_I_off_worst_case_len_temp"))) + { + dram_cell_I_off_worst_case_len_temp += (1-alpha) * scan_five_input_double(line,"-dram_cell_I_off_worst_case_len_temp","F/um", ram_cell_tech_type, g_ip->print_detail_debug); + continue; + } + if (!strncmp("-vpp", line, strlen("-vpp"))) + { + vpp += (1-alpha)* scan_five_input_double(line,"-vpp","F/um", ram_cell_tech_type, g_ip->print_detail_debug); + continue; + } + } + fclose(fp); + + //Currently we are not modeling the resistance/capacitance of poly anywhere. 
+ //Continuous function (or date have been processed) does not need linear interpolation + w_comp_inv_p1 = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + w_comp_inv_n1 = 7.5 * g_ip->F_sz_um;//this was 6 micron for the 0.8 micron process + w_comp_inv_p2 = 25 * g_ip->F_sz_um;//this was 20 micron for the 0.8 micron process + w_comp_inv_n2 = 15 * g_ip->F_sz_um;//this was 12 micron for the 0.8 micron process + w_comp_inv_p3 = 50 * g_ip->F_sz_um;//this was 40 micron for the 0.8 micron process + w_comp_inv_n3 = 30 * g_ip->F_sz_um;//this was 24 micron for the 0.8 micron process + w_eval_inv_p = 100 * g_ip->F_sz_um;//this was 80 micron for the 0.8 micron process + w_eval_inv_n = 50 * g_ip->F_sz_um;//this was 40 micron for the 0.8 micron process + w_comp_n = 12.5 * g_ip->F_sz_um;//this was 10 micron for the 0.8 micron process + w_comp_p = 37.5 * g_ip->F_sz_um;//this was 30 micron for the 0.8 micron process + + MIN_GAP_BET_P_AND_N_DIFFS = 5 * g_ip->F_sz_um; + MIN_GAP_BET_SAME_TYPE_DIFFS = 1.5 * g_ip->F_sz_um; + HPOWERRAIL = 2 * g_ip->F_sz_um; + cell_h_def = 50 * g_ip->F_sz_um; + w_poly_contact = g_ip->F_sz_um; + spacing_poly_to_contact = g_ip->F_sz_um; + spacing_poly_to_poly = 1.5 * g_ip->F_sz_um; + ram_wl_stitching_overhead_ = 7.5 * g_ip->F_sz_um; + + min_w_nmos_ = 3 * g_ip->F_sz_um / 2; + max_w_nmos_ = 100 * g_ip->F_sz_um; + w_iso = 12.5*g_ip->F_sz_um;//was 10 micron for the 0.8 micron process + w_sense_n = 3.75*g_ip->F_sz_um; // sense amplifier N-trans; was 3 micron for the 0.8 micron process + w_sense_p = 7.5*g_ip->F_sz_um; // sense amplifier P-trans; was 6 micron for the 0.8 micron process + w_sense_en = 5*g_ip->F_sz_um; // Sense enable transistor of the sense amplifier; was 4 micron for the 0.8 micron process + w_nmos_b_mux = 6 * min_w_nmos_; + w_nmos_sa_mux = 6 * min_w_nmos_; + + + w_pmos_bl_precharge = 6 * pmos_to_nmos_sz_ratio() * min_w_nmos_; + w_pmos_bl_eq = pmos_to_nmos_sz_ratio() * min_w_nmos_; + + + if (ram_cell_tech_type == comm_dram) + { + max_w_nmos_dec = 8 * g_ip->F_sz_um; + h_dec = 8; // in the unit of memory cell height + } + else + { + max_w_nmos_dec = g_tp.max_w_nmos_; + h_dec = 4; // in the unit of memory cell height + } + + + + double gmn_sense_amp_latch + = (peri_global.Mobility_n / 2) * peri_global.C_ox + * (w_sense_n / peri_global.l_elec) * peri_global.Vdsat; + double gmp_sense_amp_latch = peri_global.gmp_to_gmn_multiplier * gmn_sense_amp_latch; + gm_sense_amp_latch = gmn_sense_amp_latch + gmp_sense_amp_latch; + + + ///cout << "wire_local " << g_ip->ic_proj_type << " " << ((ram_cell_tech_type == comm_dram)?3:0) << endl; + InterconnectType wire_local_lo, wire_local_hi; + wire_local_lo.assign(in_file_lo,g_ip->ic_proj_type,(ram_cell_tech_type == comm_dram)?3:0); + wire_local_hi.assign(in_file_hi,g_ip->ic_proj_type,(ram_cell_tech_type == comm_dram)?3:0); + wire_local.interpolate(alpha,wire_local_lo,wire_local_hi); + + + ///cout << "wire_inside_mat " << g_ip->ic_proj_type << " " << g_ip->wire_is_mat_type << endl; + InterconnectType wire_inside_mat_lo, wire_inside_mat_hi; + wire_inside_mat_lo.assign(in_file_lo, g_ip->ic_proj_type, g_ip->wire_is_mat_type); + wire_inside_mat_hi.assign(in_file_hi, g_ip->ic_proj_type, g_ip->wire_is_mat_type); + wire_inside_mat.interpolate(alpha, wire_inside_mat_lo, wire_inside_mat_hi); + + ///cout << "wire_outside_mat " << g_ip->ic_proj_type << " " << g_ip->wire_os_mat_type << endl; + InterconnectType wire_outside_mat_lo, wire_outside_mat_hi; + wire_outside_mat_lo.assign(in_file_lo, g_ip->ic_proj_type, g_ip->wire_os_mat_type); 
+ wire_outside_mat_hi.assign(in_file_hi, g_ip->ic_proj_type, g_ip->wire_os_mat_type); + wire_outside_mat.interpolate(alpha, wire_outside_mat_lo, wire_outside_mat_hi); + + unit_len_wire_del = wire_inside_mat.R_per_um * wire_inside_mat.C_per_um / 2; + + // assign value for TSV parameters + + assign_tsv(in_file_hi); + + fringe_cap = wire_local_hi.fringe_cap; // fringe_cap is similar for all wire types. + + double rd = tr_R_on(min_w_nmos_, NCH, 1); + double p_to_n_sizing_r = pmos_to_nmos_sz_ratio(); + double c_load = gate_C(min_w_nmos_ * (1 + p_to_n_sizing_r), 0.0); + double tf = rd * c_load; + kinv = horowitz(0, tf, 0.5, 0.5, RISE); + double KLOAD = 1; + c_load = KLOAD * (drain_C_(min_w_nmos_, NCH, 1, 1, cell_h_def) + + drain_C_(min_w_nmos_ * p_to_n_sizing_r, PCH, 1, 1, cell_h_def) + + gate_C(min_w_nmos_ * 4 * (1 + p_to_n_sizing_r), 0.0)); + tf = rd * c_load; + FO4 = horowitz(0, tf, 0.5, 0.5, RISE); + +} + +#define PRINT(A,X) cout << A << ": " << X << " , " << tech.X << endl + +bool TechnologyParameter::isEqual(const TechnologyParameter& tech) +{ + if(!is_equal(ram_wl_stitching_overhead_,tech.ram_wl_stitching_overhead_)) {assert(false);} //fs + if(!is_equal(min_w_nmos_,tech.min_w_nmos_)) {assert(false);} //fs + if(!is_equal(max_w_nmos_,tech.max_w_nmos_)) {assert(false);} //fs + if(!is_equal(max_w_nmos_dec,tech.max_w_nmos_dec)) {assert(false);} //fs+ ram_cell_tech_type + if(!is_equal(unit_len_wire_del,tech.unit_len_wire_del)) {assert(false);} //wire_inside_mat + if(!is_equal(FO4,tech.FO4)) {assert(false);} //fs + if(!is_equal(kinv,tech.kinv)) {assert(false);} //fs + if(!is_equal(vpp,tech.vpp )) {assert(false);}//input + if(!is_equal(w_sense_en,tech.w_sense_en)) {assert(false);}//fs + if(!is_equal(w_sense_n,tech.w_sense_n)) {assert(false);} //fs + if(!is_equal(w_sense_p,tech.w_sense_p)) {assert(false);} //fs + if(!is_equal(sense_delay,tech.sense_delay)) {PRINT("sense_delay",sense_delay); assert(false);} // input + if(!is_equal(sense_dy_power,tech.sense_dy_power)) {assert(false);} //input + if(!is_equal(w_iso,tech.w_iso)) {assert(false);} //fs + if(!is_equal(w_poly_contact,tech.w_poly_contact)) {assert(false);} //fs + if(!is_equal(spacing_poly_to_poly,tech.spacing_poly_to_poly)) {assert(false);} //fs + if(!is_equal(spacing_poly_to_contact,tech.spacing_poly_to_contact)) {assert(false);}//fs + + //CACTI3D auxilary variables + ///if(!is_equal(tsv_pitch,tech.tsv_pitch)) {assert(false);} + ///if(!is_equal(tsv_diameter,tech.tsv_diameter)) {assert(false);} + ///if(!is_equal(tsv_length,tech.tsv_length)) {assert(false);} + ///if(!is_equal(tsv_dielec_thickness,tech.tsv_dielec_thickness)) {assert(false);} + ///if(!is_equal(tsv_contact_resistance,tech.tsv_contact_resistance)) {assert(false);} + ///if(!is_equal(tsv_depletion_width,tech.tsv_depletion_width)) {assert(false);} + ///if(!is_equal(tsv_liner_dielectric_constant,tech.tsv_liner_dielectric_constant)) {assert(false);} + + //CACTI3DD TSV params + + if(!is_equal(tsv_parasitic_capacitance_fine,tech.tsv_parasitic_capacitance_fine )) {PRINT("tsv_parasitic_capacitance_fine",tsv_parasitic_capacitance_fine); assert(false);} + if(!is_equal(tsv_parasitic_resistance_fine,tech.tsv_parasitic_resistance_fine)) {assert(false);} + if(!is_equal(tsv_minimum_area_fine,tech.tsv_minimum_area_fine)) {assert(false);} + + if(!is_equal(tsv_parasitic_capacitance_coarse,tech.tsv_parasitic_capacitance_coarse)) {assert(false);} + if(!is_equal(tsv_parasitic_resistance_coarse,tech.tsv_parasitic_resistance_coarse)) {assert(false);} + 
if(!is_equal(tsv_minimum_area_coarse,tech.tsv_minimum_area_coarse)) {assert(false);}
+
+  //fs
+  if(!is_equal(w_comp_inv_p1,tech.w_comp_inv_p1)) {assert(false);}
+  if(!is_equal(w_comp_inv_p2,tech.w_comp_inv_p2)) {assert(false);}
+  if(!is_equal(w_comp_inv_p3,tech.w_comp_inv_p3)) {assert(false);}
+  if(!is_equal(w_comp_inv_n1,tech.w_comp_inv_n1)) {assert(false);}
+  if(!is_equal(w_comp_inv_n2,tech.w_comp_inv_n2)) {assert(false);}
+  if(!is_equal(w_comp_inv_n3,tech.w_comp_inv_n3)) {assert(false);}
+  if(!is_equal(w_eval_inv_p,tech.w_eval_inv_p)) {assert(false);}
+  if(!is_equal(w_eval_inv_n,tech.w_eval_inv_n)) {assert(false);}
+  if(!is_equal(w_comp_n,tech.w_comp_n)) {assert(false);}
+  if(!is_equal(w_comp_p,tech.w_comp_p)) {assert(false);}
+
+  if(!is_equal(dram_cell_I_on,tech.dram_cell_I_on)) {assert(false);} //ram_cell_tech_type
+  if(!is_equal(dram_cell_Vdd,tech.dram_cell_Vdd)) {assert(false);}
+  if(!is_equal(dram_cell_I_off_worst_case_len_temp,tech.dram_cell_I_off_worst_case_len_temp)) {assert(false);}
+  if(!is_equal(dram_cell_C,tech.dram_cell_C)) {assert(false);}
+  if(!is_equal(gm_sense_amp_latch,tech.gm_sense_amp_latch)) {assert(false);} // depends on many things
+
+  if(!is_equal(w_nmos_b_mux,tech.w_nmos_b_mux)) {assert(false);} //fs
+  if(!is_equal(w_nmos_sa_mux,tech.w_nmos_sa_mux)) {assert(false);}//fs
+  if(!is_equal(w_pmos_bl_precharge,tech.w_pmos_bl_precharge)) {PRINT("w_pmos_bl_precharge",w_pmos_bl_precharge);assert(false);}//fs
+  if(!is_equal(w_pmos_bl_eq,tech.w_pmos_bl_eq)) {assert(false);}//fs
+  if(!is_equal(MIN_GAP_BET_P_AND_N_DIFFS,tech.MIN_GAP_BET_P_AND_N_DIFFS)) {assert(false);}//fs
+  if(!is_equal(MIN_GAP_BET_SAME_TYPE_DIFFS,tech.MIN_GAP_BET_SAME_TYPE_DIFFS)) {assert(false);}//fs
+  if(!is_equal(HPOWERRAIL,tech.HPOWERRAIL)) {assert(false);}//fs
+  if(!is_equal(cell_h_def,tech.cell_h_def)) {assert(false);}//fs
+
+  if(!is_equal(chip_layout_overhead,tech.chip_layout_overhead )) {assert(false);}//input
+  if(!is_equal(macro_layout_overhead,tech.macro_layout_overhead)) {cout << "macro_layout_overhead" << endl; assert(false);}//input
+
+  return true;
+}
+
+void
+DynamicParameter::init_CAM()
+{
+  const InterconnectType &wire_local = g_tp.wire_local;
+  //Disabling 3D model since a 3D stacked CAM is never tested
+  assert(NUMBER_STACKED_DIE_LAYERS == 1);
+  unsigned int capacity_per_die = g_ip->cache_sz / NUMBER_STACKED_DIE_LAYERS; // capacity per stacked die layer
+
+  if (Ndwl != 1 ||            //Ndwl is fixed to 1 for CAM
+      Ndcm != 1 ||            //Ndcm is fixed to 1 for CAM
+      Nspd < 1 || Nspd > 1 || //Nspd is fixed to 1 for CAM
+      Ndsam_lev_1 != 1 ||     //Ndsam_lev_1 is fixed to one
+      Ndsam_lev_2 != 1 ||     //Ndsam_lev_2 is fixed to one
+      Ndbl < 2)               //FIXME: why should Ndbl be >1 for very small CAMs?
+  {
+    return;
+  }
+
+  if (g_ip->specific_tag)
+  {
+    tagbits = int(ceil(g_ip->tag_w/8.0)*8);
+  }
+  else
+  {
+    tagbits = int(ceil((ADDRESS_BITS + EXTRA_TAG_BITS)/8.0)*8);
+  }
+
+  //computation of no. of rows and cols of a subarray
+  tag_num_r_subarray = (int)ceil(capacity_per_die / (g_ip->nbanks*tagbits/8.0 * Ndbl));//TODO: error check input of tagbits and blocksize //TODO: for pure CAM, g_ip->block should be number of entries.
+  tag_num_c_subarray = tagbits;
+
+  if (tag_num_r_subarray == 0) return;
+  if (tag_num_r_subarray > MAXSUBARRAYROWS) return;
+  if (tag_num_c_subarray < MINSUBARRAYCOLS) return;
+  if (tag_num_c_subarray > MAXSUBARRAYCOLS) return;
+  num_r_subarray = tag_num_r_subarray; //FIXME: what about num_c_subarray?
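+#if 0 // illustrative sketch, not part of the original source
+  {
+    // Worked numbers for the sizing above, for a hypothetical 64 KB pure CAM
+    // with tag_w = 42, one bank, Ndbl = 4:
+    //   tagbits            = ceil(42/8.0)*8 = 48 (tag rounded up to whole bytes)
+    //   tag_num_r_subarray = ceil(65536 / (1 * 48/8.0 * 4)) = ceil(65536/24) = 2731
+    int tagbits_ex = int(ceil(42 / 8.0) * 8);
+    int rows_ex    = (int)ceil(65536 / (1 * 48 / 8.0 * 4));
+    (void)tagbits_ex; (void)rows_ex;
+  }
+#endif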
+
+  num_subarrays = Ndwl * Ndbl;
+
+  // calculate cell dimensions
+  cam_cell.h = g_tp.cam.b_h + 2 * wire_local.pitch * (g_ip->num_rw_ports-1 + g_ip->num_rd_ports + g_ip->num_wr_ports)
+    + 2 * wire_local.pitch*(g_ip->num_search_ports-1) + wire_local.pitch * g_ip->num_se_rd_ports;
+  cam_cell.w = g_tp.cam.b_w + 2 * wire_local.pitch * (g_ip->num_rw_ports-1 + g_ip->num_rd_ports + g_ip->num_wr_ports)
+    + 2 * wire_local.pitch*(g_ip->num_search_ports-1) + wire_local.pitch * g_ip->num_se_rd_ports;
+
+  //FIXME: curious where this is getting used in a CAM
+  cell.h = g_tp.sram.b_h + 2 * wire_local.pitch * (g_ip->num_wr_ports +g_ip->num_rw_ports-1 + g_ip->num_rd_ports)
+    + 2 * wire_local.pitch*(g_ip->num_search_ports-1);
+  cell.w = g_tp.sram.b_w + 2 * wire_local.pitch * (g_ip->num_rw_ports -1 + (g_ip->num_rd_ports - g_ip->num_se_rd_ports)
+    + g_ip->num_wr_ports) + g_tp.wire_local.pitch * g_ip->num_se_rd_ports + 2 * wire_local.pitch*(g_ip->num_search_ports-1);
+
+  // calculate wire parameters
+  double c_b_metal = cell.h * wire_local.C_per_um;
+//  double C_bl;
+
+  c_b_metal = cam_cell.h * wire_local.C_per_um;//IBM and SUN design, SRAM array uses dummy cells to fill the blank space due to mismatch on CAM-RAM
+  V_b_sense = (0.05 * g_tp.sram_cell.Vdd > VBITSENSEMIN) ? 0.05 * g_tp.sram_cell.Vdd : VBITSENSEMIN;
+  deg_bl_muxing = 1;//FA fix as 1
+  // "/ 2.0" below is due to the fact that two adjacent access transistors share drain
+  // contacts in a physical layout
+  double Cbitrow_drain_cap = drain_C_(g_tp.cam.cell_a_w, NCH, 1, 0, cam_cell.w, false, true) / 2.0;//TODO: comment out these two lines
+//  C_bl = num_r_subarray * (Cbitrow_drain_cap + c_b_metal);
+  dram_refresh_period = 0;
+
+  // do/di: data in/out, for fully associative they are the data width for normal read and write
+  // so/si: search data in/out, for fully associative they are the data width for the search ops
+  // for CAM, si=di, but so = matching address. do = data out = di (for normal read/write)
+  // so/si needs broadcast while do/di do not
+
+  switch (Ndbl) {
+    case (0):
+      cout << " Invalid Ndbl \n"<<endl;
+      exit(0);
+      break;
+    case (1):
+      num_mats_h_dir = 1;//one subarray per mat
+      num_mats_v_dir = 1;
+      break;
+    case (2):
+      num_mats_h_dir = 1;//two subarrays per mat
+      num_mats_v_dir = 1;
+      break;
+    default:
+      num_mats_h_dir = int(floor(sqrt(Ndbl/4.0)));//four subarrays per mat
+      num_mats_v_dir = int(Ndbl/4.0 / num_mats_h_dir);
+  }
+  num_mats = num_mats_h_dir * num_mats_v_dir;
+
+  num_so_b_subbank = int(ceil(log2(num_r_subarray)) + ceil(log2(num_subarrays)));//the address of the matched entry
+  num_do_b_subbank = tag_num_c_subarray;
+
+  int deg_sa_mux_l1_non_assoc = 1;
+
+  deg_senseamp_muxing_non_associativity = deg_sa_mux_l1_non_assoc;
+
+  num_act_mats_hor_dir = 1;
+  num_act_mats_hor_dir_sl = num_mats_h_dir;//TODO: this is unnecessary, since for a search op num_mats is used
+
+  if (num_act_mats_hor_dir > num_mats_h_dir)
+  {
+    return;
+  }
+
+  num_di_b_mat = tagbits;
+  num_si_b_mat = tagbits;//*num_subarrays/num_mats;
+
+  num_di_b_subbank = num_di_b_mat * num_act_mats_hor_dir;//normal cache or normal r/w for FA
+  num_si_b_subbank = num_si_b_mat; //* num_act_mats_hor_dir_sl; inside, the data is broadcast
+
+  int num_addr_b_row_dec = _log2(num_r_subarray);
+  num_addr_b_row_dec +=_log2(num_subarrays/num_mats);
+  int number_subbanks = num_mats / num_act_mats_hor_dir;
+  number_subbanks_decode = _log2(number_subbanks);//TODO: add log2(num_subarray_per_bank) to FA/CAM
+
+  num_rw_ports = g_ip->num_rw_ports;
+  num_rd_ports = g_ip->num_rd_ports;
+  num_wr_ports = g_ip->num_wr_ports;
+  num_se_rd_ports = g_ip->num_se_rd_ports;
+  num_search_ports = g_ip->num_search_ports;
+
+  number_addr_bits_mat = num_addr_b_row_dec + _log2(deg_bl_muxing) +
+    _log2(deg_sa_mux_l1_non_assoc) + _log2(Ndsam_lev_2);
+
+  num_di_b_bank_per_port = tagbits;
+  num_si_b_bank_per_port = tagbits;
+  num_do_b_bank_per_port = tagbits;
+  num_so_b_bank_per_port = int(ceil(log2(num_r_subarray)) + ceil(log2(num_subarrays)));
+
+  if ((!is_tag) && (g_ip->data_assoc > 1) && (!g_ip->fast_access))
+  {
+    number_way_select_signals_mat = g_ip->data_assoc;
+  }
+
+  // add ECC adjustment to all data signals that traverse on H-trees.
+ if (g_ip->add_ecc_b_ == true) + { + ECC_adjustment(); + } + + is_valid = true; +} + +void +DynamicParameter::init_FA() +{ + const InterconnectType &wire_local = g_tp.wire_local; + //Disabling 3D model since a 3D stacked FA is never tested + assert(NUMBER_STACKED_DIE_LAYERS == 1); + unsigned int capacity_per_die = g_ip->cache_sz; + + if (Ndwl != 1 || //Ndwl is fixed to 1 for FA + Ndcm != 1 || //Ndcm is fixed to 1 for FA + Nspd < 1 || Nspd > 1 || //Nspd is fixed to 1 for FA + Ndsam_lev_1 != 1 || //Ndsam_lev_1 is fixed to one + Ndsam_lev_2 != 1 || //Ndsam_lev_2 is fixed to one + Ndbl < 2) + { + return; + } + + + //***********compute row, col of an subarray + + //either fully-asso or cam + if (g_ip->specific_tag) + { + tagbits = g_ip->tag_w; + } + else + { + tagbits = ADDRESS_BITS + EXTRA_TAG_BITS - _log2(g_ip->block_sz); + } + tagbits = (((tagbits + 3) >> 2) << 2); + + tag_num_r_subarray = (int)(capacity_per_die / (g_ip->nbanks*g_ip->block_sz * Ndbl)); + tag_num_c_subarray = (int)ceil((tagbits * Nspd / Ndwl));// + EPSILON); + if (tag_num_r_subarray == 0) return; + if (tag_num_r_subarray > MAXSUBARRAYROWS) return; + if (tag_num_c_subarray < MINSUBARRAYCOLS) return; + if (tag_num_c_subarray > MAXSUBARRAYCOLS) return; + + data_num_r_subarray = tag_num_r_subarray; + data_num_c_subarray = 8 * g_ip->block_sz; + if (data_num_r_subarray == 0) return; + if (data_num_r_subarray > MAXSUBARRAYROWS) return; + if (data_num_c_subarray < MINSUBARRAYCOLS) return; + if (data_num_c_subarray > MAXSUBARRAYCOLS) return; + num_r_subarray = tag_num_r_subarray; + + num_subarrays = Ndwl * Ndbl; + //****************end of computation of row, col of an subarray + + // calculate wire parameters + cam_cell.h = g_tp.cam.b_h + 2 * wire_local.pitch * (g_ip->num_rw_ports-1 + g_ip->num_rd_ports + g_ip->num_wr_ports) + + 2 * wire_local.pitch*(g_ip->num_search_ports-1) + wire_local.pitch * g_ip->num_se_rd_ports; + cam_cell.w = g_tp.cam.b_w + 2 * wire_local.pitch * (g_ip->num_rw_ports-1 + g_ip->num_rd_ports + g_ip->num_wr_ports) + + 2 * wire_local.pitch*(g_ip->num_search_ports-1) + wire_local.pitch * g_ip->num_se_rd_ports; + + cell.h = g_tp.sram.b_h + 2 * wire_local.pitch * (g_ip->num_wr_ports +g_ip->num_rw_ports-1 + g_ip->num_rd_ports) + + 2 * wire_local.pitch*(g_ip->num_search_ports-1); + cell.w = g_tp.sram.b_w + 2 * wire_local.pitch * (g_ip->num_rw_ports -1 + (g_ip->num_rd_ports - g_ip->num_se_rd_ports) + + g_ip->num_wr_ports) + g_tp.wire_local.pitch * g_ip->num_se_rd_ports + 2 * wire_local.pitch*(g_ip->num_search_ports-1); + + double c_b_metal = cell.h * wire_local.C_per_um; + // double C_bl; + + c_b_metal = cam_cell.h * wire_local.C_per_um;//IBM and SUN design, SRAM array uses dummy cells to fill the blank space due to mismatch on CAM-RAM + V_b_sense = (0.05 * g_tp.sram_cell.Vdd > VBITSENSEMIN) ? 0.05 * g_tp.sram_cell.Vdd : VBITSENSEMIN; + deg_bl_muxing = 1;//FA fix as 1 + // "/ 2.0" below is due to the fact that two adjacent access transistors share drain + // contacts in a physical layout + double Cbitrow_drain_cap = drain_C_(g_tp.cam.cell_a_w, NCH, 1, 0, cam_cell.w, false, true) / 2.0;//TODO: comment out these two lines + // C_bl = num_r_subarray * (Cbitrow_drain_cap + c_b_metal); + dram_refresh_period = 0; + + + // do/di: data in/out, for fully associative they are the data width for normal read and write + // so/si: search data in/out, for fully associative they are the data width for the search ops + // for CAM, si=di, but so = matching address. 
do = data out = di (for normal read/write)
+  // so/si needs broadcast while do/di do not
+
+  switch (Ndbl) {
+    case (0):
+      cout << " Invalid Ndbl \n"<<endl;
+      exit(0);
+      break;
+    case (1):
+      num_mats_h_dir = 1;//one subarray per mat
+      num_mats_v_dir = 1;
+      break;
+    case (2):
+      num_mats_h_dir = 1;//two subarrays per mat
+      num_mats_v_dir = 1;
+      break;
+    default:
+      num_mats_h_dir = int(floor(sqrt(Ndbl/4.0)));//four subarrays per mat
+      num_mats_v_dir = int(Ndbl/4.0 / num_mats_h_dir);
+  }
+  num_mats = num_mats_h_dir * num_mats_v_dir;
+
+  int deg_sa_mux_l1_non_assoc;
+
+  num_so_b_subbank = 8 * g_ip->block_sz;//TODO: internal prefetch should be considered also for fa
+  num_do_b_subbank = num_so_b_subbank + tag_num_c_subarray;
+
+  deg_sa_mux_l1_non_assoc = 1;
+
+  deg_senseamp_muxing_non_associativity = deg_sa_mux_l1_non_assoc;
+
+  num_act_mats_hor_dir = 1;
+  num_act_mats_hor_dir_sl = num_mats_h_dir;//TODO: this is unnecessary, since for a search op num_mats is used
+
+  //compute num_do_mat for tag
+  if (num_act_mats_hor_dir > num_mats_h_dir)
+  {
+    return;
+  }
+
+  //compute di for mat, subbank, and bank
+  if (fully_assoc)
+  {
+    num_di_b_mat = num_do_b_mat;
+    //*num_subarrays/num_mats; bits per mat of CAM/FA are the same as for a cache,
+    //but inside the mat, wire tracks need to be reserved for the search data bus
+    num_si_b_mat = tagbits;
+  }
+  num_di_b_subbank = num_di_b_mat * num_act_mats_hor_dir;//normal cache or normal r/w for FA
+  num_si_b_subbank = num_si_b_mat; //* num_act_mats_hor_dir_sl; inside, the data is broadcast
+
+  int num_addr_b_row_dec = _log2(num_r_subarray);
+  num_addr_b_row_dec +=_log2(num_subarrays/num_mats);
+  int number_subbanks = num_mats / num_act_mats_hor_dir;
+  number_subbanks_decode = _log2(number_subbanks);//TODO: add log2(num_subarray_per_bank) to FA/CAM
+
+  num_rw_ports = g_ip->num_rw_ports;
+  num_rd_ports = g_ip->num_rd_ports;
+  num_wr_ports = g_ip->num_wr_ports;
+  num_se_rd_ports = g_ip->num_se_rd_ports;
+  num_search_ports = g_ip->num_search_ports;
+
+  number_addr_bits_mat = num_addr_b_row_dec + _log2(deg_bl_muxing) +
+    _log2(deg_sa_mux_l1_non_assoc) + _log2(Ndsam_lev_2);
+
+  num_di_b_bank_per_port = g_ip->out_w + tagbits;//TODO: out_w or block_sz?
+  num_si_b_bank_per_port = tagbits;
+  num_do_b_bank_per_port = g_ip->out_w + tagbits;
+  num_so_b_bank_per_port = g_ip->out_w;
+
+  if ((!is_tag) && (g_ip->data_assoc > 1) && (!g_ip->fast_access))
+  {
+    number_way_select_signals_mat = g_ip->data_assoc;
+  }
+
+  // add ECC adjustment to all data signals that traverse on H-trees.
+  if (g_ip->add_ecc_b_ == true)
+  {
+    ECC_adjustment();
+  }
+
+  is_valid = true;
+}
+
+//DynamicParameter::init_Mem()
+//{
+//}
+//
+//DynamicParameter::init_3DMem()
+//{
+//}
+
+//*** Calculate number of rows and columns in a subarray
+bool
+DynamicParameter::calc_subarr_rc(unsigned int capacity_per_die) {
+  // If it's not an FA tag/data array, Ndwl should be at least two and Ndbl should be
+  // at least two, because an array is assumed to have at least one mat. A mat
+  // consists of two rows and two columns of subarrays.
+  if (Ndwl < 2 || Ndbl < 2)
+  {
+    return false;
+  }
+
+  if ((is_dram) && (!is_tag) && (Ndcm > 1))
+  {
+    return false; // For a DRAM array, each bitline has its own sense-amp
+  }
+
+  // if data array, let tagbits = 0
+  if (is_tag)
+  {
+    if (g_ip->specific_tag)
+    {
+      tagbits = g_ip->tag_w;
+    }
+    else
+    {
+      tagbits = ADDRESS_BITS + EXTRA_TAG_BITS - _log2(capacity_per_die) +
+        _log2(g_ip->tag_assoc*2 - 1);
+    }
+//    tagbits = (((tagbits + 3) >> 2) << 2); //FIXME: NAV: Why are we doing this?
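+#if 0 // illustrative sketch, not part of the original source
+    {
+      // Worked numbers for the tag-width formula above, taking ADDRESS_BITS = 42
+      // and EXTRA_TAG_BITS = 5 as placeholder values and _log2 as truncating:
+      // for a 2 MB die (2^21 bytes), 8-way set-associative,
+      //   tagbits = 42 + 5 - _log2(2^21) + _log2(8*2 - 1) = 47 - 21 + 3 = 29 bits
+      int tagbits_ex = 42 + 5 - 21 + 3;
+      (void)tagbits_ex;
+    }
+#endif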
+ + num_r_subarray = (int)ceil(capacity_per_die / (g_ip->nbanks * + g_ip->block_sz * g_ip->tag_assoc * Ndbl * Nspd)); + num_c_subarray = (int)ceil((tagbits * g_ip->tag_assoc * Nspd / Ndwl)); + } + else + { + num_r_subarray = (int)ceil(capacity_per_die / (g_ip->nbanks * + g_ip->block_sz * g_ip->data_assoc * Ndbl * Nspd)); + num_c_subarray = (int)ceil((8 * g_ip->block_sz * g_ip->data_assoc * Nspd / Ndwl)); + if(g_ip->is_3d_mem) + { + double capacity_per_die_double = (double)g_ip->cache_sz / g_ip->num_die_3d; + //num_c_subarray = 1 << (int)ceil((double)_log2( 8*capacity_per_die / (g_ip->nbanks * Ndbl * Ndwl) )/2 ) ; + //num_r_subarray = 1 << (int)ceil((double)_log2( 8*capacity_per_die / (g_ip->nbanks * Ndbl * Ndwl * num_c_subarray) ) ); + num_c_subarray = g_ip->page_sz_bits/Ndwl; + num_r_subarray = 1 << (int)floor(_log2((double) g_ip->cache_sz / g_ip->num_die_3d + / num_c_subarray / g_ip->nbanks / Ndbl / Ndwl * 1024 * 1024 * 1024) +0.5); + if (g_ip->print_detail_debug) + { + cout << "parameter.cc: capacity_per_die_double = " << capacity_per_die_double << " Gbit"<< endl; + cout << "parameter.cc: g_ip->nbanks * Ndbl * Ndwl = " << (g_ip->nbanks * Ndbl * Ndwl) << endl; + //cout << "parameter.cc: subarray capacity = " << 8*capacity_per_die / (g_ip->nbanks * Ndbl * Ndwl) << endl; + //cout << "parameter.cc: total bit add per subarray = " << _log2( 8*capacity_per_die / (g_ip->nbanks * Ndbl * Ndwl) ) << endl; + cout << "parameter.cc: num_r_subarray = " << num_r_subarray << endl; + cout << "parameter.cc: num_c_subarray = " << num_c_subarray << endl; + } + + } + } + + if (num_r_subarray < MINSUBARRAYROWS) return false; + if (num_r_subarray == 0) return false; + if (num_r_subarray > MAXSUBARRAYROWS) return false; + if (num_c_subarray < MINSUBARRAYCOLS) return false; + if (num_c_subarray > MAXSUBARRAYCOLS) return false; + + + + num_subarrays = Ndwl * Ndbl; + return true; +} + + + + + +DynamicParameter::DynamicParameter( + bool is_tag_, + int pure_ram_, + int pure_cam_, + double Nspd_, + unsigned int Ndwl_, + unsigned int Ndbl_, + unsigned int Ndcm_, + unsigned int Ndsam_lev_1_, + unsigned int Ndsam_lev_2_, + Wire_type wt, + bool is_main_mem_): + is_tag(is_tag_), pure_ram(pure_ram_), pure_cam(pure_cam_), tagbits(0), Nspd(Nspd_), Ndwl(Ndwl_), Ndbl(Ndbl_),Ndcm(Ndcm_), + Ndsam_lev_1(Ndsam_lev_1_), Ndsam_lev_2(Ndsam_lev_2_),wtype(wt), + number_way_select_signals_mat(0), V_b_sense(0), use_inp_params(0), + is_main_mem(is_main_mem_), cell(), is_valid(false) +{ + ram_cell_tech_type = (is_tag) ? g_ip->tag_arr_ram_cell_tech_type : g_ip->data_arr_ram_cell_tech_type; + is_dram = ((ram_cell_tech_type == lp_dram) || (ram_cell_tech_type == comm_dram)); + + unsigned int capacity_per_die = g_ip->cache_sz / NUMBER_STACKED_DIE_LAYERS; // capacity per stacked die layer + const InterconnectType & wire_local = g_tp.wire_local; + fully_assoc = (g_ip->fully_assoc) ? 
true : false; + + if (pure_cam) + { + init_CAM(); + return; + } + + if (fully_assoc) { + init_FA(); + return; + } + + //*** Calculate number of rows and columns in a subarray + // Return if their dimensions do not meet the minimum specs + if (!calc_subarr_rc(capacity_per_die)) return; + + //** Calculate cell dimensions + if(is_tag) + { + cell.h = g_tp.sram.b_h + 2 * wire_local.pitch * (g_ip->num_rw_ports - 1 + g_ip->num_rd_ports + + g_ip->num_wr_ports); + cell.w = g_tp.sram.b_w + 2 * wire_local.pitch * (g_ip->num_rw_ports - 1 + g_ip->num_wr_ports + + (g_ip->num_rd_ports - g_ip->num_se_rd_ports)) + + wire_local.pitch * g_ip->num_se_rd_ports; + } + else + { + if (is_dram) + { + cell.h = g_tp.dram.b_h; + cell.w = g_tp.dram.b_w; + } + else + { + cell.h = g_tp.sram.b_h + 2 * wire_local.pitch * (g_ip->num_wr_ports + + g_ip->num_rw_ports - 1 + g_ip->num_rd_ports); + cell.w = g_tp.sram.b_w + 2 * wire_local.pitch * (g_ip->num_rw_ports - 1 + + (g_ip->num_rd_ports - g_ip->num_se_rd_ports) + + g_ip->num_wr_ports) + g_tp.wire_local.pitch * g_ip->num_se_rd_ports; + } + } + + double c_b_metal = cell.h * wire_local.C_per_um; + double C_bl; + + if (is_dram) + { + deg_bl_muxing = 1; + if (ram_cell_tech_type == comm_dram) + { + double Cbitrow_drain_cap = drain_C_(g_tp.dram.cell_a_w, NCH, 1, 0, cell.w, true, true) / 2.0; + C_bl = num_r_subarray * (Cbitrow_drain_cap + c_b_metal); + //C_bl = num_r_subarray * c_b_metal; + V_b_sense = (g_tp.dram_cell_Vdd/2) * g_tp.dram_cell_C / (g_tp.dram_cell_C + C_bl); + if (V_b_sense < VBITSENSEMIN && !(g_ip->is_3d_mem && g_ip->force_cache_config) ) + { + return; + } + + dram_refresh_period = 64e-3; + + } + else + { + double Cbitrow_drain_cap = drain_C_(g_tp.dram.cell_a_w, NCH, 1, 0, cell.w, true, true) / 2.0; + C_bl = num_r_subarray * (Cbitrow_drain_cap + c_b_metal); + V_b_sense = (g_tp.dram_cell_Vdd/2) * g_tp.dram_cell_C /(g_tp.dram_cell_C + C_bl); + + if (V_b_sense < VBITSENSEMIN) + { + return; //Sense amp input signal is smaller that minimum allowable sense amp input signal + } + V_b_sense = VBITSENSEMIN; // in any case, we fix sense amp input signal to a constant value + //v_storage_worst = g_tp.dram_cell_Vdd / 2 - VBITSENSEMIN * (g_tp.dram_cell_C + C_bl) / g_tp.dram_cell_C; + //dram_refresh_period = 1.1 * g_tp.dram_cell_C * v_storage_worst / g_tp.dram_cell_I_off_worst_case_len_temp; + dram_refresh_period = 0.9 * g_tp.dram_cell_C * VDD_STORAGE_LOSS_FRACTION_WORST * g_tp.dram_cell_Vdd / g_tp.dram_cell_I_off_worst_case_len_temp; + } + } + else + { //SRAM + V_b_sense = (0.05 * g_tp.sram_cell.Vdd > VBITSENSEMIN) ? 0.05 * g_tp.sram_cell.Vdd : VBITSENSEMIN; + deg_bl_muxing = Ndcm; + // "/ 2.0" below is due to the fact that two adjacent access transistors share drain + // contacts in a physical layout + double Cbitrow_drain_cap = drain_C_(g_tp.sram.cell_a_w, NCH, 1, 0, cell.w, false, true) / 2.0; + C_bl = num_r_subarray * (Cbitrow_drain_cap + c_b_metal); + dram_refresh_period = 0; + } + + + // do/di: data in/out, for fully associative they are the data width for normal read and write + // so/si: search data in/out, for fully associative they are the data width for the search ops + // for CAM, si=di, but so = matching address. 
do = data out = di (for normal read/write) + // so/si needs broadcase while do/di do not + + num_mats_h_dir = MAX(Ndwl / 2, 1); + num_mats_v_dir = MAX(Ndbl / 2, 1); + num_mats = num_mats_h_dir * num_mats_v_dir; + num_do_b_mat = MAX((num_subarrays/num_mats) * num_c_subarray / (deg_bl_muxing * Ndsam_lev_1 * Ndsam_lev_2), 1); + + if (!(fully_assoc|| pure_cam) && (num_do_b_mat < (num_subarrays/num_mats))) + { + return; + } + + + int deg_sa_mux_l1_non_assoc; + //TODO:the i/o for subbank is not necessary and should be removed. + if (!is_tag) + { + if (is_main_mem == true) + { + num_do_b_subbank = g_ip->int_prefetch_w * g_ip->out_w; + //CACTI3DD DRAM page size + if(g_ip->is_3d_mem) + num_do_b_subbank = g_ip->page_sz_bits; + deg_sa_mux_l1_non_assoc = Ndsam_lev_1; + } + else + { + if (g_ip->fast_access == true) + { + num_do_b_subbank = g_ip->out_w * g_ip->data_assoc; + deg_sa_mux_l1_non_assoc = Ndsam_lev_1; + } + else + { + + num_do_b_subbank = g_ip->out_w; + deg_sa_mux_l1_non_assoc = Ndsam_lev_1 / g_ip->data_assoc; + if (deg_sa_mux_l1_non_assoc < 1) + { + return; + } + + } + } + } + else + { + num_do_b_subbank = tagbits * g_ip->tag_assoc; + if (num_do_b_mat < tagbits) + { + return; + } + deg_sa_mux_l1_non_assoc = Ndsam_lev_1; + //num_do_b_mat = g_ip->tag_assoc / num_mats_h_dir; + } + + deg_senseamp_muxing_non_associativity = deg_sa_mux_l1_non_assoc; + + num_act_mats_hor_dir = num_do_b_subbank / num_do_b_mat; + if (g_ip->is_3d_mem && num_act_mats_hor_dir == 0) + num_act_mats_hor_dir = 1; + if (num_act_mats_hor_dir == 0) + { + return; + } + + //compute num_do_mat for tag + if (is_tag) + { + if (!(fully_assoc || pure_cam)) + { + num_do_b_mat = g_ip->tag_assoc / num_act_mats_hor_dir; + num_do_b_subbank = num_act_mats_hor_dir * num_do_b_mat; + } + } + + if ((g_ip->is_cache == false && is_main_mem == true) || (PAGE_MODE == 1 && is_dram)) + { + if (num_act_mats_hor_dir * num_do_b_mat * Ndsam_lev_1 * Ndsam_lev_2 != (int)g_ip->page_sz_bits) + { + return; + } + } + +// if (is_tag == false && g_ip->is_cache == true && !fully_assoc && !pure_cam && //TODO: TODO burst transfer should also apply to RAM arrays + if (is_tag == false && g_ip->is_main_mem == true && + num_act_mats_hor_dir*num_do_b_mat*Ndsam_lev_1*Ndsam_lev_2 < ((int) g_ip->out_w * (int) g_ip->burst_len * (int) g_ip->data_assoc)) + { + return; + } + + if (num_act_mats_hor_dir > num_mats_h_dir) + { + return; + } + + + //compute di for mat subbank and bank + if(!is_tag) + { + if(g_ip->fast_access == true) + { + num_di_b_mat = num_do_b_mat / g_ip->data_assoc; + } + else + { + num_di_b_mat = num_do_b_mat; + } + } + else + { + num_di_b_mat = tagbits; + } + + num_di_b_subbank = num_di_b_mat * num_act_mats_hor_dir;//normal cache or normal r/w for FA + num_si_b_subbank = num_si_b_mat; //* num_act_mats_hor_dir_sl; inside the data is broadcast + + int num_addr_b_row_dec = _log2(num_r_subarray); + if ((fully_assoc ||pure_cam)) + num_addr_b_row_dec +=_log2(num_subarrays/num_mats); + int number_subbanks = num_mats / num_act_mats_hor_dir; + number_subbanks_decode = _log2(number_subbanks);//TODO: add log2(num_subarray_per_bank) to FA/CAM + + num_rw_ports = g_ip->num_rw_ports; + num_rd_ports = g_ip->num_rd_ports; + num_wr_ports = g_ip->num_wr_ports; + num_se_rd_ports = g_ip->num_se_rd_ports; + num_search_ports = g_ip->num_search_ports; + + if (is_dram && is_main_mem) + { + number_addr_bits_mat = MAX((unsigned int) num_addr_b_row_dec, + _log2(deg_bl_muxing) + _log2(deg_sa_mux_l1_non_assoc) + _log2(Ndsam_lev_2)); + if (g_ip->print_detail_debug) + { + cout << 
"parameter.cc: number_addr_bits_mat = " << num_addr_b_row_dec << endl; + cout << "parameter.cc: num_addr_b_row_dec = " << num_addr_b_row_dec << endl; + cout << "parameter.cc: num_addr_b_mux_sel = " << _log2(deg_bl_muxing) + _log2(deg_sa_mux_l1_non_assoc) + _log2(Ndsam_lev_2) << endl; + } + } + else + { + number_addr_bits_mat = num_addr_b_row_dec + _log2(deg_bl_muxing) + + _log2(deg_sa_mux_l1_non_assoc) + _log2(Ndsam_lev_2); + } + + if (is_tag) + { + num_di_b_bank_per_port = tagbits; + num_do_b_bank_per_port = g_ip->data_assoc; + } + else + { + num_di_b_bank_per_port = g_ip->out_w + g_ip->data_assoc; + num_do_b_bank_per_port = g_ip->out_w; + } + + if ((!is_tag) && (g_ip->data_assoc > 1) && (!g_ip->fast_access)) + { + number_way_select_signals_mat = g_ip->data_assoc; + } + + // add ECC adjustment to all data signals that traverse on H-trees. + if (g_ip->add_ecc_b_ == true) ECC_adjustment(); + + is_valid = true; +} + +void +DynamicParameter::ECC_adjustment() { + num_do_b_mat += (int) (ceil(num_do_b_mat / num_bits_per_ecc_b_)); + num_di_b_mat += (int) (ceil(num_di_b_mat / num_bits_per_ecc_b_)); + num_di_b_subbank += (int) (ceil(num_di_b_subbank / num_bits_per_ecc_b_)); + num_do_b_subbank += (int) (ceil(num_do_b_subbank / num_bits_per_ecc_b_)); + num_di_b_bank_per_port += (int) (ceil(num_di_b_bank_per_port / num_bits_per_ecc_b_)); + num_do_b_bank_per_port += (int) (ceil(num_do_b_bank_per_port / num_bits_per_ecc_b_)); + + num_so_b_mat += (int) (ceil(num_so_b_mat / num_bits_per_ecc_b_)); + num_si_b_mat += (int) (ceil(num_si_b_mat / num_bits_per_ecc_b_)); + num_si_b_subbank += (int) (ceil(num_si_b_subbank / num_bits_per_ecc_b_)); + num_so_b_subbank += (int) (ceil(num_so_b_subbank / num_bits_per_ecc_b_)); + num_si_b_bank_per_port += (int) (ceil(num_si_b_bank_per_port / num_bits_per_ecc_b_)); + num_so_b_bank_per_port += (int) (ceil(num_so_b_bank_per_port / num_bits_per_ecc_b_)); +} + +//DynamicParameter::DynamicParameter( +// bool is_tag_, +// int pure_ram_, +// int pure_cam_, +// double Nspd_, +// unsigned int Ndwl_, +// unsigned int Ndbl_, +// unsigned int Ndcm_, +// unsigned int Ndsam_lev_1_, +// unsigned int Ndsam_lev_2_, +// Wire_type wt, +// bool is_main_mem_): +// is_tag(is_tag_), pure_ram(pure_ram_), pure_cam(pure_cam_), tagbits(0), Nspd(Nspd_), Ndwl(Ndwl_), Ndbl(Ndbl_),Ndcm(Ndcm_), +// Ndsam_lev_1(Ndsam_lev_1_), Ndsam_lev_2(Ndsam_lev_2_),wtype(wt), +// number_way_select_signals_mat(0), V_b_sense(0), use_inp_params(0), +// is_main_mem(is_main_mem_), cell(), is_valid(false) +// ram_cell_tech_type = (is_tag) ? g_ip->tag_arr_ram_cell_tech_type : g_ip->data_arr_ram_cell_tech_type; +// is_dram = ((ram_cell_tech_type == lp_dram) || (ram_cell_tech_type == comm_dram)); +// +// unsigned int capacity_per_die = g_ip->cache_sz / NUMBER_STACKED_DIE_LAYERS; // capacity per stacked die layer +// const /*TechnologyParameter::*/InterconnectType & wire_local = g_tp.wire_local; +// fully_assoc = (g_ip->fully_assoc) ? 
diff --git a/Project_FARSI/cacti_for_FARSI/parameter.h b/Project_FARSI/cacti_for_FARSI/parameter.h
new file mode 100644
index 00000000..2cbd49b0
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/parameter.h
@@ -0,0 +1,779 @@
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#ifndef __PARAMETER_H__ +#define __PARAMETER_H__ + +#include "area.h" +#include "const.h" +#include "cacti_interface.h" +#include "io.h" + +// parameters which are functions of certain device technology +/** +class TechnologyParameter +{ + public: + class DeviceType + { + public: + double C_g_ideal; + double C_fringe; + double C_overlap; + double C_junc; // C_junc_area + double C_junc_sidewall; + double l_phy; + double l_elec; + double R_nch_on; + double R_pch_on; + double Vdd; + double Vth; + double Vcc_min;//allowed min vcc; for memory cell it is the lowest vcc for data retention. for logic it is the vcc to balance the leakage reduction and wakeup latency + double I_on_n; + double I_on_p; + double I_off_n; + double I_off_p; + double I_g_on_n; + double I_g_on_p; + double C_ox; + double t_ox; + double n_to_p_eff_curr_drv_ratio; + double long_channel_leakage_reduction; + double Mobility_n; + + DeviceType(): C_g_ideal(0), C_fringe(0), C_overlap(0), C_junc(0), + C_junc_sidewall(0), l_phy(0), l_elec(0), R_nch_on(0), R_pch_on(0), + Vdd(0), Vth(0), Vcc_min(0), + I_on_n(0), I_on_p(0), I_off_n(0), I_off_p(0),I_g_on_n(0),I_g_on_p(0), + C_ox(0), t_ox(0), n_to_p_eff_curr_drv_ratio(0), long_channel_leakage_reduction(0), + Mobility_n(0) { }; + void reset() + { + C_g_ideal = 0; + C_fringe = 0; + C_overlap = 0; + C_junc = 0; + l_phy = 0; + l_elec = 0; + R_nch_on = 0; + R_pch_on = 0; + Vdd = 0; + Vth = 0; + Vcc_min = 0; + I_on_n = 0; + I_on_p = 0; + I_off_n = 0; + I_off_p = 0; + I_g_on_n = 0; + I_g_on_p = 0; + C_ox = 0; + t_ox = 0; + n_to_p_eff_curr_drv_ratio = 0; + long_channel_leakage_reduction = 0; + Mobility_n = 0; + } + + void display(uint32_t indent = 0); + }; + class InterconnectType + { + public: + double pitch; + double R_per_um; + double C_per_um; + double horiz_dielectric_constant; + double vert_dielectric_constant; + double aspect_ratio; + double miller_value; + double ild_thickness; + + InterconnectType(): pitch(0), R_per_um(0), C_per_um(0) { }; + + void reset() + { + pitch = 0; + R_per_um = 0; + C_per_um = 0; + horiz_dielectric_constant = 0; + vert_dielectric_constant = 0; + aspect_ratio = 0; + miller_value = 0; + ild_thickness = 0; + } + + void display(uint32_t indent = 0); + }; + class MemoryType + { + public: + double b_w; + double b_h; + double cell_a_w; + double cell_pmos_w; + double cell_nmos_w; + double Vbitpre; + double Vbitfloating;//voltage when floating bitline is supported + + void reset() + { + b_w = 0; //fs and tech + b_h = 0; //fs and tech + cell_a_w = 0; // ram_cell_tech_type + cell_pmos_w = 0; //fs + cell_nmos_w = 0; + Vbitpre = 0; + Vbitfloating = 0; + } + + void display(uint32_t indent = 0); + }; + + class ScalingFactor + { + public: + double logic_scaling_co_eff; + double core_tx_density; + double long_channel_leakage_reduction; + + ScalingFactor(): logic_scaling_co_eff(0), core_tx_density(0), + long_channel_leakage_reduction(0) { }; + + void reset() 
+ { + logic_scaling_co_eff= 0; + core_tx_density = 0; + long_channel_leakage_reduction= 0; + } + + void display(uint32_t indent = 0); + }; + + double ram_wl_stitching_overhead_; //fs + double min_w_nmos_; //fs + double max_w_nmos_; //fs + double max_w_nmos_dec; //fs+ ram_cell_tech_type + double unit_len_wire_del; //wire_inside_mat + double FO4; //fs + double kinv; //fs + double vpp; //input + double w_sense_en;//fs + double w_sense_n; //fs + double w_sense_p; //fs + double sense_delay; // input + double sense_dy_power; //input + double w_iso; //fs + double w_poly_contact; //fs + double spacing_poly_to_poly; //fs + double spacing_poly_to_contact;//fs + + //CACTI3DD TSV params + double tsv_parasitic_capacitance_fine; + double tsv_parasitic_resistance_fine; + double tsv_minimum_area_fine; + + double tsv_parasitic_capacitance_coarse; + double tsv_parasitic_resistance_coarse; + double tsv_minimum_area_coarse; + + //fs + double w_comp_inv_p1; + double w_comp_inv_p2; + double w_comp_inv_p3; + double w_comp_inv_n1; + double w_comp_inv_n2; + double w_comp_inv_n3; + double w_eval_inv_p; + double w_eval_inv_n; + double w_comp_n; + double w_comp_p; + + double dram_cell_I_on; //ram_cell_tech_type + double dram_cell_Vdd; + double dram_cell_I_off_worst_case_len_temp; + double dram_cell_C; + double gm_sense_amp_latch; // depends on many things + + double w_nmos_b_mux;//fs + double w_nmos_sa_mux;//fs + double w_pmos_bl_precharge;//fs + double w_pmos_bl_eq;//fs + double MIN_GAP_BET_P_AND_N_DIFFS;//fs + double MIN_GAP_BET_SAME_TYPE_DIFFS;//fs + double HPOWERRAIL;//fs + double cell_h_def;//fs + + double chip_layout_overhead; //input + double macro_layout_overhead; + double sckt_co_eff; + + double fringe_cap;//input + + uint64_t h_dec; //ram_cell_tech_type + + DeviceType sram_cell; // SRAM cell transistor + DeviceType dram_acc; // DRAM access transistor + DeviceType dram_wl; // DRAM wordline transistor + DeviceType peri_global; // peripheral global + DeviceType cam_cell; // SRAM cell transistor + + DeviceType sleep_tx; // Sleep transistor cell transistor + + InterconnectType wire_local; + InterconnectType wire_inside_mat; + InterconnectType wire_outside_mat; + + ScalingFactor scaling_factor; + + MemoryType sram; + MemoryType dram; + MemoryType cam; + + void display(uint32_t indent = 0); + + void reset() + { + dram_cell_Vdd = 0; + dram_cell_I_on = 0; + dram_cell_C = 0; + vpp = 0; + + sense_delay = 0; + sense_dy_power = 0; + fringe_cap = 0; +// horiz_dielectric_constant = 0; +// vert_dielectric_constant = 0; +// aspect_ratio = 0; +// miller_value = 0; +// ild_thickness = 0; + + dram_cell_I_off_worst_case_len_temp = 0; + + sram_cell.reset(); + dram_acc.reset(); + dram_wl.reset(); + peri_global.reset(); + cam_cell.reset(); + sleep_tx.reset(); + + scaling_factor.reset(); + + wire_local.reset(); + wire_inside_mat.reset(); + wire_outside_mat.reset(); + + sram.reset(); + dram.reset(); + cam.reset(); + + chip_layout_overhead = 0; + macro_layout_overhead = 0; + sckt_co_eff = 0; + } +}; + +**/ +//ali +class DeviceType +{ + public: + double C_g_ideal; + double C_fringe; + double C_overlap; + double C_junc; // C_junc_area + double C_junc_sidewall; + double l_phy; + double l_elec; + double R_nch_on; + double R_pch_on; + double Vdd; + double Vth; + double Vcc_min;//allowed min vcc; for memory cell it is the lowest vcc for data retention. 
for logic it is the vcc to balance the leakage reduction and wakeup latency + double I_on_n; + double I_on_p; + double I_off_n; + double I_off_p; + double I_g_on_n; + double I_g_on_p; + double C_ox; + double t_ox; + double n_to_p_eff_curr_drv_ratio; + double long_channel_leakage_reduction; + double Mobility_n; + + // auxilary parameters + double Vdsat; + double gmp_to_gmn_multiplier; + + + DeviceType(): C_g_ideal(0), C_fringe(0), C_overlap(0), C_junc(0), + C_junc_sidewall(0), l_phy(0), l_elec(0), R_nch_on(0), R_pch_on(0), + Vdd(0), Vth(0), Vcc_min(0), + I_on_n(0), I_on_p(0), I_off_n(0), I_off_p(0),I_g_on_n(0),I_g_on_p(0), + C_ox(0), t_ox(0), n_to_p_eff_curr_drv_ratio(0), long_channel_leakage_reduction(0), + Mobility_n(0) { reset();}; + + void assign(const string & in_file, int tech_flavor, unsigned int temp); + void interpolate(double alpha, const DeviceType& dev1, const DeviceType& dev2); + void reset() + { + C_g_ideal=0; + C_fringe=0; + C_overlap=0; + C_junc=0; // C_junc_area + C_junc_sidewall=0; + l_phy=0; + l_elec=0; + R_nch_on=0; + R_pch_on=0; + Vdd=0; + Vth=0; + Vcc_min=0;//allowed min vcc, for memory cell it is the lowest vcc for data retention. for logic it is the vcc to balance the leakage reduction and wakeup latency + I_on_n=0; + I_on_p=0; + I_off_n=0; + I_off_p=0; + I_g_on_n=0; + I_g_on_p=0; + C_ox=0; + t_ox=0; + n_to_p_eff_curr_drv_ratio=0; + long_channel_leakage_reduction=0; + Mobility_n=0; + + // auxilary parameters + Vdsat=0; + gmp_to_gmn_multiplier=0; + } + + void display(uint32_t indent = 0) const; + bool isEqual(const DeviceType & dev); +}; + +class InterconnectType +{ + public: + double pitch; + double R_per_um; + double C_per_um; + double horiz_dielectric_constant; + double vert_dielectric_constant; + double aspect_ratio; + double miller_value; + double ild_thickness; + + //auxilary parameters + double wire_width; + double wire_thickness; + double wire_spacing; + double barrier_thickness; + double dishing_thickness; + double alpha_scatter; + double fringe_cap; + + + InterconnectType(): pitch(0), R_per_um(0), C_per_um(0) { reset(); }; + + void reset() + { + pitch=0; + R_per_um=0; + C_per_um=0; + horiz_dielectric_constant=0; + vert_dielectric_constant=0; + aspect_ratio=0; + miller_value=0; + ild_thickness=0; + + //auxilary parameters + wire_width=0; + wire_thickness=0; + wire_spacing=0; + barrier_thickness=0; + dishing_thickness=0; + alpha_scatter=0; + fringe_cap=0; + + } + void assign(const string & in_file, int projection_type, int tech_flavor); + void interpolate(double alpha, const InterconnectType & inter1, const InterconnectType & inter2); + void display(uint32_t indent = 0); + bool isEqual(const InterconnectType & inter); +}; + +class MemoryType +{ + public: + double b_w; + double b_h; + double cell_a_w; + double cell_pmos_w; + double cell_nmos_w; + double Vbitpre; + double Vbitfloating;//voltage when floating bitline is supported + + // needed to calculate b_w b_h + double area_cell; + double asp_ratio_cell; + + MemoryType(){reset();} + void reset() + { + b_w=0; + b_h=0; + cell_a_w=0; + cell_pmos_w=0; + cell_nmos_w=0; + Vbitpre=0; + Vbitfloating=0; + } + void assign(const string & in_file, int tech_flavor, int cell_type); // sram(0),cam(1),dram(2) + void interpolate(double alpha, const MemoryType& dev1, const MemoryType& dev2); + void display(uint32_t indent = 0) const; + bool isEqual(const MemoryType & mem); +}; + +class ScalingFactor +{ + public: + double logic_scaling_co_eff; + double core_tx_density; + double long_channel_leakage_reduction; + + 
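+    // assign() loads these factors from a bracketing technology .dat file
+    // (declared below); interpolate() presumably blends two known nodes
+    // linearly, i.e. roughly field = alpha*dev1.field + (1-alpha)*dev2.field.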
ScalingFactor(): logic_scaling_co_eff(0), core_tx_density(0), + long_channel_leakage_reduction(0) { reset(); }; + + void reset() + { + logic_scaling_co_eff=0; + core_tx_density=0; + long_channel_leakage_reduction=0; + } + void assign(const string & in_file); + void interpolate(double alpha, const ScalingFactor& dev1, const ScalingFactor& dev2); + void display(uint32_t indent = 0); + bool isEqual(const ScalingFactor & scal); +}; + +// parameters which are functions of certain device technology +class TechnologyParameter +{ + public: + double ram_wl_stitching_overhead_; //fs + double min_w_nmos_; //fs + double max_w_nmos_; //fs + double max_w_nmos_dec; //fs+ ram_cell_tech_type + double unit_len_wire_del; //wire_inside_mat + double FO4; //fs + double kinv; //fs + double vpp; //input + double w_sense_en;//fs + double w_sense_n; //fs + double w_sense_p; //fs + double sense_delay; // input + double sense_dy_power; //input + double w_iso; //fs + double w_poly_contact; //fs + double spacing_poly_to_poly; //fs + double spacing_poly_to_contact;//fs + + //CACTI3D auxilary variables + double tsv_pitch; + double tsv_diameter; + double tsv_length; + double tsv_dielec_thickness; + double tsv_contact_resistance; + double tsv_depletion_width; + double tsv_liner_dielectric_constant; + + //CACTI3DD TSV params + + double tsv_parasitic_capacitance_fine; + double tsv_parasitic_resistance_fine; + double tsv_minimum_area_fine; + + double tsv_parasitic_capacitance_coarse; + double tsv_parasitic_resistance_coarse; + double tsv_minimum_area_coarse; + + //fs + double w_comp_inv_p1; + double w_comp_inv_p2; + double w_comp_inv_p3; + double w_comp_inv_n1; + double w_comp_inv_n2; + double w_comp_inv_n3; + double w_eval_inv_p; + double w_eval_inv_n; + double w_comp_n; + double w_comp_p; + + double dram_cell_I_on; //ram_cell_tech_type + double dram_cell_Vdd; + double dram_cell_I_off_worst_case_len_temp; + double dram_cell_C; + double gm_sense_amp_latch; // depends on many things + + double w_nmos_b_mux;//fs + double w_nmos_sa_mux;//fs + double w_pmos_bl_precharge;//fs + double w_pmos_bl_eq;//fs + double MIN_GAP_BET_P_AND_N_DIFFS;//fs + double MIN_GAP_BET_SAME_TYPE_DIFFS;//fs + double HPOWERRAIL;//fs + double cell_h_def;//fs + + double chip_layout_overhead; //input + double macro_layout_overhead; + double sckt_co_eff; + + double fringe_cap;//input + + uint64_t h_dec; //ram_cell_tech_type + + DeviceType sram_cell; // SRAM cell transistor + DeviceType dram_acc; // DRAM access transistor + DeviceType dram_wl; // DRAM wordline transistor + DeviceType peri_global; // peripheral global + DeviceType cam_cell; // SRAM cell transistor + + DeviceType sleep_tx; // Sleep transistor cell transistor + + InterconnectType wire_local; + InterconnectType wire_inside_mat; + InterconnectType wire_outside_mat; + + ScalingFactor scaling_factor; + + MemoryType sram; + MemoryType dram; + MemoryType cam; + + void display(uint32_t indent = 0); + bool isEqual(const TechnologyParameter & tech); + + + void find_upper_and_lower_tech(double technology, int &tech_lo, string& in_file_lo, int &tech_hi, string& in_file_hi); + void assign_tsv(const string & in_file); + void init(double technology, bool is_tag); + TechnologyParameter() + { + reset(); + } + void reset() + { + ram_wl_stitching_overhead_ =0; //fs + min_w_nmos_ =0; //fs + max_w_nmos_ =0; //fs + max_w_nmos_dec =0; //fs+ ram_cell_tech_type + unit_len_wire_del =0; //wire_inside_mat + FO4 =0; //fs + kinv =0; //fs + vpp =0; //input + w_sense_en =0;//fs + w_sense_n =0; //fs + w_sense_p =0; //fs + 
sense_delay =0; // input + sense_dy_power =0; //input + w_iso =0; //fs + w_poly_contact =0; //fs + spacing_poly_to_poly =0; //fs + spacing_poly_to_contact =0;//fs + + //CACTI3D auxilary variables + tsv_pitch =0; + tsv_diameter =0; + tsv_length =0; + tsv_dielec_thickness =0; + tsv_contact_resistance =0; + tsv_depletion_width =0; + tsv_liner_dielectric_constant =0; + + //CACTI3DD TSV params + + tsv_parasitic_capacitance_fine =0; + tsv_parasitic_resistance_fine =0; + tsv_minimum_area_fine =0; + + tsv_parasitic_capacitance_coarse =0; + tsv_parasitic_resistance_coarse =0; + tsv_minimum_area_coarse =0; + + //fs + w_comp_inv_p1 =0; + w_comp_inv_p2 =0; + w_comp_inv_p3 =0; + w_comp_inv_n1 =0; + w_comp_inv_n2 =0; + w_comp_inv_n3 =0; + w_eval_inv_p =0; + w_eval_inv_n =0; + w_comp_n =0; + w_comp_p =0; + + dram_cell_I_on =0; //ram_cell_tech_type + dram_cell_Vdd =0; + dram_cell_I_off_worst_case_len_temp =0; + dram_cell_C =0; + gm_sense_amp_latch =0; // depends on many things + + w_nmos_b_mux =0;//fs + w_nmos_sa_mux =0;//fs + w_pmos_bl_precharge =0;//fs + w_pmos_bl_eq =0;//fs + MIN_GAP_BET_P_AND_N_DIFFS =0;//fs + MIN_GAP_BET_SAME_TYPE_DIFFS =0;//fs + HPOWERRAIL =0;//fs + cell_h_def =0;//fs + + chip_layout_overhead = 0; + macro_layout_overhead = 0; + sckt_co_eff = 0; + + fringe_cap=0;//input + + h_dec=0; //ram_cell_tech_type + + sram_cell.reset(); + dram_acc.reset(); + dram_wl.reset(); + peri_global.reset(); + cam_cell.reset(); + sleep_tx.reset(); + + scaling_factor.reset(); + + wire_local.reset(); + wire_inside_mat.reset(); + wire_outside_mat.reset(); + + sram.reset(); + dram.reset(); + cam.reset(); + + + } +}; + +//end ali + +class DynamicParameter +{ + public: + bool is_tag; + bool pure_ram; + bool pure_cam; + bool fully_assoc; + int tagbits; + int num_subarrays; // only for leakage computation -- the number of subarrays per bank + int num_mats; // only for leakage computation -- the number of mats per bank + double Nspd; + int Ndwl; + int Ndbl; + int Ndcm; + int deg_bl_muxing; + int deg_senseamp_muxing_non_associativity; + int Ndsam_lev_1; + int Ndsam_lev_2; + Wire_type wtype; // merge from cacti-7 code to cacti3d code. 
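+  // Rough meaning of the partitioning knobs above (CACTI convention):
+  // Ndwl/Ndbl split the array into wordline/bitline segments, Nspd sets
+  // cells per wordline relative to the output width, Ndcm is the bitline
+  // (column) mux degree, and Ndsam_lev_1/2 are the two sense-amp output
+  // mux levels.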
+ + int number_addr_bits_mat; // per port + int number_subbanks_decode; // per_port + int num_di_b_bank_per_port; + int num_do_b_bank_per_port; + int num_di_b_mat; + int num_do_b_mat; + int num_di_b_subbank; + int num_do_b_subbank; + + int num_si_b_mat; + int num_so_b_mat; + int num_si_b_subbank; + int num_so_b_subbank; + int num_si_b_bank_per_port; + int num_so_b_bank_per_port; + + int number_way_select_signals_mat; + int num_act_mats_hor_dir; + + int num_act_mats_hor_dir_sl; + bool is_dram; + double V_b_sense; + unsigned int num_r_subarray; + unsigned int num_c_subarray; + int tag_num_r_subarray;//: fully associative cache tag and data must be computed together, data and tag must be separate + int tag_num_c_subarray; + int data_num_r_subarray; + int data_num_c_subarray; + int num_mats_h_dir; + int num_mats_v_dir; + uint32_t ram_cell_tech_type; + double dram_refresh_period; + + DynamicParameter(); + DynamicParameter( + bool is_tag_, + int pure_ram_, + int pure_cam_, + double Nspd_, + unsigned int Ndwl_, + unsigned int Ndbl_, + unsigned int Ndcm_, + unsigned int Ndsam_lev_1_, + unsigned int Ndsam_lev_2_, + Wire_type wt, // merged from cacti-7 to cacti3d + bool is_main_mem_); + + int use_inp_params; + unsigned int num_rw_ports; + unsigned int num_rd_ports; + unsigned int num_wr_ports; + unsigned int num_se_rd_ports; // number of single ended read ports + unsigned int num_search_ports; + unsigned int out_w;// == nr_bits_out + bool is_main_mem; + Area cell, cam_cell;//cell is the sram_cell in both nomal cache/ram and FA. + bool is_valid; + private: + void ECC_adjustment(); + void init_CAM(); + void init_FA(); + bool calc_subarr_rc(unsigned int cap); //to calculate and check subarray rows and columns +}; + + + +extern InputParameter * g_ip; +extern TechnologyParameter g_tp; + +#endif + diff --git a/Project_FARSI/cacti_for_FARSI/powergating.cc b/Project_FARSI/cacti_for_FARSI/powergating.cc new file mode 100644 index 00000000..e0fbd907 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/powergating.cc @@ -0,0 +1,129 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+#include "area.h"
+#include "powergating.h"
+#include "parameter.h"
+#include <assert.h>
+#include <math.h>
+#include <iostream>
+
+using namespace std;
+
+//TODO: although DTSN is used, for a memory array the number of sleep txs
+//is related to the number of rows and cols, so all calculations are still based on
+//the single-sleep-tx case
+
+Sleep_tx::Sleep_tx(
+    double _perf_with_sleep_tx,
+    double _active_Isat,//of circuit block, not sleep tx
+    bool _is_footer,
+    double _c_circuit_wakeup,
+    double _V_delta,
+    int _num_sleep_tx,
+//    double _vt_circuit,
+//    double _vt_sleep_tx,
+//    double _mobility,//of sleep tx
+//    double _c_ox,//of sleep tx
+    const Area & cell_)
+:perf_with_sleep_tx(_perf_with_sleep_tx),
+ active_Isat(_active_Isat),
+ is_footer(_is_footer),
+ c_circuit_wakeup(_c_circuit_wakeup),
+ V_delta(_V_delta),
+ num_sleep_tx(_num_sleep_tx),
+// vt_circuit(_vt_circuit),
+// vt_sleep_tx(_vt_sleep_tx),
+// mobility(_mobility),
+// c_ox(_c_ox)
+ cell(cell_),
+ is_sleep_tx(true)
+{
+
+  //a single sleep tx in a network
+  double raw_area, raw_width, raw_height;
+  double p_to_n_sz_ratio = pmos_to_nmos_sz_ratio(false, false, true);
+  vdd = g_tp.peri_global.Vdd;
+  vt_circuit = g_tp.peri_global.Vth;
+  vt_sleep_tx = g_tp.sleep_tx.Vth;
+  mobility = g_tp.sleep_tx.Mobility_n;
+  c_ox = g_tp.sleep_tx.C_ox;
+
+  width = active_Isat/(perf_with_sleep_tx*mobility*c_ox*(vdd-vt_circuit)*(vdd-vt_sleep_tx))*g_ip->F_sz_um;//W/L uses physical numbers
+  width /= num_sleep_tx;
+
+  raw_area = compute_gate_area(INV, 1, width, p_to_n_sz_ratio*width, cell.w*2)/2; //only a single device, assuming the device is laid on its side
+  raw_width = cell.w;
+  raw_height = raw_area/cell.w;
+  area.set_h(raw_height);
+  area.set_w(raw_width);
+
+  compute_penalty();
+
+}
+
+double Sleep_tx::compute_penalty()
+{
+  //V_delta = VDD - VCCmin; it has nothing to do with the threshold of the sleep tx, although it might be OK to use the sleep tx to control V_delta
+//  double c_load;
+  double p_to_n_sz_ratio = pmos_to_nmos_sz_ratio(false, false, true);
+
+  if (is_footer)
+  {
+    c_intrinsic_sleep = drain_C_(width, NCH, 1, 1, area.h, false, false, false, is_sleep_tx);
+//    V_delta = _V_delta;
+    wakeup_delay = (c_circuit_wakeup + c_intrinsic_sleep)*V_delta/(simplified_nmos_Isat(width, false, false, false, is_sleep_tx)/Ilinear_to_Isat_ratio);
+    wakeup_power.readOp.dynamic = (c_circuit_wakeup + c_intrinsic_sleep)*g_tp.sram_cell.Vdd*V_delta;
+    //no 0.5 because half of the energy is spent entering sleep and half of the energy will be spent in waking up.
And they are pairs + } + else + { + c_intrinsic_sleep = drain_C_(width*p_to_n_sz_ratio, PCH, 1, 1, area.h, false, false, false,is_sleep_tx); +// V_delta = _V_delta; + wakeup_delay = (c_circuit_wakeup + c_intrinsic_sleep)*V_delta/(simplified_pmos_Isat(width, false, false, false,is_sleep_tx)/Ilinear_to_Isat_ratio); + wakeup_power.readOp.dynamic = (c_circuit_wakeup + c_intrinsic_sleep)*g_tp.sram_cell.Vdd*V_delta; + } + + return wakeup_delay; + +/* + The number of cycles in the wake-up latency set the constraint on the + minimum number of idle clock cycles needed before a processor + can enter in the corresponding sleep mode without any wakeup + overhead. + + If the circuit is half way to sleep then waken up, it is still OK + just the wakeup latency will be shorter than the wakeup time from full asleep. + So, the sleep time and energy does not matter +*/ + +} + diff --git a/Project_FARSI/cacti_for_FARSI/powergating.h b/Project_FARSI/cacti_for_FARSI/powergating.h new file mode 100644 index 00000000..c4533998 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/powergating.h @@ -0,0 +1,86 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + +#ifndef POWERGATING_H_ +#define POWERGATING_H_ + +#include "component.h" + +class Sleep_tx : public Component +{ +public: + Sleep_tx( + double _perf_with_sleep_tx, + double _active_Isat,//of circuit block, not sleep tx + bool _is_footer, + double _c_circuit_wakeup, + double _V_delta, + int _num_sleep_tx, + // double _vt_circuit, + // double _vt_sleep_tx, + // double _mobility,//of sleep tx + // double _c_ox,//of sleep tx + const Area & cell_); + + double perf_with_sleep_tx; + double active_Isat; + bool is_footer; + + double vt_circuit; + double vt_sleep_tx; + double vdd;// of circuit block not sleep tx + double mobility;//of sleep tx + double c_ox; + double width; + double c_circuit_wakeup; + double c_intrinsic_sleep; + double delay, wakeup_delay; + powerDef power, wakeup_power; +// double c_circuit_sleep; +// double sleep_delay; +// powerDef sleep_power; + double V_delta; + + int num_sleep_tx; + + const Area & cell; + bool is_sleep_tx; + + + +// void compute_area(); + double compute_penalty(); // return outrisetime + + void leakage_feedback(double temperature){}; + ~Sleep_tx(){}; +}; + +#endif /* POWERGATING_H_ */ diff --git a/Project_FARSI/cacti_for_FARSI/regression.test b/Project_FARSI/cacti_for_FARSI/regression.test new file mode 100755 index 00000000..8cd4722b --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/regression.test @@ -0,0 +1,45 @@ +cache 4 types +./cacti -infile test_configs/cache1.cfg #L1 2-way 32K +./cacti -infile test_configs/cache2.cfg #L2 4-way 256K +./cacti -infile test_configs/cache3.cfg #L3 8-way 16M +./cacti -infile test_configs/cache4.cfg #L1 full-asso 4K with single search port +RAM 4 types +./cacti -infile test_configs/ram1.cfg # 16M +./cacti -infile test_configs/ram2.cfg # itrs-hp itrs-lstp +./cacti -infile test_configs/ram3.cfg # two banks no-ecc 128M +./cacti -infile test_configs/ram4.cfg # 32K 2-way +CAM 4 types +./cacti -infile test_configs/cam1.cfg # same as ram1 but ram->cam and full-asso +./cacti -infile test_configs/cam2.cfg # same as cam1 with line size = 128 +./cacti -infile test_configs/cam3.cfg # cam1 for 40nm technology +./cacti -infile test_configs/cam4.cfg # ca1 with exclusive read and write port +NUCA 4 types +./cacti -infile test_configs/nuca1.cfg # +./cacti -infile test_configs/nuca2.cfg +./cacti -infile test_configs/nuca3.cfg +./cacti -infile test_configs/nuca3.cfg +eDRAM 4 types +./cacti -infile test_configs/edram1.cfg # +./cacti -infile test_configs/edram2.cfg +./cacti -infile test_configs/edram3.cfg +./cacti -infile test_configs/edram4.cfg +DRAM 4 types +./cacti -infile test_configs/dram1.cfg # +./cacti -infile test_configs/dram2.cfg +./cacti -infile test_configs/dram3.cfg +./cacti -infile test_configs/dram4.cfg +IO 4 different parameters +./cacti -infile test_configs/io1.cfg # +./cacti -infile test_configs/io2.cfg +./cacti -infile test_configs/io3.cfg +./cacti -infile test_configs/io4.cfg +Power gating 4 types +./cacti 
-infile test_configs/power_gate1.cfg +./cacti -infile test_configs/power_gate2.cfg +./cacti -infile test_configs/power_gate3.cfg +./cacti -infile test_configs/power_gate4.cfg +3D 4 types +./cacti -infile test_configs/3D1.cfg +./cacti -infile test_configs/3D2.cfg +./cacti -infile test_configs/3D3.cfg +./cacti -infile test_configs/3D4.cfg \ No newline at end of file diff --git a/Project_FARSI/cacti_for_FARSI/router.cc b/Project_FARSI/cacti_for_FARSI/router.cc new file mode 100644 index 00000000..929c773d --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/router.cc @@ -0,0 +1,311 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#include "router.h" + +Router::Router( + double flit_size_, + double vc_buf, /* vc size = vc_buffer_size * flit_size */ + double vc_c, + /*TechnologyParameter::*/DeviceType *dt, + double I_, + double O_, + double M_ + ):flit_size(flit_size_), + deviceType(dt), + I(I_), + O(O_), + M(M_) +{ + vc_buffer_size = vc_buf; + vc_count = vc_c; + min_w_pmos = deviceType->n_to_p_eff_curr_drv_ratio*g_tp.min_w_nmos_; + double technology = g_ip->F_sz_um; + + Vdd = dt->Vdd; + + /*Crossbar parameters. Transmisson gate is employed for connector*/ + NTtr = 10*technology*1e-6/2; /*Transmission gate's nmos tr. length*/ + PTtr = 20*technology*1e-6/2; /* pmos tr. 
length*/ + wt = 15*technology*1e-6/2; /*track width*/ + ht = 15*technology*1e-6/2; /*track height*/ +// I = 5; /*Number of crossbar input ports*/ +// O = 5; /*Number of crossbar output ports*/ + NTi = 12.5*technology*1e-6/2; + PTi = 25*technology*1e-6/2; + + NTid = 60*technology*1e-6/2; //m + PTid = 120*technology*1e-6/2; // m + NTod = 60*technology*1e-6/2; // m + PTod = 120*technology*1e-6/2; // m + + calc_router_parameters(); +} + +Router::~Router(){} + + +double //wire cap with triple spacing +Router::Cw3(double length) { + Wire wc(g_ip->wt, length, 1, 3, 3); + return (wc.wire_cap(length)); +} + +/*Function to calculate the gate capacitance*/ +double +Router::gate_cap(double w) { + return (double) gate_C (w*1e6 /*u*/, 0); +} + +/*Function to calculate the diffusion capacitance*/ +double +Router::diff_cap(double w, int type /*0 for n-mos and 1 for p-mos*/, + double s /*number of stacking transistors*/) { + return (double) drain_C_(w*1e6 /*u*/, type, (int) s, 1, g_tp.cell_h_def); +} + + +/*crossbar related functions */ + +// Model for simple transmission gate +double +Router::transmission_buf_inpcap() { + return diff_cap(NTtr, 0, 1)+diff_cap(PTtr, 1, 1); +} + +double +Router::transmission_buf_outcap() { + return diff_cap(NTtr, 0, 1)+diff_cap(PTtr, 1, 1); +} + +double +Router::transmission_buf_ctrcap() { + return gate_cap(NTtr)+gate_cap(PTtr); +} + +double +Router::crossbar_inpline() { + return (Cw3(O*flit_size*wt) + O*transmission_buf_inpcap() + gate_cap(NTid) + + gate_cap(PTid) + diff_cap(NTid, 0, 1) + diff_cap(PTid, 1, 1)); +} + +double +Router::crossbar_outline() { + return (Cw3(I*flit_size*ht) + I*transmission_buf_outcap() + gate_cap(NTod) + + gate_cap(PTod) + diff_cap(NTod, 0, 1) + diff_cap(PTod, 1, 1)); +} + +double +Router::crossbar_ctrline() { + return (Cw3(0.5*O*flit_size*wt) + flit_size*transmission_buf_ctrcap() + + diff_cap(NTi, 0, 1) + diff_cap(PTi, 1, 1) + + gate_cap(NTi) + gate_cap(PTi)); +} + +double +Router::tr_crossbar_power() { + return (crossbar_inpline()*Vdd*Vdd*flit_size/2 + + crossbar_outline()*Vdd*Vdd*flit_size/2)*2; +} + +void Router::buffer_stats() +{ + DynamicParameter dyn_p; + dyn_p.is_tag = false; + dyn_p.pure_cam = false; + dyn_p.fully_assoc = false; + dyn_p.pure_ram = true; + dyn_p.is_dram = false; + dyn_p.is_main_mem = false; + dyn_p.num_subarrays = 1; + dyn_p.num_mats = 1; + dyn_p.Ndbl = 1; + dyn_p.Ndwl = 1; + dyn_p.Nspd = 1; + dyn_p.deg_bl_muxing = 1; + dyn_p.deg_senseamp_muxing_non_associativity = 1; + dyn_p.Ndsam_lev_1 = 1; + dyn_p.Ndsam_lev_2 = 1; + dyn_p.Ndcm = 1; + dyn_p.number_addr_bits_mat = 8; + dyn_p.number_way_select_signals_mat = 1; + dyn_p.number_subbanks_decode = 0; + dyn_p.num_act_mats_hor_dir = 1; + dyn_p.V_b_sense = Vdd; // FIXME check power calc. 
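+	// The VC buffer is modeled as a small scratch RAM: one row per buffer
+	// entry (num_r_subarray = vc_buffer_size), flit_size * vc_count columns,
+	// one write port per VC and a single shared read port, as set below.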
+ dyn_p.ram_cell_tech_type = 0; + dyn_p.num_r_subarray = (int) vc_buffer_size; + dyn_p.num_c_subarray = (int) flit_size * (int) vc_count; + dyn_p.num_mats_h_dir = 1; + dyn_p.num_mats_v_dir = 1; + dyn_p.num_do_b_subbank = (int)flit_size; + dyn_p.num_di_b_subbank = (int)flit_size; + dyn_p.num_do_b_mat = (int) flit_size; + dyn_p.num_di_b_mat = (int) flit_size; + dyn_p.num_do_b_mat = (int) flit_size; + dyn_p.num_di_b_mat = (int) flit_size; + dyn_p.num_do_b_bank_per_port = (int) flit_size; + dyn_p.num_di_b_bank_per_port = (int) flit_size; + dyn_p.out_w = (int) flit_size; + + dyn_p.use_inp_params = 1; + dyn_p.num_wr_ports = (unsigned int) vc_count; + dyn_p.num_rd_ports = 1;//(unsigned int) vc_count;//based on Bill Dally's book + dyn_p.num_rw_ports = 0; + dyn_p.num_se_rd_ports =0; + dyn_p.num_search_ports =0; + + + + dyn_p.cell.h = g_tp.sram.b_h + 2 * g_tp.wire_outside_mat.pitch * (dyn_p.num_wr_ports + + dyn_p.num_rw_ports - 1 + dyn_p.num_rd_ports); + dyn_p.cell.w = g_tp.sram.b_w + 2 * g_tp.wire_outside_mat.pitch * (dyn_p.num_rw_ports - 1 + + (dyn_p.num_rd_ports - dyn_p.num_se_rd_ports) + + dyn_p.num_wr_ports) + g_tp.wire_outside_mat.pitch * dyn_p.num_se_rd_ports; + + Mat buff(dyn_p); + buff.compute_delays(0); + buff.compute_power_energy(); + buffer.power.readOp = buff.power.readOp; + buffer.power.writeOp = buffer.power.readOp; //FIXME + buffer.area = buff.area; +} + + + + void +Router::cb_stats () +{ + if (1) { + Crossbar c_b(I, O, flit_size); + c_b.compute_power(); + crossbar.delay = c_b.delay; + crossbar.power.readOp.dynamic = c_b.power.readOp.dynamic; + crossbar.power.readOp.leakage = c_b.power.readOp.leakage; + crossbar.power.readOp.gate_leakage = c_b.power.readOp.gate_leakage; + crossbar.area = c_b.area; +// c_b.print_crossbar(); + } + else { + crossbar.power.readOp.dynamic = tr_crossbar_power(); + crossbar.power.readOp.leakage = flit_size * I * O * + cmos_Isub_leakage(NTtr*g_tp.min_w_nmos_, PTtr*min_w_pmos, 1, tg); + crossbar.power.readOp.gate_leakage = flit_size * I * O * + cmos_Ig_leakage(NTtr*g_tp.min_w_nmos_, PTtr*min_w_pmos, 1, tg); + } +} + +void +Router::get_router_power() +{ + /* calculate buffer stats */ + buffer_stats(); + + /* calculate cross-bar stats */ + cb_stats(); + + /* calculate arbiter stats */ + Arbiter vcarb(vc_count, flit_size, buffer.area.w); + Arbiter cbarb(I, flit_size, crossbar.area.w); + vcarb.compute_power(); + cbarb.compute_power(); + arbiter.power.readOp.dynamic = vcarb.power.readOp.dynamic * I + + cbarb.power.readOp.dynamic * O; + arbiter.power.readOp.leakage = vcarb.power.readOp.leakage * I + + cbarb.power.readOp.leakage * O; + arbiter.power.readOp.gate_leakage = vcarb.power.readOp.gate_leakage * I + + cbarb.power.readOp.gate_leakage * O; + +// arb_stats(); + power.readOp.dynamic = ((buffer.power.readOp.dynamic+buffer.power.writeOp.dynamic) + + crossbar.power.readOp.dynamic + + arbiter.power.readOp.dynamic)*MIN(I, O)*M; + double pppm_t[4] = {1,I,I,1}; + power = power + (buffer.power*pppm_t + crossbar.power + arbiter.power)*pppm_lkg; + +} + + void +Router::get_router_delay () +{ + FREQUENCY=5; // move this to config file --TODO + cycle_time = (1/(double)FREQUENCY)*1e3; //ps + delay = 4; + max_cyc = 17 * g_tp.FO4; //s + max_cyc *= 1e12; //ps + if (cycle_time < max_cyc) { + FREQUENCY = (1/max_cyc)*1e3; //GHz + } +} + + void +Router::get_router_area() +{ + area.h = I*buffer.area.h; + area.w = buffer.area.w+crossbar.area.w; +} + + void +Router::calc_router_parameters() +{ + /* calculate router frequency and pipeline cycles */ + get_router_delay(); + + /* router 
power stats */
+	get_router_power();
+
+	/* area stats */
+	get_router_area();
+}
+
+	void
+Router::print_router()
+{
+	cout << "\n\nRouter stats:\n";
+	cout << "\tRouter Area - "<< area.get_area()*1e-6<<"(mm^2)\n";
+	cout << "\tMaximum possible network frequency - " << (1/max_cyc)*1e3 << "GHz\n";
+	cout << "\tNetwork frequency - " << FREQUENCY <<" GHz\n";
+	cout << "\tNo. of Virtual channels - " << vc_count << "\n";
+	cout << "\tNo. of pipeline stages - " << delay << endl;
+	cout << "\tLink bandwidth - " << flit_size << " (bits)\n";
+	cout << "\tNo. of buffer entries per virtual channel - "<< vc_buffer_size << "\n";
+	cout << "\tSimple buffer Area - "<< buffer.area.get_area()*1e-6<<"(mm^2)\n";
+	cout << "\tSimple buffer access (Read) - " << buffer.power.readOp.dynamic * 1e9 <<" (nJ)\n";
+	cout << "\tSimple buffer leakage - " << buffer.power.readOp.leakage * 1e3 <<" (mW)\n";
+	cout << "\tCrossbar Area - "<< crossbar.area.get_area()*1e-6<<"(mm^2)\n";
+	cout << "\tCross bar access energy - " << crossbar.power.readOp.dynamic * 1e9<<" (nJ)\n";
+	cout << "\tCross bar leakage power - " << crossbar.power.readOp.leakage * 1e3<<" (mW)\n";
+	cout << "\tArbiter access energy (VC arb + Crossbar arb) - " << arbiter.power.readOp.dynamic * 1e9 << " (nJ)\n";
+}
diff --git a/Project_FARSI/cacti_for_FARSI/router.h b/Project_FARSI/cacti_for_FARSI/router.h
new file mode 100644
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/router.h
+/*****************************************************************************
+ * CACTI 7.0
+ * SOFTWARE LICENSE AGREEMENT
+ * Copyright 2015 Hewlett-Packard Development Company, L.P.
+ * All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.”
+ *
+ ***************************************************************************/
+
+#ifndef __ROUTER_H__
+#define __ROUTER_H__
+
+#include <assert.h>
+#include <iostream>
+#include "basic_circuit.h"
+#include "cacti_interface.h"
+#include "component.h"
+#include "mat.h"
+#include "parameter.h"
+#include "wire.h"
+#include "crossbar.h"
+#include "arbiter.h"
+
+
+
+class Router : public Component
+{
+  public:
+    Router(
+        double flit_size_,
+        double vc_buf, /* vc size = vc_buffer_size * flit_size */
+        double vc_count,
+        /*TechnologyParameter::*/DeviceType *dt = &(g_tp.peri_global),
+        double I_ = 5,
+        double O_ = 5,
+        double M_ = 0.6);
+    ~Router();
+
+
+    void print_router();
+
+    Component arbiter, crossbar, buffer;
+
+    double cycle_time, max_cyc;
+    double flit_size;
+    double vc_count;
+    double vc_buffer_size; /* vc size = vc_buffer_size * flit_size */
+
+  private:
+    /*TechnologyParameter::*/DeviceType *deviceType;
+    double FREQUENCY; // move this to config file --TODO
+    double Cw3(double len);
+    double gate_cap(double w);
+    double diff_cap(double w, int type /*0 for n-mos and 1 for p-mos*/, double stack);
+    enum Wire_type wtype;
+    enum Wire_placement wire_placement;
+    //crossbar
+    double NTtr, PTtr, wt, ht, I, O, NTi, PTi, NTid, PTid, NTod, PTod, TriS1, TriS2;
+    double M; //network load
+    double transmission_buf_inpcap();
+    double transmission_buf_outcap();
+    double transmission_buf_ctrcap();
+    double crossbar_inpline();
+    double crossbar_outline();
+    double crossbar_ctrline();
+    double tr_crossbar_power();
+    void cb_stats ();
+    double arb_power();
+    void arb_stats ();
+    double buffer_params();
+    void buffer_stats();
+
+
+    //arbiter
+
+    //buffer
+
+    //router params
+    double Vdd;
+
+    void calc_router_parameters();
+    void get_router_area();
+    void get_router_power();
+    void get_router_delay();
+
+    double min_w_pmos;
+
+
+};
+
+#endif
diff --git a/Project_FARSI/cacti_for_FARSI/sample_config_files/ddr3_cache.cfg b/Project_FARSI/cacti_for_FARSI/sample_config_files/ddr3_cache.cfg
new file mode 100644
index 00000000..086bf5aa
--- /dev/null
+++ b/Project_FARSI/cacti_for_FARSI/sample_config_files/ddr3_cache.cfg
@@ -0,0 +1,259 @@
+# Cache size
+//-size (bytes) 2048
+//-size (bytes) 4096
+//-size (bytes) 32768
+//-size (bytes) 131072
+//-size (bytes) 262144
+//-size (bytes) 1048576
+//-size (bytes) 2097152
+//-size (bytes) 4194304
+-size (bytes) 8388608
+//-size (bytes) 16777216
+//-size (bytes) 33554432
+//-size (bytes) 134217728
+//-size (bytes) 67108864
+//-size (bytes)
1073741824 + +# power gating +-Array Power Gating - "false" +-WL Power Gating - "false" +-CL Power Gating - "false" +-Bitline floating - "false" +-Interconnect Power Gating - "false" +-Power Gating Performance Loss 0.01 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +//-associativity 8 +-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +-technology (u) 0.022 +//-technology (u) 0.040 +//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 8192 +-burst length 8 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 512 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. 
Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. +-Wire signaling (fullswing, lowswing, default) - "Global_30" +//-Wire signaling (fullswing, lowswing, default) - "default" +//-Wire signaling (fullswing, lowswing, default) - "lowswing" + +//-Wire inside mat - "global" +-Wire inside mat - "semi-global" +//-Wire outside mat - "global" +-Wire outside mat - "semi-global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +-Add ECC - "true" + +//-Print level (DETAILED, CONCISE) - "CONCISE" +-Print level (DETAILED, CONCISE) - "DETAILED" + +# for debugging +//-Print input parameters - "true" +-Print input parameters - "false" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +//-Force cache config - "true" +-Force cache config - "false" +-Ndwl 1 +-Ndbl 1 +-Nspd 0 +-Ndcm 1 +-Ndsam1 0 +-Ndsam2 0 + + + +#### Default CONFIGURATION values for baseline external IO parameters to DRAM. + +# Memory Type (D=DDR3, L=LPDDR2, W=WideIO, S=Low-swing differential) + +-dram_type "D" +//-dram_type "L" +//-dram_type "W" +//-dram_type "S" + +# Memory State (R=Read, W=Write, I=Idle or S=Sleep) + +//-iostate "R" +-iostate "W" +//-iostate "I" +//-iostate "S" + +# Is ECC Enabled (Y=Yes, N=No) + +-dram_ecc "Y" + +#Address bus timing + +//-addr_timing 0.5 //DDR, for LPDDR2 and LPDDR3 +-addr_timing 1.0 //SDR for DDR3, Wide-IO +//-addr_timing 2.0 //2T timing +//addr_timing 3.0 // 3T timing + +# Bandwidth (Gbytes per second, this is the effective bandwidth) + +-bus_bw 12.8 GBps //Valid range 0 to 2*bus_freq*num_dq + +# Memory Density (Gbit per memory/DRAM die) + +-mem_density 4 Gb //Valid values 2^n Gb + +# IO frequency (MHz) (frequency of the external memory interface). 
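+# Sanity check: with the 72 DQ pins set below, of which 8 are assumed here to
+# be ECC (the standard x72 DIMM split), the bandwidth formula given at -num_dq
+# yields 2 * 800 MHz * 64 = 102.4 Gbit/s = 12.8 GB/s, matching -bus_bw above.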
+ +-bus_freq 800 MHz //Valid range 0 to 1.5 GHz for DDR3, 0 to 1.2 GHz for LPDDR3, 0 - 800 MHz for WideIO and 0 - 3 GHz for Low-swing differential + +# Duty Cycle (fraction of time in the Memory State defined above) + +-duty_cycle 1.0 //Valid range 0 to 1.0 + +# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR + +# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_ca 0.5 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR + +# Number of DQ pins + +-num_dq 72 //Include ECC pins as well (if present). If ECC pins are included, the bus bandwidth is 2*(num_dq-#of ECC pins)*bus_freq. Valid range 0 to 72. + +# Number of DQS pins + +-num_dqs 18 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins. + +# Number of CA pins + +-num_ca 25 //Valid range 0 to 35 pins. + +# Number of CLK pins + +-num_clk 2 //2 x differential pair. Valid values: 0/2/4. + +# Number of Physical Ranks + +-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per DIMM or buffer chip + +# Width of the Memory Data Bus + +-mem_data_width 8 //x4 or x8 or x16 or x32 or x128 memories \ No newline at end of file diff --git a/Project_FARSI/cacti_for_FARSI/sample_config_files/diff_ddr3_cache.cfg b/Project_FARSI/cacti_for_FARSI/sample_config_files/diff_ddr3_cache.cfg new file mode 100644 index 00000000..cd623a89 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/sample_config_files/diff_ddr3_cache.cfg @@ -0,0 +1,259 @@ +# Cache size +//-size (bytes) 2048 +//-size (bytes) 4096 +//-size (bytes) 32768 +//-size (bytes) 131072 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 67108864 +//-size (bytes) 1073741824 + +# power gating +-Array Power Gating - "false" +-WL Power Gating - "false" +-CL Power Gating - "false" +-Bitline floating - "false" +-Interconnect Power Gating - "false" +-Power Gating Performance Loss 0.01 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +//-associativity 8 +-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +-technology (u) 0.022 +//-technology (u) 0.040 +//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 8192 +-burst length 8 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag 
array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 512 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. 
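+# "Global_30" below is taken from CACTI-6's wire menu: a full-swing global
+# wire sized to tolerate roughly a 30% delay penalty versus the fastest wire,
+# in exchange for lower repeater energy; "lowswing" instead models
+# differential low-voltage links.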
+-Wire signaling (fullswing, lowswing, default) - "Global_30"
+//-Wire signaling (fullswing, lowswing, default) - "default"
+//-Wire signaling (fullswing, lowswing, default) - "lowswing"
+
+//-Wire inside mat - "global"
+-Wire inside mat - "semi-global"
+//-Wire outside mat - "global"
+-Wire outside mat - "semi-global"
+
+-Interconnect projection - "conservative"
+//-Interconnect projection - "aggressive"
+
+# Contention in network (which is a function of core count and cache level) is one of
+# the critical factors used for deciding the optimal bank count value
+# core count can be 4, 8, or 16
+//-Core count 4
+-Core count 8
+//-Core count 16
+-Cache level (L2/L3) - "L3"
+
+-Add ECC - "true"
+
+//-Print level (DETAILED, CONCISE) - "CONCISE"
+-Print level (DETAILED, CONCISE) - "DETAILED"
+
+# for debugging
+//-Print input parameters - "true"
+-Print input parameters - "false"
+# force CACTI to model the cache with the
+# following Ndbl, Ndwl, Nspd, Ndsam,
+# and Ndcm values
+//-Force cache config - "true"
+-Force cache config - "false"
+-Ndwl 1
+-Ndbl 1
+-Nspd 0
+-Ndcm 1
+-Ndsam1 0
+-Ndsam2 0
+
+
+
+#### Default CONFIGURATION values for baseline external IO parameters to DRAM.
+
+# Memory Type (D=DDR3, L=LPDDR2, W=WideIO, S=Low-swing differential)
+
+//-dram_type "D"
+//-dram_type "L"
+//-dram_type "W"
+-dram_type "S"
+
+# Memory State (R=Read, W=Write, I=Idle or S=Sleep)
+
+//-iostate "R"
+-iostate "W"
+//-iostate "I"
+//-iostate "S"
+
+# Is ECC Enabled (Y=Yes, N=No)
+
+-dram_ecc "N"
+
+# Address bus timing
+
+//-addr_timing 0.5 //DDR, for LPDDR2 and LPDDR3
+-addr_timing 1.0 //SDR for DDR3, Wide-IO
+//-addr_timing 2.0 //2T timing
+//-addr_timing 3.0 // 3T timing
+
+# Bandwidth (Gbytes per second, this is the effective bandwidth)
+
+-bus_bw 6 GBps //Valid range 0 to 2*bus_freq*num_dq
+
+# Memory Density (Gbit per memory/DRAM die)
+
+-mem_density 4 Gb //Valid values 2^n Gb
+
+# IO frequency (MHz) (frequency of the external memory interface).
+
+-bus_freq 3000 MHz //Valid range 0 to 1.5 GHz for DDR3, 0 to 1.2 GHz for LPDDR3, 0 - 800 MHz for WideIO and 0 - 3 GHz for Low-swing differential
+
+# Duty Cycle (fraction of time in the Memory State defined above)
+
+-duty_cycle 1.0 //Valid range 0 to 1.0
+
+# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+
+# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_ca 0.5 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+
+# Number of DQ pins
+
+-num_dq 8 //Include ECC pins as well (if present). If ECC pins are included, the bus bandwidth is 2*(num_dq-#of ECC pins)*bus_freq. Valid range 0 to 72.
+
+# Number of DQS pins
+
+-num_dqs 2 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins.
+
+# Number of CA pins
+
+-num_ca 0 //Valid range 0 to 35 pins.
+
+# Number of CLK pins
+
+-num_clk 0 //2 x differential pair. Valid values: 0/2/4.
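+
+# Worked example (illustrative): with the values above, the bandwidth bound
+# 2*bus_freq*num_dq = 2 * 3 GHz * 8 = 48 Gbit/s = 6 GB/s, which matches the
+# configured -bus_bw 6 GBps.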
+ +# Number of Physical Ranks + +-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per DIMM or buffer chip + +# Width of the Memory Data Bus + +-mem_data_width 8 //x4 or x8 or x16 or x32 memories \ No newline at end of file diff --git a/Project_FARSI/cacti_for_FARSI/sample_config_files/lpddr3_cache.cfg b/Project_FARSI/cacti_for_FARSI/sample_config_files/lpddr3_cache.cfg new file mode 100644 index 00000000..caa24920 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/sample_config_files/lpddr3_cache.cfg @@ -0,0 +1,259 @@ +# Cache size +//-size (bytes) 2048 +//-size (bytes) 4096 +//-size (bytes) 32768 +//-size (bytes) 131072 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 67108864 +//-size (bytes) 1073741824 + +# power gating +-Array Power Gating - "false" +-WL Power Gating - "false" +-CL Power Gating - "false" +-Bitline floating - "false" +-Interconnect Power Gating - "false" +-Power Gating Performance Loss 0.01 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +//-associativity 8 +-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +-technology (u) 0.022 +//-technology (u) 0.040 +//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 8192 +-burst length 8 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 512 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. 
+# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. 
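+# Example (illustrative): to restrict the search to low-swing wires, comment
+# out the active -Wire signaling line below and uncomment the "lowswing" one.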
+-Wire signaling (fullswing, lowswing, default) - "Global_30"
+//-Wire signaling (fullswing, lowswing, default) - "default"
+//-Wire signaling (fullswing, lowswing, default) - "lowswing"
+
+//-Wire inside mat - "global"
+-Wire inside mat - "semi-global"
+//-Wire outside mat - "global"
+-Wire outside mat - "semi-global"
+
+-Interconnect projection - "conservative"
+//-Interconnect projection - "aggressive"
+
+# Contention in network (which is a function of core count and cache level) is one of
+# the critical factors used for deciding the optimal bank count value
+# core count can be 4, 8, or 16
+//-Core count 4
+-Core count 8
+//-Core count 16
+-Cache level (L2/L3) - "L3"
+
+-Add ECC - "true"
+
+//-Print level (DETAILED, CONCISE) - "CONCISE"
+-Print level (DETAILED, CONCISE) - "DETAILED"
+
+# for debugging
+//-Print input parameters - "true"
+-Print input parameters - "false"
+# force CACTI to model the cache with the
+# following Ndbl, Ndwl, Nspd, Ndsam,
+# and Ndcm values
+//-Force cache config - "true"
+-Force cache config - "false"
+-Ndwl 1
+-Ndbl 1
+-Nspd 0
+-Ndcm 1
+-Ndsam1 0
+-Ndsam2 0
+
+
+
+#### Default CONFIGURATION values for baseline external IO parameters to DRAM.
+
+# Memory Type (D=DDR3, L=LPDDR2, W=WideIO, S=Low-swing differential)
+
+//-dram_type "D"
+-dram_type "L"
+//-dram_type "W"
+//-dram_type "S"
+
+# Memory State (R=Read, W=Write, I=Idle or S=Sleep)
+
+//-iostate "R"
+-iostate "W"
+//-iostate "I"
+//-iostate "S"
+
+# Is ECC Enabled (Y=Yes, N=No)
+
+-dram_ecc "N"
+
+# Address bus timing
+
+-addr_timing 0.5 //DDR, for LPDDR2 and LPDDR3
+//-addr_timing 1.0 //SDR for DDR3, Wide-IO
+//-addr_timing 2.0 //2T timing
+//-addr_timing 3.0 // 3T timing
+
+# Bandwidth (Gbytes per second, this is the effective bandwidth)
+
+-bus_bw 6.4 GBps //Valid range 0 to 2*bus_freq*num_dq
+
+# Memory Density (Gbit per memory/DRAM die)
+
+-mem_density 4 Gb //Valid values 2^n Gb
+
+# IO frequency (MHz) (frequency of the external memory interface).
+
+-bus_freq 800 MHz //Valid range 0 to 1.5 GHz for DDR3, 0 to 1.2 GHz for LPDDR3, 0 - 800 MHz for WideIO and 0 - 3 GHz for Low-swing differential
+
+# Duty Cycle (fraction of time in the Memory State defined above)
+
+-duty_cycle 1.0 //Valid range 0 to 1.0
+
+# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+
+# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_ca 0.5 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+
+# Number of DQ pins
+
+-num_dq 32 //Include ECC pins as well (if present). If ECC pins are included, the bus bandwidth is 2*(num_dq-#of ECC pins)*bus_freq. Valid range 0 to 72.
+
+# Number of DQS pins
+
+-num_dqs 8 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins.
+
+# Number of CA pins
+
+-num_ca 14 //Valid range 0 to 35 pins.
+
+# Number of CLK pins
+
+-num_clk 2 //2 x differential pair. Valid values: 0/2/4.
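+
+# Worked example (illustrative): with the values above, the bandwidth bound
+# 2*bus_freq*num_dq = 2 * 800 MHz * 32 = 51.2 Gbit/s = 6.4 GB/s, which matches
+# the configured -bus_bw 6.4 GBps.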
+ +# Number of Physical Ranks + +-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per DIMM or buffer chip + +# Width of the Memory Data Bus + +-mem_data_width 32 //x4 or x8 or x16 or x32 or x128 memories \ No newline at end of file diff --git a/Project_FARSI/cacti_for_FARSI/sample_config_files/wideio_cache.cfg b/Project_FARSI/cacti_for_FARSI/sample_config_files/wideio_cache.cfg new file mode 100644 index 00000000..a4156702 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/sample_config_files/wideio_cache.cfg @@ -0,0 +1,259 @@ +# Cache size +//-size (bytes) 2048 +//-size (bytes) 4096 +//-size (bytes) 32768 +//-size (bytes) 131072 +//-size (bytes) 262144 +//-size (bytes) 1048576 +//-size (bytes) 2097152 +//-size (bytes) 4194304 +-size (bytes) 8388608 +//-size (bytes) 16777216 +//-size (bytes) 33554432 +//-size (bytes) 134217728 +//-size (bytes) 67108864 +//-size (bytes) 1073741824 + +# power gating +-Array Power Gating - "false" +-WL Power Gating - "false" +-CL Power Gating - "false" +-Bitline floating - "false" +-Interconnect Power Gating - "false" +-Power Gating Performance Loss 0.01 + +# Line size +//-block size (bytes) 8 +-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +//-associativity 8 +-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +-technology (u) 0.022 +//-technology (u) 0.040 +//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 8192 +-burst length 8 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 512 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. 
+# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. 
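+# Example (illustrative): to restrict the search to low-swing wires, comment
+# out the active -Wire signaling line below and uncomment the "lowswing" one.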
+-Wire signaling (fullswing, lowswing, default) - "Global_30"
+//-Wire signaling (fullswing, lowswing, default) - "default"
+//-Wire signaling (fullswing, lowswing, default) - "lowswing"
+
+//-Wire inside mat - "global"
+-Wire inside mat - "semi-global"
+//-Wire outside mat - "global"
+-Wire outside mat - "semi-global"
+
+-Interconnect projection - "conservative"
+//-Interconnect projection - "aggressive"
+
+# Contention in network (which is a function of core count and cache level) is one of
+# the critical factors used for deciding the optimal bank count value
+# core count can be 4, 8, or 16
+//-Core count 4
+-Core count 8
+//-Core count 16
+-Cache level (L2/L3) - "L3"
+
+-Add ECC - "true"
+
+//-Print level (DETAILED, CONCISE) - "CONCISE"
+-Print level (DETAILED, CONCISE) - "DETAILED"
+
+# for debugging
+//-Print input parameters - "true"
+-Print input parameters - "false"
+# force CACTI to model the cache with the
+# following Ndbl, Ndwl, Nspd, Ndsam,
+# and Ndcm values
+//-Force cache config - "true"
+-Force cache config - "false"
+-Ndwl 1
+-Ndbl 1
+-Nspd 0
+-Ndcm 1
+-Ndsam1 0
+-Ndsam2 0
+
+
+
+#### Default CONFIGURATION values for baseline external IO parameters to DRAM.
+
+# Memory Type (D=DDR3, L=LPDDR2, W=WideIO, S=Low-swing differential)
+
+//-dram_type "D"
+//-dram_type "L"
+-dram_type "W"
+//-dram_type "S"
+
+# Memory State (R=Read, W=Write, I=Idle or S=Sleep)
+
+//-iostate "R"
+-iostate "W"
+//-iostate "I"
+//-iostate "S"
+
+# Is ECC Enabled (Y=Yes, N=No)
+
+-dram_ecc "N"
+
+# Address bus timing
+
+//-addr_timing 0.5 //DDR, for LPDDR2 and LPDDR3
+-addr_timing 1.0 //SDR for DDR3, Wide-IO
+//-addr_timing 2.0 //2T timing
+//-addr_timing 3.0 // 3T timing
+
+# Bandwidth (Gbytes per second, this is the effective bandwidth)
+
+-bus_bw 12.8 GBps //Valid range 0 to 2*bus_freq*num_dq
+
+# Memory Density (Gbit per memory/DRAM die)
+
+-mem_density 4 Gb //Valid values 2^n Gb
+
+# IO frequency (MHz) (frequency of the external memory interface).
+
+-bus_freq 400 MHz //Valid range 0 to 1.5 GHz for DDR3, 0 to 1.2 GHz for LPDDR3, 0 - 800 MHz for WideIO and 0 - 3 GHz for Low-swing differential
+
+# Duty Cycle (fraction of time in the Memory State defined above)
+
+-duty_cycle 1.0 //Valid range 0 to 1.0
+
+# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+
+# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5)
+
+-activity_ca 0.5 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR
+
+# Number of DQ pins
+
+-num_dq 128 //Include ECC pins as well (if present). If ECC pins are included, the bus bandwidth is 2*(num_dq-#of ECC pins)*bus_freq. Valid range 0 to 72.
+
+# Number of DQS pins
+
+-num_dqs 16 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins.
+
+# Number of CA pins
+
+-num_ca 30 //Valid range 0 to 35 pins.
+
+# Number of CLK pins
+
+-num_clk 2 //2 x differential pair. Valid values: 0/2/4.
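+
+# Worked example (illustrative): with the values above, the bandwidth bound
+# 2*bus_freq*num_dq = 2 * 400 MHz * 128 = 102.4 Gbit/s = 12.8 GB/s, which
+# matches the configured -bus_bw 12.8 GBps.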
+ +# Number of Physical Ranks + +-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per DIMM or buffer chip + +# Width of the Memory Data Bus + +-mem_data_width 128 //x4 or x8 or x16 or x32 or x128 memories \ No newline at end of file diff --git a/Project_FARSI/cacti_for_FARSI/subarray.cc b/Project_FARSI/cacti_for_FARSI/subarray.cc new file mode 100644 index 00000000..9dfeefc8 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/subarray.cc @@ -0,0 +1,205 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + + +#include +#include +#include + +#include "subarray.h" + + +Subarray::Subarray(const DynamicParameter & dp_, bool is_fa_): + dp(dp_), num_rows(dp.num_r_subarray), num_cols(dp.num_c_subarray), + num_cols_fa_cam(dp.tag_num_c_subarray), num_cols_fa_ram(dp.data_num_c_subarray), + cell(dp.cell), cam_cell(dp.cam_cell), is_fa(is_fa_) +{ + //num_cols=7; + //cout<<"num_cols ="<< num_cols <add_ecc_b_ ? (int)ceil(num_cols / num_bits_per_ecc_b_) : 0); // ECC overhead + uint32_t ram_num_cells_wl_stitching = + (dp.ram_cell_tech_type == lp_dram) ? dram_num_cells_wl_stitching_ : + (dp.ram_cell_tech_type == comm_dram) ? comm_dram_num_cells_wl_stitching_ : sram_num_cells_wl_stitching_; + + area.h = cell.h * num_rows; + + area.w = cell.w * num_cols + + ceil(num_cols / ram_num_cells_wl_stitching) * g_tp.ram_wl_stitching_overhead_; // stitching overhead + + if (g_ip->print_detail_debug) + { + cout << "subarray.cc: ram_num_cells_wl_stitching = " << ram_num_cells_wl_stitching<add_ecc_b_ ? (int)ceil(num_cols_fa_cam / num_bits_per_ecc_b_) : 0; + num_cols_fa_ram += (g_ip->add_ecc_b_ ? (int)ceil(num_cols_fa_ram / num_bits_per_ecc_b_) : 0); + num_cols = num_cols_fa_cam + num_cols_fa_ram; + } + else + { + num_cols_fa_cam += g_ip->add_ecc_b_ ? 
(int)ceil(num_cols_fa_cam / num_bits_per_ecc_b_) : 0; + num_cols_fa_ram = 0; + num_cols = num_cols_fa_cam; + } + + area.h = cam_cell.h * (num_rows + 1);//height of subarray is decided by CAM array. blank space in sram array are filled with dummy cells + area.w = cam_cell.w * num_cols_fa_cam + cell.w * num_cols_fa_ram + + ceil((num_cols_fa_cam + num_cols_fa_ram) / sram_num_cells_wl_stitching_)*g_tp.ram_wl_stitching_overhead_ + + 16*g_tp.wire_local.pitch //the overhead for the NAND gate to connect the two halves + + 128*g_tp.wire_local.pitch;//the overhead for the drivers from matchline to wordline of RAM + + + } + + assert(area.h>0); + assert(area.w>0); + compute_C(); +} + + + +Subarray::~Subarray() +{ +} + + + +double Subarray::get_total_cell_area() +{ +// return (is_fa==false? cell.get_area() * num_rows * num_cols +// //: cam_cell.h*(num_rows+1)*(num_cols_fa_cam + sram_cell.get_area()*num_cols_fa_ram)); +// : cam_cell.get_area()*(num_rows+1)*(num_cols_fa_cam + num_cols_fa_ram)); +// //: cam_cell.get_area()*(num_rows+1)*num_cols_fa_cam + sram_cell.get_area()*(num_rows+1)*num_cols_fa_ram);//for FA, this area does not include the dummy cells in SRAM arrays. + + if (!(is_fa || dp.pure_cam)) + return (cell.get_area() * num_rows * num_cols); + else if (is_fa) + { //for FA, this area includes the dummy cells in SRAM arrays. + //return (cam_cell.get_area()*(num_rows+1)*(num_cols_fa_cam + num_cols_fa_ram)); + //cout<<"diff" < +void init_tech_params(double technology, bool is_tag) +{ + g_tp.init(technology,is_tag); +} + +void printing(const char * name, double value) +{ + cout << "tech " << name << " " << value << endl; +} + +void printing_int(const char * name, uint64_t value) +{ + cout << "tech " << name << " " << value << endl; +} +void print_g_tp() +{ + printing("g_tp.peri_global.Vdd",g_tp.peri_global.Vdd); + printing("g_tp.peri_global.Vcc_min",g_tp.peri_global.Vcc_min); + printing("g_tp.peri_global.t_ox",g_tp.peri_global.t_ox); + printing("g_tp.peri_global.Vth",g_tp.peri_global.Vth); + printing("g_tp.peri_global.C_ox",g_tp.peri_global.C_ox); + printing("g_tp.peri_global.C_g_ideal",g_tp.peri_global.C_g_ideal); + printing("g_tp.peri_global.C_fringe",g_tp.peri_global.C_fringe); + printing("g_tp.peri_global.C_junc",g_tp.peri_global.C_junc); + printing("g_tp.peri_global.C_junc_sidewall",g_tp.peri_global.C_junc_sidewall); + printing("g_tp.peri_global.l_phy",g_tp.peri_global.l_phy); + printing("g_tp.peri_global.l_elec",g_tp.peri_global.l_elec); + printing("g_tp.peri_global.I_on_n",g_tp.peri_global.I_on_n); + printing("g_tp.peri_global.R_nch_on",g_tp.peri_global.R_nch_on); + printing("g_tp.peri_global.R_pch_on",g_tp.peri_global.R_pch_on); + printing("g_tp.peri_global.n_to_p_eff_curr_drv_ratio",g_tp.peri_global.n_to_p_eff_curr_drv_ratio); + printing("g_tp.peri_global.long_channel_leakage_reduction",g_tp.peri_global.long_channel_leakage_reduction); + printing("g_tp.peri_global.I_off_n",g_tp.peri_global.I_off_n); + printing("g_tp.peri_global.I_off_p",g_tp.peri_global.I_off_p); + printing("g_tp.peri_global.I_g_on_n",g_tp.peri_global.I_g_on_n); + printing("g_tp.peri_global.I_g_on_p",g_tp.peri_global.I_g_on_p); + + printing("g_tp.peri_global.Mobility_n",g_tp.peri_global.Mobility_n); + + printing("g_tp.sleep_tx.Vdd",g_tp.sleep_tx.Vdd); + printing("g_tp.sleep_tx.Vcc_min",g_tp.sleep_tx.Vcc_min); + printing("g_tp.sleep_tx.t_ox",g_tp.sleep_tx.t_ox); + printing("g_tp.sleep_tx.Vth",g_tp.sleep_tx.Vth); + printing("g_tp.sleep_tx.C_ox",g_tp.sleep_tx.C_ox); + 
printing("g_tp.sleep_tx.C_g_ideal",g_tp.sleep_tx.C_g_ideal); + printing("g_tp.sleep_tx.C_fringe",g_tp.sleep_tx.C_fringe); + printing("g_tp.sleep_tx.C_junc",g_tp.sleep_tx.C_junc); + printing("g_tp.sleep_tx.C_junc_sidewall",g_tp.sleep_tx.C_junc_sidewall); + printing("g_tp.sleep_tx.l_phy",g_tp.sleep_tx.l_phy); + printing("g_tp.sleep_tx.l_elec",g_tp.sleep_tx.l_elec); + printing("g_tp.sleep_tx.I_on_n",g_tp.sleep_tx.I_on_n); + printing("g_tp.sleep_tx.R_nch_on",g_tp.sleep_tx.R_nch_on); + printing("g_tp.sleep_tx.R_pch_on",g_tp.sleep_tx.R_pch_on); + printing("g_tp.sleep_tx.n_to_p_eff_curr_drv_ratio",g_tp.sleep_tx.n_to_p_eff_curr_drv_ratio); + printing("g_tp.sleep_tx.long_channel_leakage_reduction",g_tp.sleep_tx.long_channel_leakage_reduction); + printing("g_tp.sleep_tx.I_off_n",g_tp.sleep_tx.I_off_n); + printing("g_tp.sleep_tx.I_off_p",g_tp.sleep_tx.I_off_p); + printing("g_tp.sleep_tx.I_g_on_n",g_tp.sleep_tx.I_g_on_n); + printing("g_tp.sleep_tx.I_g_on_p",g_tp.sleep_tx.I_g_on_p); + printing("g_tp.sleep_tx.Mobility_n",g_tp.sleep_tx.Mobility_n); + + printing("g_tp.sram_cell.Vdd",g_tp.sram_cell.Vdd); + printing("g_tp.sram_cell.Vcc_min",g_tp.sram_cell.Vcc_min); + printing("g_tp.sram_cell.l_phy",g_tp.sram_cell.l_phy); + printing("g_tp.sram_cell.l_elec",g_tp.sram_cell.l_elec); + printing("g_tp.sram_cell.t_ox",g_tp.sram_cell.t_ox); + printing("g_tp.sram_cell.Vth",g_tp.sram_cell.Vth); + printing("g_tp.sram_cell.C_g_ideal",g_tp.sram_cell.C_g_ideal); + printing("g_tp.sram_cell.C_fringe",g_tp.sram_cell.C_fringe); + printing("g_tp.sram_cell.C_junc",g_tp.sram_cell.C_junc); + printing("g_tp.sram_cell.C_junc_sidewall",g_tp.sram_cell.C_junc_sidewall); + printing("g_tp.sram_cell.I_on_n",g_tp.sram_cell.I_on_n); + printing("g_tp.sram_cell.R_nch_on",g_tp.sram_cell.R_nch_on); + printing("g_tp.sram_cell.R_pch_on",g_tp.sram_cell.R_pch_on); + printing("g_tp.sram_cell.n_to_p_eff_curr_drv_ratio",g_tp.sram_cell.n_to_p_eff_curr_drv_ratio); + printing("g_tp.sram_cell.long_channel_leakage_reduction",g_tp.sram_cell.long_channel_leakage_reduction); + printing("g_tp.sram_cell.I_off_n",g_tp.sram_cell.I_off_n); + printing("g_tp.sram_cell.I_off_p",g_tp.sram_cell.I_off_p); + printing("g_tp.sram_cell.I_g_on_n",g_tp.sram_cell.I_g_on_n); + printing("g_tp.sram_cell.I_g_on_p",g_tp.sram_cell.I_g_on_p); + + printing("g_tp.dram_cell_Vdd",g_tp.dram_cell_Vdd); + printing("g_tp.dram_acc.Vth",g_tp.dram_acc.Vth); + printing("g_tp.dram_acc.l_phy",g_tp.dram_acc.l_phy); + printing("g_tp.dram_acc.l_elec",g_tp.dram_acc.l_elec); + printing("g_tp.dram_acc.C_g_ideal",g_tp.dram_acc.C_g_ideal); + printing("g_tp.dram_acc.C_fringe",g_tp.dram_acc.C_fringe); + printing("g_tp.dram_acc.C_junc",g_tp.dram_acc.C_junc); + printing("g_tp.dram_acc.C_junc_sidewall",g_tp.dram_acc.C_junc_sidewall); + printing("g_tp.dram_cell_I_on",g_tp.dram_cell_I_on); + printing("g_tp.dram_cell_I_off_worst_case_len_temp",g_tp.dram_cell_I_off_worst_case_len_temp); + printing("g_tp.dram_acc.I_on_n",g_tp.dram_acc.I_on_n); + printing("g_tp.dram_cell_C",g_tp.dram_cell_C); + printing("g_tp.vpp",g_tp.vpp); + printing("g_tp.dram_wl.l_phy",g_tp.dram_wl.l_phy); + printing("g_tp.dram_wl.l_elec",g_tp.dram_wl.l_elec); + printing("g_tp.dram_wl.C_g_ideal",g_tp.dram_wl.C_g_ideal); + printing("g_tp.dram_wl.C_fringe",g_tp.dram_wl.C_fringe); + printing("g_tp.dram_wl.C_junc",g_tp.dram_wl.C_junc); + printing("g_tp.dram_wl.C_junc_sidewall",g_tp.dram_wl.C_junc_sidewall); + printing("g_tp.dram_wl.I_on_n",g_tp.dram_wl.I_on_n); + printing("g_tp.dram_wl.R_nch_on",g_tp.dram_wl.R_nch_on); + 
printing("g_tp.dram_wl.R_pch_on",g_tp.dram_wl.R_pch_on); + printing("g_tp.dram_wl.n_to_p_eff_curr_drv_ratio",g_tp.dram_wl.n_to_p_eff_curr_drv_ratio); + printing("g_tp.dram_wl.long_channel_leakage_reduction",g_tp.dram_wl.long_channel_leakage_reduction); + printing("g_tp.dram_wl.I_off_n",g_tp.dram_wl.I_off_n); + printing("g_tp.dram_wl.I_off_p",g_tp.dram_wl.I_off_p); + + printing("g_tp.cam_cell.Vdd",g_tp.cam_cell.Vdd); + printing("g_tp.cam_cell.l_phy",g_tp.cam_cell.l_phy); + printing("g_tp.cam_cell.l_elec",g_tp.cam_cell.l_elec); + printing("g_tp.cam_cell.t_ox",g_tp.cam_cell.t_ox); + printing("g_tp.cam_cell.Vth",g_tp.cam_cell.Vth); + printing("g_tp.cam_cell.C_g_ideal",g_tp.cam_cell.C_g_ideal); + printing("g_tp.cam_cell.C_fringe",g_tp.cam_cell.C_fringe); + printing("g_tp.cam_cell.C_junc",g_tp.cam_cell.C_junc); + printing("g_tp.cam_cell.C_junc_sidewall",g_tp.cam_cell.C_junc_sidewall); + printing("g_tp.cam_cell.I_on_n",g_tp.cam_cell.I_on_n); + printing("g_tp.cam_cell.R_nch_on",g_tp.cam_cell.R_nch_on); + printing("g_tp.cam_cell.R_pch_on",g_tp.cam_cell.R_pch_on); + printing("g_tp.cam_cell.n_to_p_eff_curr_drv_ratio",g_tp.cam_cell.n_to_p_eff_curr_drv_ratio); + printing("g_tp.cam_cell.long_channel_leakage_reduction",g_tp.cam_cell.long_channel_leakage_reduction); + printing("g_tp.cam_cell.I_off_n",g_tp.cam_cell.I_off_n); + printing("g_tp.cam_cell.I_off_p",g_tp.cam_cell.I_off_p); + printing("g_tp.cam_cell.I_g_on_n",g_tp.cam_cell.I_g_on_n); + printing("g_tp.cam_cell.I_g_on_p",g_tp.cam_cell.I_g_on_p); + + printing("g_tp.dram.cell_a_w",g_tp.dram.cell_a_w); + printing("g_tp.dram.cell_pmos_w",g_tp.dram.cell_pmos_w); + printing("g_tp.dram.cell_nmos_w",g_tp.dram.cell_nmos_w); + + + printing("g_tp.sram.cell_a_w",g_tp.sram.cell_a_w); + printing("g_tp.sram.cell_pmos_w",g_tp.sram.cell_pmos_w); + printing("g_tp.sram.cell_nmos_w",g_tp.sram.cell_nmos_w); + + + printing("g_tp.cam.cell_a_w",g_tp.cam.cell_a_w); + printing("g_tp.cam.cell_pmos_w",g_tp.cam.cell_pmos_w); + printing("g_tp.cam.cell_nmos_w",g_tp.cam.cell_nmos_w); + + printing("g_tp.scaling_factor.logic_scaling_co_eff",g_tp.scaling_factor.logic_scaling_co_eff); + printing("g_tp.scaling_factor.core_tx_density",g_tp.scaling_factor.core_tx_density); + printing("g_tp.chip_layout_overhead",g_tp.chip_layout_overhead); + printing("g_tp.macro_layout_overhead",g_tp.macro_layout_overhead); + printing("g_tp.sckt_co_eff",g_tp.sckt_co_eff); + + printing("g_tp.w_comp_inv_p1",g_tp.w_comp_inv_p1); + printing("g_tp.w_comp_inv_n1",g_tp.w_comp_inv_n1); + printing("g_tp.w_comp_inv_p2",g_tp.w_comp_inv_p2); + printing("g_tp.w_comp_inv_n2",g_tp.w_comp_inv_n2); + printing("g_tp.w_comp_inv_p3",g_tp.w_comp_inv_p3); + printing("g_tp.w_comp_inv_n3",g_tp.w_comp_inv_n3); + printing("g_tp.w_eval_inv_p",g_tp.w_eval_inv_p); + printing("g_tp.w_eval_inv_n",g_tp.w_eval_inv_n); + printing("g_tp.w_comp_n",g_tp.w_comp_n); + printing("g_tp.w_comp_p",g_tp.w_comp_p); + + printing("g_tp.MIN_GAP_BET_P_AND_N_DIFFS",g_tp.MIN_GAP_BET_P_AND_N_DIFFS); + printing("g_tp.MIN_GAP_BET_SAME_TYPE_DIFFS",g_tp.MIN_GAP_BET_SAME_TYPE_DIFFS); + printing("g_tp.HPOWERRAIL",g_tp.HPOWERRAIL); + printing("g_tp.cell_h_def",g_tp.cell_h_def); + printing("g_tp.w_poly_contact",g_tp.w_poly_contact); + printing("g_tp.spacing_poly_to_contact",g_tp.spacing_poly_to_contact); + printing("g_tp.spacing_poly_to_poly",g_tp.spacing_poly_to_poly); + printing("g_tp.ram_wl_stitching_overhead_",g_tp.ram_wl_stitching_overhead_); + + printing("g_tp.min_w_nmos_",g_tp.min_w_nmos_); + printing("g_tp.max_w_nmos_",g_tp.max_w_nmos_); + 
printing("g_tp.w_iso",g_tp.w_iso); + printing("g_tp.w_sense_n",g_tp.w_sense_n); + printing("g_tp.w_sense_p",g_tp.w_sense_p); + printing("g_tp.w_sense_en",g_tp.w_sense_en); + printing("g_tp.w_nmos_b_mux",g_tp.w_nmos_b_mux); + printing("g_tp.w_nmos_sa_mux",g_tp.w_nmos_sa_mux); + + printing("g_tp.max_w_nmos_dec",g_tp.max_w_nmos_dec); + printing_int("g_tp.h_dec",g_tp.h_dec); + + printing("g_tp.peri_global.C_overlap",g_tp.peri_global.C_overlap); + printing("g_tp.sram_cell.C_overlap",g_tp.sram_cell.C_overlap); + printing("g_tp.cam_cell.C_overlap",g_tp.cam_cell.C_overlap); + + printing("g_tp.dram_acc.C_overlap",g_tp.dram_acc.C_overlap); + printing("g_tp.dram_acc.R_nch_on",g_tp.dram_acc.R_nch_on); + + printing("g_tp.dram_wl.C_overlap",g_tp.dram_wl.C_overlap); + + printing("g_tp.gm_sense_amp_latch",g_tp.gm_sense_amp_latch); + + printing("g_tp.dram.b_w",g_tp.dram.b_w); + printing("g_tp.dram.b_h",g_tp.dram.b_h); + printing("g_tp.sram.b_w",g_tp.sram.b_w); + printing("g_tp.sram.b_h",g_tp.sram.b_h); + printing("g_tp.cam.b_w",g_tp.cam.b_w); + printing("g_tp.cam.b_h",g_tp.cam.b_h); + + printing("g_tp.dram.Vbitpre",g_tp.dram.Vbitpre); + printing("g_tp.sram.Vbitpre",g_tp.sram.Vbitpre); + printing("g_tp.sram.Vbitfloating",g_tp.sram.Vbitfloating); + printing("g_tp.cam.Vbitpre",g_tp.cam.Vbitpre); + + printing("g_tp.w_pmos_bl_precharge",g_tp.w_pmos_bl_precharge); + printing("g_tp.w_pmos_bl_eq",g_tp.w_pmos_bl_eq); + + printing("g_tp.wire_local.pitch",g_tp.wire_local.pitch); + printing("g_tp.wire_local.R_per_um",g_tp.wire_local.R_per_um); + printing("g_tp.wire_local.C_per_um",g_tp.wire_local.C_per_um); + printing("g_tp.wire_local.aspect_ratio",g_tp.wire_local.aspect_ratio); + printing("g_tp.wire_local.ild_thickness",g_tp.wire_local.ild_thickness); + printing("g_tp.wire_local.miller_value",g_tp.wire_local.miller_value); + printing("g_tp.wire_local.horiz_dielectric_constant",g_tp.wire_local.horiz_dielectric_constant); + printing("g_tp.wire_local.vert_dielectric_constant",g_tp.wire_local.vert_dielectric_constant); + + printing("g_tp.wire_inside_mat.pitch",g_tp.wire_inside_mat.pitch); + printing("g_tp.wire_inside_mat.R_per_um",g_tp.wire_inside_mat.R_per_um); + printing("g_tp.wire_inside_mat.C_per_um",g_tp.wire_inside_mat.C_per_um); + printing("g_tp.wire_inside_mat.aspect_ratio",g_tp.wire_inside_mat.aspect_ratio); + printing("g_tp.wire_inside_mat.ild_thickness",g_tp.wire_inside_mat.ild_thickness); + printing("g_tp.wire_inside_mat.miller_value",g_tp.wire_inside_mat.miller_value); + printing("g_tp.wire_inside_mat.horiz_dielectric_constant",g_tp.wire_inside_mat.horiz_dielectric_constant); + printing("g_tp.wire_inside_mat.vert_dielectric_constant",g_tp.wire_inside_mat.vert_dielectric_constant); + + printing("g_tp.wire_outside_mat.pitch",g_tp.wire_outside_mat.pitch); + printing("g_tp.wire_outside_mat.R_per_um",g_tp.wire_outside_mat.R_per_um); + printing("g_tp.wire_outside_mat.C_per_um",g_tp.wire_outside_mat.C_per_um); + printing("g_tp.wire_outside_mat.aspect_ratio",g_tp.wire_outside_mat.aspect_ratio); + printing("g_tp.wire_outside_mat.ild_thickness",g_tp.wire_outside_mat.ild_thickness); + printing("g_tp.wire_outside_mat.miller_value",g_tp.wire_outside_mat.miller_value); + printing("g_tp.wire_outside_mat.horiz_dielectric_constant",g_tp.wire_outside_mat.horiz_dielectric_constant); + printing("g_tp.wire_outside_mat.vert_dielectric_constant",g_tp.wire_outside_mat.vert_dielectric_constant); + + printing("g_tp.unit_len_wire_del",g_tp.unit_len_wire_del); + + printing("g_tp.sense_delay",g_tp.sense_delay); + 
printing("g_tp.sense_dy_power",g_tp.sense_dy_power); + + printing("g_tp.tsv_parasitic_resistance_fine",g_tp.tsv_parasitic_resistance_fine); + printing("g_tp.tsv_parasitic_capacitance_fine",g_tp.tsv_parasitic_capacitance_fine); + printing("g_tp.tsv_minimum_area_fine",g_tp.tsv_minimum_area_fine); + + printing("g_tp.tsv_parasitic_resistance_coarse",g_tp.tsv_parasitic_resistance_coarse); + printing("g_tp.tsv_parasitic_capacitance_coarse",g_tp.tsv_parasitic_capacitance_coarse); + printing("g_tp.tsv_minimum_area_coarse",g_tp.tsv_minimum_area_coarse); + + printing("g_tp.tsv_minimum_area_coarse",g_tp.tsv_minimum_area_coarse); + printing("g_tp.fringe_cap",g_tp.fringe_cap); + printing("g_tp.kinv",g_tp.kinv); + printing("g_tp.FO4",g_tp.FO4); + +} diff --git a/Project_FARSI/cacti_for_FARSI/uca.cc b/Project_FARSI/cacti_for_FARSI/uca.cc new file mode 100644 index 00000000..bb6124f1 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/uca.cc @@ -0,0 +1,818 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#include +#include + +#include "uca.h" +#include "TSV.h" +#include "memorybus.h" + + +UCA::UCA(const DynamicParameter & dyn_p) + :dp(dyn_p), bank(dp), nbanks(g_ip->nbanks), refresh_power(0) +{ + int num_banks_ver_dir = 1 << ((bank.area.h > bank.area.w) ? 
_log2(nbanks)/2 : (_log2(nbanks) - _log2(nbanks)/2)); + int num_banks_hor_dir = nbanks/num_banks_ver_dir; + + if (dp.use_inp_params) + { + RWP = dp.num_rw_ports; + ERP = dp.num_rd_ports; + EWP = dp.num_wr_ports; + SCHP = dp.num_search_ports; + } + else + { + RWP = g_ip->num_rw_ports; + ERP = g_ip->num_rd_ports; + EWP = g_ip->num_wr_ports; + SCHP = g_ip->num_search_ports; + } + + num_addr_b_bank = (dp.number_addr_bits_mat + dp.number_subbanks_decode)*(RWP+ERP+EWP); + num_di_b_bank = dp.num_di_b_bank_per_port * (RWP + EWP); + num_do_b_bank = dp.num_do_b_bank_per_port * (RWP + ERP); + num_si_b_bank = dp.num_si_b_bank_per_port * SCHP; + num_so_b_bank = dp.num_so_b_bank_per_port * SCHP; + + if (!dp.fully_assoc && !dp.pure_cam) + { + + if (g_ip->fast_access && dp.is_tag == false) + { + num_do_b_bank *= g_ip->data_assoc; + } + + htree_in_add = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank,0, num_do_b_bank,0,num_banks_ver_dir*2, num_banks_hor_dir*2, Add_htree, true); + htree_in_data = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank, 0, num_do_b_bank, 0, num_banks_ver_dir*2, num_banks_hor_dir*2, Data_in_htree, true); + htree_out_data = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank, 0, num_do_b_bank, 0, num_banks_ver_dir*2, num_banks_hor_dir*2, Data_out_htree, true); + } + + else + { + + htree_in_add = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank, num_si_b_bank, num_do_b_bank, num_so_b_bank, num_banks_ver_dir*2, num_banks_hor_dir*2, Add_htree, true); + htree_in_data = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank,num_si_b_bank, num_do_b_bank, num_so_b_bank, num_banks_ver_dir*2, num_banks_hor_dir*2, Data_in_htree, true); + htree_out_data = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank,num_si_b_bank, num_do_b_bank, num_so_b_bank, num_banks_ver_dir*2, num_banks_hor_dir*2, Data_out_htree, true); + htree_in_search = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank,num_si_b_bank, num_do_b_bank, num_so_b_bank, num_banks_ver_dir*2, num_banks_hor_dir*2, Data_in_htree, true); + htree_out_search = new Htree2(g_ip->wt, bank.area.w, bank.area.h, + num_addr_b_bank, num_di_b_bank,num_si_b_bank, num_do_b_bank, num_so_b_bank, num_banks_ver_dir*2, num_banks_hor_dir*2, Data_out_htree, true); + } + + area.w = htree_in_data->area.w; + area.h = htree_in_data->area.h; + + area_all_dataramcells = bank.mat.subarray.get_total_cell_area() * dp.num_subarrays * g_ip->nbanks; +// cout<<"area cell"<print_detail_debug) + cout << "uca.cc: g_ip->is_3d_mem = " << g_ip->is_3d_mem << endl; + if(g_ip->is_3d_mem) + { + membus_RAS = new Memorybus(g_ip->wt, bank.mat.area.w, bank.mat.area.h, bank.mat.subarray.area.w, bank.mat.subarray.area.h, + _log2(dp.num_r_subarray * dp.Ndbl), _log2(dp.num_c_subarray * dp.Ndwl), g_ip->burst_depth*g_ip->io_width, dp.Ndbl, dp.Ndwl, Row_add_path, dp); + membus_CAS = new Memorybus(g_ip->wt, bank.mat.area.w, bank.mat.area.h, bank.mat.subarray.area.w, bank.mat.subarray.area.h, + _log2(dp.num_r_subarray * dp.Ndbl), _log2(dp.num_c_subarray * dp.Ndwl), g_ip->burst_depth*g_ip->io_width, dp.Ndbl, dp.Ndwl, Col_add_path, dp); + membus_data = new Memorybus(g_ip->wt, bank.mat.area.w, bank.mat.area.h, bank.mat.subarray.area.w, bank.mat.subarray.area.h, + _log2(dp.num_r_subarray * dp.Ndbl), _log2(dp.num_c_subarray * dp.Ndwl), g_ip->burst_depth*g_ip->io_width, dp.Ndbl, dp.Ndwl, Data_path, 
dp); + area.h = membus_RAS->area.h; + area.w = membus_RAS->area.w; + + if (g_ip->print_detail_debug) + { + cout<<"uca.cc: area.h = "<is_3d_mem) + { + // Add TSV delay to the terms + // --- Although there are coarse and fine, because is_array and os_bank TSV are the same, so they are the same + TSV tsv_os_bank(Coarse); + TSV tsv_is_subarray(Fine); + if(g_ip->print_detail_debug) + { + tsv_os_bank.print_TSV(); + tsv_is_subarray.print_TSV(); + } + + comm_bits = 6; + row_add_bits = _log2(dp.num_r_subarray * dp.Ndbl); + col_add_bits = _log2(dp.num_c_subarray * dp.Ndwl); + data_bits = g_ip->burst_depth * g_ip->io_width; + + //enum Part_grain part_gran = Fine_rank_level; + + double redundancy_perc_TSV = 0.5; + switch(g_ip->partition_gran) + { + case 0:// Coarse_rank_level: + delay_TSV_tot = (g_ip->num_die_3d-1) * tsv_os_bank.delay; + num_TSV_tot = (comm_bits + row_add_bits + col_add_bits + data_bits*2) * (1 + redundancy_perc_TSV); //* (g_ip->nbanks/4) + area_TSV_tot = num_TSV_tot * tsv_os_bank.area.get_area(); + dyn_pow_TSV_tot = num_TSV_tot * (g_ip->num_die_3d-1) * tsv_os_bank.power.readOp.dynamic; + dyn_pow_TSV_per_access = (comm_bits + row_add_bits + col_add_bits + data_bits) * (g_ip->num_die_3d-1) * tsv_os_bank.power.readOp.dynamic; + area_address_bus = membus_RAS->area_address_bus * (1.0 + (double)comm_bits/(double)(row_add_bits + col_add_bits)); + area_data_bus = membus_RAS->area_data_bus; + break; + case 1://Fine_rank_level: + delay_TSV_tot = (g_ip->num_die_3d) * tsv_os_bank.delay; + num_TSV_tot = (comm_bits + row_add_bits + col_add_bits + data_bits/2) * g_ip->nbanks * (1 + redundancy_perc_TSV); + area_TSV_tot = num_TSV_tot * tsv_os_bank.area.get_area(); + dyn_pow_TSV_tot = num_TSV_tot * (g_ip->num_die_3d) * tsv_os_bank.power.readOp.dynamic; + dyn_pow_TSV_per_access = (comm_bits + row_add_bits + col_add_bits + data_bits) * (g_ip->num_die_3d) * tsv_os_bank.power.readOp.dynamic; + //area_address_bus = (comm_bits + row_add_bits + col_add_bits) * 25.0; + //area_data_bus = membus_RAS->area_data_bus + (double)data_bits/2 * 25.0; + break; + case 2://Coarse_bank_level: + delay_TSV_tot = (g_ip->num_die_3d) * tsv_os_bank.delay; + num_TSV_tot = (comm_bits + row_add_bits + col_add_bits + data_bits/2) * g_ip->nbanks + * g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd * (1 + redundancy_perc_TSV); + area_TSV_tot = num_TSV_tot * tsv_os_bank.area.get_area(); + dyn_pow_TSV_tot = num_TSV_tot * (g_ip->num_die_3d) * tsv_os_bank.power.readOp.dynamic; + dyn_pow_TSV_per_access = (comm_bits + row_add_bits + col_add_bits + data_bits) * (g_ip->num_die_3d) * tsv_os_bank.power.readOp.dynamic; + //area_address_bus = (comm_bits + row_add_bits + col_add_bits) * 25.0; + //area_data_bus = (double)data_bits/2 * 25.0; + + //activate_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + //read_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + //write_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + //precharge_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + break; + case 3://Fine_bank_level: + delay_TSV_tot = (g_ip->num_die_3d) * tsv_os_bank.delay; + num_TSV_tot = (comm_bits + row_add_bits + col_add_bits + data_bits) * g_ip->nbanks *g_ip->ndwl *g_ip->ndbl + /g_ip->num_tier_col_sprd /g_ip->num_tier_row_sprd * (1 + redundancy_perc_TSV); + area_TSV_tot = num_TSV_tot * tsv_os_bank.area.get_area(); + dyn_pow_TSV_tot = num_TSV_tot * (g_ip->num_die_3d) * tsv_os_bank.power.readOp.dynamic; + dyn_pow_TSV_per_access = (comm_bits + row_add_bits + col_add_bits + data_bits) * 
(g_ip->num_die_3d) * tsv_os_bank.power.readOp.dynamic; + //area_address_bus = pow(2, (comm_bits + row_add_bits + col_add_bits)) * 25.0; + //area_data_bus = pow(2, data_bits/2) * 25.0; + //activate_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + //read_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + //write_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + //precharge_energy *= g_ip->num_tier_row_sprd * g_ip->num_tier_col_sprd; + break; + default: + assert(0); + break; + } + + if(g_ip->print_detail_debug) + { + cout << "uca.cc: num_TSV_tot = " << num_TSV_tot << endl; + } + + area_lwl_drv = membus_RAS->area_lwl_drv * g_ip->nbanks; + area_row_predec_dec = membus_RAS->area_row_predec_dec * g_ip->nbanks; + area_col_predec_dec = membus_CAS->area_col_predec_dec * g_ip->nbanks; + + area_subarray = membus_RAS->area_subarray * g_ip->nbanks; + area_bus = membus_RAS->area_bus * g_ip->nbanks; + + + area_data_drv = membus_data->area_data_drv * g_ip->nbanks; + area_IOSA = membus_data->area_IOSA * g_ip->nbanks; + area_sense_amp = membus_data->area_sense_amp * g_ip->nbanks; + + area_address_bus = membus_RAS->area_address_bus * (1.0 + (double)comm_bits/(double)(row_add_bits + col_add_bits)) * g_ip->nbanks;; + area_data_bus = membus_RAS->area_data_bus + membus_data->area_local_dataline * g_ip->nbanks; + + area_per_bank = (area_lwl_drv + area_row_predec_dec + area_col_predec_dec + + area_subarray + area_bus + area_data_drv + area_IOSA + + area_address_bus + area_data_bus)/g_ip->nbanks + area_sense_amp; + + + t_RCD += delay_TSV_tot; + t_RAS += delay_TSV_tot; + t_RC += delay_TSV_tot; + t_RP += delay_TSV_tot; + t_CAS += 2 * delay_TSV_tot; + t_RRD += delay_TSV_tot; + + activate_energy += dyn_pow_TSV_per_access; + read_energy += dyn_pow_TSV_per_access; + write_energy += dyn_pow_TSV_per_access; + precharge_energy += dyn_pow_TSV_per_access; + + //double area_per_die = area.get_area(); + //double area_stack_tot = g_ip->num_die_3d * (area.get_area() + area_TSV_tot); + //int num_die = g_ip->num_die_3d; + //area.set_area(area_stack_tot); + + if(g_ip->num_die_3d > 1 || g_ip->partition_gran > 0) + total_area_per_die = area_all_dataramcells + area_TSV_tot; + else + total_area_per_die = area_all_dataramcells; + + + + if(g_ip->is_3d_mem && g_ip->print_detail_debug) + { + + cout<<"------- CACTI 3D DRAM Main Memory -------"<cache_sz) << endl; + cout << " Number of banks: " << (int) g_ip->nbanks << endl; + cout << " Technology size (nm): " << + g_ip->F_sz_nm << endl; + cout << " Page size (bits): " << g_ip->page_sz_bits << endl; + cout << " Burst depth: " << g_ip->burst_depth << endl; + cout << " Chip IO width: " << g_ip->io_width << endl; + cout << " Ndwl: " << dp.Ndwl << endl; + cout << " Ndbl: " << dp.Ndbl << endl; + cout << " # rows in subarray: " << dp.num_r_subarray << endl; + cout << " # columns in subarray: " << dp.num_c_subarray << endl; + + cout << "\nResults:\n"; + cout<<" ******************Timing terms******************"<burst_depth)/(g_ip->sys_freq_MHz*1e6)/2) * 1e3 << " mW" <print_detail_debug) + { + cout<<" ********************Other terms******************"<center_stripe->power.readOp.dynamic + membus_RAS->bank_bus->power.readOp.dynamic + + membus_RAS->add_predec->power.readOp.dynamic + membus_RAS->add_dec->power.readOp.dynamic; + cout<<" Act Bus Energy: "<< act_bus_energy * 1e9 <<" nJ"<center_stripe->delay + membus_RAS->bank_bus->delay + + membus_RAS->add_predec->delay + membus_RAS->add_dec->delay; + cout<<" Act Bus Latency: "<< act_bus_latency * 1e9 <<" 
ns"<num_die_3d>1) + { + cout<<" ********************TSV terms******************"<is_3d_mem) + { + delete membus_RAS; + delete membus_CAS; + delete membus_data; + } +} + + + +double UCA::compute_delays(double inrisetime) +{ + double outrisetime = bank.compute_delays(inrisetime); + //CACTI3DD + if (g_ip->is_3d_mem) + { + outrisetime = bank.compute_delays(membus_RAS->out_rise_time); + + //ram_delay_inside_mat = bank.mat.delay_bitline;// + bank.mat.delay_matchchline; + //access_time = membus_RAS->delay + bank.mat.delay_bitline + bank.mat.delay_sa + membus_CAS->delay + membus_data->delay; + + //double t_rcd = membus_RAS->delay + bank.mat.delay_bitline + bank.mat.delay_sa; + //t_RCD= membus_RAS->add_dec->delay + membus_RAS->lwl_drv->delay + bank.mat.delay_bitline + bank.mat.delay_sa; + t_RCD = membus_RAS->add_dec->delay + membus_RAS->lwl_drv->delay + bank.mat.delay_bitline + bank.mat.delay_sa; + t_RAS = membus_RAS->delay + bank.mat.delay_bitline + bank.mat.delay_sa + bank.mat.delay_bl_restore; + precharge_delay = bank.mat.delay_writeback + + bank.mat.delay_wl_reset + bank.mat.delay_bl_restore; + t_RP = precharge_delay; + t_RC = t_RAS + t_RP; + t_CAS = membus_CAS->delay + bank.mat.delay_subarray_out_drv + membus_data->delay; + t_RRD = membus_RAS->center_stripe->delay + membus_RAS->bank_bus->delay; + //t_RRD = membus_RAS->delay; + access_time = t_RCD + t_CAS; + multisubbank_interleave_cycle_time = membus_RAS->center_stripe->delay + membus_RAS->bank_bus->delay; + //cout<<"uca.cc: multisubbank_interleave_cycle_time = "<delay = "<delay * 1e9 << " ns" <delay = "<delay * 1e9 << " ns" <delay = "<delay * 1e9 << " ns" <center_stripe->delay = "<center_stripe->delay * 1e9 << " ns" <bank_bus->delay = "<bank_bus->delay * 1e9 << " ns" <add_predec->delay = "<add_predec->delay * 1e9 << " ns" <add_dec->delay = "<add_dec->delay * 1e9 << " ns" <global_WL->delay = "<global_WL->delay * 1e9 << " ns" <lwl_drv->delay = "<lwl_drv->delay * 1e9 << " ns" <center_stripe->delay = "<center_stripe->delay * 1e9 << " ns" <bank_bus->delay = "<bank_bus->delay * 1e9 << " ns" <add_predec->delay = "<add_predec->delay * 1e9 << " ns" <add_dec->delay = "<add_dec->delay * 1e9 << " ns" <column_sel->delay = "<column_sel->delay * 1e9 << " ns" <center_stripe->delay = "<center_stripe->delay * 1e9 << " ns" <bank_bus->delay = "<bank_bus->delay * 1e9 << " ns" <global_data->delay = "<global_data->delay * 1e9 << " ns" <data_drv->delay = "<data_drv->delay * 1e9 << " ns" <local_data->delay = "<local_data->delay * 1e9 << " ns" <delay + bank.htree_in_add->delay; + double max_delay_before_row_decoder = delay_array_to_mat + bank.mat.r_predec->delay; + delay_array_to_sa_mux_lev_1_decoder = delay_array_to_mat + + bank.mat.sa_mux_lev_1_predec->delay + + bank.mat.sa_mux_lev_1_dec->delay; + delay_array_to_sa_mux_lev_2_decoder = delay_array_to_mat + + bank.mat.sa_mux_lev_2_predec->delay + + bank.mat.sa_mux_lev_2_dec->delay; + double delay_inside_mat = bank.mat.row_dec->delay + bank.mat.delay_bitline + bank.mat.delay_sa; + + delay_before_subarray_output_driver = + MAX(MAX(max_delay_before_row_decoder + delay_inside_mat, // row_path + delay_array_to_mat + bank.mat.b_mux_predec->delay + bank.mat.bit_mux_dec->delay + bank.mat.delay_sa), // col_path + MAX(delay_array_to_sa_mux_lev_1_decoder, // sa_mux_lev_1_path + delay_array_to_sa_mux_lev_2_decoder)); // sa_mux_lev_2_path + delay_from_subarray_out_drv_to_out = bank.mat.delay_subarray_out_drv_htree + + bank.htree_out_data->delay + htree_out_data->delay; + access_time = bank.mat.delay_comparator; + + double 
ram_delay_inside_mat; + if (dp.fully_assoc) + { + //delay of FA contains both CAM tag and RAM data + { //delay of CAM + ram_delay_inside_mat = bank.mat.delay_bitline + bank.mat.delay_matchchline; + access_time = htree_in_add->delay + bank.htree_in_add->delay; + //delay of fully-associative data array + access_time += ram_delay_inside_mat + delay_from_subarray_out_drv_to_out; + } + } + else + { + access_time = delay_before_subarray_output_driver + delay_from_subarray_out_drv_to_out; //data_acc_path + } + + if (dp.is_main_mem) + { + double t_rcd = max_delay_before_row_decoder + delay_inside_mat; + double cas_latency = MAX(delay_array_to_sa_mux_lev_1_decoder, delay_array_to_sa_mux_lev_2_decoder) + + delay_from_subarray_out_drv_to_out; + access_time = t_rcd + cas_latency; + } + + double temp; + + if (!dp.fully_assoc) + { + temp = delay_inside_mat + bank.mat.delay_wl_reset + bank.mat.delay_bl_restore;//TODO: : revisit + if (dp.is_dram) + { + temp += bank.mat.delay_writeback; // temp stores random cycle time + } + + + temp = MAX(temp, bank.mat.r_predec->delay); + temp = MAX(temp, bank.mat.b_mux_predec->delay); + temp = MAX(temp, bank.mat.sa_mux_lev_1_predec->delay); + temp = MAX(temp, bank.mat.sa_mux_lev_2_predec->delay); + } + else + { + ram_delay_inside_mat = bank.mat.delay_bitline + bank.mat.delay_matchchline; + temp = ram_delay_inside_mat + bank.mat.delay_cam_sl_restore + bank.mat.delay_cam_ml_reset + bank.mat.delay_bl_restore + + bank.mat.delay_hit_miss_reset + bank.mat.delay_wl_reset; + + temp = MAX(temp, bank.mat.b_mux_predec->delay);//TODO: revisit whether distinguish cam and ram bitline etc. + temp = MAX(temp, bank.mat.sa_mux_lev_1_predec->delay); + temp = MAX(temp, bank.mat.sa_mux_lev_2_predec->delay); + } + + // The following is true only if the input parameter "repeaters_in_htree" is set to false --Nav + if (g_ip->rpters_in_htree == false) + { + temp = MAX(temp, bank.htree_in_add->max_unpipelined_link_delay); + } + cycle_time = temp; + + double delay_req_network = max_delay_before_row_decoder; + double delay_rep_network = delay_from_subarray_out_drv_to_out; + multisubbank_interleave_cycle_time = MAX(delay_req_network, delay_rep_network); + + if (dp.is_main_mem) + { + multisubbank_interleave_cycle_time = htree_in_add->delay; + precharge_delay = htree_in_add->delay + + bank.htree_in_add->delay + bank.mat.delay_writeback + + bank.mat.delay_wl_reset + bank.mat.delay_bl_restore; + cycle_time = access_time + precharge_delay; + } + else + { + precharge_delay = 0; + } +/** + double dram_array_availability = 0; + if (dp.is_dram) + { + dram_array_availability = (1 - dp.num_r_subarray * cycle_time / dp.dram_refresh_period) * 100; + } +**/ + }//CACTI3DD, else + return outrisetime; +} + + + +// note: currently, power numbers are for a bank of an array +void UCA::compute_power_energy() +{ + bank.compute_power_energy(); + power = bank.power; + //CACTI3DD + if (g_ip->is_3d_mem) + { + double datapath_energy = 0.505e-9 *g_ip->F_sz_nm / 55; + //double chip_IO_width = 4; + //g_ip->burst_len = 4; + activate_energy = membus_RAS->power.readOp.dynamic + (bank.mat.power_bitline.readOp.dynamic + + bank.mat.power_sa.readOp.dynamic) * dp.Ndwl; // /4 + read_energy = (membus_CAS->power.readOp.dynamic + bank.mat.power_subarray_out_drv.readOp.dynamic + + membus_data->power.readOp.dynamic ) + datapath_energy; //* g_ip->burst_len; + write_energy = (membus_CAS->power.readOp.dynamic + bank.mat.power_subarray_out_drv.readOp.dynamic + + membus_data->power.readOp.dynamic + bank.mat.power_sa.readOp.dynamic * 
+
+
+
+// note: currently, power numbers are for a bank of an array
+void UCA::compute_power_energy()
+{
+  bank.compute_power_energy();
+  power = bank.power;
+  //CACTI3DD
+  if (g_ip->is_3d_mem)
+  {
+    double datapath_energy = 0.505e-9 * g_ip->F_sz_nm / 55;
+    //double chip_IO_width = 4;
+    //g_ip->burst_len = 4;
+    activate_energy = membus_RAS->power.readOp.dynamic + (bank.mat.power_bitline.readOp.dynamic +
+                      bank.mat.power_sa.readOp.dynamic) * dp.Ndwl; // /4
+    read_energy = (membus_CAS->power.readOp.dynamic + bank.mat.power_subarray_out_drv.readOp.dynamic +
+                   membus_data->power.readOp.dynamic ) + datapath_energy; //* g_ip->burst_len;
+    write_energy = (membus_CAS->power.readOp.dynamic + bank.mat.power_subarray_out_drv.readOp.dynamic +
+                    membus_data->power.readOp.dynamic + bank.mat.power_sa.readOp.dynamic *
+                    g_ip->burst_depth*g_ip->io_width/g_ip->page_sz_bits) + datapath_energy; //* g_ip->burst_len;
+    precharge_energy = (bank.mat.power_bitline.readOp.dynamic +
+                        bank.mat.power_bl_precharge_eq_drv.readOp.dynamic)* dp.Ndwl; // /4
+
+    activate_power = activate_energy / t_RC;
+    double col_cycle_act_row;
+    //col_cycle_act_row = MAX(MAX(MAX(membus_CAS->center_stripe->delay + membus_CAS->bank_bus->delay, bank.mat.delay_subarray_out_drv),
+    //                            membus_data->delay), membus_data->out_seg->delay *g_ip->burst_depth);
+    //col_cycle_act_row = membus_data->out_seg->delay * g_ip->burst_depth;
+    col_cycle_act_row = (1e-6/(double)g_ip->sys_freq_MHz)/2 * g_ip->burst_depth;
+    //--- Activity factor assumption comes from Micron data spreadsheet.
+    read_power = 0.25 * read_energy / col_cycle_act_row;
+    write_power = 0.15 * write_energy / col_cycle_act_row;
+
+    if (g_ip->print_detail_debug)
+    {
+      cout<<"Row Address Delay components: "<<endl;
+      cout<<"membus_RAS->power.readOp.dynamic = "<< membus_RAS->power.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_CAS->power.readOp.dynamic = "<< membus_CAS->power.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_data->power.readOp.dynamic = "<< membus_data->power.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_RAS->power_bus.readOp.dynamic = "<<membus_RAS->power_bus.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_RAS->power_add_predecoder.readOp.dynamic = "<< membus_RAS->power_add_predecoder.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_RAS->power_add_decoders.readOp.dynamic = "<< membus_RAS->power_add_decoders.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_RAS->power_lwl_drv.readOp.dynamic = "<< membus_RAS->power_lwl_drv.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_CAS->power_bus.readOp.dynamic = "<< membus_CAS->power_bus.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_CAS->power_add_predecoder.readOp.dynamic = "<< membus_CAS->power_add_predecoder.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_CAS->power_add_decoders.readOp.dynamic = "<< membus_CAS->power_add_decoders.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_CAS->power.readOp.dynamic = "<< membus_CAS->power.readOp.dynamic * 1e9 << " nJ" <<endl;
+      cout<<"membus_data->power.readOp.dynamic = "<< membus_data->power.readOp.dynamic * 1e9 << " nJ" <<endl;
+    }
+  }
+  else
+  {//CACTI3DD
+  power_routing_to_bank.readOp.dynamic = htree_in_add->power.readOp.dynamic + htree_out_data->power.readOp.dynamic;
+  power_routing_to_bank.writeOp.dynamic = htree_in_add->power.readOp.dynamic + htree_in_data->power.readOp.dynamic;
+  if (dp.fully_assoc || dp.pure_cam)
+      power_routing_to_bank.searchOp.dynamic= htree_in_search->power.searchOp.dynamic + htree_out_search->power.searchOp.dynamic;
+
+  power_routing_to_bank.readOp.leakage += htree_in_add->power.readOp.leakage +
+                                       htree_in_data->power.readOp.leakage +
+                                       htree_out_data->power.readOp.leakage;
+
+  power_routing_to_bank.readOp.gate_leakage += htree_in_add->power.readOp.gate_leakage +
+                                       htree_in_data->power.readOp.gate_leakage +
+                                       htree_out_data->power.readOp.gate_leakage;
+  if (dp.fully_assoc || dp.pure_cam)
+  {
+      power_routing_to_bank.readOp.leakage += htree_in_search->power.readOp.leakage + htree_out_search->power.readOp.leakage;
+      power_routing_to_bank.readOp.gate_leakage += htree_in_search->power.readOp.gate_leakage + htree_out_search->power.readOp.gate_leakage;
+  }
+
+  power.searchOp.dynamic += power_routing_to_bank.searchOp.dynamic;
+  power.readOp.dynamic += power_routing_to_bank.readOp.dynamic;
+  power.readOp.leakage += power_routing_to_bank.readOp.leakage;
+  power.readOp.gate_leakage += power_routing_to_bank.readOp.gate_leakage;
+
+  // calculate total write energy per access
+  power.writeOp.dynamic = power.readOp.dynamic
+                        - bank.mat.power_bitline.readOp.dynamic * dp.num_act_mats_hor_dir
+                        + bank.mat.power_bitline.writeOp.dynamic * dp.num_act_mats_hor_dir
+                        - power_routing_to_bank.readOp.dynamic
+                        + power_routing_to_bank.writeOp.dynamic
+                        +
bank.htree_in_data->power.readOp.dynamic + - bank.htree_out_data->power.readOp.dynamic; + + if (dp.is_dram == false) + { + power.writeOp.dynamic -= bank.mat.power_sa.readOp.dynamic * dp.num_act_mats_hor_dir; + } + + dyn_read_energy_from_closed_page = power.readOp.dynamic; + dyn_read_energy_from_open_page = power.readOp.dynamic - + (bank.mat.r_predec->power.readOp.dynamic + + bank.mat.power_row_decoders.readOp.dynamic + + bank.mat.power_bl_precharge_eq_drv.readOp.dynamic + + bank.mat.power_sa.readOp.dynamic + + bank.mat.power_bitline.readOp.dynamic) * dp.num_act_mats_hor_dir; + + dyn_read_energy_remaining_words_in_burst = + (MAX((g_ip->burst_len / g_ip->int_prefetch_w), 1) - 1) * + ((bank.mat.sa_mux_lev_1_predec->power.readOp.dynamic + + bank.mat.sa_mux_lev_2_predec->power.readOp.dynamic + + bank.mat.power_sa_mux_lev_1_decoders.readOp.dynamic + + bank.mat.power_sa_mux_lev_2_decoders.readOp.dynamic + + bank.mat.power_subarray_out_drv.readOp.dynamic) * dp.num_act_mats_hor_dir + + bank.htree_out_data->power.readOp.dynamic + + power_routing_to_bank.readOp.dynamic); + dyn_read_energy_from_closed_page += dyn_read_energy_remaining_words_in_burst; + dyn_read_energy_from_open_page += dyn_read_energy_remaining_words_in_burst; + + activate_energy = htree_in_add->power.readOp.dynamic + + bank.htree_in_add->power_bit.readOp.dynamic * bank.num_addr_b_routed_to_mat_for_act + + (bank.mat.r_predec->power.readOp.dynamic + + bank.mat.power_row_decoders.readOp.dynamic + + bank.mat.power_sa.readOp.dynamic) * dp.num_act_mats_hor_dir; + read_energy = (htree_in_add->power.readOp.dynamic + + bank.htree_in_add->power_bit.readOp.dynamic * bank.num_addr_b_routed_to_mat_for_rd_or_wr + + (bank.mat.sa_mux_lev_1_predec->power.readOp.dynamic + + bank.mat.sa_mux_lev_2_predec->power.readOp.dynamic + + bank.mat.power_sa_mux_lev_1_decoders.readOp.dynamic + + bank.mat.power_sa_mux_lev_2_decoders.readOp.dynamic + + bank.mat.power_subarray_out_drv.readOp.dynamic) * dp.num_act_mats_hor_dir + + bank.htree_out_data->power.readOp.dynamic + + htree_in_data->power.readOp.dynamic) * g_ip->burst_len; + write_energy = (htree_in_add->power.readOp.dynamic + + bank.htree_in_add->power_bit.readOp.dynamic * bank.num_addr_b_routed_to_mat_for_rd_or_wr + + htree_in_data->power.readOp.dynamic + + bank.htree_in_data->power.readOp.dynamic + + (bank.mat.sa_mux_lev_1_predec->power.readOp.dynamic + + bank.mat.sa_mux_lev_2_predec->power.readOp.dynamic + + bank.mat.power_sa_mux_lev_1_decoders.readOp.dynamic + + bank.mat.power_sa_mux_lev_2_decoders.readOp.dynamic) * dp.num_act_mats_hor_dir) * g_ip->burst_len; + precharge_energy = (bank.mat.power_bitline.readOp.dynamic + + bank.mat.power_bl_precharge_eq_drv.readOp.dynamic) * dp.num_act_mats_hor_dir; + } //CACTI3DD + leak_power_subbank_closed_page = + (bank.mat.r_predec->power.readOp.leakage + + bank.mat.b_mux_predec->power.readOp.leakage + + bank.mat.sa_mux_lev_1_predec->power.readOp.leakage + + bank.mat.sa_mux_lev_2_predec->power.readOp.leakage + + bank.mat.power_row_decoders.readOp.leakage + + bank.mat.power_bit_mux_decoders.readOp.leakage + + bank.mat.power_sa_mux_lev_1_decoders.readOp.leakage + + bank.mat.power_sa_mux_lev_2_decoders.readOp.leakage + + bank.mat.leak_power_sense_amps_closed_page_state) * dp.num_act_mats_hor_dir; + + leak_power_subbank_closed_page += + (bank.mat.r_predec->power.readOp.gate_leakage + + bank.mat.b_mux_predec->power.readOp.gate_leakage + + bank.mat.sa_mux_lev_1_predec->power.readOp.gate_leakage + + bank.mat.sa_mux_lev_2_predec->power.readOp.gate_leakage + + 
bank.mat.power_row_decoders.readOp.gate_leakage + + bank.mat.power_bit_mux_decoders.readOp.gate_leakage + + bank.mat.power_sa_mux_lev_1_decoders.readOp.gate_leakage + + bank.mat.power_sa_mux_lev_2_decoders.readOp.gate_leakage) * dp.num_act_mats_hor_dir; //+ + //bank.mat.leak_power_sense_amps_closed_page_state) * dp.num_act_mats_hor_dir; + + leak_power_subbank_open_page = + (bank.mat.r_predec->power.readOp.leakage + + bank.mat.b_mux_predec->power.readOp.leakage + + bank.mat.sa_mux_lev_1_predec->power.readOp.leakage + + bank.mat.sa_mux_lev_2_predec->power.readOp.leakage + + bank.mat.power_row_decoders.readOp.leakage + + bank.mat.power_bit_mux_decoders.readOp.leakage + + bank.mat.power_sa_mux_lev_1_decoders.readOp.leakage + + bank.mat.power_sa_mux_lev_2_decoders.readOp.leakage + + bank.mat.leak_power_sense_amps_open_page_state) * dp.num_act_mats_hor_dir; + + leak_power_subbank_open_page += + (bank.mat.r_predec->power.readOp.gate_leakage + + bank.mat.b_mux_predec->power.readOp.gate_leakage + + bank.mat.sa_mux_lev_1_predec->power.readOp.gate_leakage + + bank.mat.sa_mux_lev_2_predec->power.readOp.gate_leakage + + bank.mat.power_row_decoders.readOp.gate_leakage + + bank.mat.power_bit_mux_decoders.readOp.gate_leakage + + bank.mat.power_sa_mux_lev_1_decoders.readOp.gate_leakage + + bank.mat.power_sa_mux_lev_2_decoders.readOp.gate_leakage ) * dp.num_act_mats_hor_dir; + //bank.mat.leak_power_sense_amps_open_page_state) * dp.num_act_mats_hor_dir; + + leak_power_request_and_reply_networks = + power_routing_to_bank.readOp.leakage + + bank.htree_in_add->power.readOp.leakage + + bank.htree_in_data->power.readOp.leakage + + bank.htree_out_data->power.readOp.leakage; + + leak_power_request_and_reply_networks += + power_routing_to_bank.readOp.gate_leakage + + bank.htree_in_add->power.readOp.gate_leakage + + bank.htree_in_data->power.readOp.gate_leakage + + bank.htree_out_data->power.readOp.gate_leakage; + + if (dp.fully_assoc || dp.pure_cam) + { + leak_power_request_and_reply_networks += htree_in_search->power.readOp.leakage + htree_out_search->power.readOp.leakage; + leak_power_request_and_reply_networks += htree_in_search->power.readOp.gate_leakage + htree_out_search->power.readOp.gate_leakage; + } + + + if (dp.is_dram) + { // if DRAM, add contribution of power spent in row predecoder drivers, blocks and decoders to refresh power + refresh_power = (bank.mat.r_predec->power.readOp.dynamic * dp.num_act_mats_hor_dir + + bank.mat.row_dec->power.readOp.dynamic) * dp.num_r_subarray * dp.num_subarrays; + refresh_power += bank.mat.per_bitline_read_energy * dp.num_c_subarray * dp.num_r_subarray * dp.num_subarrays; + refresh_power += bank.mat.power_bl_precharge_eq_drv.readOp.dynamic * dp.num_act_mats_hor_dir; + refresh_power += bank.mat.power_sa.readOp.dynamic * dp.num_act_mats_hor_dir; + refresh_power /= dp.dram_refresh_period; + } + + + if (dp.is_tag == false) + { + power.readOp.dynamic = dyn_read_energy_from_closed_page; + power.writeOp.dynamic = dyn_read_energy_from_closed_page + - dyn_read_energy_remaining_words_in_burst + - bank.mat.power_bitline.readOp.dynamic * dp.num_act_mats_hor_dir + + bank.mat.power_bitline.writeOp.dynamic * dp.num_act_mats_hor_dir + + (power_routing_to_bank.writeOp.dynamic - + power_routing_to_bank.readOp.dynamic - + bank.htree_out_data->power.readOp.dynamic + + bank.htree_in_data->power.readOp.dynamic) * + (MAX((g_ip->burst_len / g_ip->int_prefetch_w), 1) - 1); //FIXME + + if (dp.is_dram == false) + { + power.writeOp.dynamic -= bank.mat.power_sa.readOp.dynamic * 
dp.num_act_mats_hor_dir; + } + } + + // if DRAM, add refresh power to total leakage + if (dp.is_dram) + { + power.readOp.leakage += refresh_power; + } + + // TODO: below should be avoided. + /*if (dp.is_main_mem) + { + power.readOp.leakage += MAIN_MEM_PER_CHIP_STANDBY_CURRENT_mA * 1e-3 * g_tp.peri_global.Vdd / g_ip->nbanks; + }*/ + + if (g_ip->is_3d_mem) + {// ---This is only to make sure the following assert() functions don't generate errors. The values are not used in 3D DRAM models + // power = power + membus_RAS->power + membus_CAS->power + membus_data->power; //for leakage power add up, not used yet for optimization + power.readOp.dynamic = read_energy; + power.writeOp.dynamic = write_energy; + // ---Before the brackets, power = power.bank, and all the specific leakage terms have and only have accounted for bank to mat levels. + // power.readOp.leakage = power.readOp.leakage + membus_RAS->power.readOp.leakage + membus_CAS->power.readOp.leakage + membus_data->power.readOp.leakage; + power.readOp.leakage =membus_RAS->power.readOp.leakage + membus_CAS->power.readOp.leakage + membus_data->power.readOp.leakage; + //cout << "test: " << power.readOp.dynamic << endl; + //cout << "test: " << membus_RAS->power.readOp.leakage << endl; + //cout << "test: " << membus_CAS->power.readOp.leakage << endl; + //cout << "test: " << membus_data->power.readOp.leakage << endl; + //cout << "test: power.readOp.leakage" << power.readOp.leakage << endl; + } + + assert(power.readOp.dynamic > 0); + assert(power.writeOp.dynamic > 0); + assert(power.readOp.leakage > 0); +} + diff --git a/Project_FARSI/cacti_for_FARSI/uca.h b/Project_FARSI/cacti_for_FARSI/uca.h new file mode 100644 index 00000000..7b6aa38e --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/uca.h @@ -0,0 +1,116 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + + + +#ifndef __UCA_H__ +#define __UCA_H__ + +#include "area.h" +#include "bank.h" +#include "component.h" +#include "parameter.h" +#include "htree2.h" +#include "memorybus.h" +#include "basic_circuit.h" +#include "cacti_interface.h" + + + +class UCA : public Component +{ + public: + UCA(const DynamicParameter & dyn_p); + ~UCA(); + double compute_delays(double inrisetime); // returns outrisetime + void compute_power_energy(); + + DynamicParameter dp; + Bank bank; + + Htree2 * htree_in_add; + Htree2 * htree_in_data; + Htree2 * htree_out_data; + Htree2 * htree_in_search; + Htree2 * htree_out_search; + + Memorybus * membus_RAS; + Memorybus * membus_CAS; + Memorybus * membus_data; + + powerDef power_routing_to_bank; + + uint32_t nbanks; + + int num_addr_b_bank; + int num_di_b_bank; + int num_do_b_bank; + int num_si_b_bank; + int num_so_b_bank; + int RWP, ERP, EWP,SCHP; + double area_all_dataramcells; + double total_area_per_die; + + double dyn_read_energy_from_closed_page; + double dyn_read_energy_from_open_page; + double dyn_read_energy_remaining_words_in_burst; + + double refresh_power; // only for DRAM + double activate_energy; + double read_energy; + double write_energy; + double precharge_energy; + double leak_power_subbank_closed_page; + double leak_power_subbank_open_page; + double leak_power_request_and_reply_networks; + + double delay_array_to_sa_mux_lev_1_decoder; + double delay_array_to_sa_mux_lev_2_decoder; + double delay_before_subarray_output_driver; + double delay_from_subarray_out_drv_to_out; + double access_time; + double precharge_delay; + double multisubbank_interleave_cycle_time; + + double t_RAS, t_CAS, t_RCD, t_RC, t_RP, t_RRD; + double activate_power, read_power, write_power; + + double delay_TSV_tot, area_TSV_tot, dyn_pow_TSV_tot, dyn_pow_TSV_per_access; + unsigned int num_TSV_tot; + unsigned int comm_bits, row_add_bits, col_add_bits, data_bits; + double area_lwl_drv, area_row_predec_dec, area_col_predec_dec, + area_subarray, area_bus, area_address_bus, area_data_bus, area_data_drv, area_IOSA, area_sense_amp, + area_per_bank; + +}; + +#endif + diff --git a/Project_FARSI/cacti_for_FARSI/version_cacti.h b/Project_FARSI/cacti_for_FARSI/version_cacti.h new file mode 100644 index 00000000..e1528bb1 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/version_cacti.h @@ -0,0 +1,40 @@ +/***************************************************************************** + * McPAT + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. 
+ * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + +#ifndef VERSION_H_ +#define VERSION_H_ + +#define VER_MAJOR_CACTI 7 /* 3dd */ +#define VER_MINOR_CACTI 0 +#define VER_COMMENT_CACTI "3DD Prerelease" +#define VER_UPDATE_CACTI "Aug, 2012" + +#endif /* VERSION_H_ */ diff --git a/Project_FARSI/cacti_for_FARSI/wire.cc b/Project_FARSI/cacti_for_FARSI/wire.cc new file mode 100644 index 00000000..55a08ae1 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/wire.cc @@ -0,0 +1,830 @@ +/***************************************************************************** + * CACTI 7.0 + * SOFTWARE LICENSE AGREEMENT + * Copyright 2015 Hewlett-Packard Development Company, L.P. + * All Rights Reserved + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.” + * + ***************************************************************************/ + +#include "wire.h" +#include "cmath" +// use this constructor to calculate wire stats +Wire::Wire( + enum Wire_type wire_model, + double wl, + int n, + double w_s, + double s_s, + enum Wire_placement wp, + double resistivity, + /*TechnologyParameter::*/DeviceType *dt + ):wt(wire_model), wire_length(wl*1e-6), nsense(n), w_scale(w_s), s_scale(s_s), + resistivity(resistivity), deviceType(dt) +{ + wire_placement = wp; + min_w_pmos = deviceType->n_to_p_eff_curr_drv_ratio*g_tp.min_w_nmos_; + in_rise_time = 0; + out_rise_time = 0; + if (initialized != 1) { + cout << "Wire not initialized. Initializing it with default values\n"; + Wire winit; + } + calculate_wire_stats(); + // change everything back to seconds, microns, and Joules + repeater_spacing *= 1e6; + wire_length *= 1e6; + wire_width *= 1e6; + wire_spacing *= 1e6; + assert(wire_length > 0); + assert(power.readOp.dynamic > 0); + assert(power.readOp.leakage > 0); + assert(power.readOp.gate_leakage > 0); +} + + // the following values are for peripheral global technology + // specified in the input config file + Component Wire::global; + Component Wire::global_5; + Component Wire::global_10; + Component Wire::global_20; + Component Wire::global_30; + Component Wire::low_swing; + + int Wire::initialized; + double Wire::wire_width_init; + double Wire::wire_spacing_init; + + +Wire::Wire(double w_s, double s_s, enum Wire_placement wp, double resis, /*TechnologyParameter::*/DeviceType *dt) +{ + w_scale = w_s; + s_scale = s_s; + deviceType = dt; + wire_placement = wp; + resistivity = resis; + min_w_pmos = deviceType->n_to_p_eff_curr_drv_ratio * g_tp.min_w_nmos_; + in_rise_time = 0; + out_rise_time = 0; + + switch (wire_placement) + { + case outside_mat: wire_width = g_tp.wire_outside_mat.pitch/2; break; + case inside_mat : wire_width = g_tp.wire_inside_mat.pitch/2; break; + default: wire_width = g_tp.wire_local.pitch/2; break; + } + + wire_spacing = wire_width; + + wire_width *= (w_scale * 1e-6/2) /* (m) */; + wire_spacing *= (s_scale * 1e-6/2) /* (m) */; + + initialized = 1; + init_wire(); + wire_width_init = wire_width; + wire_spacing_init = wire_spacing; + + assert(power.readOp.dynamic > 0); + assert(power.readOp.leakage > 0); + assert(power.readOp.gate_leakage > 0); +} + + + +Wire::~Wire() +{ +} + + + +void +Wire::calculate_wire_stats() +{ + + if (wire_placement == outside_mat) { + wire_width = g_tp.wire_outside_mat.pitch/2; + } + else if (wire_placement == inside_mat) { + wire_width = g_tp.wire_inside_mat.pitch/2; + } + else { + wire_width = g_tp.wire_local.pitch/2; + } + + wire_spacing = wire_width; + + wire_width *= (w_scale * 1e-6/2) /* (m) */; + wire_spacing *= (s_scale * 1e-6/2) /* (m) */; + + + if (wt != Low_swing) { + + // delay_optimal_wire(); + + if (wt == Global) { + delay = global.delay * wire_length; + power.readOp.dynamic = global.power.readOp.dynamic * wire_length; + power.readOp.leakage = 
global.power.readOp.leakage * wire_length; + power.readOp.gate_leakage = global.power.readOp.gate_leakage * wire_length; + repeater_spacing = global.area.w; + repeater_size = global.area.h; + area.set_area((wire_length/repeater_spacing) * + compute_gate_area(INV, 1, min_w_pmos * repeater_size, + g_tp.min_w_nmos_ * repeater_size, g_tp.cell_h_def)); + } + else if (wt == Global_5) { + delay = global_5.delay * wire_length; + power.readOp.dynamic = global_5.power.readOp.dynamic * wire_length; + power.readOp.leakage = global_5.power.readOp.leakage * wire_length; + power.readOp.gate_leakage = global_5.power.readOp.gate_leakage * wire_length; + repeater_spacing = global_5.area.w; + repeater_size = global_5.area.h; + area.set_area((wire_length/repeater_spacing) * + compute_gate_area(INV, 1, min_w_pmos * repeater_size, + g_tp.min_w_nmos_ * repeater_size, g_tp.cell_h_def)); + } + else if (wt == Global_10) { + delay = global_10.delay * wire_length; + power.readOp.dynamic = global_10.power.readOp.dynamic * wire_length; + power.readOp.leakage = global_10.power.readOp.leakage * wire_length; + power.readOp.gate_leakage = global_10.power.readOp.gate_leakage * wire_length; + repeater_spacing = global_10.area.w; + repeater_size = global_10.area.h; + area.set_area((wire_length/repeater_spacing) * + compute_gate_area(INV, 1, min_w_pmos * repeater_size, + g_tp.min_w_nmos_ * repeater_size, g_tp.cell_h_def)); + } + else if (wt == Global_20) { + delay = global_20.delay * wire_length; + power.readOp.dynamic = global_20.power.readOp.dynamic * wire_length; + power.readOp.leakage = global_20.power.readOp.leakage * wire_length; + power.readOp.gate_leakage = global_20.power.readOp.gate_leakage * wire_length; + repeater_spacing = global_20.area.w; + repeater_size = global_20.area.h; + area.set_area((wire_length/repeater_spacing) * + compute_gate_area(INV, 1, min_w_pmos * repeater_size, + g_tp.min_w_nmos_ * repeater_size, g_tp.cell_h_def)); + } + else if (wt == Global_30) { + delay = global_30.delay * wire_length; + power.readOp.dynamic = global_30.power.readOp.dynamic * wire_length; + power.readOp.leakage = global_30.power.readOp.leakage * wire_length; + power.readOp.gate_leakage = global_30.power.readOp.gate_leakage * wire_length; + repeater_spacing = global_30.area.w; + repeater_size = global_30.area.h; + area.set_area((wire_length/repeater_spacing) * + compute_gate_area(INV, 1, min_w_pmos * repeater_size, + g_tp.min_w_nmos_ * repeater_size, g_tp.cell_h_def)); + } + out_rise_time = delay*repeater_spacing/deviceType->Vth; + } + else if (wt == Low_swing) { + low_swing_model (); + repeater_spacing = wire_length; + repeater_size = 1; + } + else { + assert(0); + } +} + + + +/* + * The fall time of an input signal to the first stage of a circuit is + * assumed to be same as the fall time of the output signal of two + * inverters connected in series (refer: CACTI 1 Technical report, + * section 6.1.3) + */ + double +Wire::signal_fall_time () +{ + + /* rise time of inverter 1's output */ + double rt; + /* fall time of inverter 2's output */ + double ft; + double timeconst; + + timeconst = (drain_C_(g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) + + drain_C_(min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + gate_C(min_w_pmos + g_tp.min_w_nmos_, 0)) * + tr_R_on(min_w_pmos, PCH, 1); + rt = horowitz (0, timeconst, deviceType->Vth/deviceType->Vdd, deviceType->Vth/deviceType->Vdd, FALL) / (deviceType->Vdd - deviceType->Vth); + timeconst = (drain_C_(g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) + + drain_C_(min_w_pmos, PCH, 1, 1, 
g_tp.cell_h_def) + + gate_C(min_w_pmos + g_tp.min_w_nmos_, 0)) * + tr_R_on(g_tp.min_w_nmos_, NCH, 1); + ft = horowitz (rt, timeconst, deviceType->Vth/deviceType->Vdd, deviceType->Vth/deviceType->Vdd, RISE) / deviceType->Vth; + return ft; +} + + + +double Wire::signal_rise_time () +{ + + /* rise time of inverter 1's output */ + double ft; + /* fall time of inverter 2's output */ + double rt; + double timeconst; + + timeconst = (drain_C_(g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) + + drain_C_(min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + gate_C(min_w_pmos + g_tp.min_w_nmos_, 0)) * + tr_R_on(g_tp.min_w_nmos_, NCH, 1); + rt = horowitz (0, timeconst, deviceType->Vth/deviceType->Vdd, deviceType->Vth/deviceType->Vdd, RISE) / deviceType->Vth; + timeconst = (drain_C_(g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) + + drain_C_(min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + gate_C(min_w_pmos + g_tp.min_w_nmos_, 0)) * + tr_R_on(min_w_pmos, PCH, 1); + ft = horowitz (rt, timeconst, deviceType->Vth/deviceType->Vdd, deviceType->Vth/deviceType->Vdd, FALL) / (deviceType->Vdd - deviceType->Vth); + return ft; //sec +} + + + +/* Wire resistance and capacitance calculations + * wire width + * + * /__/ + * | | + * | | height = ASPECT_RATIO*wire width (ASPECT_RATIO = 2.2, ref: ITRS) + * |__|/ + * + * spacing between wires in same level = wire width + * + */ + +double Wire::wire_cap (double len /* in m */, bool call_from_outside) +{ + //TODO: this should be consistent with the wire_res in technology file + double sidewall, adj, tot_cap; + double wire_height; + double epsilon0 = 8.8542e-12; + double aspect_ratio, horiz_dielectric_constant, vert_dielectric_constant, miller_value,ild_thickness; + + switch (wire_placement) + { + case outside_mat: + { + aspect_ratio = g_tp.wire_outside_mat.aspect_ratio; + horiz_dielectric_constant = g_tp.wire_outside_mat.horiz_dielectric_constant; + vert_dielectric_constant = g_tp.wire_outside_mat.vert_dielectric_constant; + miller_value = g_tp.wire_outside_mat.miller_value; + ild_thickness = g_tp.wire_outside_mat.ild_thickness; + break; + } + case inside_mat : + { + aspect_ratio = g_tp.wire_inside_mat.aspect_ratio; + horiz_dielectric_constant = g_tp.wire_inside_mat.horiz_dielectric_constant; + vert_dielectric_constant = g_tp.wire_inside_mat.vert_dielectric_constant; + miller_value = g_tp.wire_inside_mat.miller_value; + ild_thickness = g_tp.wire_inside_mat.ild_thickness; + break; + } + default: + { + aspect_ratio = g_tp.wire_local.aspect_ratio; + horiz_dielectric_constant = g_tp.wire_local.horiz_dielectric_constant; + vert_dielectric_constant = g_tp.wire_local.vert_dielectric_constant; + miller_value = g_tp.wire_local.miller_value; + ild_thickness = g_tp.wire_local.ild_thickness; + break; + } + } + + if (call_from_outside) + { + wire_width *= 1e-6; + wire_spacing *= 1e-6; + } + wire_height = wire_width/w_scale*aspect_ratio; + /* + * assuming height does not change. 
wire_width = width_original*w_scale + * So wire_height does not change as wire width increases + */ + +// capacitance between wires in the same level +// sidewall = 2*miller_value * horiz_dielectric_constant * (wire_height/wire_spacing) +// * epsilon0; + + sidewall = miller_value * horiz_dielectric_constant * (wire_height/wire_spacing) + * epsilon0; + + + // capacitance between wires in adjacent levels + //adj = miller_value * vert_dielectric_constant *w_scale * epsilon0; + //adj = 2*vert_dielectric_constant *wire_width/(ild_thickness*1e-6) * epsilon0; + + adj = miller_value *vert_dielectric_constant *wire_width/(ild_thickness*1e-6) * epsilon0; + //Change ild_thickness from micron to M + + //tot_cap = (sidewall + adj + (deviceType->C_fringe * 1e6)); //F/m + tot_cap = (sidewall + adj + (g_tp.fringe_cap * 1e6)); //F/m + + if (call_from_outside) + { + wire_width *= 1e6; + wire_spacing *= 1e6; + } + return (tot_cap*len); // (F) +} + + + double +Wire::wire_res (double len /*(in m)*/) +{ + + double aspect_ratio,alpha_scatter =1.05, dishing_thickness=0, barrier_thickness=0; + //TODO: this should be consistent with the wire_res in technology file + //The whole computation should be consistent with the wire_res in technology.cc too! + + switch (wire_placement) + { + case outside_mat: + { + aspect_ratio = g_tp.wire_outside_mat.aspect_ratio; + break; + } + case inside_mat : + { + aspect_ratio = g_tp.wire_inside_mat.aspect_ratio; + break; + } + default: + { + aspect_ratio = g_tp.wire_local.aspect_ratio; + break; + } + } + return (alpha_scatter * resistivity * 1e-6 * len/((aspect_ratio*wire_width/w_scale-dishing_thickness - barrier_thickness)* + (wire_width-2*barrier_thickness))); +} + +/* + * Calculates the delay, power and area of the transmitter circuit. + * + * The transmitter delay is the sum of nand gate delay, inverter delay + * low swing nmos delay, and the wire delay + * (ref: Technical report 6) + */ + void +Wire::low_swing_model() +{ + double len = wire_length; + double beta = pmos_to_nmos_sz_ratio(); + + + double inputrise = (in_rise_time == 0) ? signal_rise_time() : in_rise_time; + + /* Final nmos low swing driver size calculation: + * Try to size the driver such that the delay + * is less than 8FO4. + * If the driver size is greater than + * the max allowable size, assume max size for the driver. + * In either case, recalculate the delay using + * the final driver size assuming slow input with + * finite rise time instead of ideal step input + * + * (ref: Technical report 6) + */ + double cwire = wire_cap(len); /* load capacitance */ + double rwire = wire_res(len); + +#define RES_ADJ (8.6) // Increase in resistance due to low driving vol. + + double driver_res = (-8*g_tp.FO4/(log(0.5) * cwire))/RES_ADJ; + double nsize = R_to_w(driver_res, NCH); + + nsize = MIN(nsize, g_tp.max_w_nmos_); + nsize = MAX(nsize, g_tp.min_w_nmos_); + + if(rwire*cwire > 8*g_tp.FO4) + { + nsize = g_tp.max_w_nmos_; + } + + // size the inverter appropriately to minimize the transmitter delay + // Note - In order to minimize leakage, we are not adding a set of inverters to + // bring down delay. Instead, we are sizing the single gate + // based on the logical effort. 
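+  // Logical-effort view of the sizing below: st_eff is the per-stage effort of
+  // the nand+inverter chain (square root of the ratio of the load at the final
+  // nmos gate to a minimum nand's input capacitance), and req_cin is the input
+  // capacitance the inverter must present for balanced stage efforts. Note
+  // that (2+beta/1+beta) parses as 2 + 2*beta; the NAND2 logical-effort term
+  // is conventionally written (2+beta)/(1+beta).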
+ double st_eff = sqrt((2+beta/1+beta)*gate_C(nsize, 0)/(gate_C(2*g_tp.min_w_nmos_, 0) + + gate_C(2*min_w_pmos, 0))); + double req_cin = ((2+beta/1+beta)*gate_C(nsize, 0))/st_eff; + double inv_size = req_cin/(gate_C(min_w_pmos, 0) + gate_C(g_tp.min_w_nmos_, 0)); + inv_size = MAX(inv_size, 1); + + /* nand gate delay */ + double res_eq = (2 * tr_R_on(g_tp.min_w_nmos_, NCH, 1)); + double cap_eq = 2 * drain_C_(min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + drain_C_(2*g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) + + gate_C(inv_size*g_tp.min_w_nmos_, 0) + + gate_C(inv_size*min_w_pmos, 0); + + double timeconst = res_eq * cap_eq; + + delay = horowitz(inputrise, timeconst, deviceType->Vth/deviceType->Vdd, + deviceType->Vth/deviceType->Vdd, RISE); + double temp_power = cap_eq*deviceType->Vdd*deviceType->Vdd; + + inputrise = delay / (deviceType->Vdd - deviceType->Vth); /* for the next stage */ + + /* Inverter delay: + * The load capacitance of this inv depends on + * the gate capacitance of the final stage nmos + * transistor which in turn depends on nsize + */ + res_eq = tr_R_on(inv_size*min_w_pmos, PCH, 1); + cap_eq = drain_C_(inv_size*min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) + + drain_C_(inv_size*g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def) + + gate_C(nsize, 0); + timeconst = res_eq * cap_eq; + + delay += horowitz(inputrise, timeconst, deviceType->Vth/deviceType->Vdd, + deviceType->Vth/deviceType->Vdd, FALL); + temp_power += cap_eq*deviceType->Vdd*deviceType->Vdd; + + + transmitter.delay = delay; + transmitter.power.readOp.dynamic = temp_power*2; /* since it is a diff. model*/ + transmitter.power.readOp.leakage = deviceType->Vdd * + (4 * cmos_Isub_leakage(g_tp.min_w_nmos_, min_w_pmos, 2, nand) + + 4 * cmos_Isub_leakage(g_tp.min_w_nmos_, min_w_pmos, 1, inv)); + + transmitter.power.readOp.gate_leakage = deviceType->Vdd * + (4 * cmos_Ig_leakage(g_tp.min_w_nmos_, min_w_pmos, 2, nand) + + 4 * cmos_Ig_leakage(g_tp.min_w_nmos_, min_w_pmos, 1, inv)); + + inputrise = delay / deviceType->Vth; + + /* nmos delay + wire delay */ + cap_eq = cwire + drain_C_(nsize, NCH, 1, 1, g_tp.cell_h_def)*2 + + nsense * sense_amp_input_cap(); //+receiver cap + /* + * NOTE: nmos is used as both pull up and pull down transistor + * in the transmitter. 
This is because for low voltage swing, drive + * resistance of nmos is less than pmos + * (for a detailed graph ref: On-Chip Wires: Scaling and Efficiency) + */ + timeconst = (tr_R_on(nsize, NCH, 1)*RES_ADJ) * (cwire + + drain_C_(nsize, NCH, 1, 1, g_tp.cell_h_def)*2) + + rwire*cwire/2 + + (tr_R_on(nsize, NCH, 1)*RES_ADJ + rwire) * + nsense * sense_amp_input_cap(); + + /* + * since we are pre-equalizing and overdriving the low + * swing wires, the net time constant is less + * than the actual value + */ + delay += horowitz(inputrise, timeconst, deviceType->Vth/deviceType->Vdd, .25, 0); +#define VOL_SWING .1 + temp_power += cap_eq*VOL_SWING*.400; /* .4v is the over drive voltage */ + temp_power *= 2; /* differential wire */ + + l_wire.delay = delay - transmitter.delay; + l_wire.power.readOp.dynamic = temp_power - transmitter.power.readOp.dynamic; + l_wire.power.readOp.leakage = deviceType->Vdd* + (4* cmos_Isub_leakage(nsize, 0, 1, nmos)); + + l_wire.power.readOp.gate_leakage = deviceType->Vdd* + (4* cmos_Ig_leakage(nsize, 0, 1, nmos)); + + //double rt = horowitz(inputrise, timeconst, deviceType->Vth/deviceType->Vdd, + // deviceType->Vth/deviceType->Vdd, RISE)/deviceType->Vth; + + delay += g_tp.sense_delay; + + sense_amp.delay = g_tp.sense_delay; + out_rise_time = g_tp.sense_delay/(deviceType->Vth); + sense_amp.power.readOp.dynamic = g_tp.sense_dy_power; + sense_amp.power.readOp.leakage = 0; //FIXME + sense_amp.power.readOp.gate_leakage = 0; + + power.readOp.dynamic = temp_power + sense_amp.power.readOp.dynamic; + power.readOp.leakage = transmitter.power.readOp.leakage + + l_wire.power.readOp.leakage + + sense_amp.power.readOp.leakage; + power.readOp.gate_leakage = transmitter.power.readOp.gate_leakage + + l_wire.power.readOp.gate_leakage + + sense_amp.power.readOp.gate_leakage; +} + + double +Wire::sense_amp_input_cap() +{ + return drain_C_(g_tp.w_iso, PCH, 1, 1, g_tp.cell_h_def) + + gate_C(g_tp.w_sense_en + g_tp.w_sense_n, 0) + + drain_C_(g_tp.w_sense_n, NCH, 1, 1, g_tp.cell_h_def) + + drain_C_(g_tp.w_sense_p, PCH, 1, 1, g_tp.cell_h_def); +} + + +void Wire::delay_optimal_wire () +{ + double len = wire_length; + //double min_wire_width = wire_width; //m + double beta = pmos_to_nmos_sz_ratio(); + double switching = 0; // switching energy + double short_ckt = 0; // short-circuit energy + double tc = 0; // time constant + // input cap of min sized driver + double input_cap = gate_C(g_tp.min_w_nmos_ + min_w_pmos, 0); + + // output parasitic capacitance of + // the min. 
sized driver
+  double out_cap = drain_C_(min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) +
+                   drain_C_(g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def);
+  // drive resistance
+  double out_res = (tr_R_on(g_tp.min_w_nmos_, NCH, 1) +
+                    tr_R_on(min_w_pmos, PCH, 1))/2;
+  double wr = wire_res(len); //ohm
+
+  // wire cap /m
+  double wc = wire_cap(len);
+
+  // size the repeater such that the delay of the wire is minimum
+  double repeater_scaling = sqrt(out_res*wc/(wr*input_cap)); // len will cancel
+
+  // calc the optimum spacing between the repeaters (m)
+
+  repeater_spacing = sqrt(2 * out_res * (out_cap + input_cap)/
+                          ((wr/len)*(wc/len)));
+  repeater_size = repeater_scaling;
+
+  switching = (repeater_scaling * (input_cap + out_cap) +
+               repeater_spacing * (wc/len)) * deviceType->Vdd * deviceType->Vdd;
+
+  tc = out_res * (input_cap + out_cap) +
+       out_res * wc/len * repeater_spacing/repeater_scaling +
+       wr/len * repeater_spacing * input_cap * repeater_scaling +
+       0.5 * (wr/len) * (wc/len)* repeater_spacing * repeater_spacing;
+
+  delay = 0.693 * tc * len/repeater_spacing;
+
+#define Ishort_ckt 65e-6 /* across all tech Ref:Banerjee et al. {IEEE TED} */
+  short_ckt = deviceType->Vdd * g_tp.min_w_nmos_ * Ishort_ckt * 1.0986 *
+              repeater_scaling * tc;
+
+  area.set_area((len/repeater_spacing) *
+                compute_gate_area(INV, 1, min_w_pmos * repeater_scaling,
+                                  g_tp.min_w_nmos_ * repeater_scaling, g_tp.cell_h_def));
+  power.readOp.dynamic = ((len/repeater_spacing)*(switching + short_ckt));
+  power.readOp.leakage = ((len/repeater_spacing)*
+                          deviceType->Vdd*
+                          cmos_Isub_leakage(g_tp.min_w_nmos_*repeater_scaling, beta*g_tp.min_w_nmos_*repeater_scaling, 1, inv));
+  power.readOp.gate_leakage = ((len/repeater_spacing)*
+                               deviceType->Vdd*
+                               cmos_Ig_leakage(g_tp.min_w_nmos_*repeater_scaling, beta*g_tp.min_w_nmos_*repeater_scaling, 1, inv));
+}
+
+
+
+// calculate power/delay values for wires with suboptimal repeater sizing/spacing
+void
+Wire::init_wire(){
+  wire_length = 1;
+  delay_optimal_wire();
+  double sp, si;
+  powerDef pow;
+  si = repeater_size;
+  sp = repeater_spacing;
+  sp *= 1e6; // in microns
+
+  double i, j, del;
+  repeated_wire.push_back(Component());
+  for (j=sp; j < 4*sp; j+=100) {
+    for (i = si; i > 1; i--) {
+      pow = wire_model(j*1e-6, i, &del);
+      if (j == sp && i == si) {
+        global.delay = del;
+        global.power = pow;
+        global.area.h = si;
+        global.area.w = sp*1e-6; // m
+      }
+//      cout << "Repeater size - "<< i <<
+//        " Repeater spacing - " << j <<
+//        " Delay - " << del <<
+//        " PowerD - " << pow.readOp.dynamic <<
+//        " PowerL - " << pow.readOp.leakage <<endl;
+      repeated_wire.back().delay = del;
+      repeated_wire.back().power = pow;
+      repeated_wire.back().area.w = j*1e-6; // m
+      repeated_wire.back().area.h = i;
+      repeated_wire.push_back(Component());
+    }
+  }
+  repeated_wire.pop_back();
+  update_fullswing();
+  Wire *l_wire = new Wire(Low_swing, 0.001 /* 1 mm */, 1);
+  low_swing.delay = l_wire->delay;
+  low_swing.power = l_wire->power;
+  delete l_wire;
+}
+
+
+
+void Wire::update_fullswing()
+{
+
+  list<Component>::iterator citer;
+  double del[4];
+  del[3] = this->global.delay + this->global.delay*.3;
+  del[2] = global.delay + global.delay*.2;
+  del[1] = global.delay + global.delay*.1;
+  del[0] = global.delay + global.delay*.05;
+  double threshold;
+  double ncost;
+  double cost;
+  int i = 4;
+  while (i>0) {
+    threshold = del[i-1];
+    cost = BIGNUM;
+    for (citer = repeated_wire.begin(); citer != repeated_wire.end(); citer++)
+    {
+      if (citer->delay > threshold) {
+        citer = repeated_wire.erase(citer);
+        citer --;
+      }
+      else {
+        ncost = citer->power.readOp.dynamic/global.power.readOp.dynamic +
+                citer->power.readOp.leakage/global.power.readOp.leakage;
+        if(ncost < cost)
+        {
+          cost = ncost;
+          if (i == 4) {
+            global_30.delay = citer->delay;
+            global_30.power = citer->power;
+            global_30.area = citer->area;
+          }
+          else if (i==3) {
+            global_20.delay = citer->delay;
+            global_20.power = citer->power;
+            global_20.area = citer->area;
+          }
+          else if(i==2) {
+            global_10.delay = citer->delay;
+            global_10.power = citer->power;
+            global_10.area = citer->area;
+          }
+          else if(i==1) {
+            global_5.delay = citer->delay;
+            global_5.power = citer->power;
+            global_5.area = citer->area;
+          }
+        }
+      }
+    }
+    i--;
+  }
+}
+
+
+
+powerDef Wire::wire_model (double space, double size, double *delay)
+{
+  powerDef ptemp;
+  double len = 1;
+  //double min_wire_width = wire_width; //m
+  double beta = pmos_to_nmos_sz_ratio();
+  // switching energy
+  double switching = 0;
+  // short-circuit energy
+  double short_ckt = 0;
+  // time constant
+  double tc = 0;
+  // input cap of min sized driver
+  double input_cap = gate_C (g_tp.min_w_nmos_ +
+                             min_w_pmos, 0);
+
+  // output parasitic capacitance of
+  // the min. sized driver
+  double out_cap = drain_C_(min_w_pmos, PCH, 1, 1, g_tp.cell_h_def) +
+                   drain_C_(g_tp.min_w_nmos_, NCH, 1, 1, g_tp.cell_h_def);
+  // drive resistance
+  double out_res = (tr_R_on(g_tp.min_w_nmos_, NCH, 1) +
+                    tr_R_on(min_w_pmos, PCH, 1))/2;
+  double wr = wire_res(len); //ohm
+
+  // wire cap /m
+  double wc = wire_cap(len);
+
+  repeater_spacing = space;
+  repeater_size = size;
+
+  switching = (repeater_size * (input_cap + out_cap) +
+               repeater_spacing * (wc/len)) * deviceType->Vdd * deviceType->Vdd;
+
+  tc = out_res * (input_cap + out_cap) +
+       out_res * wc/len * repeater_spacing/repeater_size +
+       wr/len * repeater_spacing * out_cap * repeater_size +
+       0.5 * (wr/len) * (wc/len)* repeater_spacing * repeater_spacing;
+
+  *delay = 0.693 * tc * len/repeater_spacing;
+
+#define Ishort_ckt 65e-6 /* across all tech Ref:Banerjee et al. {IEEE TED} */
+  short_ckt = deviceType->Vdd * g_tp.min_w_nmos_ * Ishort_ckt * 1.0986 *
+              repeater_size * tc;
+
+  ptemp.readOp.dynamic = ((len/repeater_spacing)*(switching + short_ckt));
+  ptemp.readOp.leakage = ((len/repeater_spacing)*
+                          deviceType->Vdd*
+                          cmos_Isub_leakage(g_tp.min_w_nmos_*repeater_size, beta*g_tp.min_w_nmos_*repeater_size, 1, inv));
+
+  ptemp.readOp.gate_leakage = ((len/repeater_spacing)*
+                               deviceType->Vdd*
+                               cmos_Ig_leakage(g_tp.min_w_nmos_*repeater_size, beta*g_tp.min_w_nmos_*repeater_size, 1, inv));
+
+  return ptemp;
+}
+
+void
+Wire::print_wire()
+{
+
+  cout << "\nWire Properties:\n\n";
+  cout << "  Delay Optimal\n\tRepeater size - "<< global.area.h <<
+    " \n\tRepeater spacing - " << global.area.w*1e3 << " (mm)"
+    " \n\tDelay - " << global.delay*1e6 << " (ns/mm)"
+    " \n\tPowerD - " << global.power.readOp.dynamic *1e6<< " (nJ/mm)"
+    " \n\tPowerL - " << global.power.readOp.leakage << " (mW/mm)"
+    " \n\tPowerLgate - " << global.power.readOp.gate_leakage << " (mW/mm)\n";
+  cout << "\tWire width - " <<wire_width_init*1e6 << " microns\n";
+  cout << "\tWire spacing - " <<wire_spacing_init*1e6 << " microns\n";
+  cout <<endl;
+}
 diff --git a/Project_FARSI/cacti_for_FARSI/wire.h b/Project_FARSI/cacti_for_FARSI/wire.h new file mode 100644 --- /dev/null +++ b/Project_FARSI/cacti_for_FARSI/wire.h
+/*****************************************************************************
+ *                                CACTI 7.0
+ *                      SOFTWARE LICENSE AGREEMENT
+ *            Copyright 2015 Hewlett-Packard Development Company, L.P.
+ *                          All Rights Reserved
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met: redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer;
+ * redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution;
+ * neither the name of the copyright holders nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ ***************************************************************************/
+
+#ifndef __WIRE_H__
+#define __WIRE_H__
+
+#include "basic_circuit.h"
+#include "component.h"
+#include "parameter.h"
+#include "cacti_interface.h"
+#include <iostream>
+#include <list>
+
+class Wire : public Component
+{
+  public:
+    Wire(enum Wire_type wire_model, double len /* in u*/,
+         int nsense = 1/* no. of sense amps connected to the low-swing wire */,
+         double width_scaling = 1,
+         double spacing_scaling = 1,
+         enum Wire_placement wire_placement = outside_mat,
+         double resistivity = CU_RESISTIVITY,
+         /*TechnologyParameter::*/DeviceType *dt = &(g_tp.peri_global));
+    ~Wire();
+
+    Wire( double width_scaling = 1,
+          double spacing_scaling = 1,
+          enum Wire_placement wire_placement = outside_mat,
+          double resistivity = CU_RESISTIVITY,
+          /*TechnologyParameter::*/DeviceType *dt = &(g_tp.peri_global)
+        ); // should be used only once for initializing static members
+    void init_wire();
+
+    void calculate_wire_stats();
+    void delay_optimal_wire();
+    double wire_cap(double len, bool call_from_outside=false);
+    double wire_res(double len);
+    void low_swing_model();
+    double signal_fall_time();
+    double signal_rise_time();
+    double sense_amp_input_cap();
+
+    enum Wire_type wt;
+    double wire_spacing;
+    double wire_width;
+    enum Wire_placement wire_placement;
+    double repeater_size;
+    double repeater_spacing;
+    double wire_length;
+    double in_rise_time, out_rise_time;
+
+    void set_in_rise_time(double rt)
+    {
+      in_rise_time = rt;
+    }
+    static Component global;
+    static Component global_5;
+    static Component global_10;
+    static Component global_20;
+    static Component global_30;
+    static Component low_swing;
+    static double wire_width_init;
+    static double wire_spacing_init;
+    void print_wire();
+
+  private:
+
+    int nsense; // no. of sense amps connected to a low-swing wire if it
+                // is broadcasting data to multiple destinations
+    // width and spacing scaling factor can be used
+    // to model low level wires or special
+    // fat wires
+    double w_scale, s_scale;
+    double resistivity;
+    powerDef wire_model (double space, double size, double *delay);
+    list<Component> repeated_wire;
+    void update_fullswing();
+    static int initialized;
+
+
+    //low-swing
+    Component transmitter;
+    Component l_wire;
+    Component sense_amp;
+
+    double min_w_pmos;
+
+    /*TechnologyParameter::*/DeviceType *deviceType;
+
+};
+
+#endif
 diff --git a/Project_FARSI/data_collection/collection_utils/home_settings.py b/Project_FARSI/data_collection/collection_utils/home_settings.py new file mode 100644 index 00000000..a3391dd5 --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/home_settings.py @@ -0,0 +1,9 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import os
+import sys
+sys.path.append(os.path.abspath('../../../'))
+home_dir = os.getcwd() + "/../../Project_FARSI/"
+#home_dir = os.getcwd() + "/../../../"
 diff --git a/Project_FARSI/data_collection/collection_utils/replay/FARSI_replay.py b/Project_FARSI/data_collection/collection_utils/replay/FARSI_replay.py new file mode 100644 index 00000000..f8ae22c3 --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/replay/FARSI_replay.py @@ -0,0 +1,59 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
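+
+# Driver script for the replayer: designs can be replayed individually
+# (see the hard-coded folder names kept below for reference) or in batch,
+# by walking every design folder under config.replay_folder_base.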
+
+import os
+import sys
+sys.path.append(os.path.abspath('./../'))
+import home_settings
+from DSE_utils.design_space_exploration_handler import *
+from specs.database_input import *
+from replayer import *
+
+# This file shows examples of how to set up the replayer
+
+# ------------------------------
+# Functionality:
+#       list the subdirectories of a directory
+# Variables:
+#       result_base_dir: parent directory
+#       prefix: prefix to add to the name of the subdirectories (to generate the address)
+# ------------------------------
+def list_dirs(result_base_dir, prefix):
+    dirs = os.listdir(result_base_dir)
+    result = []
+    for dir in dirs:
+        dir_full_addr = os.path.join(result_base_dir, dir)  # add path to each file
+        if not os.path.isdir(dir_full_addr):
+            continue
+        result.append(os.path.join(prefix, dir, os.listdir(dir_full_addr)[0]))
+    return result
+
+replayer_obj = Replayer()
+
+
+mode = "individual"  # individual, entire_folder
+
+
+# individual replay
+# TODO: clean this up, and add it as an option
+"""
+# individual replay
+#des_folder_name = "/two_three_mem/data/04-29_14-08_22_5/04-29_14-08_22_5"
+#des_folder_name = "/two_three_mem/data/04-29_14-07_00_0"
+#des_folder_name = "/replay/two_bus/data/04-29_22-14_01_3/04-29_22-14_01_3"
+#des_folder_name = "/two_bus/data/04-29_22-14_01_3/04-29_22-14_01_3"
+#des_folder_name = "/05-08_18-43_07_13/05-08_18-43_07_13_13"
+des_folder_name = "bus_slave_to_master_connection_bug/05-08_20-30_41_40/05-08_20-30_41_40_40"
+des_folder_list = [des_folder_name]
+"""
+
+# batch replay
+prefix = "DP_Stats_testing_Data/"
+folder_to_look_into = config.replay_folder_base
+des_folder_list = list_dirs(folder_to_look_into, prefix)
+
+# iterate through the folder and call replay on it
+for des_folder_name in des_folder_list:
+    replayer_obj.replay(des_folder_name)
+    replayer_obj.gen_pa()
 diff --git a/Project_FARSI/data_collection/collection_utils/replay/replayer.py b/Project_FARSI/data_collection/collection_utils/replay/replayer.py new file mode 100644 index 00000000..e24ccf7e --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/replay/replayer.py @@ -0,0 +1,106 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import pickle
+from DSE_utils import hill_climbing
+from specs.data_base import *
+from design_utils.design import *
+from visualization_utils import vis_sim
+from visualization_utils import vis_hardware, vis_stats, plot
+import importlib
+
+if config.simulation_method == "power_knobs":
+    from specs import database_input_powerKnobs as database_input
+elif config.simulation_method == "performance":
+    from specs import database_input
+else:
+    raise NameError("Simulation method unavailable")
+
+# This class allows for replaying the designs already generated by FARSI.
+# This is useful for debugging: it gives a closer, step-by-step look at the
+# simulation of a design with a large error.
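+# Typical flow: load_design() unpickles an ex_dp from a design folder,
+# replay() re-evaluates it through the DSE's eval_design(), and gen_pa()
+# regenerates the PA-digestible files for verification.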
+class Replayer:
+    def __init__(self):
+
+        # TODO: needs to be updated to work
+        print("needs to be updated to work")
+        exit(0)
+
+        self.database = DataBase(database_input.tasksL, database_input.blocksL,
+                                 database_input.pe_mapsL, database_input.pe_schedulesL, database_input.SOCsL)  # hw/sw database
+        self.dse = hill_climbing.HillClimbing(self.database)  # design space exploration to use
+        self.home_dir = config.home_dir  # home directory of repo
+        self.data_collection_top_folder = self.home_dir + "/data_collection/data_already_collected"  # directory with designs in it.
+        self.replay_top_folder = self.data_collection_top_folder + "/replay/"  # directory that replayer dumps its results into.
+        self.pickled_file_name = "ex_dp_pickled.txt"  # generated designs are stored in pickle format, that are then read by replayer
+        self.latest_ex_dp = None  # design to look at (replay)
+        self.latest_sim_dp = None  # latest simulated design
+        self.latest_stat_result = None  # latest stats associated with the design
+        self.latest_des_folder_addr = None  # latest design folder
+        self.name_ctr = 0  # name ctr (id for designs)
+
+    # ------------------------------
+    # Functionality:
+    #       loading a design.
+    # Variables:
+    #       des_folder_name: folder where the design resides in
+    # ------------------------------
+    def load_design(self, des_folder_name):
+        # des_folder_name is with respect to data_collection
+        self.latest_des_folder_addr = self.data_collection_top_folder + "/" + des_folder_name
+        self.latest_top_replay_folder_addr = self.replay_top_folder + "/" + "/".join(des_folder_name.split("/")[:-1])
+        pickled_file_addr = self.latest_des_folder_addr + "/" + "ex_dp_pickled.txt"
+        return pickle.load(open(pickled_file_addr, "rb"))
+
+    # ------------------------------
+    # Functionality:
+    #       replay the design.
+    # Variables:
+    #       des_folder_name: folder where the design resides in
+    # ------------------------------
+    def replay(self, des_folder_name):
+        self.latest_ex_dp = self.load_design(des_folder_name)
+        self.latest_ex_dp = copy.deepcopy(self.latest_ex_dp)  # need to do this to clear some memories
+        self.latest_sim_dp = self.dse.eval_design(self.latest_ex_dp, self.database)
+        self.latest_stat_result = self.latest_sim_dp.dp_stats
+
+    # ------------------------------
+    # Functionality:
+    #       generate PA digestible files for the replayed design.
+    # ------------------------------
+    def gen_pa(self):
+        import_ver = importlib.import_module("data_collection.FB_private.verification_utils.PA_generation.PA_generators")
+        # get PA generators
+        pa_ver_obj = import_ver.PAVerGen()
+        # make all the combinations
+        knobs_list, knob_order = pa_ver_obj.gen_all_PA_knob_combos(import_ver.PA_knobs_to_explore)
+        os.makedirs(self.latest_top_replay_folder_addr, exist_ok=True)
+
+        # go through all the combinations and generate the corresponding PA design.
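+        # Each knob combination below gets its own result folder (keyed by the
+        # design id) holding hardware visualizations, dumped stats, the
+        # generated PA files, and the pickled design/simulation objects.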
+ for knobs in knobs_list: + self.latest_ex_dp.reset_PA_knobs() + self.latest_ex_dp.update_PA_knobs(knobs, knob_order) + PA_result_folder = os.path.join(self.latest_top_replay_folder_addr, str(self.latest_ex_dp.id)) + os.makedirs(PA_result_folder, exist_ok =True) + # visualize the hardware + vis_hardware.vis_hardware(self.latest_ex_dp, "block_extra", PA_result_folder, "system_image_block_extra.pdf") + vis_hardware.vis_hardware(self.latest_ex_dp, "block_task", PA_result_folder, "system_image_block_task.pdf") + self.latest_stat_result.dump_stats(PA_result_folder) + if config.VIS_SIM_PER_GEN: vis_sim.plot_sim_data(self.latest_stat_result, self.latest_ex_dp, PA_result_folder) + # generate PA + pa_obj = import_ver.PAGen(database_input.proj_name, self.latest_ex_dp, PA_result_folder, config.sw_model) + pa_obj.gen_all() + self.latest_ex_dp.dump_props(PA_result_folder) + + # pickle the result + ex_dp_pickled_file = open(os.path.join(PA_result_folder, "ex_dp_pickled.txt"), "wb") + pickle.dump(self.latest_ex_dp, ex_dp_pickled_file) + ex_dp_pickled_file.close() + + sim_dp_pickled_file = open(os.path.join(PA_result_folder, "sim_dp_pickled.txt"), "wb") + pickle.dump(self.latest_sim_dp, sim_dp_pickled_file) + sim_dp_pickled_file.close() + self.name_ctr += 1 + + diff --git a/Project_FARSI/data_collection/collection_utils/sim_run/exploration_time.png b/Project_FARSI/data_collection/collection_utils/sim_run/exploration_time.png new file mode 100644 index 00000000..6c050501 Binary files /dev/null and b/Project_FARSI/data_collection/collection_utils/sim_run/exploration_time.png differ diff --git a/Project_FARSI/data_collection/collection_utils/sim_run/sim_run_array_x.py b/Project_FARSI/data_collection/collection_utils/sim_run/sim_run_array_x.py new file mode 100644 index 00000000..d08e58bc --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/sim_run/sim_run_array_x.py @@ -0,0 +1,138 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
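+
+# Sweep script: runs the FARSI simulation once per (burst_size, queue_size)
+# combination and appends the resulting latency (in us) to latency_in_us.txt
+# under the result folder.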
+ +import sys +import os +sys.path.append(os.path.abspath('./../')) +import home_settings +from top.main_FARSI import run_FARSI_only_simulation +from settings import config +import os +import itertools +# main function +import numpy as np +from mpl_toolkits.mplot3d import Axes3D +from matplotlib import cm +import matplotlib.pyplot as plt +from visualization_utils.vis_hardware import * +import numpy as np +from specs.LW_cl import * +from specs.database_input import * +import math +import matplotlib.colors as colors +#import pandas +import matplotlib.colors as mcolors +import pandas as pd +import argparse, sys +import data_collection.collection_utils.what_ifs.FARSI_what_ifs as wf + + +# selecting the database based on the simulation method (power or performance) +if config.simulation_method == "power_knobs": + from specs import database_input_powerKnobs as database_input +elif config.simulation_method == "performance": + from specs import database_input +else: + raise NameError("Simulation method unavailable") + +if __name__ == "__main__": + case_study = "simple_sim_run" + file_prefix = config.FARSI_simple_sim_run_study_prefix + current_process_id = 0 + total_process_cnt = 1 + #starting_exploration_mode = config.exploration_mode + print('case study:' + case_study) + + # ------------------------------------------- + # set result folder + # ------------------------------------------- + result_home_dir_default = os.path.join(os.getcwd(), "data_collection/data/" + case_study) + result_home_dir = os.path.join(config.home_dir, "data_collection/data/" + case_study) + date_time = datetime.now().strftime('%m-%d_%H-%M_%S') + result_folder = os.path.join(result_home_dir, + date_time) + + # ------------------------------------------- + # set parameters + # ------------------------------------------- + experiment_repetition_cnt = 1 + reduction = "most_likely" + #workloads = {"audio_decoder", "edge_detection"} + #workloads = {"audio_decoder"} + #workloads = {"edge_detection"} + #workloads = {"hpvm_cava"} + workloads = {"partial_SOC_example_hard"} + workloads = {"SOC_example_8p"} + tech_node_SF = {"perf":1, "energy":{"non_gpp":.064, "gpp":1}, "area":{"non_mem":.0374 , "mem":.079, "gpp":1}} # technology node scaling factor + db_population_misc_knobs = {"ip_freq_correction_ratio": 1, "gpp_freq_correction_ratio": 1, + "tech_node_SF":tech_node_SF, + "base_budget_scaling":{"latency":.5, "power":1, "area":1}} + sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "parse", + "workloads": workloads, "misc_knobs":db_population_misc_knobs} + + # ------------------------------------------- + # distribute the work + # ------------------------------------------- + work_per_process = math.ceil(experiment_repetition_cnt / total_process_cnt) + run_ctr = 0 + # ------------------------------------------- + # run the combination and collect the data + # ------------------------------------------- + # ------------------------------------------- + # collect the exact hw sampling + # ------------------------------------------- + accuracy_percentage = {} + accuracy_percentage["sram"] = accuracy_percentage["dram"] = accuracy_percentage["ic"] = accuracy_percentage["gpp"] = accuracy_percentage[ + "ip"] = \ + {"latency": 1, + "energy": 1, + "area": 1, + "one_over_area": 1} + hw_sampling = {"mode": "exact", "population_size": 1, "reduction": reduction, + "accuracy_percentage": accuracy_percentage} + + + + burst_sizes = [64, 128, 256] + queue_sizes = [1, 2, 4, 8] + for burst_size in burst_sizes: + for queue_size in 
queue_sizes: + + db_input = database_input_class(sw_hw_database_population) + print("hw_sampling:" + str(hw_sampling)) + print("budget set to:" + str(db_input.get_budget_dict("glass"))) + unique_suffix = str(total_process_cnt) + "_" + str(current_process_id) + "_" + str(run_ctr) + config.default_cmd_queue_size = queue_size + config.default_data_queue_size = queue_size + config.default_burst_size = burst_size + dse_hndlr = run_FARSI_only_simulation(result_folder, unique_suffix, db_input, hw_sampling, sw_hw_database_population["hw_graph_mode"]) + run_ctr += 1 + + # write the results in the general folder + result_dir_specific = os.path.join(result_folder, "result_summary") + reason_to_terminate = "simple_sim_run" + # wf.write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, reason_to_terminate, case_study, result_dir_specific, + # unique_suffix, + # file_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) + + # write the results in the specific folder + result_folder_modified = result_folder + "/runs/" + str(run_ctr) + "/" + os.system("mkdir -p " + result_folder_modified) + + for key, val in dse_hndlr.dse.so_far_best_sim_dp.dp_stats.SOC_metric_dict["latency"]["glass"][0].items(): + print("lat is {} for {}".format(val, key)) + lat = val + burst_size = config.default_burst_size + queue_size = config.default_data_queue_size + print("burst size is {}".format(burst_size)) + print("queue size is {}".format(queue_size)) + + with open(result_folder+"/latency_in_us.txt", "a") as f: + f.write("burst_size={}, queue_size={}, farsi_simtime={}\n".format(burst_size, + queue_size, + lat*1000000)) + wf.copy_DSE_data(result_folder_modified) + # wf.write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, reason_to_terminate, case_study, + # result_folder_modified, unique_suffix, + # file_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) diff --git a/Project_FARSI/data_collection/collection_utils/sim_run/simple_sim_run.py b/Project_FARSI/data_collection/collection_utils/sim_run/simple_sim_run.py new file mode 100644 index 00000000..05a27ef0 --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/sim_run/simple_sim_run.py @@ -0,0 +1,118 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
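+
+# Illustrative sketch (hypothetical helper): the sweep in sim_run_array_x.py
+# above appends lines of the form
+#   burst_size=64, queue_size=2, farsi_simtime=123.4
+# to latency_in_us.txt. A minimal parser for that log could look like this.
+def _parse_latency_log(path):
+    """Yield (burst_size, queue_size, simtime_us) tuples from the sweep log."""
+    with open(path) as f:
+        for line in f:
+            fields = dict(kv.split("=") for kv in line.strip().split(", "))
+            yield (int(fields["burst_size"]),
+                   int(fields["queue_size"]),
+                   float(fields["farsi_simtime"]))
+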
+ +import sys +import os +sys.path.append(os.path.abspath('./../')) +import home_settings +from top.main_FARSI import run_FARSI_only_simulation +from settings import config +import os +import itertools +# main function +import numpy as np +from mpl_toolkits.mplot3d import Axes3D +from matplotlib import cm +import matplotlib.pyplot as plt +from visualization_utils.vis_hardware import * +import numpy as np +from specs.LW_cl import * +from specs.database_input import * +import math +import matplotlib.colors as colors +#import pandas +import matplotlib.colors as mcolors +import pandas as pd +import argparse, sys +import data_collection.collection_utils.what_ifs.FARSI_what_ifs as wf + + +# selecting the database based on the simulation method (power or performance) +if config.simulation_method == "power_knobs": + from specs import database_input_powerKnobs as database_input +elif config.simulation_method == "performance": + from specs import database_input +else: + raise NameError("Simulation method unavailable") + +if __name__ == "__main__": + case_study = "simple_sim_run" + file_prefix = config.FARSI_simple_sim_run_study_prefix + current_process_id = 0 + total_process_cnt = 1 + #starting_exploration_mode = config.exploration_mode + print('case study:' + case_study) + + # ------------------------------------------- + # set result folder + # ------------------------------------------- + result_home_dir_default = os.path.join(os.getcwd(), "data_collection/data/" + case_study) + result_home_dir = os.path.join(config.home_dir, "data_collection/data/" + case_study) + date_time = datetime.now().strftime('%m-%d_%H-%M_%S') + result_folder = os.path.join(result_home_dir, + date_time) + + # ------------------------------------------- + # set parameters + # ------------------------------------------- + experiment_repetition_cnt = 1 + reduction = "most_likely" + #workloads = {"audio_decoder", "edge_detection"} + #workloads = {"audio_decoder"} + #workloads = {"edge_detection"} + #workloads = {"hpvm_cava"} + workloads = {"partial_SOC_example_hard"} + workloads = {"SOC_example_1p_2r"} + tech_node_SF = {"perf":1, "energy":{"non_gpp":.064, "gpp":1}, "area":{"non_mem":.0374 , "mem":.079, "gpp":1}} # technology node scaling factor + db_population_misc_knobs = {"ip_freq_correction_ratio": 1, "gpp_freq_correction_ratio": 1, + "tech_node_SF":tech_node_SF, + "base_budget_scaling":{"latency":.5, "power":1, "area":1}} + sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "parse", + "workloads": workloads, "misc_knobs":db_population_misc_knobs} + + # ------------------------------------------- + # distribute the work + # ------------------------------------------- + work_per_process = math.ceil(experiment_repetition_cnt / total_process_cnt) + run_ctr = 0 + # ------------------------------------------- + # run the combination and collect the data + # ------------------------------------------- + # ------------------------------------------- + # collect the exact hw sampling + # ------------------------------------------- + accuracy_percentage = {} + accuracy_percentage["sram"] = accuracy_percentage["dram"] = accuracy_percentage["ic"] = accuracy_percentage["gpp"] = accuracy_percentage[ + "ip"] = \ + {"latency": 1, + "energy": 1, + "area": 1, + "one_over_area": 1} + hw_sampling = {"mode": "exact", "population_size": 1, "reduction": reduction, + "accuracy_percentage": accuracy_percentage} + + + + + db_input = database_input_class(sw_hw_database_population) + print("hw_sampling:" + str(hw_sampling)) + 
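+    # Note: "mode": "exact" with every accuracy percentage at 1 means the
+    # hardware models are sampled without injected estimation error; the
+    # sensitivity studies in FARSI_what_ifs.py lower e.g.
+    # accuracy_percentage["ip"]["latency"] below 1 to model noisy IP estimates.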
print("budget set to:" + str(db_input.get_budget_dict("glass"))) + unique_suffix = str(total_process_cnt) + "_" + str(current_process_id) + "_" + str(run_ctr) + dse_hndlr = run_FARSI_only_simulation(result_folder, unique_suffix, db_input, hw_sampling, sw_hw_database_population["hw_graph_mode"]) + run_ctr += 1 + + # write the results in the general folder + result_dir_specific = os.path.join(result_folder, "result_summary") + reason_to_terminate = "simple_sim_run" + wf.write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, reason_to_terminate, case_study, result_dir_specific, + unique_suffix, + file_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) + + # write the results in the specific folder + result_folder_modified = result_folder + "/runs/" + str(run_ctr) + "/" + os.system("mkdir -p " + result_folder_modified) + wf.copy_DSE_data(result_folder_modified) + wf.write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, reason_to_terminate, case_study, + result_folder_modified, unique_suffix, + file_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) diff --git a/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs.py b/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs.py new file mode 100644 index 00000000..958299f4 --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs.py @@ -0,0 +1,893 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +import sys +import os +sys.path.append(os.path.abspath('./../')) +#import home_settings +from top.main_FARSI import run_FARSI +from top.main_FARSI import run_FARSI_only_simulation +from settings import config +import os +import itertools +# main function +import numpy as np +from mpl_toolkits.mplot3d import Axes3D +from matplotlib import cm +import matplotlib.pyplot as plt +from visualization_utils.vis_hardware import * +import numpy as np +from specs.LW_cl import * +from specs.database_input import * +import math +import matplotlib.colors as colors +#import pandas +import matplotlib.colors as mcolors +import pandas as pd +import argparse, sys + +# selecting the database based on the simulation method (power or performance) +if config.simulation_method == "power_knobs": + from specs import database_input_powerKnobs as database_input +elif config.simulation_method == "performance": + from specs import database_input +else: + raise NameError("Simulation method unavailable") + + +# ------------------------------ +# Functionality: +# show the the result of power/performance/area sweep +# Variables: +# full_dir_addr: name of the directory to get data from +# full_file_addr: name of the file to get data from +# ------------------------------ +def plot_3d_dist(full_dir_addr, full_file_addr, workloads): + # getting data + df = pd.read_csv(full_file_addr) + + # get avarages + grouped_multiple = df.groupby(['latency_budget', 'power_budget', "area_budget"]).agg( + {'latency': ['mean'], "power": ["mean"], "area": ["mean"], "cost":["mean"]}) + # the follow two lines is really usefull when we have mulitple aggregrations for each + # key above, e.g., latency:['mean, 'max'] + grouped_multiple.columns = ['latency_avg', 'power_avg', 'area_avg', 'cost_avg'] + grouped_multiple = grouped_multiple.reset_index() + + # calculate the distance to goal and insert into the df + dist_list = [] + for idx, row in 
grouped_multiple.iterrows(): + latency_dist = max(0, row['latency_avg'] - row['latency_budget'])/row['latency_budget'] + power_dist = max(0, row['power_avg'] - row['power_budget'])/row['power_budget'] + area_dist = max(0, row['area_avg'] - row['area_budget'])/row['area_budget'] + dist_list.append(latency_dist+ power_dist+ area_dist) + grouped_multiple.insert(2, "norm_dist", dist_list, True) + + # get the data + latency_budget = grouped_multiple["latency_budget"] + power_budget = grouped_multiple["power_budget"] + area_budget = grouped_multiple["area_budget"] + + color_values = [] + for el in grouped_multiple["norm_dist"]: + if el == 0: + color_values.append(0) + else: + color_values.append(el + max(grouped_multiple["norm_dist"])) + + #color_values = grouped_multiple["norm_dist"] + print("maximum distance" + str(max(color_values))) + X = latency_budget + Y = power_budget + Z = area_budget + + """ + X = [el/min(latency_budget) for el in latency_budget] + Y = [el/min(power_budget) for el in power_budget] + Z = [el/min(area_budget) for el in area_budget] + """ + + # 3D plot, with color + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + + bounds = np.array(list(np.arange(0, 1+.005, .01))) + norm = colors.BoundaryNorm(boundaries=bounds, ncolors=230) + + # 3D plot, with color + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + + color_values_norm = [col/max(max(color_values), .0000000000001) for col in color_values] + p = ax.scatter(X, Y, Z, norm=norm, c=color_values_norm, cmap=plt.get_cmap("jet")) + plt.colorbar(p) + + ax.set_xlabel("latency" + "(ms)") + ax.set_ylabel("power"+ "(mw)") + ax.set_zlabel("area" + "(mm2)") + ax.set_title('Budget Sweep for ' + list(workloads)[0] + '.\n Hotter col= higher dist to budget, max:' + + str(max(color_values)) + " min:" + str(min(color_values))) + fig.savefig(os.path.join(full_dir_addr, config.FARSI_cost_correlation_study_prefix +"_3d.png")) + plt.show() + + +# copy the DSE results to the result dir +def copy_DSE_data(result_dir): + #result_dir_specific = os.path.join(result_dirresult_summary") + os.system("cp " + config.latest_visualization+"/*" + " " + result_dir) + + +# ------------------------------ +# Functionality: +# write the results into a file +# Variables: +# sim_dp: design point simulation +# result_dir: result directory +# unique_number: a number to differentiate between designs +# file_name: output file name +# ------------------------------ +def write_one_results(sim_dp, dse, reason_to_terminate, case_study, result_dir_specific, unique_number, file_name): + """ + def convert_dict_to_parsable_csv(dict_): + list = [] + for k,v in dict_.items(): + list.append(str(k)+"="+str(v)) + return list + """ + def convert_tuple_list_to_parsable_csv(list_): + result = "" + for k, v in list_: + result +=str(k) + "=" + str(v) + "___" + return result + + def convert_dictionary_to_parsable_csv_with_semi_column(dict_): + result = "" + for k, v in dict_.items(): + result +=str(k) + "=" + str(v) + ";" + return result + + + + if not os.path.isdir(result_dir_specific): + os.makedirs(result_dir_specific) + + + compute_system_attrs = sim_dp.dp_stats.get_compute_system_attr() + bus_system_attrs = sim_dp.dp_stats.get_bus_system_attr() + memory_system_attrs = sim_dp.dp_stats.get_memory_system_attr() + speedup_dict, speedup_attrs = sim_dp.dp_stats.get_speedup_analysis(dse) + + + + output_file_minimal = os.path.join(result_dir_specific, file_name+ ".csv") + + base_budget_scaling = 
sim_dp.database.db_input.sw_hw_database_population["misc_knobs"]["base_budget_scaling"] + + # minimal output + if os.path.exists(output_file_minimal): + output_fh_minimal = open(output_file_minimal, "a") + else: + output_fh_minimal = open(output_file_minimal, "w") + for metric in config.all_metrics: + output_fh_minimal.write(metric + ",") + if metric in sim_dp.database.db_input.get_budget_dict("glass").keys(): + output_fh_minimal.write(metric+"_budget" + ",") + output_fh_minimal.write("sampling_mode,") + output_fh_minimal.write("sampling_reduction" +",") + for metric, accuracy_percentage in sim_dp.database.hw_sampling["accuracy_percentage"]["ip"].items(): + output_fh_minimal.write(metric+"_accuracy" + ",") # for now only write the latency accuracy as the other + for block_type, porting_effort in sim_dp.database.db_input.porting_effort.items(): + output_fh_minimal.write(block_type+"_effort" + ",") # for now only write the latency accuracy as the other + + output_fh_minimal.write("output_design_status"+ ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("case_study"+ ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("unique_number" + ",") # for now only write the latency accuracy as the other + + output_fh_minimal.write("SA_total_depth,") + output_fh_minimal.write("reason_to_terminate" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("population generation cnt" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("iteration cnt" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("workload_set" + ",") # for now only write the latency accuracy as the other + #output_fh_minimal.write("iterationxdepth number" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("simulation time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move generation time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("kernel selection time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("block selection time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("transformation selection time" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("transformation_selection_mode" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("dist_to_goal_all" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("dist_to_goal_non_cost" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system block count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system PE count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system bus count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("system memory count" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("routing complexity" + ",") # for now only write the latency accuracy as the other + #output_fh_minimal.write("area_breakdown_subtype" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("block_impact_sorted" + ",") # for now only write the latency accuracy as the other + 
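+        # Note: every header cell written above and below must be matched, in
+        # the same order, by exactly one value cell written after the header
+        # block ends; e.g. the "system PE count" header is paired later with
+        # output_fh_minimal.write(str(pe_cnt) + ","). Keeping the two
+        # sequences aligned is what keeps the CSV parsable.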
output_fh_minimal.write("kernel_impact_sorted" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("metric_impact_sorted" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_metric" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_transformation_name" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_kernel" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_block_name" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_block_type" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("move_dir" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("comm_comp" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("high_level_optimization" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("architectural_principle" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("area_dram" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("area_non_dram" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write("channel_cnt" + ",") # for now only write the latency accuracy as the other + for key, val in compute_system_attrs.items(): + output_fh_minimal.write(str(key) + ",") + for key, val in bus_system_attrs.items(): + output_fh_minimal.write(str(key) + ",") + for key, val in memory_system_attrs.items(): + output_fh_minimal.write(str(key) + ",") + + for key, val in speedup_attrs.items(): + output_fh_minimal.write(str(key) + ",") + + for key, val in speedup_dict.items(): + output_fh_minimal.write(str(key)+"_speedup_analysis" + ",") + + for key,val in base_budget_scaling.items(): + output_fh_minimal.write("budget_scaling_"+str(key) + ",") + + + + output_fh_minimal.write("\n") + for metric in config.all_metrics: + data_ = sim_dp.dp_stats.get_system_complex_metric(metric) + if isinstance(data_, dict): + data__ = convert_dictionary_to_parsable_csv_with_semi_column(data_) + else: + data__ = data_ + + output_fh_minimal.write(str(data__) + ",") + + if metric in sim_dp.database.db_input.get_budget_dict("glass").keys(): + data_ = sim_dp.database.db_input.get_budget_dict("glass")[metric] + if isinstance(data_, dict): + data__ = convert_dictionary_to_parsable_csv_with_semi_column(data_) + else: + data__ = data_ + output_fh_minimal.write(str(data__) + ",") + + output_fh_minimal.write(sim_dp.database.hw_sampling["mode"] + ",") + output_fh_minimal.write(sim_dp.database.hw_sampling["reduction"] + ",") + for metric, accuracy_percentage in sim_dp.database.hw_sampling["accuracy_percentage"]["ip"].items(): + output_fh_minimal.write(str(accuracy_percentage) + ",") # for now only write the latency accuracy as the other + for block_type, porting_effort in sim_dp.database.db_input.porting_effort.items(): + output_fh_minimal.write(str(porting_effort)+ ",") # for now only write the latency accuracy as the other + + if sim_dp.dp_stats.fits_budget(1): + output_fh_minimal.write("budget_met"+ ",") # for now only write the latency accuracy as the other + else: + output_fh_minimal.write("budget_not_met" + ",") # for now only write the latency accuracy as the other + output_fh_minimal.write(case_study + ",") # for now only write the latency 
accuracy as the other
+    output_fh_minimal.write(str(unique_number) + ",")
+
+    output_fh_minimal.write(str(config.SA_depth) + ",")
+    output_fh_minimal.write(str(reason_to_terminate) + ",")
+
+    ma = sim_dp.get_move_applied()  # move applied
+    if ma is not None:
+        sorted_metrics = convert_tuple_list_to_parsable_csv([(el, val) for el, val in ma.sorted_metrics.items()])
+        metric = ma.get_metric()
+        transformation_name = ma.get_transformation_name()
+        task_name = ma.get_kernel_ref().get_task_name()
+        block_type = ma.get_block_ref().type
+        dir = ma.get_dir()
+        generation_time = ma.get_generation_time()
+        sorted_blocks = convert_tuple_list_to_parsable_csv([(el.get_generic_instance_name(), val) for el, val in ma.sorted_blocks])
+        sorted_kernels = convert_tuple_list_to_parsable_csv([(el.get_task_name(), val) for el, val in ma.sorted_kernels.items()])
+        blk_instance_name = ma.get_block_ref().get_generic_instance_name()
+        blk_type = ma.get_block_ref().type
+
+        comm_comp = (ma.get_system_improvement_log())["comm_comp"]
+        high_level_optimization = (ma.get_system_improvement_log())["high_level_optimization"]
+        exact_optimization = (ma.get_system_improvement_log())["exact_optimization"]
+        architectural_variable_to_improve = (ma.get_system_improvement_log())["architectural_principle"]
+        block_selection_time = ma.get_logs("block_selection_time")
+        kernel_selection_time = ma.get_logs("kernel_selection_time")
+        transformation_selection_time = ma.get_logs("transformation_selection_time")
+    else:  # happens at the very first iteration, before any move is applied
+        sorted_metrics = ""
+        metric = ""
+        transformation_name = ""
+        task_name = ""
+        block_type = ""
+        dir = ""
+        generation_time = ''
+        sorted_blocks = ''
+        sorted_kernels = {}
+        blk_instance_name = ''
+        blk_type = ''
+        comm_comp = ""
+        high_level_optimization = ""
+        exact_optimization = ""  # keep in sync with the branch above
+        architectural_variable_to_improve = ""
+        block_selection_time = ""
+        kernel_selection_time = ""
+        transformation_selection_time = ""
+
+    routing_complexity = sim_dp.dp_rep.get_hardware_graph().get_routing_complexity()
+    simple_topology = sim_dp.dp_rep.get_hardware_graph().get_simplified_topology_code()
+    blk_cnt = sum([int(el) for el in simple_topology.split("_")])
+    bus_cnt = [int(el) for el in simple_topology.split("_")][0]
+    mem_cnt = [int(el) for el in simple_topology.split("_")][1]
+    pe_cnt = [int(el) for el in simple_topology.split("_")][2]
+    #itr_depth_multiplied = sim_dp.dp_rep.get_iteration_number()*config.SA_depth + sim_dp.dp_rep.get_depth_number()
+
+    output_fh_minimal.write(str(sim_dp.dp_rep.get_population_generation_cnt()) + ",")
+    output_fh_minimal.write(str(dse.get_total_iteration_cnt()) + ",")
+    output_fh_minimal.write('_'.join(sim_dp.database.db_input.workload_tasks.keys()) + ",")
+    #output_fh_minimal.write(str(itr_depth_multiplied) + ",")
+    output_fh_minimal.write(str(sim_dp.dp_rep.get_simulation_time()) + ",")
+    output_fh_minimal.write(str(generation_time) + ",")
+    output_fh_minimal.write(str(kernel_selection_time) + ",")
+    output_fh_minimal.write(str(block_selection_time) + ",")
+    output_fh_minimal.write(str(transformation_selection_time) + ",")
+    output_fh_minimal.write(str(config.transformation_selection_mode) + ",")
+    output_fh_minimal.write(str(sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power", "cost"], mode="eliminate")) + ",")
+    output_fh_minimal.write(str(sim_dp.dp_stats.dist_to_goal(metrics_to_look_into=["area", "latency", "power"], mode="eliminate")) + ",")
+    output_fh_minimal.write(str(blk_cnt) + ",")
+    output_fh_minimal.write(str(pe_cnt) + ",")
+    output_fh_minimal.write(str(bus_cnt) + ",")
+    output_fh_minimal.write(str(mem_cnt) + ",")
+    output_fh_minimal.write(str(routing_complexity) + ",")
+    #output_fh_minimal.write(convert_dictionary_to_parsable_csv_with_semi_column(sim_dp.dp_stats.SOC_area_subtype_dict.keys()) + ",")
+    output_fh_minimal.write(str(sorted_blocks) + ",")
+    output_fh_minimal.write(str(sorted_kernels) + ",")
+    output_fh_minimal.write(str(sorted_metrics) + ",")
+    output_fh_minimal.write(str(metric) + ",")
+    output_fh_minimal.write(transformation_name + ",")
+    output_fh_minimal.write(task_name + ",")
+    output_fh_minimal.write(blk_instance_name + ",")
+    output_fh_minimal.write(blk_type + ",")
+    output_fh_minimal.write(str(dir) + ",")
+    output_fh_minimal.write(str(comm_comp) + ",")
+    output_fh_minimal.write(str(high_level_optimization) + ",")
+    output_fh_minimal.write(str(architectural_variable_to_improve) + ",")
+    output_fh_minimal.write(str(sim_dp.dp_stats.get_system_complex_area_stacked_dram()["dram"]) + ",")
+    output_fh_minimal.write(str(sim_dp.dp_stats.get_system_complex_area_stacked_dram()["non_dram"]) + ",")
+    output_fh_minimal.write(str(sim_dp.dp_rep.get_hardware_graph().get_number_of_channels()) + ",")
+    for key, val in compute_system_attrs.items():
+        output_fh_minimal.write(str(val) + ",")
+    for key, val in bus_system_attrs.items():
+        output_fh_minimal.write(str(val) + ",")
+    for key, val in memory_system_attrs.items():
+        output_fh_minimal.write(str(val) + ",")
+
+    for key, val in speedup_attrs.items():
+        output_fh_minimal.write(str(val) + ",")
+
+    for key, val in speedup_dict.items():
+        output_fh_minimal.write(convert_dictionary_to_parsable_csv_with_semi_column(val) + ",")
+
+    for key, val in base_budget_scaling.items():
+        output_fh_minimal.write(str(val) + ",")
+
+    output_fh_minimal.close()
+
+
+def simple_run_iterative(result_folder, sw_hw_database_population, system_workers=(1, 1)):
+    case_study = "simple_run_iterative"
+    current_process_id = system_workers[0]
+    total_process_cnt = system_workers[1]
+    starting_exploration_mode = "from_scratch"
+    print('case study: ' + case_study)
+    # -------------------------------------------
+    # set parameters
+    # -------------------------------------------
+    experiment_repetition_cnt = 1
+    reduction = "most_likely"
+
+    # -------------------------------------------
+    # distribute the work
+    # -------------------------------------------
+    work_per_process = math.ceil(experiment_repetition_cnt / total_process_cnt)
+    run_ctr = 0
+
+    # -------------------------------------------
+    # run the combination and collect the data
+    # -------------------------------------------
+    # get the budget, set them and run
FARSI + for i in range(0, work_per_process): + # ------------------------------------------- + # collect the exact hw sampling + # ------------------------------------------- + accuracy_percentage = {} + accuracy_percentage["sram"] = accuracy_percentage["dram"] = accuracy_percentage["ic"] = accuracy_percentage[ + "gpp"] = accuracy_percentage["ip"] = \ + {"latency": 1, + "energy": 1, + "area": 1, + "one_over_area": 1} + hw_sampling = {"mode": "exact", "population_size": 1, "reduction": reduction, + "accuracy_percentage": accuracy_percentage} + + unique_suffix = str(total_process_cnt) + "_" + str(current_process_id) + "_" + str(run_ctr) + + #study = ["boundedness", "serial_parallel"] + study = "boundedness" + #study = "serial_parallel" + #study = "hop_NoC_studies" + + if study == "serial_parallel": + # for serial/parallel studies + #serial_sweep = list(range(0,20,8)) + #parallel_sweep = list(range(0,20, 8)) + #parallel_sweep = [0, 1,2, 4,8] + #serial_sweep = [0, 1, 2, 4,8] + + + parallel_sweep = [0,4,8] #, 4,8] + serial_sweep = [0, 4,8] + #parallel_sweep =[5] + + all_sweep = [serial_sweep, parallel_sweep] + parallel_task_type_list= ["audio_style", "edge_detection_style"] # "audio_style, edge_detection_style + parallel_task_type_list= ["audio_style"] #, "edge_detection_style"] # "audio_style, edge_detection_style + #parallel_task_type_list= ["edge_detection"] #, "edge_detection_style"] # "audio_style, edge_detection_style + for parallel_task_type in parallel_task_type_list: + combos = itertools.product(*all_sweep) + for serial_task_cnt, parallel_task_cnt in combos: + datamovement_scaling_ratio = 1 + sw_hw_database_population["misc_knobs"]['task_spawn']['serial_task_cnt'] = serial_task_cnt + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_cnt'] = parallel_task_cnt + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_type'] = parallel_task_type # can be audio or edge detection + sw_hw_database_population["misc_knobs"]['task_spawn']['boundedness'] = ["memory_intensive", datamovement_scaling_ratio, 1] + sw_hw_database_population["misc_knobs"]['num_of_hops'] = 4 + sw_hw_database_population["misc_knobs"]['num_of_NoCs'] = 4 + #sw_hw_database_population["hw_graph_mode"] = "star_mode" + sw_hw_database_population["hw_graph_mode"] = "hop_mode" + db_input = database_input_class(sw_hw_database_population) + print("hw_sampling:" + str(hw_sampling)) + print("budget set to:" + str(db_input.get_budget_dict("glass"))) + dse_hndlr = run_FARSI_only_simulation(result_folder, unique_suffix, db_input, hw_sampling, + sw_hw_database_population["hw_graph_mode"]) + elif study == "boundedness": + # for boundedness studies + boundedness_ratio_range = list(np.arange(0,1.1,.1)) + #boundedness_ratio_range = [0, .3,.6,1] + boundedness_ratio_range = [.3] + #datamovement_scaling_ratio_range = list(np.arange(.1,1,.2)) + datamovement_scaling_ratio_range = [1] + for boundedness_ratio in boundedness_ratio_range: + for datamovement_scaling_ratio in datamovement_scaling_ratio_range: + sw_hw_database_population["misc_knobs"]['task_spawn']['serial_task_cnt'] = 4 #6 + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_cnt'] = 0 #1 + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_type'] = "audio_style" # can be audio or edge detection + sw_hw_database_population["misc_knobs"]['task_spawn']['boundedness'] = ["memory_intensive", datamovement_scaling_ratio, boundedness_ratio] + sw_hw_database_population["misc_knobs"]['num_of_hops'] = 1 + 
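+                    # Note: the boundedness knob is a triple of
+                    # [style, datamovement_scaling_ratio, boundedness_ratio];
+                    # here ["memory_intensive", 1, .3] marks the spawned tasks
+                    # as memory-intensive with a .3 boundedness ratio (exact
+                    # semantics live in database_input_class).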
sw_hw_database_population["misc_knobs"]['num_of_NoCs'] = 1 + db_input = database_input_class(sw_hw_database_population) + print("hw_sampling:" + str(hw_sampling)) + print("budget set to:" + str(db_input.get_budget_dict("glass"))) + dse_hndlr = run_FARSI_only_simulation(result_folder, unique_suffix, db_input, hw_sampling, + sw_hw_database_population["hw_graph_mode"]) + elif study == "hop_studies": + # for boundedness studies + boundedness_ratio_range = [1] + datamovement_scaling_ratio_range = [1] + datamovement_scaling_ratio = 1 + #num_of_hops_range = [1] + num_of_hops_range = [4] + boundedness_ratio = 1 + + #serial_sweep = [0] + #parallel_sweep = [0,2, 4, 8] #, 16] + + serial_sweep = [0] + parallel_sweep = [5,7, 9] #, 16] + + all_sweep = [serial_sweep, parallel_sweep] + parallel_task_type_list= ["audio_style"] #, "edge_detection_style"] # "audio_style, edge_detection_style + + for num_of_hops in num_of_hops_range: + combos = itertools.product(*all_sweep) + for serial_task_cnt, parallel_task_cnt in combos: + sw_hw_database_population["misc_knobs"]['task_spawn']['serial_task_cnt'] = serial_task_cnt + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_cnt'] = parallel_task_cnt + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_type'] = "audio_style" # can be audio or edge detection + sw_hw_database_population["misc_knobs"]['task_spawn']['boundedness'] = ["memory_intensive", datamovement_scaling_ratio, boundedness_ratio] + sw_hw_database_population["misc_knobs"]['num_of_hops'] = num_of_hops + sw_hw_database_population["misc_knobs"]['num_of_NoCs'] = num_of_hops # same as num_of_hops + sw_hw_database_population["hw_graph_mode"] = "hop_mode" + db_input = database_input_class(sw_hw_database_population) + print("hw_sampling:" + str(hw_sampling)) + print("budget set to:" + str(db_input.get_budget_dict("glass"))) + dse_hndlr = run_FARSI_only_simulation(result_folder, unique_suffix, db_input, hw_sampling, + sw_hw_database_population["hw_graph_mode"]) + elif study == "hop_NoC_studies": + # for boundedness studies + boundedness_ratio_range = [1] + datamovement_scaling_ratio_range = [1] + datamovement_scaling_ratio = 1 + #num_of_hops_range = [1] + #num_of_hop_NoCs_range =[[1,1], [2,2],[3,3], [2,3], [2,4],[3,4]] + #num_of_hop_NoCs_range =[[3,4]] + #num_of_hop_NoCs_range =[[1,1], [2,2], [3,3]] + num_of_hop_NoCs_range =[[1,1],[2,2],[2,3], [2,4], [3,3], [3,4], [4,4]] + boundedness_ratio = 1 + + #serial_sweep = [0] + #parallel_sweep = [0,2, 4, 8] #, 16] + + serial_sweep = [4] + parallel_sweep = [0] #, 16] + + all_sweep = [serial_sweep, parallel_sweep] + parallel_task_type_list= ["audio_style"] #, "edge_detection_style"] # "audio_style, edge_detection_style + + for num_of_hops, num_of_NoCs in num_of_hop_NoCs_range: + combos = itertools.product(*all_sweep) + for serial_task_cnt, parallel_task_cnt in combos: + sw_hw_database_population["misc_knobs"]['task_spawn']['serial_task_cnt'] = serial_task_cnt + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_cnt'] = parallel_task_cnt + sw_hw_database_population["misc_knobs"]['task_spawn']['parallel_task_type'] = "audio_style" # can be audio or edge detection + sw_hw_database_population["misc_knobs"]['task_spawn']['boundedness'] = ["memory_intensive", datamovement_scaling_ratio, boundedness_ratio] + sw_hw_database_population["misc_knobs"]['num_of_hops'] = num_of_hops + sw_hw_database_population["misc_knobs"]['num_of_NoCs'] = num_of_NoCs + sw_hw_database_population["hw_graph_mode"] = "hop_mode" + db_input = 
database_input_class(sw_hw_database_population)
+                    print("hw_sampling:" + str(hw_sampling))
+                    print("budget set to:" + str(db_input.get_budget_dict("glass")))
+                    dse_hndlr = run_FARSI_only_simulation(result_folder, unique_suffix, db_input, hw_sampling,
+                                                          sw_hw_database_population["hw_graph_mode"])
+
+        run_ctr += 1
+        # write the results in the general folder
+        result_dir_specific = os.path.join(result_folder, "result_summary")
+        write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, dse_hndlr.dse.reason_to_terminate, case_study,
+                          result_dir_specific, unique_suffix,
+                          config.FARSI_simple_run_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt))
+        dse_hndlr.dse.write_data_log(list(dse_hndlr.dse.get_log_data()), dse_hndlr.dse.reason_to_terminate, case_study, result_dir_specific, unique_suffix,
+                                     config.FARSI_simple_run_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt))
+
+        # write the results in the specific folder
+        result_folder_modified = result_folder + "/runs/" + str(run_ctr) + "/"
+        os.system("mkdir -p " + result_folder_modified)
+        copy_DSE_data(result_folder_modified)
+        write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, dse_hndlr.dse.reason_to_terminate, case_study, result_folder_modified, unique_suffix,
+                          config.FARSI_simple_run_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt))
+
+    os.system("cp " + config.home_dir + "/settings/config.py" + " " + result_folder)
+
+
+# ------------------------------
+# Functionality:
+#       a simple run, where FARSI is run to meet a certain budget
+# Variables:
+#       system_workers: used for parallelizing the data collection: (current process id, total number of workers)
+# ------------------------------
+def simple_run(result_folder, sw_hw_database_population, system_workers=(1, 1)):
+    case_study = "simple_run"
+    current_process_id = system_workers[0]
+    total_process_cnt = system_workers[1]
+    starting_exploration_mode = "from_scratch"
+    print('case study: ' + case_study)
+    # -------------------------------------------
+    # set parameters
+    # -------------------------------------------
+    experiment_repetition_cnt = 1
+    reduction = "most_likely"
+
+    # -------------------------------------------
+    # distribute the work
+    # -------------------------------------------
+    work_per_process = math.ceil(experiment_repetition_cnt / total_process_cnt)
+    run_ctr = 0
+
+    # -------------------------------------------
+    # run the combination and collect the data
+    # -------------------------------------------
+    # get the budget, set them and run FARSI
+    for i in range(0, work_per_process):
+        # -------------------------------------------
+        # collect the exact hw sampling
+        # -------------------------------------------
+        accuracy_percentage = {}
+        accuracy_percentage["sram"] = accuracy_percentage["dram"] = accuracy_percentage["ic"] = accuracy_percentage[
+            "gpp"] = accuracy_percentage["ip"] = \
+            {"latency": 1,
+             "energy": 1,
+             "area": 1,
+             "one_over_area": 1}
+        hw_sampling = {"mode": "exact", "population_size": 1, "reduction": reduction,
+                       "accuracy_percentage": accuracy_percentage}
+        db_input = database_input_class(sw_hw_database_population)
+        print("hw_sampling:" + str(hw_sampling))
+        print("budget set to:" + str(db_input.get_budget_dict("glass")))
+        unique_suffix = str(total_process_cnt) + "_" + str(current_process_id) + "_" + str(run_ctr)
+
+        # run FARSI
+        dse_hndlr = run_FARSI(result_folder, unique_suffix, case_study, db_input, hw_sampling,
+                              sw_hw_database_population["hw_graph_mode"])
+        #dse_hndlr = run_FARSI_only_simulation(result_folder, unique_suffix, db_input, hw_sampling,
+        #                                      sw_hw_database_population["hw_graph_mode"])
+
+        run_ctr += 1
+
+    return dse_hndlr
+
+
+# ------------------------------
+# Functionality:
+#       generate a range of values between a lower and an upper bound
+# Variables:
+#       cnt: number of points within the range to generate
+# ------------------------------
+def gen_range(lower_bound, upper_bound, cnt):
+    if cnt == 1:
+        upper_bound = lower_bound
+        step_size = lower_bound
+    else:
+        step_size = (upper_bound - lower_bound) / cnt
+
+    range_ = list(np.arange(lower_bound, min(.9 * lower_bound, (10 ** -9) * upper_bound) + upper_bound, step_size))
+    range_formatted = [float("{:.9f}".format(el)) for el in range_]
+    if len(range_formatted) == 1:
+        return range_formatted
+    else:
+        return range_formatted[:-1]
+
+
+# ------------------------------
+# Functionality:
+#       generate all the combinations of inputs
+# Variables:
+#       args: a list of arg values. Each arg specifies a range of values using (lower bound, upper bound, cnt)
+# ------------------------------
+def gen_combinations(args):
+    list_of_ranges = []
+    for arg in args:
+        list_of_ranges.append(gen_range(arg[0], arg[1], arg[2]))
+    all_budget_combinations = itertools.product(*list_of_ranges)
+    all_budget_combinations_listified = [el for el in all_budget_combinations]
+    return all_budget_combinations_listified
+
+
+# ------------------------------
+# Functionality:
+#       conduct a host of studies depending on whether input_error or input_cost are set. Here are the combinations:
+#       input_error = False, input_cost = False: conduct a cost-PPA study
+#       input_error = True , input_cost = False: conduct a study to figure out the impact of input error on cost
+#       input_error = True , input_cost = True : TBD
+# Variables:
+#       system_workers: used for parallelizing the data collection: (current process id, total number of workers)
+# ------------------------------
+def input_error_output_cost_sensitivity_study(result_folder, sw_hw_database_population, system_workers=(1, 1), input_error=False, input_cost=False):
+    current_process_id = system_workers[0]
+    total_process_cnt = system_workers[1]
+
+    # -----------------------
+    # set up the case study
+    # -----------------------
+    case_study = ""
+    if not input_error and not input_cost:
+        case_study = "cost_PPA"
+        file_prefix = config.FARSI_cost_correlation_study_prefix
+    elif input_error and not input_cost:
+        case_study = "input_error_output_cost"
+        print("input error cost study")
+        file_prefix = config.FARSI_input_error_output_cost_sensitivity_study_prefix
+    elif input_error and input_cost:
+        case_study = "input_error_input_cost"
+        file_prefix = config.FARSI_input_error_input_cost_sensitivity_study_prefix
+    else:
+        print("this study is not supported")
+        exit(1)
+
+    print("conducting " + case_study)
+
+    # -----------------------
+    # first extract the current budget
+    # -----------------------
+    accuracy_percentage = {}
+    accuracy_percentage["sram"] = accuracy_percentage["dram"] = accuracy_percentage["ic"] = accuracy_percentage["gpp"] = \
+        accuracy_percentage["ip"] = \
+        {"latency": 1,
+         "energy": 1,
+         "area": 1,
+         "one_over_area": 1}
+    hw_sampling = {"mode": "exact", "population_size": 1, "reduction": "most_likely",
+                   "accuracy_percentage": accuracy_percentage}
+    db_input = database_input_class(sw_hw_database_population)
+    budgets_dict = {}  # set the reference budget
+    budgets_dict['latency'] = db_input.budgets_dict['glass']['latency'][list(workloads)[0]]
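+    # Note: the reference budgets extracted here seed gen_combinations below.
+    # A worked example of the helpers, under their current semantics:
+    # gen_range(1, 10, 3) uses step_size = 3 and returns [1.0, 4.0, 7.0]
+    # (the last point is dropped), so gen_combinations([(1, 10, 3), (.5, 1, 1)])
+    # returns [(1.0, 0.5), (4.0, 0.5), (7.0, 0.5)].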
budgets_dict['power'] = db_input.budgets_dict['glass']['power'] + budgets_dict['area'] = db_input.budgets_dict['glass']['area'] + + #------------------------------------------- + # set sweeping parameters + #------------------------------------------- + experiment_repetition_cnt = 1 + budget_cnt = 3 + budget_upper_bound_factor = {} + budget_upper_bound_factor["perf"] = 10 + budget_upper_bound_factor["power"] = 10 + budget_upper_bound_factor["area"] = 100 + + if not input_error: + accuracy_lower_bound = accuracy_upper_bound = 1 + accuracy_cnt = 1 # number of accuracy values to use + else: + accuracy_lower_bound = .5 + accuracy_cnt = 3 # number of accuracy values to use + accuracy_upper_bound = 1 + + if not input_cost: + effort_lower_bound = effort_upper_bound = 100 + effort_cnt = 1 # number of accuracy values to use + else: + effort_lower_bound = 20 + effort_cnt = 3 # number of accuracy values to use + effort_upper_bound = 100 + + # ------------------------------------------- + # generate all the combinations of the budgets + # ------------------------------------------- + combination_input =[] + combination_input.append((budgets_dict["latency"], budget_upper_bound_factor["perf"]*budgets_dict["latency"], budget_cnt)) + combination_input.append((budgets_dict["power"], budget_upper_bound_factor["power"]*budgets_dict["power"], budget_cnt)) + combination_input.append((budgets_dict["area"], budget_upper_bound_factor["area"]*budgets_dict["area"], budget_cnt)) + + combination_input.append((accuracy_lower_bound, accuracy_upper_bound, accuracy_cnt)) + combination_input.append((effort_lower_bound, effort_upper_bound, effort_cnt)) + + all_combinations = gen_combinations(combination_input) + + #------------------------------------------- + # distribute the work + #------------------------------------------- + combo_cnt = len(list(all_combinations)) + work_per_process = math.ceil(combo_cnt/total_process_cnt) + run_ctr = 0 + + #------------------------------------------- + # run the combination and collect the data + #------------------------------------------- + # get the budget, set them and run FARSI + reduction = "most_likely_with_accuracy_percentage" + for i in range(0, experiment_repetition_cnt): + for latency, power, area, accuracy, effort in list(all_combinations)[current_process_id* work_per_process: min((current_process_id+ 1) * work_per_process, combo_cnt)]: + # iterate though metrics and set the budget + + accuracy_percentage = {} + accuracy_percentage["sram"] = accuracy_percentage["dram"] = accuracy_percentage["ic"] = accuracy_percentage["gpp"] = {"latency": 1, "energy": 1, + "area": 1, "one_over_area": 1} + accuracy_percentage["ip"] = {"latency": accuracy, "energy": 1 / pow(accuracy, 2), "area": 1, + "one_over_area": 1} + hw_sampling = {"mode": "exact", "population_size": 1, "reduction": reduction, + "accuracy_percentage": accuracy_percentage} + db_input = database_input_class(sw_hw_database_population) + + # set the budget + budgets_dict = {} + budgets_dict["glass"] = {} + budgets_dict["glass"]["latency"] = {list(workloads)[0]:latency} + budgets_dict["glass"]["power"] = power + budgets_dict["glass"]["area"] = area + db_input.set_budgets_dict_directly(budgets_dict) + db_input.set_porting_effort_for_block("ip", effort) # only playing with ip now + unique_suffix = str(total_process_cnt) + "_" + str(current_process_id) + "_" + str(run_ctr) + + print("hw_sampling:" + str(hw_sampling)) + print("budget set to:" + str(db_input.budgets_dict)) + dse_hndlr = run_FARSI(result_folder, unique_suffix, 
db_input, hw_sampling, sw_hw_database_population["hw_graph_mode"]) + run_ctr += 1 + # write the results in the general folder + result_dir_specific = os.path.join(result_folder, "result_summary") + write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, dse_hndlr.dse.reason_to_terminate, case_study, result_dir_specific, unique_suffix, + file_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) + + # write the results in the specific folder + result_folder_modified = result_folder + "/runs/" + str(run_ctr) + "/" + os.system("mkdir -p " + result_folder_modified) + copy_DSE_data(result_folder_modified) + write_one_results(dse_hndlr.dse.so_far_best_sim_dp, dse_hndlr.dse, dse_hndlr.dse.reason_to_terminate, case_study, result_folder_modified, unique_suffix, + file_prefix + "_" + str(current_process_id) + "_" + str(total_process_cnt)) + + +if __name__ == "__main__": + # set the number of workers to be used (parallelism applied) + current_process_id = 0 + total_process_cnt = 1 + system_workers = (current_process_id, total_process_cnt) + + + # set the study type + #study_type = "cost_PPA" + study_type = "simple_run" + #study_subtype = "plot_3d_distance" + study_subtype = "run" + assert study_type in ["cost_PPA", "simple_run", "input_error_output_cost_sensitivity", "input_error_input_cost_sensitivity"] + assert study_subtype in ["run", "plot_3d_distance"] + + + # set the study parameters + # set the workload + + #workloads = {"edge_detection"} + #workloads = {"hpvm_cava"} + #workloads = {"audio_decoder"} + #workloads ={"edge_detection","hpvm_cava", "audio_decoder"} + workloads = {"edge_detection", "audio_decoder"} + workloads = {"SLAM"} + workloads = {"simple_multiple_hops"} + workloads = {"simple_serial_tg"} + #workloads = {"partial_SOC_example_hard"} + #workloads = {"simple_all_parallel"} + + workloads_first_letter = '_'.join(sorted([el[0] for el in workloads])) + # set result folder + result_home_dir_default = os.path.join(os.getcwd(), "data_collection/data/" + study_type) + result_home_dir = os.path.join(config.home_dir, "data_collection/data/" + study_type) + date_time = datetime.now().strftime('%m-%d_%H-%M_%S') + #config_obj = config + #config_obj_budg_dict = config_obj.budget_dict + budget_values = "pow_"+str(config.budget_dict["glass"]["power"]) + "__area_"+str(config.budget_dict["glass"]["area"]) + result_folder = os.path.join(result_home_dir, + date_time + "____"+ budget_values+"___workloads_"+workloads_first_letter) + + + # set the IP spawning params + ip_loop_unrolling = {"incr": 2, "max_spawn_ip": 17, "spawn_mode": "geometric"} + #ip_freq_range = {"incr":3, "upper_bound":8} + #mem_freq_range = {"incr":3, "upper_bound":6} + #ic_freq_range = {"incr":4, "upper_bound":6} + ip_freq_range = [1,4,6,8] + mem_freq_range = [1,4,6] + ic_freq_range = [1,4,6] + #tech_node_SF = {"perf":1, "energy":.064, "area":.079} # technology node scaling factor + tech_node_SF = {"perf":1, "energy":{"non_gpp":.064, "gpp":1}, "area":{"non_mem":.0374 , "mem":.07, "gpp":1}} # technology node scaling factor + db_population_misc_knobs = {"ip_freq_correction_ratio": 1, "gpp_freq_correction_ratio": 1, + "ip_spawn": {"ip_loop_unrolling": ip_loop_unrolling, "ip_freq_range": ip_freq_range}, + "mem_spawn": {"mem_freq_range":mem_freq_range}, + "ic_spawn": {"ic_freq_range":ic_freq_range}, + "tech_node_SF":tech_node_SF, + "base_budget_scaling":{"latency":.5, "power":1, "area":1}} + + # set software hardware database population + # for SLAM, and serial experiment + sw_hw_database_population = 
{"db_mode": "hardcoded", "hw_graph_mode": "generated_from_scratch", + "workloads": workloads, "misc_knobs": db_population_misc_knobs} + + #for multiple hops experiment + #sw_hw_database_population = {"db_mode": "hardcoded", "hw_graph_mode": "hardcoded", + # "workloads": workloads, "misc_knobs": db_population_misc_knobs} + # for paper workloads + #sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "generated_from_scratch", + # "workloads": workloads, "misc_knobs": db_population_misc_knobs} + # for check pointed + #sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "generated_from_check_point", + # "workloads": workloads, "misc_knobs": db_population_misc_knobs} + + + # depending on the study/substudy type, invoke the appropriate function + if study_type == "simple_run": + simple_run(result_folder, sw_hw_database_population, system_workers) + elif study_type == "cost_PPA" and study_subtype == "run": + input_error_output_cost_sensitivity_study(result_folder, sw_hw_database_population, system_workers, False, False) + elif study_type == "input_error_output_cost_sensitivity" and study_subtype == "run": + input_error_output_cost_sensitivity_study(result_folder, sw_hw_database_population, system_workers, True, False) + elif study_type == "input_error_input_cost_sensitivity" and study_subtype == "run": + input_error_output_cost_sensitivity_study(result_folder, sw_hw_database_population, system_workers,True, True) + elif study_type == "cost_PPA" and study_subtype == "plot_3d_distance": + result_folder = "05-28_18-46_40" # edge detection + result_folder = "05-28_18-47_33" # hpvm cava + result_folder = "05-28_18-47_03" + result_folder = "05-31_16-24_49" # hpvm cava (2, tighter constraints) + result_dir_addr= os.path.join(config.home_dir, 'data_collection/data/', study_type, result_folder, + "result_summary") + full_file_addr = os.path.join(result_dir_addr, + config.FARSI_cost_correlation_study_prefix + "_0_1.csv") + plot_3d_dist(result_dir_addr, full_file_addr, workloads) \ No newline at end of file diff --git a/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs_simple.py b/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs_simple.py new file mode 100644 index 00000000..cfa911d2 --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs_simple.py @@ -0,0 +1,60 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
+ +import sys +import os +sys.path.append(os.path.abspath('./../')) +import home_settings +from top.main_FARSI import run_FARSI +from settings import config +import os +import itertools +# main function +import numpy as np +from mpl_toolkits.mplot3d import Axes3D +from matplotlib import cm +import matplotlib.pyplot as plt +from visualization_utils.vis_hardware import * +import numpy as np +from specs.LW_cl import * +from specs.database_input import * +import math +import matplotlib.colors as colors +#import pandas +import matplotlib.colors as mcolors +import pandas as pd +import argparse, sys +from specs import database_input + + +import FARSI_what_ifs + + +if __name__ == "__main__": + # set the number of workers to be used (parallelism applied) + current_process_id = 0 + total_process_cnt = 1 + system_workers = (current_process_id, total_process_cnt) + + # set the study type + study_type = "cost_PPA" + study_subtype = "run" + assert study_type in ["cost_PPA", "simple_run", "input_error_output_cost_sensitivity", "input_error_input_cost_sensitivity"] + assert study_subtype in ["run", "plot_3d_distance"] + + # set result folder according to the time and the study type + result_home_dir = os.path.join(config.home_dir, "data_collection/data/" + study_type) + date_time = datetime.now().strftime('%m-%d_%H-%M_%S') + result_folder = os.path.join(result_home_dir, + date_time) + + # set the workload + workloads = {"audio_decoder"} # select from {"audio_decoder", "edge_detection", "hpvm_cava"} + + # set software hardware database population (refer to database.py for more information} + sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "generated_from_scratch", "workloads": workloads, + "misc_knobs":{}} + + # run FARSI (a simple exploration study) + FARSI_what_ifs.simple_run(result_folder, sw_hw_database_population, system_workers) \ No newline at end of file diff --git a/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs_with_params.py b/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs_with_params.py new file mode 100644 index 00000000..62b900fe --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/what_ifs/FARSI_what_ifs_with_params.py @@ -0,0 +1,352 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
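+
+# Illustrative sketch (hypothetical, simplified): when config.memory_conscious
+# is set, run() below launches each exploration in a child process so the OS
+# reclaims all simulator memory between runs. The core pattern, stripped of
+# FARSI specifics:
+import multiprocessing as _mp
+
+def _one_run(ret_value):
+    # ... perform one memory-hungry exploration here ...
+    ret_value.value = 0.0  # 1.0 would signal out_of_memory
+
+def _run_isolated():
+    ret_value = _mp.Value("d", 0.0, lock=False)
+    p = _mp.Process(target=_one_run, args=[ret_value])
+    p.start()
+    p.join()  # child exits; its memory goes back to the OS
+    return ret_value.value == 1.0  # True if the child ran out of memory
+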
+ +import sys +import os +import shutil +import multiprocessing +import psutil +sys.path.append(os.path.abspath('./../')) +import home_settings +from top.main_FARSI import run_FARSI +from top.main_FARSI import run_FARSI +from settings import config +import os +import itertools +# main function +import numpy as np +from mpl_toolkits.mplot3d import Axes3D +from matplotlib import cm +import matplotlib.pyplot as plt +from visualization_utils.vis_hardware import * +import numpy as np +from specs.LW_cl import * +from specs.database_input import * +import math +import matplotlib.colors as colors +#import pandas +import matplotlib.colors as mcolors +import pandas as pd +import argparse, sys +from FARSI_what_ifs import * +import os.path + + + +# selecting the database based on the simulation method (power or performance) +if config.simulation_method == "power_knobs": + from specs import database_input_powerKnobs as database_input +elif config.simulation_method == "performance": + from specs import database_input +else: + raise NameError("Simulation method unavailable") + + + +def run_with_params(workloads, SA_depth, freq_range, base_budget_scaling, trans_sel_mode, study_type, workload_folder, date_time, check_points, ret_value): + config.transformation_selection_mode = trans_sel_mode + config.SA_depth = SA_depth + # set the number of workers to be used (parallelism applied) + current_process_id = 0 + total_process_cnt = 1 + system_workers = (current_process_id, total_process_cnt) + + # set the study type + #study_type = "cost_PPA" + + + workloads_first_letter = '_'.join(sorted([el[0] for el in workloads])) + budget_values = "lat_"+str(base_budget_scaling["latency"])+"__pow_"+str(base_budget_scaling["power"]) + "__area_"+str(base_budget_scaling["area"]) + + + # set result folder + if check_points["start"]: + append = check_points["folder"].split("/")[-2] + result_folder = os.path.join(workload_folder, append) + # copy the previous results + if config.memory_conscious and not check_points['prev_itr'] == "": + src = check_points["prev_itr"] + des = os.path.join(result_folder, "result_summary", "prev_iter") + os.makedirs(des, exist_ok=True) + des = os.path.join(result_folder, "result_summary", "prev_iter", "result_summary") + destination = shutil.copytree(src, des) + else: + result_folder = os.path.join(workload_folder, + date_time + "____"+ budget_values +"___workloads_"+workloads_first_letter) + # set the IP spawning params + ip_loop_unrolling = {"incr": 2, "max_spawn_ip": 17, "spawn_mode": "geometric"} + #ip_freq_range = {"incr":3, "upper_bound":8} + #mem_freq_range = {"incr":3, "upper_bound":6} + #ic_freq_range = {"incr":4, "upper_bound":6} + ip_freq_range = freq_range + mem_freq_range = freq_range + ic_freq_range = freq_range + tech_node_SF = {"perf":1, "energy":{"non_gpp":.064, "gpp":1}, "area":{"non_mem":.0374 , "mem":.07, "gpp":1}} # technology node scaling factor + db_population_misc_knobs = {"ip_freq_correction_ratio": 1, "gpp_freq_correction_ratio": 1, + "ip_spawn": {"ip_loop_unrolling": ip_loop_unrolling, "ip_freq_range": ip_freq_range}, + "mem_spawn": {"mem_freq_range":mem_freq_range}, + "ic_spawn": {"ic_freq_range":ic_freq_range}, + "tech_node_SF":tech_node_SF, + "base_budget_scaling":base_budget_scaling, + "queue_available_size":[1, 2, 4, 8, 16], + "burst_size_options":[1024], + "task_spawn":{"parallel_task_cnt":2, "serial_task_cnt":3}} + + # set software hardware database population + # for scalibility studies + #sw_hw_database_population = {"db_mode": "generate", "hw_graph_mode": 
"generated_from_scratch", + # "workloads": workloads, "misc_knobs": db_population_misc_knobs} + # for SLAM + #sw_hw_database_population = {"db_mode": "hardcoded", "hw_graph_mode": "generated_from_scratch", + # "workloads": workloads, "misc_knobs": db_population_misc_knobs} + # for paper workloads + sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "generated_from_scratch", + "workloads": workloads, "misc_knobs": db_population_misc_knobs} + #sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "generated_from_check_point", + # "workloads": workloads, "misc_knobs": db_population_misc_knobs} + # for check pointed + if check_points["start"]: + config.check_point_folder = check_points["folder"] + if not os.path.exists(config.check_point_folder) : + print("check point folder to start from doesn't exist") + print("either start from scratch or fix the folder address") + exit(0) + sw_hw_database_population = {"db_mode": "parse", "hw_graph_mode": "generated_from_check_point", + "workloads": workloads, "misc_knobs": db_population_misc_knobs} + + + + # depending on the study/substudy type, invoke the appropriate function + if study_type == "simple_run": + dse_hndler = simple_run(result_folder, sw_hw_database_population, system_workers) + if study_type == "simple_run_iterative": + dse_hndler = simple_run_iterative(result_folder, sw_hw_database_population, system_workers) + elif study_type == "cost_PPA" and study_subtype == "run": + input_error_output_cost_sensitivity_study(result_folder, sw_hw_database_population, system_workers, False, False) + elif study_type == "input_error_output_cost_sensitivity" and study_subtype == "run": + input_error_output_cost_sensitivity_study(result_folder, sw_hw_database_population, system_workers, True, False) + elif study_type == "input_error_input_cost_sensitivity" and study_subtype == "run": + input_error_output_cost_sensitivity_study(result_folder, sw_hw_database_population, system_workers,True, True) + elif study_type == "cost_PPA" and study_subtype == "plot_3d_distance": + result_folder = "05-28_18-46_40" # edge detection + result_folder = "05-28_18-47_33" # hpvm cava + result_folder = "05-28_18-47_03" + result_folder = "05-31_16-24_49" # hpvm cava (2, tighter constraints) + result_dir_addr= os.path.join(config.home_dir, 'data_collection/data/', study_type, result_folder, + "result_summary") + full_file_addr = os.path.join(result_dir_addr, + config.FARSI_cost_correlation_study_prefix + "_0_1.csv") + plot_3d_dist(result_dir_addr, full_file_addr, workloads) + + print("reason to terminate: " + dse_hndler.dse.reason_to_terminate) + ret_value.value = int(dse_hndler.dse.reason_to_terminate == "out_of_memory") + + + + +def run(check_points_start, check_points_top_folder, previous_results): + #study_type = "simple_run_iterative" + study_type = "simple_run" + #study_subtype = "plot_3d_distance" + study_subtype = "run" + assert study_type in ["cost_PPA", "simple_run", "input_error_output_cost_sensitivity", "input_error_input_cost_sensitivity", "simple_run_iterative"] + assert study_subtype in ["run", "plot_3d_distance"] + SA_depth = [10] + freq_range = [1, 4, 6, 8] + #freq_range = [1] #, 4, 6, 8] + + # fast run + workloads = [{"audio_decoder"}] + #workloads = [{"synthetic"}] + workloads = [{"hpvm_cava"}] + workloads = [{"edge_detection"}] + workloads = [ {"edge_detection_1"},{"edge_detection_1", "edge_detection_2"}, {"edge_detection_1", "edge_detection_2", "edge_detection_3"}, {"edge_detection_1", "edge_detection_2", "edge_detection_3", 
"edge_detection_4"} ]#, "edge_detection_4"}] + + #workloads = [{"edge_detection_1", "edge_detection_2"}] + #workloads = [{"SLAM"}] + + #workloads =[{"audio_decoder", "hpvm_cava"}] + + # each workload in isolation + #workloads =[{"audio_decoder"}, {"edge_detection"}, {"hpvm_cava"}] + + # all workloads together + #workloads =[{"audio_decoder", "edge_detection", "hpvm_cava"}] + + # entire workload set + #workloads = [{"hpvm_cava"}, {"audio_decoder"}, {"edge_detection"}, {"edge_detection", "audio_decoder"}, {"hpvm_cava", "audio_decoder"}, {"hpvm_cava", "edge_detection"} , {"audio_decoder", "edge_detection", "hpvm_cava"}] + + latency_scaling_range = [.8, 1, 1.2] + power_scaling_range = [.8,1,1.2] + area_scaling_range = [.8,1,1.2] + + # edge detection lower budget + latency_scaling_range = [1] + # for audio + #power_scaling_range = [.6,.5,.4,.3] + #area_scaling_range = [.6,.5,.5,.3] + + power_scaling_range = [1] + area_scaling_range = [1] + + result_home_dir_default = os.path.join(os.getcwd(), "data_collection/data/" + study_type) + result_folder = os.path.join(config.home_dir, "data_collection/data/" + study_type) + date_time = datetime.now().strftime('%m-%d_%H-%M_%S') + run_folder = os.path.join(result_folder, date_time) + os.mkdir(run_folder) + + #transformation_selection_mode_list = ["random", "arch-aware"] # choose from {random, arch-aware} + #transformation_selection_mode_list = ["random"] + transformation_selection_mode_list = ["arch-aware"] + + check_points_values = [] + if check_points_start: + if not os.path.exists(check_points_top_folder) : + print("check point folder to start from doesn't exist") + print("either start from scratch or fix the folder address") + exit(0) + + all_dirs = [x[0] for x in os.walk(check_points_top_folder)] + check_point_folders = [dir for dir in all_dirs if "check_points" in dir] + + for folder in check_point_folders: + check_points_values.append((True, folder)) + else: + check_points_values.append((False, "")) + + for check_point_el in check_points_values: + check_point = {"start":check_point_el[0], "folder":check_point_el[1], "prev_itr": previous_results} + for trans_sel_mode in transformation_selection_mode_list: + for w in workloads: + workloads_first_letter = '_'.join(sorted([el[0] for el in w])) +"__"+trans_sel_mode[0] + workload_folder = os.path.join(run_folder, workloads_first_letter) + if not os.path.exists(workload_folder): + os.mkdir(workload_folder) + for d in SA_depth: + for latency_scaling,power_scaling, area_scaling in itertools.product(latency_scaling_range, power_scaling_range, area_scaling_range): + base_budget_scaling = {"latency": latency_scaling, "power": power_scaling, "area": area_scaling} + if config.memory_conscious: + # use subprocess to free memory + ret_value = multiprocessing.Value("d", 0.0, lock=False) + p = multiprocessing.Process(target=run_with_params, args=[w, d, freq_range, base_budget_scaling, trans_sel_mode, study_type, workload_folder, date_time, check_point, ret_value]) + p.start() + p.join() + + # checking for memory issues + if ret_value.value == 1: + return "out_of_memory", run_folder + else: + dse_hndler = run_with_params(w, d, freq_range, base_budget_scaling, trans_sel_mode, study_type, workload_folder, date_time, check_point) + return "others", run_folder + +def create_final_folder(run_folder): + source = run_folder + destination_parts = run_folder.split("/") + destination_last_folder = "final_" + destination_parts[-1] + destination_parts[-1] = destination_last_folder + destination = "/".join(destination_parts) + 
os.rename(source, destination) + return destination + +def aggregate_results(run_folder): + all_dirs = [x[0] for x in os.walk(run_folder) if 'result_summary' == x[0].split("/")[-1]] + sorted_based_on_depth = sorted(all_dirs, reverse=True) + + # create a new file + most_recent_directory = sorted_based_on_depth[-1] + file_to_copy_to = os.path.join(most_recent_directory, "aggregate_all_results.csv") + with open(file_to_copy_to, 'w') as fp: + pass + + # iterate through all the folder and append to the new file + first = True + for dir in sorted_based_on_depth: + if "result_summary" in dir: + file_to_copy = os.path.join(dir, "FARSI_simple_run_0_1_all_reults.csv") + file = open(file_to_copy, "r") + data2 = file.read().splitlines(True) + file.close() + fout = open(file_to_copy_to, "a") + if first: + fout.writelines(data2[:]) + else: + fout.writelines(data2[1:]) + first = False + fout.close() + + previous_results = [dir for dir in all_dirs if "result_summary" in dir][0] + + +def run_batch(check_points_start, check_points_top_folder): + # check pointing information + """ + #check_points_start = False + # check_points_top_folder = "/Users/behzadboro/Project_FARSI_dir/Project_FARSI_with_channels/data_collection/data/simple_run/12-20_15-37_33/data_per_design/12-20_15-39_38_16/PA_knob_ctr_0/" + # "/media/reddi-rtx/KINGSTON/FARSI_results/scaling_of_1_2_4_across_all_budgets_07-31" + # check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/02-28_17-00_03/a_e_h__r/02-28_17-00_03____lat_1__pow_1__area_1___workloads_a_e_h/check_points" + # check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/02-28_17-52_30/a_e_h__r/02-28_17-52_30____lat_1__pow_1__area_1___workloads_a_e_h/check_points" + check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/third_leg/a_e_h__r/02-28_17-52_30____lat_1__pow_1__area_1___workloads_a_e_h/check_points" + check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/03-01_15-54_25/a_e_h__r/03-01_15-54_25____lat_1__pow_1__area_1___workloads_a_e_h/check_points" + check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/03-01_15-54_25" + check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/03-02_13-47_03" + check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/03-03_08-17_32" + check_points_top_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/03-03_13-47_59" + #check_points_top_folder ="/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/03-04_08-47_00" + #check_points_top_folder = "" + #previous_results = "" + """ + if check_points_start: + all_dirs = [x[0] for x in os.walk(check_points_top_folder)] + previous_results = [dir for dir in all_dirs if "result_summary" in dir][0] + else: + previous_results = "" + + ctr =0 + while True: + termination_cause, run_folder = run(check_points_start, check_points_top_folder, previous_results) + # to be backward compatible, + # we leave this scenario in + if not config.memory_conscious: + break + + # if out of memory, run 
again from the check point + if termination_cause == "out_of_memory": + ctr += 1 + check_points_start = True + check_points_top_folder = run_folder + all_dirs = [x[0] for x in os.walk(check_points_top_folder)] + previous_results = [dir for dir in all_dirs if "result_summary" in dir][0] + else: + # adjust the name so we would know which folder contains the final information + run_folder = create_final_folder(run_folder) + # aggregate the results (as they are spread out among multiple folders) + aggregate_results(run_folder) + break + + +def get_all_final_folders(check_points_start): + if not check_points_start: + return "" + result_folder = os.path.join(config.home_dir, "data_collection/data/simple_run/"+config.heuristic_type) + all_dirs = [os.path.join(result_folder, f) for f in os.listdir(result_folder)] + #all_dirs = [x[0] for x in os.walk(result_folder)] + check_point_folders = [dir for dir in all_dirs if "final" in dir] + return check_point_folders + +if __name__ == "__main__": + batch_count = 1 + check_points_top_folders = ["/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/03-03_13-47_59"] + check_points_start = False + check_points_top_folders = get_all_final_folders(check_points_start) + + + if check_points_start: + for check_point_top_folder in check_points_top_folders: + #assert(batch_count == 1) + for batch_number in range(0, batch_count): + run_batch(check_points_start, check_point_top_folder) + else: + for batch_number in range(0, batch_count): + run_batch(check_points_start, "") + diff --git a/Project_FARSI/data_collection/collection_utils/what_ifs/autoWLandDep.py b/Project_FARSI/data_collection/collection_utils/what_ifs/autoWLandDep.py new file mode 100644 index 00000000..b7d26c55 --- /dev/null +++ b/Project_FARSI/data_collection/collection_utils/what_ifs/autoWLandDep.py @@ -0,0 +1,13 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +depths = [3, 5, 8, 10] +workloads = [{"edge_detection"}, {"hpvm_cava"}, {"audio_decoder"}, {"edge_detection", "hpvm_cava"}, {"edge_detection", "audio_decoder"}, {"hpvm_cava", "audio_decoder"}, {"audio_decoder", "edge_detection", "hpvm_cava"}] + +for d in depths: + for w in workloads: + print(d) + print(w) + print("\n") + print("\n") \ No newline at end of file diff --git a/Project_FARSI/design_utils/__init__.py b/Project_FARSI/design_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/design_utils/common_design_utils.py b/Project_FARSI/design_utils/common_design_utils.py new file mode 100644 index 00000000..d51c5bc8 --- /dev/null +++ b/Project_FARSI/design_utils/common_design_utils.py @@ -0,0 +1,34 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
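+# Worked example for slice_phases_with_PWP below (hypothetical numbers):
+# with config.PCP = 0.5 and sorted_phase_latency_dict = {0: 0.2, 1: 0.4, 2: 0.1},
+# the function returns [(0, 2), (2, 3)]: phases 0 and 1 fill the first power
+# collection period (the phase that exhausts the budget closes the slice),
+# and the leftover phase 2 forms the final slice.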
+
+
+from settings import config
+# ------------------------------
+# Functionality:
+#       helps convert energy to power by slicing the energy collected over time into
+#       smaller phases (power collection periods) and computing the power over each.
+#       PCP: power collection period
+# ------------------------------
+def slice_phases_with_PWP(sorted_phase_latency_dict):
+    lower_bound_idx, upper_bound_idx = 0, 0
+    phase_bounds_list = []
+    budget_before_next_collection = config.PCP
+    for phase, latency in sorted_phase_latency_dict.items():
+        if latency == 0:
+            continue
+        if latency > budget_before_next_collection:
+            budget_before_next_collection = config.PCP
+            phase_bounds_list.append(
+                (lower_bound_idx, upper_bound_idx + 1))  # we increment by 1, since this is used as the upper bound
+            lower_bound_idx = upper_bound_idx + 1
+        else:
+            budget_before_next_collection -= latency
+            upper_bound_idx += 1
+
+    # add whatever wasn't included at the end
+    if not phase_bounds_list:
+        phase_bounds_list.append((0, len(list(sorted_phase_latency_dict.values()))))
+    elif not phase_bounds_list[-1][1] == len(list(sorted_phase_latency_dict.values())):
+        phase_bounds_list.append((phase_bounds_list[-1][1], len(list(sorted_phase_latency_dict.values()))))
+    return phase_bounds_list
\ No newline at end of file
diff --git a/Project_FARSI/design_utils/components/__init__.py b/Project_FARSI/design_utils/components/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/Project_FARSI/design_utils/components/hardware.py b/Project_FARSI/design_utils/components/hardware.py
new file mode 100644
index 00000000..c5ed7280
--- /dev/null
+++ b/Project_FARSI/design_utils/components/hardware.py
@@ -0,0 +1,1841 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import itertools
+from settings import config
+from typing import List, Tuple
+from design_utils.components.workload import *
+import copy
+import math
+import time
+from collections import deque
+
+
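+# Illustrative use of the HWQueue class below (hypothetical sizes, and hypothetical
+# front objects exposing a .total_work attribute):
+#
+#   q = HWQueue(max_size=1024, buffer_size=64)
+#   q.enqueue(front)        # returns False once queued work reaches max_size
+#   q.peek(); q.dequeue()   # peek/pop from the right end of the deque
+#
+# enqueue() inserts at index 0 and dequeue() pops from the end, so the deque
+# behaves as a first-in, first-out hardware queue bounded by total work.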
+# This class emulates hardware queues.
+# At the moment, we are using the pipe class to do the same thing,
+# so this class is not used.
class HWQueue():
+    def __init__(self, max_size, buffer_size):
+        self.max_size = max_size
+        self.q_data = deque()
+        self.buffer_size = buffer_size
+
+    def enqueue(self, data):
+        if self.is_full():
+            return False
+        self.q_data.insert(0, data)
+        return True
+
+    def dequeue(self):
+        if self.is_empty():
+            return False
+        self.q_data.pop()
+        return True
+
+    def peek(self):
+        if self.is_empty():
+            return False
+        return self.q_data[-1]
+
+    def size(self):
+        return len(self.q_data)
+
+    def is_empty(self):
+        return (self.size() == 0)
+
+    def is_full(self):
+        total_data = 0
+        for front_ in self.q_data:
+            total_data += front_.total_work
+        return total_data >= self.max_size
+
+        #return self.size() == self.max_size
+
+    def __str__(self):
+        return str(self.q_data)
+
+
+if config.simulation_method == "power_knobs":
+    from specs import database_input_powerKnobs as database_input
+elif config.simulation_method == "performance":
+    from specs import database_input
+else:
+    raise NameError("Simulation method unavailable")
+
+
+# This class emulates the behavior of a hardware block
+class Block:
+    id_counter = 0
+    block_numbers_seen = []
+    def __init__(self, db_input, hw_sampling, instance_name, type, subtype,
+                 peak_work_rate_distribution, work_over_energy_distribution, work_over_area_distribution,
+                 one_over_area_distribution, clock_freq, bus_width, loop_itr_cnt, loop_max_possible_itr_cnt, hop_latency, pipe_line_depth,
+                 leakage_power="", power_knobs="",
+                 SOC_type="", SOC_id=""):
+        self.db_input = db_input                # database input
+        self.__instance_name = instance_name    # name of the block instance
+        self.subtype = subtype                  # subtype of the block (e.g., gpp or ip)
+
+        # specs
+        self.peak_work_rate_distribution = peak_work_rate_distribution   # peak work rate (work over time).
+                                                # The concept of work rate differs by block type;
+                                                # concretely, for a PE the work rate is IPC,
+                                                # and for memory/buses it is bandwidth.
+        self.hop_latency = hop_latency
+        self.pipe_line_depth = pipe_line_depth
+        self.clock_freq = clock_freq
+        self.bus_width = bus_width
+        self.loop_itr_cnt = loop_itr_cnt
+        self.loop_max_possible_itr_cnt = loop_max_possible_itr_cnt
+        self.work_over_energy_distribution = work_over_energy_distribution   # work over energy
+        self.work_over_area_distribution = work_over_area_distribution       # work over area
+        self.one_over_area_distribution = one_over_area_distribution
+        self.set_rates(hw_sampling)   # set the above rates based on the hardware sampling. Note that
+                                      # the rates can vary if they are a distribution rather than one value.
+        self.leakage_power = leakage_power
+
+        # power knobs of the block
+        self.power_knobs = power_knobs
+        # indices of power knobs that can be used for the turbo mode (increase performance)
+        self.turbo_power_knob_indices = []
+        # indices of power knobs that can be used for the slow-down mode (decrease performance)
+        self.slow_down_power_knob_indices = []
+        # divide power knobs into turbo and slow-down ones
+        #self.categorize_power_knobs()
+
+        self.pipes = []           # these are the queues that connect different blocks
+        self.pipe_clusters = []
+        self.type = type          # type of the block (i.e., pe, mem, ic)
+        self.neighs: List[Block] = []    # neighbours, i.e., the connected blocks
+        self.__task_name_dir_list = []   # tasks on the block
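+        # Illustrative shape of the mapping kept below (hypothetical task names):
+        #   {(task_A, "read"):  {"task_parent": 0.5},
+        #    (task_A, "write"): {"task_child": 1.0}}
+        # i.e., per (task, direction), the work ratio to each family task.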
+        # task_dir is the tuple specifying which task is making a request to access memory and in which direction.
+        self.__tasks_dir_work_ratio: Dict[(Task, str): float] = {}   # tasks on the block and their work ratios.
+        self.id = Block.id_counter
+        Block.id_counter += 1
+        if Block.id_counter in Block.block_numbers_seen:
+            raise Exception("can not have two blocks with id:" + str(Block.id_counter))
+        Block.block_numbers_seen.append(Block.id_counter)
+        self.area = 0
+        self.SOC_type = SOC_type
+        self.SOC_id = SOC_id
+
+        self.PA_prop_dict = {}              # props used for PA (Platform Architect) design generation.
+        self.PA_prop_auto_tuning_list = []  # list of variables to be auto-tuned
+
+        self.area_list = [0]            # list of areas for different task calls
+        self.area_in_bytes_list = [0]   # list of areas (in bytes) for different task calls
+        # only for memory
+        self.area_task_dir_list = []
+        self.task_mem_map_dict = {}     # memory map associated with different tasks for memory
+        self.system_bus_behavior = False   # if true, the block is the system bus
+
+    def set_system_bus_behavior(self, qualifier):
+        self.system_bus_behavior = qualifier
+
+    def is_system_bus(self):
+        return self.system_bus_behavior
+
+    def get_block_bus_width(self):
+        return self.bus_width
+
+    def get_loop_itr_cnt(self):
+        return self.loop_itr_cnt
+
+    def get_loop_max_possible_itr_cnt(self):
+        return self.loop_max_possible_itr_cnt
+
+
+    def get_block_freq(self):
+        return self.clock_freq
+
+    def get_hop_latency(self):
+        return self.hop_latency
+
+    def get_pipe_line_depth(self):
+        return self.pipe_line_depth
+
+    # ---------------
+    # Functionality:
+    #       Return peak_work_rate. Note that the work-rate definition varies based on the
+    #       hardware type: work-rate for a PE is IPS, whereas work-rate for bus/mem means BW.
+    # ---------------
+    def get_peak_work_rate(self, pk_id=0):
+        if (pk_id == 0) or (self.type != "pe"):
+            return self.peak_work_rate
+        else:
+            (perf_change, dyn_power_change, leakage_power_change) = self.power_knobs[pk_id - 1]
+            return self.peak_work_rate * perf_change
+
+    # ---------------
+    # Functionality:
+    #       Return work_over_energy. Note that the work definition varies based on the
+    #       hardware type: work for a PE is instructions, whereas work for bus/mem means bytes.
+    # ---------------
+    def get_work_over_energy(self, pk_id=0):
+        if (pk_id == 0) or (self.type != "pe"):
+            return self.work_over_energy
+        else:
+            (perf_change, dyn_power_change, leakage_power_change) = self.power_knobs[pk_id - 1]
+            # work over energy is instructions per Joule
+            return (perf_change / dyn_power_change) * self.work_over_energy
+
+    # ---------------
+    # Functionality:
+    #       Return work_over_power. Note that the work definition varies based on the
+    #       hardware type: work for a PE is instructions, whereas work for bus/mem means bytes.
+    # ---------------
+    def get_work_over_power(self, pk_id=0):
+        # TODO: fix this to actually use work_over_power
+        return self.get_work_over_energy(pk_id)
+
+    # ---------------
+    # Functionality:
+    #       Return work_over_area. Note that the work definition varies based on the
+    #       hardware type: work for a PE is instructions, whereas work for bus/mem means bytes.
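+    #       (Hypothetical example: a block with work_over_area = 1e10 ops/mm^2
+    #       implies that 1e9 ops of mapped work account for 0.1 mm^2 of area.)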
+ # --------------- + def get_work_over_area(self, pk_id=0): + if (pk_id == 0) or (self.type != "pe"): + return self.work_over_area + else: + #print("Returning DVFS value") + (perf_change, dyn_power_change, leakage_power_change) = self.power_knobs[pk_id - 1] + # Work over area change; area is constant so only amount of work done is changed + return perf_change * self.work_over_area + + # Return the private variable leakage energy + # if a power_knob is used then return the leakage energy associated with that knob + def get_leakage_power(self, pk_id=0): + + if (pk_id == 0) or (self.type != "pe"): + return self.leakage_power + else: + (perf_change, dyn_power_change, leakage_power_change) = self.power_knobs[pk_id - 1] + # leakage power change + return leakage_power_change * self.leakage_power + + def get_peak_work_rate_distribution(self): + return self.peak_work_rate_distribution + + def get_work_over_energy_distribution(self): + return self.work_over_energy_distribution + + def get_work_over_area_distribution(self): + return self.work_over_area_distribution + + def get_one_over_area_distribution(self): + return self.one_over_area_distribution + + # get average of the distribution + def get_avg(self, distribution_dict): + avg = 0 + for key, value in distribution_dict.items(): + avg += key*value + return avg + + # --------------- + # Functionality: + # set rates, i.e., work_rate, work_over_energy, work_over_area + # Note that work definition varies based on the + # hardware type. work for PE is instruction, where as work for bus/mem means bytes. + # --------------- + def set_rates(self, hw_sampling): + mode = hw_sampling["reduction"] + accuracy_percentage = hw_sampling["accuracy_percentage"][self.subtype] + if mode in ["random", "most_likely", "min", "max", "most_likely_with_accuracy_percentage"]: + if mode == "random": + time.sleep(.005) + np.random.seed(datetime.now().microsecond) + # sample the peak_work_rate + work_rates = [work_rate for work_rate, work_rate_prob in self.get_peak_work_rate_distribution().items()] + work_rate_probs = [work_rate_prob for work_rate, work_rate_prob in self.get_peak_work_rate_distribution().items()] + if not(sum(work_rate_probs) == 1): + print("break point") + work_rate_selected = np.random.choice(work_rates, p=work_rate_probs) + # get the index and use it to collect other values (since there is a one to one correspondence) + work_rate_idx = list(self.get_peak_work_rate_distribution().keys()).index(work_rate_selected) + elif mode == "most_likely" or mode == "most_likely_with_accuracy_percentage": # used when we don't want to do statistical analysis + # sample the peak_work_rate + work_rates_sorted = collections.OrderedDict(sorted(self.get_peak_work_rate_distribution().items(), key=lambda kv: kv[1])) + work_rate_selected = list(work_rates_sorted.keys())[-1] + work_rate_idx = list(self.get_peak_work_rate_distribution().keys()).index(work_rate_selected) + elif mode in ["min", "max"]: # used when we don't want to do statistical analysis + # sort based on the value + work_rates_sorted = collections.OrderedDict(sorted(self.get_peak_work_rate_distribution().items(), key=lambda kv: kv[0])) + if mode == "min": # worse case design, hence smallest workrate + work_rate_selected = list(work_rates_sorted.keys())[0] + elif mode == 'max': # best case design + work_rate_selected = list(work_rates_sorted.keys())[-1] + work_rate_idx = list(self.get_peak_work_rate_distribution().keys()).index(work_rate_selected) + + self.peak_work_rate = 
list(self.get_peak_work_rate_distribution().keys())[work_rate_idx] + self.work_over_energy = list(self.get_work_over_energy_distribution().keys())[work_rate_idx] + self.one_over_power = self.work_over_energy*(1/self.peak_work_rate) + try: + self.work_over_area = list(self.get_work_over_area_distribution().keys())[work_rate_idx] + self.one_over_area = list(self.get_one_over_area_distribution().keys())[work_rate_idx] + except: + print("what ") + exit(0) + elif mode == "avg": + self.peak_work_rate = self.get_avg(self.get_peak_work_rate_distribution()) + self.work_over_energy = self.get_avg(self.get_work_over_energy_distribution()) + self.work_over_area = self.get_avg(self.get_work_over_area_distribution()) + self.one_over_area = self.get_avg(self.get_one_over_area_distribution()) + self.one_over_power = self.work_over_energy*(1/self.peak_work_rate) + else: + print("mode" + mode + " is not supported for block sampling") + exit(0) + + if mode == "most_likely_with_accuracy_percentage": + # use the error in + self.peak_work_rate *= accuracy_percentage['latency'] + self.work_over_energy *= accuracy_percentage['energy'] + self.work_over_area *= accuracy_percentage['area'] + self.one_over_area *= accuracy_percentage['one_over_area'] + self.one_over_power *= accuracy_percentage['energy'] + + # ------------------------------------------- + # power-knobs related functions + # ------------------------------------------- + # Get all the power_knob configurations available + # each power_knob : ([(performance_change(ips/Bps), active_power_change(ipj/Bpj), leakage_power_change)]) + # for example (1.2,1.2,1.2) shows that by 20% performance increase, dynamic and leakage powers increase by 20% + def get_power_knob_tuples(self): + return self.power_knobs + + # Looks into power knobs and adds the one giving performance improvement > 1 to turbo mode ones + # and adds the ones with perf improvement <= 1 to slow down ones + # Please note that it does add only the indices of the power_knob to either list + def categorize_power_knobs(self): + for power_knob_idx, power_knob in enumerate(self.power_knobs): + if power_knob[0] > 1: + self.turbo_power_knob_indices.append(power_knob_idx) + elif power_knob[0] < 1: + self.slow_down_power_knob_indices.append(power_knob_idx) + else: + self.turbo_power_knob_indices.append(power_knob_idx) + self.slow_down_power_knob_indices.append(power_knob_idx) + + # return the turbo mode indices in the power knobs list + def get_turbo_power_knob_indices(self): + return self.turbo_power_knob_indices + + # return the slow down mode indices in the power knobs list + def get_slow_down_power_knob_indices(self): + return self.slow_down_power_knob_indices + + # A wrapper that gets the mode of power knob and send the list of corresponding indices + def get_power_knob_indices(self, mode="slow_down"): + if mode == "slow_down": + return self.get_slow_down_power_knob_indices() + elif mode == "turbo": + return self.get_turbo_power_knob_indices() + else: + raise Exception("Power knob mode {} is not available!".format(mode)) + + # ------------------------------------------- + # setters + # ------------------------------------------- + # ----------- + # Functionality: + # resetting Platform architect props + # ----------- + def reset_PA_props(self): + self.PA_prop_dict = collections.OrderedDict() + + # ----------- + # Functionality: + # updating Platform architect props + # ----------- + def update_PA_props(self, PA_prop_dict): + self.PA_prop_dict.update(PA_prop_dict) + + def 
update_PA_auto_tunning_knob_list(self, prop_auto_tuning_list): + self.PA_prop_auto_tuning_list = prop_auto_tuning_list + + # ----------- + # Functionality: + # assign an SOC for the block + # Variables: + # SOC_type: type of SOC + # SOC_id: id of the SOC + # ----------- + def set_SOC(self, SOC_type, SOC_id): + self.SOC_type = SOC_type + self.SOC_id = SOC_id + + # --------------------------- + # Functionality: + # get the static area (for the blocks that are not sized on the run time, such as cores and NoCs. Note that memory would be size + # on the run time. + # convention is that work_over_area is 1/fix_area if you are statically sized + # -------------------------- + def get_static_size(self): + return float(1)/self.work_over_area + + # --------------------------- + # Functionality: + # get the area in terms of byte. only for memory. Then we can convert it to mm^2 easily + # -------------------------- + def get_area_in_bytes(self): + if not self.type == "mem": + print("area in byte for non-memory blocks are not defined") + return 0 + else: + #area = self.get_area() + #return math.ceil(area*self.work_over_area) + max_work = max(self.area_in_bytes_list) + return math.ceil(max_work) + #return area*self.db_input.misc_data["ref_mem_work_over_area"] + + # return the number of banks + def get_num_of_banks(self): + if self.type == "mem": + #max_work = max(self.area_list)*self.db_input.misc_data["ref_mem_work_over_area"] + max_work = max(self.area_in_bytes_list) + num_of_banks = math.ceil(max_work/self.db_input.misc_data["memory_block_size"]) + return num_of_banks + else: + print("asking number of banks for non-memory blocks does not make sense") + exit(0) + + """ + # get capacity of memories + def get_capacity(self): + if self.type == "mem": + max_work = max(self.area_list)*self.db_input.misc_data["ref_mem_work_over_area"] + max_work_rounded = math.ceil(max_work/self.db_input.memory_block_size)*self.db_input.memory_block_size + return max_work_rounded + else: + print("asking capacity of non-memory blocks does not make sense") + exit(0) + """ + + def set_area_directly(self, area): + self.area_list = [area] + self.area = area + + # used with cacti + def update_area_energy_power_rate(self, energy_per_byte, area_per_byte): + if self.type not in ["mem", "ic"]: + print("should not be updateing the block values") + exit(0) + self.work_over_area = 1/area_per_byte + self.work_over_energy = 1/max(energy_per_byte,.0000000001) + self.one_over_power = self.work_over_energy/self.peak_work_rate + self.one_over_area = self.work_over_area + + # --------------------------- + # Functionality: + # get the area associated with a block. 
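+    #       (For memories, the byte footprint is rounded up to whole memory_block_size
+    #       banks before dividing by work_over_area; a hypothetical peak of 3000 B
+    #       with 2048 B banks is therefore charged as 4096 B.)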
+ # -------------------------- + def get_area(self): + if self.type == "mem": + if not config.use_cacti: # if using cacti, the area is already calculated (note that at the moment, area list doesn't matter for cacti) + max_work = max(self.area_in_bytes_list) + max_work_rounded = math.ceil(max_work/self.db_input.misc_data["memory_block_size"])*self.db_input.misc_data["memory_block_size"] + self.area = max_work_rounded/self.work_over_area + else: + self.area = max(self.area_list) + return self.area + + def get_leakage_power_calculated_after(self): + if self.type == "mem": + max_work = max(self.area_list)*self.db_input.ref_mem_work_over_area + max_work_rounded = math.ceil(max_work/self.db_input.memory_block_size)*self.db_input.memory_block_size + self.leakage_power_calculated_after = max_work_rounded/self.db_input.ref_work_over_leakage + else: + self.leakage_power_calculated_after = 0 + return self.leakage_power_calculated_after + + + def set_instance_name(self, name): + self.__instance_name = name + + @property + def instance_name_without_id(self): + return self.__instance_name + + + @property + def instance_name(self): + return self.__instance_name + "_"+ str(self.id) + + @instance_name.setter + def instance_name(self, instance_name): + self.__instance_name = instance_name + + # --------------------------- + # Functionality: + # Return whether a block only contains a dummy task (i.e., source and sink tasks) + # Used for prunning the block later. + # -------------------------- + def only_dummy_tasks(self): + num_of_tasks = len(self.get_tasks_of_block()) + if num_of_tasks == 2: + a = [task.name for task in self.get_tasks_of_block()] + if any("souurce" in task.name for task in self.get_tasks_of_block()) and \ + any("siink" in task.name for task in self.get_tasks_of_block()): + return True + elif num_of_tasks == 1: + a = [task.name for task in self.get_tasks_of_block()] + if any("souurce" in task.name for task in self.get_tasks_of_block()) or \ + any("siink" in task.name for task in self.get_tasks_of_block()): + return True + else: + return False + + def update_work(self, work): + self.work_list + + # --------------------------- + # Functionality: + # updating the area when more tasks are assigned to a memory block. Only used for memory. 
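+    #       (bytes can be negative, modeling a read that releases capacity, or
+    #       positive, modeling a write that claims it; e.g., hypothetically,
+    #       update_area(+1024, t) followed by update_area(-1024, t) leaves the
+    #       running footprint unchanged.)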
+ # Variables: + # area: current area + # task_requesting: the task is that is using the memory + # -------------------------- + def update_area(self, area, task_requesting): # note that bytes can be negative (if reading from the block) or positive + if self.type == "mem": + if area < 0: dir = 'read' + else: dir = 'write' + if (self.area_list[-1]+area) < -1*self.db_input.misc_data["area_error_margin"] and not config.use_cacti: + raise Exception("memory size can not go bellow the error margin") + #if dir == "write": + # self.task_mem_map_dict[task_requesting] = (hex(int(self.area_list[-1])), hex(int(self.area_list[-1]+area))) + self.area_list.append(self.area_list[-1]+area) + self.area_task_dir_list.append((task_requesting, dir)) + else: + if self.only_dummy_tasks(): + area = 0 + self.area_list.append(area) + + + def update_area_in_bytes(self, area_in_byte, task_requesting): # note that bytes can be negative (if reading from the block) or positive + if self.type == "mem": + if area_in_byte < 0: dir = 'read' + else: dir = 'write' + if (self.area_in_bytes_list[-1]+area_in_byte) < -1*self.db_input.misc_data["area_error_margin"]: + raise Exception("memory size can not go bellow the error margin") + #if dir == "write": + # self.task_mem_map_dict[task_requesting] = (hex(int(self.area_in_bytes_list[-1])), hex(int(self.area_in_bytes_list[-1]+area_in_byte))) + self.area_in_bytes_list.append(self.area_in_bytes_list[-1]+area_in_byte) + else: + if self.only_dummy_tasks(): + area = 0 + self.area_in_bytes_list.append(area_in_byte) + + + def __deepcopy__(self, memo): + #Block.id_counter = -1 + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + setattr(result, k, copy.deepcopy(v, memo)) + #Block.id_counter += 1 + result.area_list = [0] + result.area_in_bytes_list = [0] + result.area_task_dir_list = [] + result.task_mem_map_dict = {} + return result + + # each hardware block has a generic name, so it can be identified + def get_generic_instance_name(self): + return self.__instance_name + + # get the type name: {pe, mem, ic} + def get_block_type_name(self): + return self.type + # ------ + # comparisons + #------ + def __lt__(self, other): + return self.peak_work_rate < other.peak_work_rate + + # --------------------------- + # Functionality: + # get all task's parents/children and the work ratio between them. + # "dir" determines whether the family task is parent or child. + # "work_ratio" is the ratio between the work of the family member and the task itself. + # -------------------------- + def get_tasks_dir_work_ratio_for_printing(self): + temp_dict = [] + for task_dir, work_ratio in self.__tasks_dir_work_ratio.items(): + print((task_dir[0].name, task_dir[1]), work_ratio) + + # --------------------------- + # Functionality: + # get all task's parents/children and the work ratio between them. + # "dir" determines whether the family task is parent or child. + # "work_ratio" is the ratio between the work of the family member and the task itself. 
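+    #       (e.g., a hypothetical work_ratio of 0.5 for a parent under "read" means
+    #       the data read from that parent amounts to half of this task's own work.)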
+    # --------------------------
+    def get_tasks_dir_work_ratio(self):
+        return self.__tasks_dir_work_ratio
+
+    # ---------------------------
+    # Functionality:
+    #       return the task if the name matches
+    # --------------------------
+    def get_tasks_by_name(self, task_name):
+        tasks = self.get_tasks_of_block()
+        for task in tasks:
+            if (task.name == task_name):
+                return True, task
+        print("erroring out for block " + self.instance_name)
+        self.get_tasks_dir_work_ratio_for_printing()
+        print("task with the name of " + task_name + " is not loaded on block " + self.instance_name)
+        return False, "_"
+
+    # ---------------------------
+    # Functionality:
+    #       update the work_ratios. Used in jitter modeling, when a new task from the distribution is generated.
+    # --------------------------
+    def update_all_tasks_dir_work_ratio(self):
+        for task_dir, family_tasks_name in self.__tasks_dir_work_ratio.items():
+            task, dir = task_dir
+            for family_task_name in family_tasks_name.keys():
+                work_ratio = task.get_work_ratio_by_family_task_name(family_task_name)
+                self.__tasks_dir_work_ratio[task_dir][family_task_name] = work_ratio   # where the work ratio is set
+
+        self.__task_name_dir_list = []
+        for task_dir, family_tasks_name in self.__tasks_dir_work_ratio.items():
+            task, dir = task_dir
+            for family_task_name in family_tasks_name.keys():
+                self.__task_name_dir_list.append((task.name, dir))
+
+        # the following couple of lines are for debugging
+        if not len(self.__task_name_dir_list) == sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()]):
+            blah = len(self.__task_name_dir_list)
+            blah2 = sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()])
+            print("this should not happen, so please debug")
+
+    # get the fronts that will be/have been scheduled on the block.
+    # Note that a front is a stream of data/instructions:
+    # for PEs, it's just a bunch of instructions sliced up into chunks;
+    # for memories and buses, it's a stream of data that needs to move from one block (bus, memory) to
+    # another block (bus, memory). 
This class is not used for the moment + def get_fronts(self, mode="task_name_dir"): + if mode == "task_name_dir": + return self.__task_name_dir_list + elif mode == 'task_dir_work_ratio': + result = [] + for el in self.__tasks_dir_work_ratio.values(): + result.extend(el.keys()) + return result + else: + print("mode:" + mode + " not supported") + exit(0) + # --------------------------- + # Functionality: + # get the task_dir + # -------------------------- + def get_task_dirs_of_block(self): + tasks_with_possible_duplicates = [task_dir for task_dir in self.__tasks_dir_work_ratio.keys()] + # the duplicates happens cause we can have one task that writes and reads into that block + return list(set(tasks_with_possible_duplicates)) + + # --------------------------- + # Functionality: + # get the tasks of a block + # -------------------------- + def get_tasks_of_block(self): + tasks_with_possible_duplicates = [task_dir[0] for task_dir in self.__tasks_dir_work_ratio.keys()] + #if len(tasks_with_possible_duplicates) == 0: + # print("what") + # the duplicates happens cause we can have one task that writes and reads into that block + results = list(set(tasks_with_possible_duplicates)) + return results + + def get_tasks_of_block_by_dir(self, dir_): + tasks_with_possible_duplicates = [task_dir[0] for task_dir in self.__tasks_dir_work_ratio.keys() if task_dir[1] == dir_] + #if len(tasks_with_possible_duplicates) == 0: + # print("what") + # the duplicates happens cause we can have one task that writes and reads into that block + results = list(set(tasks_with_possible_duplicates)) + return results + + + # --------------------------- + # Functionality: + # get the tasks work ratio by the family task. + # Variables: + # task: the family task + # -------------------------- + def get_task_s_work_ratio_by_task(self, task): + blah = self.__tasks_dir_work_ratio + work_ratio = [work_ratio for task_dir, work_ratio in self.__tasks_dir_work_ratio.items() if task_dir[0] == task] + assert(len(work_ratio) < 3), ("task:" + str(task.name) + " can only have one or two (i.e., read and/or write) ratio") + return work_ratio + + # --------------------------- + # Functionality: + # get the task to task work. + # Variables: + # task: the family task + # dir: the direction (read/write) of the family task. + # -------------------------- + def get_task_s_family_by_task_and_dir(self, task, dir): + res = [] + task_dir__work_ratio_up = [(task_dir, work_ratio) for task_dir, work_ratio in self.__tasks_dir_work_ratio.items() if task_dir[0] == task] + for task_dir, work_ratio in task_dir__work_ratio_up: + dir_ = task_dir[1] + if dir_ == dir: + for family_task_ratio in work_ratio: + family_task = family_task_ratio + res.append(family_task) + + return res + + # --------------------------- + # Functionality: + # get the task to task work ratio. + # Variables: + # task: the family task + # dir: the direction (read/write) of the family task. + # -------------------------- + def get_task_s_work_ratio_by_task_and_dir(self, task, dir): + res = [] + task_dir__work_ratio_tup = [(task_dir, work_ratio) for task_dir, work_ratio in self.__tasks_dir_work_ratio.items() if task_dir[0] == task] + for task_dir, work_ratio in task_dir__work_ratio_tup: + dir_ = task_dir[1] + if dir_ == dir: + #for task_, work_ratio_val in work_ratio.items() + res.append(work_ratio) + if (len(res) > 1): + raise Exception('can not happen. Delete this later. 
for debugging now')
+        return res[0]
+
+    # ---------------------------
+    # Functionality:
+    #       get the task relationship
+    # Variables:
+    #       task: the family task
+    # --------------------------
+    def get_task_dir_by_task_name(self, task):
+        task_dir = [task_dir for task_dir in self.__tasks_dir_work_ratio.keys() if task_dir[0].name == task.name]
+        assert(len(task_dir) < 3), ("task:" + str(task.name) + " can only have one or two (i.e., read and/or write) ratios")
+        return task_dir
+
+    # ---------------------------
+    # Functionality:
+    #       connect a block to another block.
+    # Variables:
+    #       neigh: neighbour, the block to connect to.
+    # --------------------------
+    def connect(self, neigh):
+        if neigh not in self.neighs:
+            neigh.neighs.append(self)
+            self.neighs.append(neigh)
+
+    # ---------------------------
+    # Functionality:
+    #       disconnect a block from another block.
+    # Variables:
+    #       neigh: neighbour, the block to disconnect from.
+    # --------------------------
+    def disconnect(self, neigh):
+        if neigh not in self.neighs:
+            return
+        self.neighs.remove(neigh)
+        neigh.neighs.remove(self)
+
+    # ---------------------------
+    # Functionality:
+    #       load (map) a task on a block.
+    # Variables:
+    #       task: task to load (map)
+    #       family_task: a family task is a task that writes to / reads from another task
+    # --------------------------
+    def load_improved(self, task, family_task):
+
+        # the following couple of lines are for debugging
+        if not len(self.__task_name_dir_list) == sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()]):
+            blah = len(self.__task_name_dir_list)
+            blah2 = sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()])
+            print("this should not happen, so please debug")
+
+        # determine the relationship between the tasks
+        relationship = task.get_relationship(family_task)
+        if relationship == "parent": dir = "read"
+        elif relationship == "child": dir = "write"
+        elif relationship == "self": dir = "loop_back"
+        else:
+            print("relationship between " + task.name + " and " + family_task.name + " is " + relationship)
+            exit(0)
+        task_dir = (task, dir)
+        work_ratio = task.get_work_ratio(family_task)
+        if task_dir in self.__tasks_dir_work_ratio.keys():
+            if not family_task.name in self.__tasks_dir_work_ratio[task_dir].keys():
+                self.__task_name_dir_list.append((task.name, dir))
+            self.__tasks_dir_work_ratio[task_dir][family_task.name] = work_ratio   # where the work ratio is set
+        else:
+            self.__tasks_dir_work_ratio[task_dir] = {}
+            self.__tasks_dir_work_ratio[task_dir][family_task.name] = work_ratio   # where the work ratio is set
+            self.__task_name_dir_list.append((task.name, dir))
+
+        # the following couple of lines are for debugging
+        if not len(self.__task_name_dir_list) == sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()]):
+            blah = len(self.__task_name_dir_list)
+            blah2 = sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()])
+            print("this should not happen, so please debug")
+
+    # add a pipe for the block. 
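+    # (Pipes model the master-to-slave queues a block participates in; pipe clusters,
+    # attached below via set_pipe_cluster, group a block's incoming pipes with its
+    # single outgoing pipe per direction; see the PipeCluster class further down.)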
+    def set_pipe(self, pipe_):
+        self.pipes.append(pipe_)
+
+    def set_pipe_cluster(self, cluster):
+        self.pipe_clusters.append(cluster)
+
+    def reset_pipes(self):
+        self.pipes = []
+
+    def reset_clusters(self):
+        self.pipe_clusters = []
+
+    def get_pipe_clusters(self):
+        return self.pipe_clusters
+
+    def get_pipe_clusters_of_task(self, task):
+        clusters = []
+        for pipe_cluster in self.pipe_clusters:
+            if pipe_cluster.is_task_present(task):
+                clusters.append(pipe_cluster)
+        return clusters
+
+
+    def get_pipes(self, channel_name):
+        block = self
+        result = []
+        if channel_name == "same":
+            dirs = ["read", "write"]
+        else:
+            dirs = [channel_name]
+
+        for dir_ in dirs:
+            pipes_of_channel = list(filter(lambda pipe: pipe.dir == dir_, self.pipes))
+            for pipe in pipes_of_channel:
+                if block == pipe.get_slave():
+                    result.append(pipe)
+
+        #if len(result) == 0:
+        #    print("what")
+        #assert(len(result) > 0)
+        return result
+
+    # ---------------------------
+    # Functionality:
+    #       get the family tasks of a task by the direction
+    # Variables:
+    #       dir_: read / write (corresponding to parent / child respectively)
+    # --------------------------
+    def get_tasks_by_direction(self, dir_):
+        assert (dir_ in ["read", "write"])
+        return [task_dir for task_dir in self.__tasks_dir_work_ratio if task_dir[1] == dir_]
+
+    # ---------------------------
+    # Functionality:
+    #       unload all the tasks that read from the block
+    # --------------------------
+    def unload_read(self):
+        change = True
+        delete = [task_dir for task_dir in self.__tasks_dir_work_ratio if task_dir[1] == "read"]
+        for el in delete: del self.__tasks_dir_work_ratio[el]
+
+        list_delete = [task_dir for task_dir in self.__task_name_dir_list if task_dir[1] == "read"]
+        for el in list_delete: self.__task_name_dir_list.remove(el)
+
+        # the following couple of lines are for debugging; get rid of them later
+        if not len(self.__task_name_dir_list) == sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()]):
+            blah = len(self.__task_name_dir_list)
+            blah2 = sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()])
+            print("this should not happen, so please debug")
+
+    # ---------------------------
+    # Functionality:
+    #       unload the task from the block with a certain direction (i.e., read/write)
+    # Variables:
+    #       task_dir
+    # --------------------------
+    def unload(self, task_dir):
+        task, dir = task_dir
+        if not (task.name, dir) in self.__task_name_dir_list:
+            print("what")
+
+        while (task.name, dir) in self.__task_name_dir_list:
+            self.__task_name_dir_list.remove((task.name, dir))
+        del self.__tasks_dir_work_ratio[task_dir]
+
+        # the following couple of lines are for debugging; get rid of them later
+        if not len(self.__task_name_dir_list) == sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()]):
+            blah = len(self.__task_name_dir_list)
+            blah2 = sum([len(el.values()) for el in self.__tasks_dir_work_ratio.values()])
+            print("this should not happen, so please debug")
+
+    # ---------------------------
+    # Functionality:
+    #       unload all the tasks.
+    # --------------------------
+    def unload_all(self):
+        self.__task_name_dir_list = []
+        self.__tasks_dir_work_ratio = {}
+
+    # ---------------------------
+    # Functionality:
+    #       disconnect the block from all the blocks. 
+    # --------------------------
+    def disconnect_all(self):
+        neighs = self.neighs[:]
+        for neigh in neighs:
+            self.disconnect(neigh)
+
+    # ---------------------------
+    # Functionality:
+    #       get all the neighbouring blocks (i.e., all the blocks that are connected to this block)
+    # --------------------------
+    def get_neighs(self):
+        return self.neighs
+
+    # return the work rate used to estimate the given metric for this block
+    def get_metric(self, metric):
+        if metric == "latency":
+            return self.get_peak_work_rate()
+        if metric == "energy":
+            return self.get_work_over_energy()
+        if metric == "power":
+            return self.get_work_over_power()
+        if metric == "area":
+            return self.get_work_over_area()
+
+
+# traffic is a stream that moves between two pipes (from one task to another)
+class traffic:
+    def __init__(self, parent, child, dir, work):
+        self.parent = parent   # parent task
+        self.child = child     # child task
+        self.work = work       # work
+        self.dir = dir         # direction (read/write/loop)
+
+
+    def get_child(self):
+        return self.child
+
+    def get_parent(self):
+        return self.parent
+
+
+
+# physical channels inside the router
+class PipeCluster:
+    def __init__(self, ref_block, dir, outgoing_pipe, incoming_pipes, unique_name):
+        self.ref_block = ref_block
+        self.dir = dir
+        self.cluster_type = "regular"
+        self.pathlet_phase_work_rate = {}
+        self.pathlet_phase_latency = {}
+        if outgoing_pipe == None:
+            self.outgoing_pipe = None
+        else:
+            self.outgoing_pipe = outgoing_pipe
+
+        if incoming_pipes == None:
+            self.incoming_pipes = []
+        else:
+            self.incoming_pipes = incoming_pipes
+        self.unique_name = unique_name
+
+        self.pathlets = []
+        for in_pipe in self.incoming_pipes:   # iterate over the normalized list (a None argument becomes [])
+            self.pathlets.append(pathlet(in_pipe, outgoing_pipe, self.dir))
+
+    def change_to_dummy(self, tasks):
+        self.cluster_type = "dummy"
+        self.set_dir("same")
+        self.dummy_tasks = tasks
+
+    def set_dir(self, dir):
+        self.dir = dir
+
+    def get_unique_name(self):
+        return self.unique_name
+
+    # the block at the center of the cluster (with source (incoming) pipes and dest (outgoing) pipe)
+    def get_block_ref(self):
+        return self.ref_block
+
+    def get_pathlets(self):
+        return self.pathlets
+
+    # incoming pipes
+    def get_incoming_pipes(self):
+        assert not(self.cluster_type == "dummy")
+        return self.incoming_pipes
+
+    def get_outgoing_pipe(self):
+        assert not(self.cluster_type == "dummy")
+        return self.outgoing_pipe
+
+    def get_dir(self):
+        return self.dir
+
+    def get_task_work_unit(self, task):
+        assert not(self.cluster_type == "dummy")
+        work_unit = 0
+        for pipe in self.incoming_pipes:
+            work_unit += pipe.get_task_work_unit(task)
+        return work_unit
+
+    def get_task_work(self, task):
+        assert not(self.cluster_type == "dummy")
+        work = 0
+        for pipe in self.incoming_pipes:
+            work += pipe.get_task_work(task)
+        return work
+
+    def is_task_present(self, task):
+        if self.get_block_ref().type == "pe":
+            return task.get_name() in [el.get_name() for el in self.dummy_tasks]
+        elif self.get_block_ref().type in ["ic"]:
+            return self.outgoing_pipe.is_task_present(task)
+        else:
+            return any([pipe.is_task_present(task) for pipe in self.incoming_pipes])
+
+    def get_info(self):
+        incoming_pipes = []
+        for el in self.incoming_pipes:
+            incoming_pipes.append(el.get_info())
+        if self.outgoing_pipe == None:
+            outgoing = " None "
+            #outgoing_tasks = "None"
+        else:
+            outgoing = self.outgoing_pipe.get_info()
+
+
+        return "block:" + self.get_block_ref().instance_name + " incoming_pipes:" + str(incoming_pipes) + " outgoing_pipes:" + str(outgoing)
+
+    # for pathlets within the pipe cluster, set the work rate
+    def set_pathlet_phase_work_rate(self, pathlet, phase_num, work_rate):
+        in_pipe = pathlet.get_in_pipe()
+        out_pipe = pathlet.get_out_pipe()
+        if in_pipe not in self.incoming_pipes or out_pipe not in [self.outgoing_pipe]:
+            print("pipe should already exist")
+            exit(0)
+        else:
+            if pathlet not in self.pathlet_phase_work_rate.keys():
+                self.pathlet_phase_work_rate[pathlet] = {}
+            if phase_num not in self.pathlet_phase_work_rate[pathlet].keys():
+                self.pathlet_phase_work_rate[pathlet][phase_num] = 0
+            self.pathlet_phase_work_rate[pathlet][phase_num] += work_rate
+
+
+
+    def get_pathlet_phase_work_rate(self):
+        return self.pathlet_phase_work_rate
+
+    def get_pathlet_last_phase_work_rate(self):
+        pathlet_last_phase_work_rate = {}
+        last_phase = None   # stays None if no work rate has been recorded yet
+        for pathlet, phase_work_rate in self.get_pathlet_phase_work_rate().items():
+            last_phase = sorted(phase_work_rate.keys())[-1]
+            pathlet_last_phase_work_rate[pathlet] = phase_work_rate[last_phase]
+        return pathlet_last_phase_work_rate, last_phase
+
+    def set_pathlet_latency(self, pathlet, phase_num, latency_dict):
+        if pathlet not in self.pathlet_phase_latency.keys():
+            self.pathlet_phase_latency[pathlet] = {}
+        for latency_typ, trv_dir_val in latency_dict.items():
+            for trv_dir, val in trv_dir_val.items():
+                if trv_dir not in self.pathlet_phase_latency[pathlet].keys():
+                    self.pathlet_phase_latency[pathlet][trv_dir] = {}
+                if phase_num not in self.pathlet_phase_latency[pathlet][trv_dir].keys():
+                    self.pathlet_phase_latency[pathlet][trv_dir][phase_num] = 0
+                self.pathlet_phase_latency[pathlet][trv_dir][phase_num] += val
+        return self.pathlet_phase_latency[pathlet]
+
+
+    def get_pathlet_phase_latency(self):
+        return self.pathlet_phase_latency
+
+class pathlet:
+    def __init__(self, in_pipe, out_pipe, dir_):
+        self.in_pipe = in_pipe
+        self.out_pipe = out_pipe
+        self.dir = dir_
+
+    def get_out_pipe(self):
+        return self.out_pipe
+
+    def get_in_pipe(self):
+        return self.in_pipe
+
+    def get_dir(self):
+        return self.dir
+
+
+# A pipe is a queue that connects blocks
+class pipe:
+    def __init__(self, master, slave, dir_, number, cmd_queue_size, data_queue_size):
+        self.traffics = []     # traffic on the pipe
+        self.dir = dir_        # direction of the traffic
+        self.master = master   # master block
+        self.slave = slave     # slave block
+        self.number = number   # an assigned id
+        self.cmd_queue_size = cmd_queue_size
+        self.data_queue_size = data_queue_size
+
+    def get_cmd_queue_size(self):
+        return self.cmd_queue_size
+
+    def set_cmd_queue_size(self, size):
+        self.cmd_queue_size = size
+
+    def get_data_queue_size(self):
+        return self.data_queue_size
+
+    def set_data_queue_size(self, size):
+        self.data_queue_size = size
+
+    def get_dir(self):
+        return self.dir
+
+    def get_master(self):
+        return self.master
+
+    def get_slave(self):
+        return self.slave
+
+    def update_traffic_and_dir(self, parent, child, work, dir):
+        if dir not in self.dir:
+            print("this should not happen. The same direction should be assigned every time")
+            exit(0)
+
+        self.traffics.append(traffic(parent, child, dir, work))
+
+    def get_task_work_unit(self, task):
+        traffics_ = []
+        for traffic_ in self.traffics:
+            if traffic_.dir == "read" and task.name == traffic_.child.name:
+                traffics_.append(traffic_)
+            elif traffic_.dir == "write" and task.name == traffic_.parent.name:
+                traffics_.append(traffic_)
+
+        work = 0
+        for traffic_ in traffics_:
+            if traffic_.dir == "read":
+                work += task.get_self_to_family_task_work_unit(traffic_.parent)
+            elif traffic_.dir == "write":
+                work += task.get_self_to_family_task_work_unit(traffic_.child)
+
+        return work
+
+    def get_task_work(self, task):
+        traffics_ = []
+        for traffic_ in self.traffics:
+            if traffic_.dir == "read" and task.name == traffic_.child.name:
+                traffics_.append(traffic_)
+            elif traffic_.dir == "write" and task.name == traffic_.parent.name:
+                traffics_.append(traffic_)
+
+        work = 0
+        for traffic_ in traffics_:
+            if traffic_.dir == "read":
+                work += task.get_self_to_family_task_work(traffic_.parent)
+            elif traffic_.dir == "write":
+                work += task.get_self_to_family_task_work(traffic_.child)
+
+        return work
+
+    # is the task on the pipe?
+    def is_task_present(self, task):
+        result = False
+        if "write" in self.dir:
+            result = result or (task.name in (set([el.parent.name for el in self.traffics])))
+        if "read" in self.dir:
+            result = result or (task.name in (set([el.child.name for el in self.traffics])))
+
+        return result
+
+    def get_tasks(self):
+        results = set()   # union over directions, so "read" results don't clobber "write" ones
+        if "write" in self.dir:
+            results |= set([el.parent.name for el in self.traffics])
+        if "read" in self.dir:
+            results |= set([el.child.name for el in self.traffics])
+        return results
+
+    def get_traffic(self):
+        return self.traffics
+
+
+    def get_traffic_names(self):
+        result = []
+        for el in self.traffics:
+            result.append(el.parent.name + " " + el.child.name + " - ")
+        return result
+
+    def get_info(self):
+        return "m:" + self.master.instance_name + " s:" + self.slave.instance_name
+
+# This class is the graph with nodes denoting the blocks and the edges denoting
+# the relationship (parent/child) between the nodes.
+# generation_mode is either tool_generated or user_generated; different checks are
+# applied for the different scenarios.
+class HardwareGraph:
+    def __init__(self, block_to_prime_with:Block, generation_mode="tool_generated"):
+        self.pipes = []
+        self.pipe_clusters = []
+        self.last_pipe_assigned_number = 0     # this is used for setting up the pipes.
+        self.last_cluster_assigned_number = 0
+        self.blocks = self.traverse_neighs_recursively(block_to_prime_with, [])   # all the blocks in the graph
+        self.task_graph = TaskGraph(self.get_all_tasks())   # the task graph
+        self.config_code = str(-1)       # used to differentiate between systems (focuses on type/number of elements)
+        self.SOC_design_code = ""        # used to differentiate between systems (focuses on type/number of elements)
+        self.simplified_topology_code = str(-1)   # used to differentiate between systems (focuses on number of elements)
+        self.pipe_design()   # set up the pipes
+        self.generation_mode = generation_mode
+
+
+
+    def get_blocks(self):
+        return self.blocks
+
+    # get a pipe, given its master and slave
+    def get_pipe_with_master_slave(self, master_, slave_, dir_):
+        for pipe in self.pipes:
+            if pipe.get_master() == master_ and pipe.get_slave() == slave_ and dir_ in pipe.get_dir():
+                return pipe
+
+        print("this pipe was not found. 
something wrong") + master_to_slave_path = self.get_path_between_two_vertecies(master_, slave_) + exit(0) + + # do sanity check on the pipe + def pipe_is_sane(self, pipe_): + if pipe_.master.type == "mem": + return False + if pipe_.slave.type == "pe": + return False + return True + + # traverse the hardware graph and assign the pipes between them + def traverse_and_assign_pipes(self, block, blocks_visited): + if block in blocks_visited: + return None + blocks_visited.append(block) + for neigh in block.neighs: + dirs = [["read"], ["write"]] + for dir_ in dirs: + pipe_ = pipe(block, neigh, dir_, self.last_pipe_assigned_number) + if self.pipe_is_sane(pipe_): + self.last_pipe_assigned_number += 1 + block.set_pipe(pipe_) + neigh.set_pipe(pipe_) + self.pipes.append(pipe_) + self.traverse_and_assign_pipes(neigh, blocks_visited) + return blocks_visited + + # --------------------------- + # Functionality: + # traverse all the neighbours of a block recursively. + # Variables: + # block: block to get neighbours for + # blocks_visited: blocks already visited in the depth first search. Prevent double counting certain blocks. + # -------------------------- + def traverse_neighs_recursively(self, block, blocks_visited): # depth first search + if block in blocks_visited: + return None + blocks_visited.append(block) + for neigh in block.neighs: + self.traverse_neighs_recursively(neigh, blocks_visited) + return blocks_visited + + # a node is unnecessary when no task lives on it. + # this happens when we apply moves (e.g., migrate tasks around) + def prune_unnecessary_nodes(self, block_to_prime_with): # depth first search + blocks = self.traverse_neighs_recursively(block_to_prime_with, []) + for block in blocks: + if block.type == "ic": + connectd_pes = [block_ for block_ in block.get_neighs() if block_.type == "pe"] + connectd_mems = [block_ for block_ in block.get_neighs() if block_.type == "mem"] + connected_ics = [block_ for block_ in block.get_neighs() if block_.type == "ic"] + #only_one_pe_one_ic = len(connectd_mems) == 0 and len(connectd_pes) == 1 and len(connected_ics) == 1 + no_mem_one_ic = len(connectd_mems) == 0 and len(connected_ics) == 1 + no_pe_one_ic_no_system_bus = len(connectd_pes) == 0 and len(connected_ics) == 1 and (True or not block.is_system_ic()) + no_pe_no_mem = (len(connectd_pes) == 0 and len(connectd_mems) == 0) + + if no_mem_one_ic: + ic_to_connect_to = connected_ics[0] + for pe in connectd_pes: + pe.connect(ic_to_connect_to) + pe.disconnect(block) + elif no_pe_one_ic_no_system_bus: + ic_to_connect_to = connected_ics[0] + for mem in connectd_mems: + mem.connect(ic_to_connect_to) + mem.disconnect(block) + # if either of above true, take care of ics as well + if no_mem_one_ic or no_pe_one_ic_no_system_bus or no_pe_no_mem: + if len(connected_ics) == 1: + block.disconnect(connected_ics[0]) + else: + ic_to_connect_to = connected_ics[0] + block.disconnect(connected_ics[0]) + for ic in connected_ics[1:]: + ic.connect(ic_to_connect_to) + block.disconnect(ic) + + # --------------------------- + # Functionality: + # sample a task (from the task distribution) in the task graph. Used for jitter modeling. 
+    # --------------------------
+    def sample(self, hw_sampling):
+        # sample tasks
+        for task_ in self.get_all_tasks():
+            task_.update_task_work(task_.sample_self_task_work())
+            for child_task in task_.get_children():
+                task_.update_task_to_child_work(child_task, task_.sample_self_to_child_task_work(child_task))
+
+        # update blocks with the sampled tasks
+        blocks = self.traverse_neighs_recursively(self.get_root(), [])
+        for block in blocks:
+            block.update_all_tasks_dir_work_ratio()
+
+        # sample blocks
+        for block in blocks:
+            if hw_sampling["mode"] in ["error_integration", "exact"]:  # including the error
+                block.set_rates(hw_sampling)
+            else:
+                print("hw_sampling mode " + hw_sampling["mode"] + " is not defined")
+                exit(0)
+
+    # ---------------------------
+    # Functionality:
+    #       get all the tasks associated with the hardware graph (basically all the blocks' tasks)
+    # --------------------------
+    def get_all_tasks(self):
+        all_tasks = []
+        for block in self.blocks:
+            all_tasks.extend(block.get_tasks_of_block())
+        all_tasks = list(set(all_tasks))  # get rid of duplicates
+        return all_tasks
+
+    def get_blocks_by_type(self, type_):
+        return [block for block in self.blocks if block.type == type_]
+
+    # get all the blocks that host the task (looked up by name)
+    def get_blocks_of_task_by_name(self, task_name):
+        all_blocks = []
+        for block in self.blocks:
+            task_names = [el.get_name() for el in block.get_tasks_of_block()]
+            if task_name in task_names:
+                all_blocks.append(block)
+
+        return all_blocks
+
+    # get all the blocks that host the task
+    def get_blocks_of_task(self, task):
+        all_blocks = []
+        for block in self.blocks:
+            if task in block.get_tasks_of_block():
+                all_blocks.append(block)
+
+        return all_blocks
+
+    # this is just an identifier (string) that helps us encode a design topology/block type
+    def set_config_code(self):
+        pes = str(sorted(['_'.join(blck.instance_name.split("_")[:-1]) for blck in self.get_blocks_by_type("pe")]))
+        mems = str(sorted(['_'.join(blck.instance_name.split("_")[:-1]) for blck in self.get_blocks_by_type("mem")]))
+        ics = str(sorted(['_'.join(blck.instance_name.split("_")[:-1]) for blck in self.get_blocks_by_type("ic")]))
+
+        self.config_code = str(len(self.get_blocks_by_type("ic"))) + "_" + \
+                           str(len(self.get_blocks_by_type("mem"))) + "_" + \
+                           str(len(self.get_blocks_by_type("pe")))
+        self.config_code = pes + "__" + mems + "__" + ics
+
+    # this code (value) uniquely specifies a design.
+    # Usage: prevent regenerating/re-evaluating the design, for example.
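+    # Illustrative (hypothetical) shape of the resulting code: per-task entries of the
+    # form "<task_name>_<sorted names of blocks hosting it>___", followed by a string
+    # that walks the PEs, their neighbours, and the neighbouring memories. Two designs
+    # mapping the same tasks onto the same topology therefore produce the same code.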
+ def set_SOC_design_code(self): + # sort based on the blk name and task names tuples + def sort_blk_tasks(blks): + blk_tasks = [] + for blk in blks: + blk_tasks_sorted = str([el.name for el in sorted(blk.get_tasks_of_block(), key=lambda x: x.name)]) + blk_name_stripped = '_'.join(blk.instance_name.split("_")[:-1]) + blk_tasks.append((blk, blk_tasks_sorted + "__tasks_one_blk_" + blk_name_stripped)) + blk_sorted = sorted(blk_tasks, key=lambda x: x[1]) + return blk_sorted + + self.SOC_design_code= "" # used to differentiat between systems (focuses on type/number of elements) + hg_string = "" + + # sort the PEs + pes_sorted = sort_blk_tasks(self.get_blocks_by_type("pe")) + # iterate through PEs + for pe, string_ in pes_sorted: + hg_string += string_ + # sort the neighbours + neighs_sorted = sort_blk_tasks(pe.get_neighs()) + # iterate through neighbours + for neigh, string_ in neighs_sorted: + hg_string += string_ + neigh_s_neighs_sorted = sort_blk_tasks(neigh.get_neighs()) + for neigh_s_neigh, string_ in neigh_s_neighs_sorted: + if not neigh_s_neigh.type == "mem": + continue + hg_string += string_ + + # task graph based id + TG = sorted([tsk.name for tsk in self.get_task_graph().get_all_tasks()]) + for tsk_name in TG: + task = self.get_task_graph().get_task_by_name(tsk_name) + blks_hosting_task = self.get_blocks_of_task(task) + blks_hosting_task_sorted = str(sorted(['_'.join(blk.instance_name.split("_")[:-1]) for blk in blks_hosting_task])) + self.SOC_design_code += tsk_name + "_" + blks_hosting_task_sorted + "___" + self.SOC_design_code += hg_string + + # this is just a number that helps us encode a design topology/block type + def get_SOC_design_code(self): + return self.SOC_design_code + + # just a string to specify the simplified_topology + def set_simplified_topology_code(self): + self.simplified_topology_code = str(len(self.get_blocks_by_type("ic"))) + "_" + \ + str(len(self.get_blocks_by_type("mem"))) + "_" + \ + str(len(self.get_blocks_by_type("pe"))) + + def get_simplified_topology_code(self): + if self.simplified_topology_code == "-1": + self.set_simplified_topology_code() + return self.simplified_topology_code + + def get_number_of_channels(self): + ics = self.get_blocks_by_type("ic") + total_number_channels = 0 + for blk in ics: + total_number_channels+= len(blk.get_pipe_clusters()) + + return total_number_channels + + + def get_routing_complexity(self): + pes = self.get_blocks_by_type("pe") + mems = self.get_blocks_by_type("mem") + ics = self.get_blocks_by_type("ic") + + # a measure of how hard it is to rout, + # which depends on how many different paths that can be taken between master and slaves + complexity = 0 + + for pe in pes: + for mem in mems: + all_paths = self.get_all_paths_between_two_vertecies(pe, mem) + complexity += len(all_paths) + + # normalized to the number of master slaves + complexity = complexity/(len(pes)*len(mems)) + return complexity + + def get_config_code(self): + if self.config_code == "-1": + self.set_config_code() + return self.config_code + + # --------------------------- + # Functionality: + # update the graph without prunning. 
The no-pruning policy allows hardware graphs to be directly
+    #       (without modifications) absorbed from the input when requested.
+    # --------------------------
+    def update_graph_without_prunning(self, block_to_prime_with=None):
+        if not block_to_prime_with:
+            block_to_prime_with = self.get_root()
+        self.blocks = self.traverse_neighs_recursively(block_to_prime_with, [])
+        self.set_config_code()
+        self.set_SOC_design_code()
+
+    # ---------------------------
+    # Functionality:
+    #       update the graph. Used for jitter modeling after a new task was sampled from the task distribution.
+    # --------------------------
+    def update_graph(self, block_to_prime_with=None):
+        if not block_to_prime_with:
+            block_to_prime_with = self.get_root()
+        elif block_to_prime_with not in self.get_blocks():
+            for blck in self.get_blocks():
+                if blck.instance_name == block_to_prime_with.instance_name:
+                    block_to_prime_with = blck
+                    break
+        self.prune_unnecessary_nodes(block_to_prime_with)
+        self.blocks = self.traverse_neighs_recursively(block_to_prime_with, [])
+        self.set_config_code()
+        self.set_SOC_design_code()
+        #self.assign_pipes()  # rebuild pipes from scratch
+        # re-assigning pipes
+
+    # assign tasks to the pipes
+    def task_the_pipes(self, task, pipes, dir):
+        for pipe_ in pipes:
+            if dir == "read":
+                block = pipe_.get_slave()
+                for parent in task.get_parents():  # [par.name for par in task.get_parents()]:
+                    parent_names = block.get_task_s_family_by_task_and_dir(task, "read")
+                    if parent.name in parent_names:
+                        work = parent.get_self_to_family_task_work(task)
+                        pipe_.update_traffic_and_dir(parent, task, work, "read")
+            elif dir == "write":
+                block = pipe_.get_slave()
+                for child in task.get_children():
+                    children_names = block.get_task_s_family_by_task_and_dir(task, "write")
+                    if child.name in children_names:
+                        work = task.get_self_to_family_task_work(child)
+                        pipe_.update_traffic_and_dir(task, child, work, "write")
+
+    def get_pipes_between_two_blocks(self, blck_1, blck_2, dir_):
+        pipes = []
+        # get blocks along the way
+        master_to_slave_path = self.get_path_between_two_vertecies(blck_1, blck_2)
+        # get pipes along the way
+        for idx in range(0, len(master_to_slave_path) - 1):
+            block_master = master_to_slave_path[idx]
+            block_slave = master_to_slave_path[idx + 1]
+            pipes.append(self.get_pipe_with_master_slave(block_master, block_slave, dir_))
+        return pipes
+
+    def filter_empty_pipes(self):
+        empty_pipes = []
+        for pipe in self.pipes:
+            if pipe.traffics == []:
+                empty_pipes.append(pipe)
+
+        for pipe in empty_pipes:
+            self.pipes.remove(pipe)
+
+    # assign tasks to pipes
+    def task_all_the_pipes(self):
+        def get_blocks_of_task(task):
+            blocks = []
+            for block in self.blocks:
+                if task in block.get_tasks_of_block():
+                    blocks.append(block)
+            return blocks
+
+        # assign tasks to pipes
+        self.task_pipes = {}
+        all_tasks = self.get_all_tasks()
+        for task in all_tasks:
+            pe = [block for block in get_blocks_of_task(task) if block.type == "pe"][0]
+            mem_reads = [block for block in get_blocks_of_task(task) if block.type == "mem" and (task, "read") in block.get_tasks_dir_work_ratio().keys()]
+            mem_writes = [block for block in get_blocks_of_task(task) if block.type == "mem" and (task, "write") in block.get_tasks_dir_work_ratio().keys()]
+
+            # get all the paths leading from the read memories to the pe
+            seen_pipes = []
+            for mem in mem_reads:
+                pipes = self.get_pipes_between_two_blocks(pe, mem, "read")
+                pipes_to_consider = []
+                for pipe in pipes:
+                    if pipe not in seen_pipes:
+                        pipes_to_consider.append(pipe)
+                        seen_pipes.append(pipe)
+                if
len(pipes_to_consider) == 0: + continue + else: + self.task_the_pipes(task, pipes_to_consider, "read") + + # get all the paths leading from mem reads to pe + seen_pipes = [] + for mem in mem_writes: + pipes = self.get_pipes_between_two_blocks(pe, mem, "write") + pipes_to_consider = [] + for pipe in pipes: + if pipe not in seen_pipes: + pipes_to_consider.append(pipe) + seen_pipes.append(pipe) + if len(pipes_to_consider) == 0: + continue + else: + self.task_the_pipes(task, pipes_to_consider, "write") + + + def get_pipe_dir(self, block): + #if block.type == "pe": + # return ["same"] + #else: + return ["write", "read"] + + def get_pipe_clusters(self): + return self.pipe_clusters + + def cluster_pipes(self): + self.pipe_clusters = [] + for block in self.blocks: + block.reset_clusters() + + def traffic_overlaps(blck_pipe, neigh_pipe): + blck_pipe_traffic = blck_pipe.get_traffic_names() + neigh_pipe_traffic = neigh_pipe.get_traffic_names() + traffic_non_overlap = list(set(blck_pipe_traffic) - set(neigh_pipe_traffic)) + return len(traffic_non_overlap) < len(blck_pipe_traffic) + + if (len(self.pipes) == 0): + print("something is wrong") + + assert(len(self.pipes) > 0), "you need to assign pipes first" + pipe_cluster_dict = {} + + # iterate through blocks and neighbours to generate clusters + for block in self.blocks: + pipe_cluster_dict[block] = {} + for neigh in block.get_neighs(): + if neigh.type == "pe": continue + for dir in self.get_pipe_dir(block): + if dir not in pipe_cluster_dict[block]: pipe_cluster_dict[block][dir] = {} + + # get the pipes + block_pipes = block.get_pipes(dir) + neigh_pipes = neigh.get_pipes(dir) + + if block_pipes == [] and block.type == "pe": + pipe_cluster_dict[block][dir] = {} + for neigh_pipe in neigh_pipes: + if neigh_pipe.master == block and neigh_pipe.dir == dir: + pipe_cluster_dict[block][dir][neigh_pipe] = [] + elif block.type == "mem": + pipe_cluster_dict[block][dir] = {} + pipe_cluster_dict[block][dir][None] = block_pipes + elif block.type == "ic": + if dir not in pipe_cluster_dict[block]: + pipe_cluster_dict[block][dir] = {} + for blck_pipe in block_pipes: + for neigh_pipe in neigh_pipes: + if not blck_pipe.slave == neigh_pipe.master: + continue + if not(blck_pipe.dir == neigh_pipe.dir): + continue + if traffic_overlaps(blck_pipe, neigh_pipe): + if neigh_pipe not in pipe_cluster_dict[block][dir].keys(): + pipe_cluster_dict[block][dir][neigh_pipe] = [] + if blck_pipe not in pipe_cluster_dict[block][dir][neigh_pipe]: + pipe_cluster_dict[block][dir][neigh_pipe].append(blck_pipe) + + # now generates the clusters + for block, dir_outgoing_incoming_pipes in pipe_cluster_dict.items(): + for dir, outgoing_pipe_incoming_pipes in dir_outgoing_incoming_pipes.items(): + for outgoing_pipe, incoming_pipes in outgoing_pipe_incoming_pipes.items(): + pipe_cluster_ = PipeCluster(block, dir, outgoing_pipe, incoming_pipes, self.last_cluster_assigned_number) + if outgoing_pipe: # for pe and ic + if outgoing_pipe.master.type == "pe": # only push once + pipe_cluster_.change_to_dummy(outgoing_pipe.master.get_tasks_of_block()) + if len(outgoing_pipe.master.get_pipe_clusters()) == 0: + outgoing_pipe.master.set_pipe_cluster(pipe_cluster_) + else: + outgoing_pipe.master.set_pipe_cluster(pipe_cluster_) + elif incoming_pipes: # for mem + incoming_pipes[0].slave.set_pipe_cluster(pipe_cluster_) + self.last_cluster_assigned_number += 1 + self.pipe_clusters.append(pipe_cluster_) + + pass + + def size_queues(self): + # set to the default value + for pipe in self.pipes: + # by default set the 
cmd/data size to master queue size + pipe.set_cmd_queue_size(config.default_data_queue_size) + pipe.set_data_queue_size(config.default_data_queue_size) + + # actually size the queues + """ + for pipe in self.pipes: + # by default set the cmd/data size to master queue size + pipe_slave = pipe.get_slave() + # ignore PEs + if pipe_slave.type == "pe": + continue + pipe_line_depth = pipe_slave.get_pipe_line_depth() + pipe.set_cmd_queue_size(pipe_line_depth) + pipe.set_data_queue_size(pipe_line_depth) + pipe_tasks = pipe.get_tasks() + """ + + + def generate_pipes(self): + # assign number to pipes + self.last_pipe_assigned_number = 0 + pes = self.get_blocks_by_type("pe") + mems = self.get_blocks_by_type("mem") + ics = self.get_blocks_by_type("ic") + def seen_pipe(pipe__): + for pipe in self.pipes: + if pipe__.master == pipe.master and pipe__.slave == pipe.slave and pipe__.dir == pipe.dir: + return True + return False + + # iterate through all the blocks and specify their pipes + for pe in pes: + for mem in mems: + master_to_slave_path = self.get_path_between_two_vertecies(pe, mem) + if len(master_to_slave_path) > len(ics)+2: # two is for pe and memory + print('something has gone wrong with the path calculation') + exit(0) + # get pipes along the way + for idx in range(0, len(master_to_slave_path) - 1): + block_master = master_to_slave_path[idx] + block_slave = master_to_slave_path[idx + 1] + for dir_ in ["write", "read"]: + pipe_ = pipe(block_master, block_slave, dir_, self.last_pipe_assigned_number, 1, 1) + if not seen_pipe(pipe_): + self.pipes.append(pipe_) + self.last_pipe_assigned_number +=1 + + def connect_pipes_to_blocks(self): + for pipe in self.pipes: + master_block = pipe.get_master() + slave_block = pipe.get_slave() + master_block.set_pipe(pipe) + slave_block.set_pipe(pipe) + + # assign pipes to different blocks (depending on which blocks the pipes are connected to) + def pipe_design(self): + for block in self.blocks: + block.reset_pipes() + self.pipes = [] + self.pipe_clusters = [] + + # generate pipes everywhere + self.generate_pipes() + # assign tasks + self.task_all_the_pipes() + # filter pipes without tasks + self.filter_empty_pipes() + self.connect_pipes_to_blocks() + self.cluster_pipes() + self.size_queues() + + # --------------------------- + # Functionality: + # finding all the paths (set of edges) that connect two blocks (nodes) in the hardware graph. + # Variables: + # vertex: source vertex + # v_des: destination vertex + # vertecies_visited: vertices visited already (avoid circular traversal of the graph) + # path: the accumulated path so far. At the end, this will contain the total path. + # -------------------------- + def get_all_paths(self, vertex, v_des, vertecies_neigh_visited, path): + paths = self.get_shortest_path_helper(vertex, v_des, vertecies_neigh_visited, path) + #sorted_paths = sorted(paths, key=len) + return paths + + + # --------------------------- + # Functionality: + # finding the path (set of edges) that connect two blocks (nodes) in the hardware graph. + # Variables: + # vertex: source vertex + # v_des: destination vertex + # vertecies_visited: vertices visited already (avoid circular traversal of the graph) + # path: the accumulated path so far. At the end, this will contain the total path. 
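+    #       Note: the helper below enumerates simple paths depth-first, and the caller
+    #       picks the shortest by sorting the returned paths by length. On the small
+    #       IC topologies FARSI builds this is cheap, though path enumeration is
+    #       exponential in the worst case.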
+ # -------------------------- + def get_shortest_path(self, vertex, v_des, vertecies_neigh_visited, path): + paths = self.get_shortest_path_helper(vertex, v_des, vertecies_neigh_visited, path) + sorted_paths = sorted(paths, key=len) + return sorted_paths[0] + + def get_shortest_path_helper(self, vertex, v_des, vertecies_neigh_visited, path): + neighs = vertex.get_neighs() + path.append(vertex) + + # iterate through neighbours and remove the ones that you have already visited + neighs_to_ignore = [] + for neigh in neighs: + if (vertex,neigh) in vertecies_neigh_visited: + neighs_to_ignore.append(neigh) + neighs_to_look_at = list(set(neighs) - set(neighs_to_ignore)) + + if vertex == v_des: + return [path] + elif len(neighs_to_look_at) == 0: + return [] + else: + for neigh in neighs_to_look_at: + vertecies_neigh_visited.append((neigh, vertex)) + + paths = [] + for vertex_ in neighs_to_look_at: + paths_ = self.get_shortest_path_helper(vertex_, v_des, vertecies_neigh_visited[:], path[:]) + for path_ in paths_: + if len(path) == 0: + continue + paths.append(path_) + + return paths + + + # --------------------------- + # Functionality: + # finding the path (set of edges) that connect two blocks (nodes) in the hardware graph. + # Variables: + # vertex: source vertex + # v_des: destination vertex + # vertecies_visited: vertices visited already (avoid circular traversal of the graph) + # path: the accumulated path so far. At the end, this will contain the total path. + # -------------------------- + def get_path_helper(self, vertex, v_des, vertecies_visited, path): + path.append(vertex) + if vertex in vertecies_visited: + return [] + if vertex == v_des: + return path + else: + vertecies_visited.append(vertex) + paths = [self.get_path_helper(vertex_, v_des, vertecies_visited[:], path[:]) for vertex_ in vertex.neighs] + flatten_path = list(itertools.chain(*paths)) + return flatten_path + + # --------------------------- + # Functionality: + # finding the path (set of edges) that connet two blocks (nodes) in the hardware graph. + # Variables: + # v1: source vertex + # v2: destination vertex + # -------------------------- + def get_path_between_two_vertecies(self, v1, v2): + #path = self.get_path_helper(v1, v2, [], []) + #if (len(path)) <= 0: + # print("catch this error") + #assert(len(path) > 0), "no path between the two nodes" + shortest_path = self.get_shortest_path(v1, v2, [],[]) + #if not shortest_path == path: + # print("something gone wrong with path calculation fix this") + return shortest_path + + # --------------------------- + # Functionality: + # finding all the paths (set of edges) that connet two blocks (nodes) in the hardware graph. + # Variables: + # v1: source vertex + # v2: destination vertex + # -------------------------- + def get_all_paths_between_two_vertecies(self, v1, v2): + all_paths = self.get_all_paths(v1, v2, [],[]) + return all_paths + + + + + # --------------------------- + # Functionality: + # get root of the hardware graph. + # -------------------------- + def get_root(self): + return self.blocks[0] + + # --------------------------- + # Functionality: + # get task's graph + # -------------------------- + def get_task_graph(self): + return self.task_graph \ No newline at end of file diff --git a/Project_FARSI/design_utils/components/krnel.py b/Project_FARSI/design_utils/components/krnel.py new file mode 100644 index 00000000..dd4d46ad --- /dev/null +++ b/Project_FARSI/design_utils/components/krnel.py @@ -0,0 +1,1932 @@ +#Copyright (c) Facebook, Inc. and its affiliates. 
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+import math
+
+import numpy as np
+import copy
+from design_utils.components.hardware import *
+from design_utils.components.workload import *
+from design_utils.components.mapping import *
+from design_utils.components.scheduling import *
+import operator
+from collections import OrderedDict
+from collections import defaultdict
+from design_utils.common_design_utils import *
+import warnings
+import queue
+from collections import deque
+
+
+# This is a purpose-built Queue class. Not used at the moment, but may possibly be
+# used later for queue modeling.
+class MyQueue():
+    def __init__(self, max_size):  # fixed: was misspelled __init, so the constructor never ran
+        self.max_size = max_size
+        self.q_data = deque()
+
+    def enqueue(self, data):
+        if self.is_full():
+            return False
+        self.q_data.insert(0, data)
+        return True
+
+    def dequeue(self):
+        if self.is_empty():
+            return False
+        self.q_data.pop()
+        return True
+
+    def peek(self):
+        if self.is_empty():
+            return False
+        return self.q_data[-1]
+
+    def size(self):
+        return len(self.q_data)
+
+    def is_empty(self):
+        return (self.size() == 0)
+
+    def is_full(self):
+        return self.size() == self.max_size
+
+    def __str__(self):
+        return str(self.q_data)
+
+
+# This class emulates the concept of waves, i.e., a kernel that has multiple fronts
+# where each front is only a burst of data. We are not using this class any more;
+# at the moment, the concept of a front is captured by the Front class below.
+class Wave:
+    def __init__(self, block, channel_name, task_name, total_work, work_unit_size, id):
+        self.task_name = task_name
+        self.channel_name = channel_name
+        self.block = block
+        self.total_work = total_work
+        self.id = id
+
+        # set up the fronts
+        total_work_to_distribute = total_work
+        fq_size = math.ceil(total_work/work_unit_size)  # front queue size
+        self.front_queue = MyQueue(fq_size)
+        for i in range(0, fq_size):
+            if total_work_to_distribute < work_unit_size:
+                front_work = total_work_to_distribute
+            else:
+                front_work = work_unit_size
+            # fixed: enqueue the front (the queue object itself is not callable);
+            # note that Front's signature below expects more arguments, but since
+            # Wave is unused this call is never exercised
+            self.front_queue.enqueue(Front(block, channel_name, task_name, front_work, i))
+            total_work_to_distribute -= work_unit_size  # fixed: the remaining work shrinks by one unit per front (was +=)
+
+
+# This class emulates a stream of data/instructions.
+# For PEs, it is just a bunch of instructions sliced up into chunks.
+# For memories and buses, it is a stream of data that needs to move from one block (bus, memory) to another block (bus, memory).
+# This class is not used for the moment.
+class Front:
+    def __init__(self, host_block, src_block, dest_block, channel_name, main_task, family_task, total_work, work_unit):
+        self.main_task = main_task      # the task that the stream is doing work for
+        self.family_task = family_task  # the task that the stream reads/writes to.
In case of instructions, family task is yourself + self.channel_name = channel_name # read/write or same + self.host_block = host_block # the block that + # only one of the following is populated depending on read (src) or write(dest) + self.src_block = src_block + self.dest_block = dest_block + self.total_work = total_work + self.work_unit = work_unit + self.work_left = total_work + self.state = "in_active" # ["waiting", "to_be_waiting", "running"] + + def update_work_left(self, work_done): + work_left_copy = self.work_left + self.work_left -= work_done + if self.work_left < -.001: + print("this should not happen") + exit(0) + + def update_state(self, state): + assert(state in ["waiting", "to_be_waiting", "running", "done"]) + self.state = state + + +# class for getting relevant metrics (latency, energy, ...) for the task +class KernelStats: + def __init__(self): + self.latency = 0 + self.area= 0 + self.power = 0 + self.energy= 0 + self.throughput = None + self.blck_pwr = {} # (blck and power associated with it after execution) + self.blck_energy = {} # (blck and area associated with it after execution) + self.blck_area = {} # (block, area of block) + self.phase_latency_dict = {} # duration consumed per phase + self.phase_energy_dict = {} # energy consumed per phase + self.phase_leakage_energy_dict = {} # leakage energy consumed per phase + self.phase_area_dict = {} # area consumed per phase (note that we will be double counting the statically size + # blocks + + self.phase_bytes_dict = {} + self.design_cost = None + self.phase_block_duration_bottleneck:Dict[int, (Block, float)] = {} # dictionary containing phases and how different + # blocks for different durations become the kernel + # bottleneck + self.block_phase_energy_dict = {} + self.starting_time = 0 # when a kernel starts + self.completion_time = 0 # when a kernel completes + self.latency_till_earliest_child_start = 0 + + # ------------------------------ + # Functionality + # get the block bottleneck for the task from the latency perspective + # Variables: + # phase: the phase which the bottleneck qurried for. By default, the querry is for all the phases. 
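+    #      For example (hypothetical numbers): with
+    #          phase_block_duration_bottleneck = {0: (ic0, 1e-6), 1: (pe1, 3e-6), 2: (ic0, 1e-6)}
+    #      the accumulated per-block totals are {ic0: 2e-6, pe1: 3e-6}, so pe1 is
+    #      returned as the latency bottleneck.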
+ # ------------------------------ + def get_block_latency_bottleneck(self, phase="all"): + # sanity checks + if not self.phase_block_duration_bottleneck: + raise Exception('block_bottleneck is not identified yet.') + elif isinstance(phase, int): + raise Exception("bottleneck for specific phases are not supported yet.") + else: + block_duration_bottleneck: Dict[Block, float] = {} + # iterate and accumulate the duration that blocks are the bottleneck for the kernel + for phase, block_duration in self.phase_block_duration_bottleneck.items(): + block = block_duration[0] + duration = block_duration[1] + if block in block_duration_bottleneck: + block_duration_bottleneck[block] += duration + else: + block_duration_bottleneck[block] = duration + + # sort from worst to best and return the worst (the bottleneck) + sorted_block_duration_bottleneck = OrderedDict( + sorted(block_duration_bottleneck.items(), key=operator.itemgetter(1))) + + return list(sorted_block_duration_bottleneck.keys())[-1] + + # get the block bottleneck from power perspective + # phase: Simulation phases + def get_block_power_bottleneck(self, phase): + #index = random.choice([-3,-2,-1]) + index = random.choice([-1]) + return (sorted(self.blck_pwr.items(), key=operator.itemgetter(1))[index])[0] + + def get_block_area(self): + return self.blck_area + + # get the block bottleneck from area perspective + # phase: Simulation phases + def get_block_area_bottleneck(self, phase): + index = random.choice([-3,-2,-1]) + index = random.choice([-1]) + return (sorted(self.blck_area.items(), key=operator.itemgetter(1))[index])[0] + + # get the block bottleneck from energy perspective + # phase: Simulation phases + def get_block_energy_bottleneck(self, phase): + return (sorted(self.blck_energy.items(), key=operator.itemgetter(1))[-1])[0] + + # get the block bottleneck from cost perspective + # phase: Simulation phases + def get_block_cost_bottleneck(self, phase): + return (sorted(self.blck_cost.items(), key=operator.itemgetter(1))[-1])[0] + + # ------------------------------ + # Functionality + # get the block bottleneck for the task + # Variables: + # phase: the phase which the bottleneck querried for. By default, the query is for all the phases. 
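+    #      The result is a list of (block, percentage) pairs; since only one block can
+    #      be the latency bottleneck in any given phase, blocks that never bottleneck
+    #      a phase are appended with a 0% share.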
+ # ------------------------------ + def get_block_latency_sorted(self, phase="all"): + # sanity checks + if not self.phase_block_duration_bottleneck: + raise Exception('block_bottleneck is not identified yet.') + elif isinstance(phase, int): + raise Exception("bottleneck for specific phases are not supported yet.") + else: + block_duration_bottleneck: Dict[Block, float] = {} + for phase, block_duration in self.phase_block_duration_bottleneck.items(): + block = block_duration[0] + duration = block_duration[1] + if block in block_duration_bottleneck: + block_duration_bottleneck[block] += duration + else: + block_duration_bottleneck[block] = duration + + sorted_block_duration_bottleneck = OrderedDict( + sorted(block_duration_bottleneck.items(), key=operator.itemgetter(1))) + sorted_normalized = [(key, 100*value / sum(sorted_block_duration_bottleneck.values())) for key, value in sorted_block_duration_bottleneck.items()] + + # for latency, we need to zero out the rest since there is only one bottleneck through out each phase + non_bottleneck_blocks =[] + for block in self.blck_area.keys(): + if block.instance_name == sorted_normalized[0][0]: + continue + non_bottleneck_blocks.append(block) + for block in non_bottleneck_blocks: + sorted_normalized.append((block, 0)) + return sorted_normalized + + # get the block bottlenecks (from power perspective) sorted + # phase: Simulation phases + def get_block_power_sorted(self, phase): + sorted_list = sorted(self.blck_pwr.items(), key=operator.itemgetter(1)) + values = [tuple_[1] for tuple_ in sorted_list] + sorted_normalized = [(key, 100*value/max(sum(values), .00000001)) for key, value in sorted_list] + return sorted_normalized + + # get the block bottlenecks (from area perspective) sorted + # phase: Simulation phases + def get_block_area_sorted(self, phase): + sorted_list = sorted(self.blck_area.items(), key=operator.itemgetter(1)) + values = [tuple_[1] for tuple_ in sorted_list] + sorted_normalized = [(key, 100*value/sum(values)) for key, value in sorted_list] + return sorted_normalized + + # get the block bottlenecks (from energy perspective) sorted + # phase: Simulation phases + def get_block_energy_sorted(self, phase): + sorted_list = sorted(self.blck_energy.items(), key=operator.itemgetter(1)) + values = [tuple_[1] for tuple_ in sorted_list] + sorted_normalized = [(key, 100*value/sum(values)) for key, value in sorted_list] + return sorted_normalized + + # get the block bottlenecks (from cost perspective) sorted + # phase: Simulation phases + def get_block_cost_sorted(self, phase): + sorted_list = sorted(self.blck_cost.items(), key=operator.itemgetter(1)) + values = [tuple_[1] for tuple_ in sorted_list] + sorted_normalized = [(key, 100*value/max(sum(values),.00000001)) for key, value in sorted_list] + return sorted_normalized + + # get the block bottlenecks (from the metric of interest perspective) sorted + # phase: Simulation phases + def get_block_sorted(self, metric="latency", phase="all"): + if metric == "latency": + return self.get_block_latency_sorted(phase) + elif metric == "power": + return self.get_block_power_sorted(phase) + elif metric == "area": + return self.get_block_area_sorted(phase) + elif metric == "energy": + return self.get_block_energy_sorted(phase) + elif metric == "cost": + return self.get_block_cost_sorted(phase) + + # get the block bottlenecks from the metric of interest perspective + # phase: Simulation phases + # metric: metric of interest to pick the bottleneck for + def get_block_bottleneck(self, metric="latency", 
phase="all"): + if metric == "latency": + return self.get_block_latency_bottleneck(phase) + elif metric == "power": + return self.get_block_power_bottleneck(phase) + elif metric == "area": + return self.get_block_area_bottleneck(phase) + elif metric == "energy": + return self.get_block_energy_bottleneck(phase) + elif metric == "cost": + return self.get_block_cost_bottleneck(phase) + + # ----- + # setter + # ----- + def set_stats(self): + for metric in config.all_metrics: + if metric == "latency": + self.set_latency() + if metric == "energy": + self.set_energy() + if metric == "power": + self.set_power() + if metric == "cost": + return + # already set before + if metric == "area": + # already set before + return + + def set_latency(self): + # already + return + + # sets power for the entire kernel and also per blocks hosting the kernel + def set_power(self): + # get energy first + sorted_listified_phase_latency_dict = sorted(self.phase_latency_dict.items(), key=operator.itemgetter(0)) + sorted_durations = [duration for phase, duration in sorted_listified_phase_latency_dict] + sorted_phase_latency_dict = collections.OrderedDict(sorted_listified_phase_latency_dict) + sorted_listified_phase_energy_dict = sorted(self.phase_energy_dict.items(), key=operator.itemgetter(0)) + sorted_phase_energy_dict = collections.OrderedDict(sorted_listified_phase_energy_dict) + phase_bounds_lists = slice_phases_with_PWP(sorted_phase_latency_dict) + + # calculate power + power_list = [] # list of power values collected based on the power collection freq + for lower_bnd, upper_bnd in phase_bounds_lists: + if sum(sorted_durations[lower_bnd:upper_bnd]) > 0: + power_list.append( + sum(list(sorted_phase_energy_dict.values())[lower_bnd:upper_bnd]) / sum(sorted_durations[lower_bnd:upper_bnd])) + else: + power_list.append(0) + self.power = max(power_list) + + # now calculate the above per block + blck_pwr_list = defaultdict(list) + for block, phase_energy_dict in self.block_phase_energy_dict.items(): + sorted_listified_phase_energy_dict = sorted(phase_energy_dict.items(), key=operator.itemgetter(0)) + sorted_phase_energy_dict = collections.OrderedDict(sorted_listified_phase_energy_dict) + for lower_bnd, upper_bnd in phase_bounds_lists: + if sum(sorted_durations[lower_bnd:upper_bnd]) > 0: + blck_pwr_list[block].append( + sum(list(sorted_phase_energy_dict.values())[lower_bnd:upper_bnd]) / sum(sorted_durations[lower_bnd:upper_bnd])) + else: + blck_pwr_list[block].append(0) + + for blck in blck_pwr_list.keys() : + self.blck_pwr[blck] = max(blck_pwr_list[blck]) + + def set_area(self, area): + self.area = area + + def set_block_area(self, area_dict): + self.blck_area = area_dict + self.set_area(sum(list(area_dict.values()))) + + def set_cost(self, cost): + self.cost = cost + + def set_block_cost(self, cost_dict): + self.blck_cost = cost_dict + self.set_cost(sum(list(cost_dict.values()))) + + def set_energy(self): + for block, phase_energy_dict in self.block_phase_energy_dict.items(): + self.blck_energy[block] = sum(list(phase_energy_dict.values())) + self.energy = sum(self.phase_energy_dict.values()) + + def set_stats_directly(self, metric_name, metric_value): + if (metric_name == "latency"): + self.latency = metric_value + elif (metric_name == "energy"): + self.energy = metric_value + elif (metric_name == "power"): + self.power = metric_value + elif (metric_name == "area"): + self.area = metric_value + elif (metric_name == "cost"): + self.cost = metric_value + else: + print("metric:" + metric_name + " is not supported in the 
stats") + exit(0) + + # -------- + # getters + # ------ + def get_latency(self): + return self.latency + + def get_cost(self): + return self.cost + + def get_power(self): + return self.power + + def get_area(self): + return self.area + + def get_energy(self): + return self.energy + + def get_metric(self, metric): + if metric == "latency": + return self.get_latency() + if metric == "energy": + return self.get_energy() + if metric == "power": + return self.get_power() + if metric == "area": + return self.get_area() + if metric == "cost": + return self.get_cost() + + +# This class emulates the a task within the workload. +# The difference between kernel and task is that kernel is a simulation construct containing timing/energy/power information. +class Kernel: + def __init__(self, task_to_blocks_map: TaskToBlocksMap): #, task_to_pe_block_schedule: TaskToPEBlockSchedule): + # constructor argument vars, any changes to these need to initiate a reset + self.__task_to_blocks_map = task_to_blocks_map # mapping of the task to blocks + self.kernel_total_work = {} + #self.kernel_total_work = self.__task_to_blocks_map.task.get_self_task_work() # work (number of instructions for the task) + self.kernel_total_work["execute"] = self.__task_to_blocks_map.task.get_self_total_work("execute") + self.kernel_total_work["read"] = self.__task_to_blocks_map.task.get_self_total_work("read") + self.kernel_total_work["write"] =self.__task_to_blocks_map.task.get_self_total_work("write") + self.data_work_left = {} + self.max_iteration_ctr = self.iteration_ctr = self.__task_to_blocks_map.task.iteration_ctr + # status vars + self.cur_phase_bottleneck = "" # which phase does the bottleneck occurs + self.block_att_work_rate_dict = defaultdict(dict) # for the kernel, block and their attainable work rate. + # attainable work rate is peak work rate (BW or IPC) of the block + # but attenuated as it is being shared among multiple kernels/fronts + + self.type = self.__task_to_blocks_map.task.get_type() + self.throughput_info = self.__task_to_blocks_map.task.get_throughput_info() + self.operating_state = "none" + self.block_path_dir_phase_latency = {} + self.path_dir_phase_latency = {} + self.pathlet_phase_latency_dict = {} + self.stats = KernelStats() + self.workload_pe_total_work, self.workload_fraction, self.pe_s_work_left, self.progress, self.status = [None]*5 + self.block_dir_work_left = defaultdict(dict) # block and the direction (write/read or loop) and how much work is left for the + # the kernel on this block + self.phase_num = -1 # since the very first time, doesn't count + self.block_phase_work_dict = defaultdict(dict) # how much work per block and phase is done + self.block_phase_read_dict = defaultdict(dict) # how much work per block and phase is done + self.block_phase_write_dict = defaultdict(dict) # how much work per block and phase is done + self.block_phase_energy_dict = defaultdict(dict) # how much energy phase block and phase is consumed + # how much leakage energy phase block and phase is consumed (PE and mem) + self.block_phase_leakage_energy_dict = defaultdict(dict) + self.block_phase_area_dict = defaultdict(dict) # how much area phase block and phase is consumed + self.SOC_type = "" + self.SOC_id = "" + self.set_SOC() + self.task_name = self.__task_to_blocks_map.task.name + + # The variable shows what power_knob is being used! 
+ # 0 means baseline, and any other number is a DVFS/power_knob mode + self.power_knob_id = 0 + self.starting_time = 0 + self.completion_time = 0 + + # Shows what block is the bottleneck at every phase of the kernel execution + self.kernel_phase_bottleneck_blocks_dict = defaultdict(dict) + self.block_num_shared_blocks_dict = {} # how many other kernels shared this block + #self.work_unit_dict = self.__task_to_blocks_map.task.__task_to_family_task_work_unit # determines for each burst how much work needs + # to be done. + self.path_structural_latency = {} + + self.firing_time_to_meet_throughput = {} + self.firing_work_to_meet_throughput = {} + self.data_work_left_to_meet_throughput = {} + + + # This function is used for the power knobs simulator; After each run the stats + # gathered for each kernel is removed to avoid conflicting with past simulations + def reset_sim_stats(self): + # status vars + self.cur_phase_bottleneck = "" + self.block_att_work_rate_dict = defaultdict(dict) + self.stats = KernelStats() + self.workload_pe_total_work, self.workload_fraction, self.pe_s_work_left, self.progress, self.status = [None]*5 + self.phase_num = -1 # since the very first time, doesn't count + self.block_normalized_work_rate = {} # blocks and their work_rate (not normalized) including sharing + self.block_phase_work_dict = defaultdict(dict) # how much work per block and phase is done + self.block_phase_read_dict = defaultdict(dict) # how much work per block and phase is done + self.block_phase_write_dict = defaultdict(dict) # how much work per block and phase is done + self.block_phase_energy_dict = defaultdict(dict) # how much energy phase block and phase is consumed + self.block_phase_leakage_energy_dict = defaultdict(dict) + self.block_phase_area_dict = defaultdict(dict) # how much area phase block and phase is consumed + self.SOC_type = "" + self.SOC_id = "" + self.set_SOC() + self.task_name = self.__task_to_blocks_map.task.name + self.work_unit_dict = self.__task_to_blocks_map.task.__task_to_family_task_work_unit # determines for each burst how much work needs + self.power_knob_id = 0 + self.starting_time = 0 + self.completion_time = 0 + self.kernel_phase_bottleneck_blocks_dict = defaultdict(dict) + self.block_num_shared_blocks_dict = {} + self.operating_state = "none" + + # populate the statistics + def set_stats(self): + self.stats.set_block_area(self.calc_area_used_per_block()) + self.stats.set_block_cost(self.calc_cost_per_block()) + self.stats.set_stats() + + # set the SOC for the kernel + def set_SOC(self): + SOC_list = [(block.SOC_type, block.SOC_id) for block in self.__task_to_blocks_map.get_blocks()] + if len(set(SOC_list)) > 1: + raise Exception("kernel can not be resident in more than 1 SOC") + SOC = SOC_list[0] + self.SOC_type = SOC[0] + self.SOC_id = SOC[1] + + + def update_krnl_iteration_ctr(self): + if self.iteration_ctr == -1: + pass + else: + self.iteration_ctr = max(0, self.iteration_ctr -1) + + # -------------- + # getters + # -------------- + def get_task(self): + return self.__task_to_blocks_map.task + + def get_task_name(self): + return self.__task_to_blocks_map.task.name + + # work here is PE's work, so specified in terms of number of instructions + def get_total_work(self): + return self.kernel_total_work["execute"] + + # get the list of blocks the kernel uses + def get_block_list_names(self): + return [block.instance_name for block in self.__task_to_blocks_map.get_blocks()] + + def get_blocks(self): + return self.__task_to_blocks_map.get_blocks() + + # get the 
reference block for the kernel (reference block is what work rate is calculated + # based off of. + def get_ref_block(self): + for block in self.__task_to_blocks_map.get_blocks(): + if block.type == "pe": + return block + + # get kernel's memory blocks + # dir: direction of interest (read/write) + def get_kernel_s_mems(self, dir): + mems = [] + for block in self.__task_to_blocks_map.get_blocks(): + if block.type == "mem": + task_dir = block.get_task_dir_by_task_name(self.get_task()) + for task, dir_ in task_dir: + if dir_ == dir: + mems.append(block) + break + + if "souurce" in self.get_task_name() and dir == "read": + if not len(mems) == 0: + raise Exception(" convention is that no read memory for souurce task") + else: + return [] + elif "siink" in self.get_task_name() and dir == "write": + if not len(mems) == 0: + raise Exception(" convention is that no write memory for siink task") + else: + return [] + else: + return mems + + def set_power_knob_id(self, pk_id): + self.power_knob_id = pk_id + return + + def get_power_knob_id(self): + return self.power_knob_id + + # return the a dictionary containing the kernel bottleneck across different phases of execution + def get_bottleneck_dict(self): + return self.kernel_phase_bottleneck_blocks_dict + + # calculate the area + def calc_area_used(self): + total_area = 0 + for my_block in self.__task_to_blocks_map.get_blocks(): + if my_block.get_block_type_name() in ["ic", "pe"]: + total_area += my_block.get_area() + elif my_block.get_block_type_name() in ["mem"]: + mem_work_ratio_read = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(my_block.instance_name, "read") + mem_work_ratio_write = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(my_block.instance_name, "write") + total_area += self.get_total_work()*mem_work_ratio_read/my_block.get_work_over_area() + total_area += self.get_total_work()*mem_work_ratio_write/my_block.get_work_over_area() + return total_area + + # calculate area per block + def calc_area_used_per_block(self): + area_dict = {} + total_area = 0 + for my_block in self.__task_to_blocks_map.get_blocks(): + if my_block.get_block_type_name() in ["ic", "pe"]: + area_dict[my_block] = my_block.get_area() + elif my_block.get_block_type_name() in ["mem"]: + mem_work_ratio_read = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(my_block.instance_name, "read") + mem_work_ratio_write = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(my_block.instance_name, "write") + area_dict[my_block] = self.get_total_work()*mem_work_ratio_read/my_block.get_work_over_area() + area_dict[my_block] += self.get_total_work()*mem_work_ratio_write/my_block.get_work_over_area() + return area_dict + + # calculate area per block + def calc_traffic_per_block(self, my_block): + traffic = 0 + if my_block.get_block_type_name() in ["mem"]: + mem_work_ratio_read = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(my_block.instance_name, "read") + mem_work_ratio_write = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(my_block.instance_name, "write") + traffic += self.get_total_work()*mem_work_ratio_write + traffic += self.get_total_work()*mem_work_ratio_read + + return traffic + + + + + # calculate the cost per block + def calc_cost_per_block(self): + cost_dict = {} + for my_block in self.__task_to_blocks_map.get_blocks(): + cost_dict[my_block] = 0 + return cost_dict + + # return the kernels that currently are on the block and channel of the block. 
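+    # (In practice this reduces to asking the pipe cluster whether the kernel's task
+    # is present on it; the block argument is unused in the check itself.)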
+ def kernel_currently_uses_the_block_pipe_cluster(self, block, task, pipe_cluster): + return pipe_cluster.is_task_present(task) + + # how many kernels using the block + def get_block_num_shared_krnels(self, block, pipe_cluster, scheduled_kernels): + krnl_using_block_channl = list(filter(lambda krnl: krnl.kernel_currently_uses_the_block_pipe_cluster(block, krnl.get_task(), pipe_cluster), scheduled_kernels)) + return len(krnl_using_block_channl) + + # creates a list of the blocks that are shared with other tasks in the current phase + # Stores the numbers of tasks using the same resources to divide them among all later + def get_blocks_num_shared_krnels(self, scheduled_kernels): + #block_num_shared_blocks_dict = defaultdict(defaultdict(defaultdict)) + block_num_shared_blocks_dict = {} + + blocks = self.get_blocks() + for block in blocks: + for pipe_cluster in block.get_pipe_clusters_of_task(self.get_task()): + dir = pipe_cluster.get_dir() + cluster_UN = pipe_cluster.get_unique_name() # cluster unique name + if block not in block_num_shared_blocks_dict.keys(): + block_num_shared_blocks_dict[block] = {} + if dir not in block_num_shared_blocks_dict[block].keys(): + block_num_shared_blocks_dict[block][dir] = {} + block_num_shared_blocks_dict[block][dir][cluster_UN] = self.get_block_num_shared_krnels(block, pipe_cluster, scheduled_kernels) + return block_num_shared_blocks_dict + + # return the pipes that have kernels on them + def filter_in_active_pipes(self, incoming_pipes, outcoming_pipe, scheduled_kernels): + active_pipes_with_duplicates = [] + for in_pipe_ in incoming_pipes: + for krnl in scheduled_kernels: + if outcoming_pipe == None: # for memory + task_present_on_outcoming_pipe = True + else: + task_present_on_outcoming_pipe = outcoming_pipe.is_task_present(krnl.__task_to_blocks_map.task) + + if in_pipe_.is_task_present(krnl.__task_to_blocks_map.task) and task_present_on_outcoming_pipe: + active_pipes_with_duplicates.append(in_pipe_) + + active_pipes = list(set(active_pipes_with_duplicates)) + assert(len(active_pipes) <= len(incoming_pipes)) + return active_pipes + + # calculate what the work rate (BW or IPC) of each block is (for the kernel at hand) + # This method assumes that each pipe gets an equal portion of the bandwidth, in other words, + # equal pipe arbitration policy + def alotted_work_rate_equal_pipe(self, pipe_cluster, scheduled_kernels): + block = pipe_cluster.get_block_ref() + dir = pipe_cluster.get_dir() + # helper function + def pipes_serial_work_ratio(pipe, scheduled_kernels, mode = "equal_per_kernel"): + mode = "proportional_to_kernel" + #mode = "equal_per_kernel" + if mode == "equal_per_kernel": + num_tasks_present = 0 + num_tasks_present = sum([int(pipe.is_task_present(krnel.__task_to_blocks_map.task)) for krnel in scheduled_kernels]) + task_present = int(pipe.is_task_present(self.__task_to_blocks_map.task)) + if num_tasks_present == 0: # this scenario happens when we have schedulued only one task (specifically souurce or siink on a processor) + num_tasks_present = 1 + serial_work_rate = task_present / num_tasks_present + elif mode == "proportional_to_kernel": + all_kernels_work = sum([(pipe.get_task_work_unit(krnel.__task_to_blocks_map.task)) for krnel in scheduled_kernels]) + own_kernel_work = pipe.get_task_work_unit(self.__task_to_blocks_map.task) + if all_kernels_work == 0: # this scenario happens when we have schedulued only one task (sepcifically souurce or siink on a processor) + all_kernels_work = 1 + serial_work_rate = own_kernel_work/all_kernels_work + return 
serial_work_rate + + if block.type == "pe": # pipes are not important for PEs + allotted_work_rate = 1/self.block_num_shared_blocks_dict[block][dir][pipe_cluster.get_unique_name()] + else: + # get the pipes that kernel is running on and use bottleneck analysis + # to find the work rate + incoming_pipes = pipe_cluster.get_incoming_pipes() + outgoing_pipe = pipe_cluster.get_outgoing_pipe() + pipes_with_traffic = self.filter_in_active_pipes(incoming_pipes,outgoing_pipe, scheduled_kernels) + allotted_work_rate = 0 + for pipe in pipes_with_traffic: + pipe_serial_work_rate = pipes_serial_work_ratio(pipe, scheduled_kernels) + allotted_work_rate += (1/len(pipes_with_traffic)) * pipe_serial_work_rate + return allotted_work_rate + + # calculate the work rate (BW or IPC depending on the hardware block) of each kernel, while considering + # sharing of the block across live kernels + def calc_allotted_work_rate_relative_to_other_kernles(self, mode, pipe_cluster, scheduled_kernels): + assert(mode in ["equal_rate_per_kernel", "equal_rate_per_pipe"]) + if mode == "equal_rate_per_kernel": + return float(1./self.block_num_shared_blocks_dict[pipe_cluster.get_ref_block()][pipe_cluster.dir][pipe_cluster.get_unique_name()]) + elif mode == "equal_rate_per_pipe": + return self.alotted_work_rate_equal_pipe(pipe_cluster, scheduled_kernels) + + def get_block_family_tasks_in_use(self, block): + blocks_family_members = self.__task_to_blocks_map.get_block_family_members_allocated(block.instance_name) + return blocks_family_members + + + def get_queue_impact_simplified(self, block, pipe_cluster, schedulued_krnels): + def get_flit_count_on_pipe(block, pipe, schedulued_krnels): + work_unit_total = 0 + for krnl in schedulued_krnels: + if pipe.is_task_present(krnl.get_task()) and krnl.get_task_name() == self.get_task_name(): + work_unit_total += pipe.get_task_work_unit(krnl.get_task()) + + flit_cnt = math.ceil(work_unit_total/block.get_block_bus_width()) + return flit_cnt + + def get_flit_count_on_pipe_cluster(block, schedulued_kernels, mode): + incoming_pipes = pipe_cluster.get_incoming_pipes() + own_pipe = "NA" + other_pipes = [] + for pipe_ in incoming_pipes: + if pipe_.is_task_present(self.get_task()): + own_pipe = pipe_ + + for pipe_ in incoming_pipes: + if pipe_.is_task_present(self.get_task()): + continue + other_pipes.append(pipe_) + + if mode == "own": + work_unit_total = own_pipe.get_task_work_unit(self.get_task()) + + if mode == "serial": + for krnl in schedulued_krnels: + if own_pipe.is_task_present(krnl.get_task()) and not krnl.get_task_name() == self.get_task_name(): + work_unit_total += own_pipe.get_task_work_unit(krnl.get_task()) + + if mode == "parallel": + for krnl in schedulued_krnels: + for pipe_ in other_pipes: + if pipe_.is_task_present(krnl.get_task()) and not krnl.get_task_name() == self.get_task_name(): + work_unit_total += pipe_.get_task_work_unit(krnl.get_task()) + + + return math.ceil(work_unit_total / block.get_block_bus_width()) + + + # calculate the queue impact + queue_impact = 1 + if block.type == "pe": + queue_impact = 1 + else: + incoming_pipes = pipe_cluster.get_incoming_pipes() + # use a random pipe for now. 
TODO: fix later + + + return queue_impact + + + def get_queue_impact(self, block, pipe_cluster, schedulued_krnels): + def get_flit_count_on_pipe(block, pipe, schedulued_krnels): + work_unit_total = 0 + for krnl in schedulued_krnels: + if pipe.is_task_present(krnl.get_task()): + work_unit_total += pipe.get_task_work_unit(krnl.get_task()) + + flit_cnt = math.ceil(work_unit_total/block.get_block_bus_width()) + return flit_cnt + + + + def get_flit_count_on_pipe_by_type(block, pipe, schedulued_krnels): + work_unit_total = 0 + for krnl in schedulued_krnels: + if pipe.is_task_present(krnl.get_task()): + work_unit_total += pipe.get_task_work_unit(krnl.get_task()) + + flit_cnt = math.ceil(work_unit_total/block.get_block_bus_width()) + queue_size = pipe.get_data_queue_size() + + flit_cnt_to_prime_with = min(queue_size, flit_cnt) + if flit_cnt > queue_size: + flit_cnt_after_priming = (math.floor(flit_cnt/queue_size) - 1)*queue_size + flit_cnt_for_draining = flit_cnt - flit_cnt_to_prime_with - flit_cnt_after_priming + else: + flit_cnt_after_priming = 0 + flit_cnt_for_draining = 0 + + assert(flit_cnt_for_draining >= 0), "flit cnt draining needs to be greater or equal to zero" + assert(flit_cnt_after_priming >= 0), "flit count after priming needs to be greater or equal to zero" + assert(flit_cnt_to_prime_with >= 0), "flit count to prime with needs to be greater or equal to zero" + + flit_cnt_type = {"to_prime_with": flit_cnt_to_prime_with, "after_priming":flit_cnt_after_priming, "draining":flit_cnt_for_draining} + return flit_cnt_type + + + # calculate the queue impact + queue_impact = 1 + if block.type == "pe": + queue_impact = 1 + else: + incoming_pipes = pipe_cluster.get_incoming_pipes() + # use a random pipe for now. TODO: fix later + + bus_width = block.get_block_bus_width() + block_pipe_line_depth = block.get_pipe_line_depth() + + # select an incoming pipe + for pipe_ in incoming_pipes: + for krnl in schedulued_krnels: + if pipe_.is_task_present(krnl.get_task()): + default_pipe = pipe_ + #default_pipe = incoming_pipes[0] + pipe = default_pipe + + queue_size = pipe.get_data_queue_size() + flit_cnt = get_flit_count_on_pipe(block, pipe, schedulued_krnels) + + + flit_cnt_by_type = get_flit_count_on_pipe_by_type(block, pipe, schedulued_krnels) + + # spend all the flits to 3, + # flits that prime the pipeline, flits after priming, flits for draining + total_cycles_spent_on_all_flits = 0 + # what portion of bandwidth would be curbed due to queuing impact + # while the pipeline is being primed + modeling_quanta = flit_cnt_by_type["to_prime_with"] + number_of_quanta = 1 + quanta_over_all_percentage = (modeling_quanta*number_of_quanta)/flit_cnt + if not quanta_over_all_percentage == 0: + cycles_spent_on_quanta = max(block_pipe_line_depth - (queue_size - 1),1) + (modeling_quanta - 1) + #cycles_spent_on_quanta = (modeling_quanta - 1) + 1 + quanta_curbing_coeff = modeling_quanta/cycles_spent_on_quanta # quanta is "to_prime_with" + flits_to_prime_with_impact = quanta_over_all_percentage * quanta_curbing_coeff + total_cycles_spent_on_all_flits+= number_of_quanta*cycles_spent_on_quanta + else: + flits_to_prime_with_impact = 0 + + # after the pipeline is primed + modeling_quanta = queue_size + number_of_quanta = (flit_cnt_by_type["after_priming"]/queue_size) + quanta_over_all_percentage = (modeling_quanta*number_of_quanta)/flit_cnt + if not quanta_over_all_percentage == 0: + cycles_spent_on_quanta = max(block_pipe_line_depth - (queue_size - 1),1) + (modeling_quanta - 1) # quanta is queue size + 
quanta_curbing_coeff = modeling_quanta/cycles_spent_on_quanta + flits_after_priming_impact = quanta_over_all_percentage * quanta_curbing_coeff + total_cycles_spent_on_all_flits += number_of_quanta*cycles_spent_on_quanta + else: + flits_after_priming_impact = 0 + + # pipeline drainage + modeling_quanta = flit_cnt_by_type["draining"] + number_of_quanta = 1 + quanta_over_all_percentage = (modeling_quanta*number_of_quanta)/flit_cnt + if not quanta_over_all_percentage == 0: + cycles_spent_on_quanta = max(block_pipe_line_depth - (queue_size - 1),1) + (modeling_quanta - 1) # quanta is queue size + quanta_curbing_coeff = modeling_quanta / cycles_spent_on_quanta + flits_for_draining_impact = quanta_over_all_percentage* quanta_curbing_coeff + total_cycles_spent_on_all_flits += number_of_quanta*cycles_spent_on_quanta + else: + flits_for_draining_impact = 0 + + + # calculate queue impact + """ + queue_occupancy = min(queue_size, flit_cnt) # measured in number of occupied cells + pipe_line_utilization = queue_occupancy/block_pipe_line_depth + pipe_line_utilization = min(pipe_line_utilization, 1) # can't be above one + queue_impact = pipe_line_utilization + #queue_impact = flits_after_priming_impact + flits_to_prime_with_impact + flits_for_draining_impact + """ + + fw_latency = 2*block_pipe_line_depth + (queue_size - 1) + bw_latency = 3 * block_pipe_line_depth + total_hop_latency = fw_latency + bw_latency + + # add the hop latency + if len(schedulued_krnels) > 1 : + krnls_running_cnt = len(schedulued_krnels) + #krnls_running_cnt = 8 + unhidden_latency = total_hop_latency - (krnls_running_cnt - 1) * total_cycles_spent_on_all_flits + if unhidden_latency > 0: + unhidden_latency = (unhidden_latency)/krnls_running_cnt + else: + unhidden_latency = 0 + else: + unhidden_latency = total_hop_latency + + total_cycles_spent_on_all_flits += unhidden_latency + queue_impact = flit_cnt/total_cycles_spent_on_all_flits + + return queue_impact + + + + # get each blocks work-rate while considering sharing the block across the active kernels + # Normalization is the process of normalizing the work_rate of each block with respect of the + # reference work (work done by the PE). This then allows us to easily find the bottleneck + # for the block as we have already normalized the data. + def calc_all_block_normalized_work_rate(self, scheduled_kernels): + self.block_num_shared_blocks_dict = self.get_blocks_num_shared_krnels(scheduled_kernels) + #mode = "equal_rate_per_kernel" + mode = "equal_rate_per_pipe" + block_work_rate_norm_dict = defaultdict(defaultdict) + + # iterate through each block, channel. + # (1) calculate the share of each kernel for the channel. 
(2) get their work ratio to normalize to + # (3) use peak rate, share of each kernel and work ratio to generate final results + for block in self.get_blocks(): + for pipe_cluster in block.get_pipe_clusters_of_task(self.get_task()): + if self.get_task().is_task_dummy(): + block_work_rate_norm_dict[block][pipe_cluster] = 1 + continue + # calculate share of each kernel + allocated_work_rate_relative_to_other_kernels = self.calc_allotted_work_rate_relative_to_other_kernles(mode, pipe_cluster, scheduled_kernels) + if allocated_work_rate_relative_to_other_kernels == 0: # when the channel in the block is not being used + continue + + # get work ratio (so you can normalize to it) + dir = pipe_cluster.get_dir() + + if config.sw_model == "sequential": + # filter blocks based on the stage + if block.type == "pe" and not self.operating_state == "execute": + continue + elif not block.type == "pe" and not dir == self.operating_state: + continue + work_ratio = 1 + else: + work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_family_member_names_and_channel_eliminating_fake( + block.instance_name, self.get_block_family_tasks_in_use(block), dir) + + if work_ratio == 0: + print("this should be looked at") + work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_family_member_names_and_channel_eliminating_fake( + block.instance_name, self.get_block_family_tasks_in_use(block), dir) + + + # queue impact + #queue_impact = self.get_queue_impact(block, pipe_cluster, scheduled_kernels) + # TODO: change later (for now set to 1). + queue_impact = 1 + + work_rate = queue_impact*float(block.get_peak_work_rate(self.get_power_knob_id()))*allocated_work_rate_relative_to_other_kernels/work_ratio + block_work_rate_norm_dict[block][pipe_cluster] = work_rate + + return block_work_rate_norm_dict + + # simply go through all the block work rate and pick the smallest + def calc_block_s_bottleneck(self, block_work_rate_norm_dict): + # only if work unit is left + block_bottleneck = {"write": None, "read":None} + bottleneck_work_rate = {"write": np.Inf, "read":np.Inf} + # iterate through all the blocks/channels and ge the minimum work rate. Since + # the data is normalized, minimum is the bottleneck + for block, pipe_cluster_work_rate in block_work_rate_norm_dict.items(): + for pipe_cluster, work_rate in pipe_cluster_work_rate.items(): + dir_ = pipe_cluster.get_dir() + if dir_ == "same": # same is for PEs. In that case, we will apply it for both the read and write path + dirs = ["write", "read"] + else: + dirs = [dir_] + for dir__ in dirs: + if not block_bottleneck[dir__]: + block_bottleneck[dir__] = block + bottleneck_work_rate[dir__] = work_rate + else: + if work_rate < bottleneck_work_rate[dir__]: + bottleneck_work_rate[dir__] = work_rate + block_bottleneck[dir__] = block + + return block_bottleneck, bottleneck_work_rate + + + # calculate the unnormalized work rate. + # Normalization is the process of normalizing the work_rate of each block with respect of the + # reference work (work done by the PE). This then allows us to easily find the bottleneck + # for the block as we have already normalized the data. 
Unnormalization is the reverse + # process + def calc_unnormalize_work_rate_by_dir(self, block_work_rate_norm_dict, bottleneck_work_rate): + block_dir_att_work_rate_dict = defaultdict(dict) + for block, pipe_cluster_work_rate in block_work_rate_norm_dict.items(): + if block.type == "pe": + continue + for pipe_cluster, work_rate in pipe_cluster_work_rate.items(): + dir = pipe_cluster.get_dir() + if dir == "same": + dirs = ["write", "read"] + else: + dirs = [dir] + for dir_ in dirs: + work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_family_member_names_and_channel_eliminating_fake( + block.instance_name, self.get_block_family_tasks_in_use(block), dir_) + if "souurce" in self.get_task_name() or "siink" in self.get_task_name(): + work_ratio = 1 + block_dir_att_work_rate_dict[block][pipe_cluster] = bottleneck_work_rate[dir_]*work_ratio + #self.update_pipe_cluster_work_rate(pipe_cluster, bottleneck_work_rate) + + return block_dir_att_work_rate_dict + + + # calculate the unnormalized work rate. + # Normalization is the process of normalizing the work_rate of each block with respect of the + # reference work (work done by the PE). This then allows us to easily find the bottleneck + # for the block as we have already normalized the data. Unnormalization is the reverse + # process + def calc_unnormalize_work_rate(self, block_work_rate_norm_dict, bottleneck_work_rate): + block_att_work_rate_dict = defaultdict(dict) + for block, pipe_cluster_work_rate in block_work_rate_norm_dict.items(): + for pipe_cluster, work_rate in pipe_cluster_work_rate.items(): + dir = pipe_cluster.get_dir() + work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_family_member_names_and_channel_eliminating_fake( + block.instance_name, self.get_block_family_tasks_in_use(block), dir) + + if "souurce" in self.get_task_name() or "siink" in self.get_task_name() or config.sw_model == "sequential": + work_ratio = 1 + block_att_work_rate_dict[block][pipe_cluster] = bottleneck_work_rate*work_ratio + #self.update_pipe_cluster_work_rate(pipe_cluster, bottleneck_work_rate) + + return block_att_work_rate_dict + + # update paths (inpipe-outpipe) work rate + def update_pipe_clusters_pathlet_work_rate(self): + for block in self.get_blocks(): + for pipe_cluster in block.get_pipe_clusters(): + self.update_pipe_cluster_pathlet_work_rate(pipe_cluster, self.cur_phase_bottleneck_work_rate) + + + + # latency taken for an ic to wait for the traffic ahead + def get_traffic_latency(self, pathlet_, pipe_cluster, scheduled_kernels): + traversal_dir_latency = {} # forward/backward latency + block_ref = pipe_cluster.get_block_ref() + if not block_ref.type == "ic" : + traversal_dir_latency["backward"] = 0 + traversal_dir_latency["forward"] = 0 + return traversal_dir_latency + + traffic_latency = 0 + dir_ = pathlet_.get_dir() + pipes_with_traffic = self.filter_in_active_pipes(pipe_cluster.get_incoming_pipes(), + pipe_cluster.get_outgoing_pipe(), scheduled_kernels) + + pathlets_work_rate, last_phase = pipe_cluster.get_pathlet_last_phase_work_rate() # last phase work_rate + sum_work_rate = max(sum(pathlets_work_rate.values()), .000000000000000000001) # max is there to avoid division by 0 + + for pathlet__, work_rate in pathlets_work_rate.items(): + if pathlet__ == pathlet_: + traffic_latency += 64/max(work_rate,.0000000000000001) + #traffic_latency = sum_work_rate/max(work_rate,.0000000000000001) + + # determine the forward/backward path latency + if dir_ == "write": + traversal_dir_latency["backward"] = 0 + 
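+            # (editorial note) traffic_latency above charges one 64-byte transfer at this
+            # pathlet's last-phase work rate; all of that waiting is attributed to the
+            # forward traversal below, so the backward direction stays 0 for both reads
+            # and writes.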
traversal_dir_latency["forward"] = traffic_latency + if dir_ == "read": + traversal_dir_latency["backward"] = 0 + traversal_dir_latency["forward"] = traffic_latency + + return traversal_dir_latency + + # latency taken for an ic to arbiterate between different requests + def get_arbiteration_latency(self, pathlet_, pipe_cluster, scheduled_kernels): + traversal_dir_latency = {} # forward/backward latency + block_ref = pipe_cluster.get_block_ref() + if not block_ref.type == "ic" : + traversal_dir_latency["backward"] = 0 + traversal_dir_latency["forward"] = 0 + return traversal_dir_latency + + arbiteration_latency = 0 + dir_ = pathlet_.get_dir() + pipes_with_traffic = self.filter_in_active_pipes(pipe_cluster.get_incoming_pipes(), + pipe_cluster.get_outgoing_pipe(), scheduled_kernels) + + pathlets_work_rate, last_phase = pipe_cluster.get_pathlet_last_phase_work_rate() # last phase work_rate + # aribiteration latency is total_number of working (pathlets - 1) as IC needs to iterate through them all + occupided_pathlets = [pathlet for pathlet, work_rate in pathlets_work_rate.items() if work_rate>0] + if dir_ == "write": + traversal_dir_latency["forward"] = len(occupided_pathlets) - 1 + traversal_dir_latency["backward"] = 0 + elif dir_ == "read": + traversal_dir_latency["forward"] = 0 + traversal_dir_latency["backward"] = len(occupided_pathlets) - 1 + + return traversal_dir_latency + + # latency taken for an ic to generate the fleets of a request + def get_pathlet_flee_cnt(self, block_ref, pathlet_): + dir_ = pathlet_.get_dir() + family_tasks = self.get_family_tasks_on_the_pipe_cluster(dir_) + work = 0 + # iterate and add work_unit for all the pathlets + in_pipe = pathlet_.get_in_pipe() + #TODO: for now only pe to ic fleet generation is accounted for, + # but I believe if there is bus width descrpeency between adjacent blocks, + # fleets might be generated + #if in_pipe.get_slave().type == "mem": + # print("ok") + if not ((in_pipe.get_master().type == "pe" and dir_ == "write") or (in_pipe.get_slave().type == "mem" and dir_ == "read")): + return 0 + master_tasks = self.get_masters_relevant_tasks_on_the_pipe_cluster(in_pipe, dir_) + # calculate work ratio + for family_task in family_tasks: + if family_task in master_tasks: + work += self.get_task().get_self_to_family_task_work_unit(family_task) + + fleet_cnt = (work/block_ref.get_block_bus_width()) + return fleet_cnt + + def get_fleet_generation_latency(self, pathlet_, pipe_cluster, schedulued_kernels): + #if self.task_name in self. 
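+        # (editorial sketch of the model below, numbers hypothetical) a "fleet" is one
+        # bus-width worth of data: a request of `work` bytes over a bus of width W is split
+        # into roughly work/W fleets, each costing one cycle of the generating block, e.g.,
+        # 256 bytes over a 32-byte bus -> 8 fleets, i.e., 8 cycles of fleet-generation latency.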
+ + traversal_dir_latency = {} # forward/backward latency + dir_ = pathlet_.get_dir() + block_ref = pipe_cluster.get_block_ref() + if not ((block_ref.type == "pe" and dir_ == "write") or (block_ref.type == "mem" and dir_ == "read")): + traversal_dir_latency["backward"] = 0 + traversal_dir_latency["forward"] = 0 + return traversal_dir_latency + + fleet_generation_latency = 0 + for pathlet_ in pipe_cluster.get_pathlets(): + fleet_generation_latency += self.get_pathlet_flee_cnt(block_ref, pathlet_) + + # determine the forward/backward path latency + if dir_ == "write": + traversal_dir_latency["backward"] = 0 + traversal_dir_latency["forward"] = fleet_generation_latency + if dir_ == "read": + traversal_dir_latency["backward"] = 0 + traversal_dir_latency["forward"] = 0 + + return traversal_dir_latency + + def get_hop_latency(self, pipe_cluster): + block_ref = pipe_cluster.get_block_ref() + dir_ = pipe_cluster.get_dir() + + # determine the forward/backward path latency + traversal_dir_latency = {} # forward/backward latency + if dir_ == "write": + traversal_dir_latency["backward"] = 0 + traversal_dir_latency["forward"] = 1 + if dir_ == "read": + traversal_dir_latency["backward"] = 0 # at the moment, our back ward hop is zero matching platform architect's default + traversal_dir_latency["forward"] = 1 + for k in traversal_dir_latency.keys(): + traversal_dir_latency[k] = traversal_dir_latency[k]/block_ref.get_block_freq() + + return traversal_dir_latency + + def update_path_latency_1(self): + for pathlet, trv_dir_phase_num_val in self.pathlet_phase_latency_dict.items(): + for trv_dir, phase_num_val in trv_dir_phase_num_val.items(): + for dir_ in ["write", "read"]: + if dir_ not in self.path_dir_phase_latency.keys(): + self.path_dir_phase_latency[dir_] = {} + for phase_num, val in phase_num_val.items(): + if phase_num not in self.path_dir_phase_latency[dir_].keys(): + self.path_dir_phase_latency[dir_][phase_num] = 0 + + if self.get_task().is_task_dummy(): + continue + #blk_freq = pathlet.get_in_pipe().get_slave().get_block_freq() + #self.path_dir_phase_latency[dir_][phase_num] += val + block = pathlet.get_in_pipe().get_slave() + pipe_cluster = pathlet.getma + block_att_work_rate_dict[block][pipe_cluster] + self.path_dir_phase_latency[dir_][phase_num] = 64/self.block_att_work_rate_dict[pathlet.get_in_pipe().get_slave()] + + # update paths (inpipe-outpipe) work rate + def update_path_latency(self): + if self.get_task().is_task_dummy(): + return + + for block, pipe_cluster_work_rate in self.block_dir_att_work_rate_dict.items(): + if block.type == "pe": + continue + if block not in self.block_path_dir_phase_latency.keys(): + self.block_path_dir_phase_latency[block] = {} + + for pipe_cluster, work_rate in pipe_cluster_work_rate.items(): + if pipe_cluster not in self.block_path_dir_phase_latency[block].keys(): + self.block_path_dir_phase_latency[block][pipe_cluster] = {} + + dir_ = pipe_cluster.get_dir() + if dir_ == "same": + dirs = ["write", "read"] + else: + dirs = [dir_] + for dir in dirs: + if dir not in self.block_path_dir_phase_latency[block][pipe_cluster].keys(): + self.block_path_dir_phase_latency[block][pipe_cluster][dir] = {} + _, last_phase = pipe_cluster.get_pathlet_last_phase_work_rate() # last phase work_rate + #last_phase = self.phase_num + if last_phase not in self.block_path_dir_phase_latency[block][pipe_cluster][dir].keys(): + self.block_path_dir_phase_latency[block][pipe_cluster][dir][last_phase] = 0 + + work_unit = self.get_task().get_smallest_task_work_unit_by_dir(dir) + 
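+                        # (editorial note, numbers hypothetical) the per-phase path latency
+                        # recorded below is the smallest unit of work in this direction divided
+                        # by the pipe cluster's attainable work rate, e.g., a 64-byte work unit
+                        # at 16 GB/s -> 4 ns for this hop.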
self.block_path_dir_phase_latency[block][pipe_cluster][dir][last_phase] = work_unit / work_rate
+
+    # does not consider congestion in the system;
+    # calculate the structural path latency for a kernel
+    def update_path_structural_latency(self, design):
+        if self.get_task().is_task_dummy():
+            return
+        blocks = self.get_blocks()
+        pes = [blk for blk in blocks if blk.type == "pe"]
+        mems = [blk for blk in blocks if blk.type == "mem"]
+        buses = [blk for blk in blocks if blk.type == "ic"]
+        read_mems = [block for block in mems if block.get_task_dir_by_task_name(self.get_task())[0][1] == "read"]
+        write_mems = [block for block in mems if block.get_task_dir_by_task_name(self.get_task())[0][1] == "write"]
+
+        # iterate through the paths (from pe to memory) and add up the hop and fleet generation latencies
+        for pe in pes:
+            for mem in read_mems:
+                dir = "read"
+                path = design.get_hardware_graph().get_path_between_two_vertecies(pe, mem)
+                hop_latency = 0
+                # hop latency
+                for blk in path:
+                    hop_latency += 1/blk.get_block_freq()
+
+                # fleet generation latency
+                # at the moment, we only do it for a representative bus;
+                # we might need to add extra latency if there is a mismatch between the bus width of the adjacent buses
+                work_unit = self.get_task().get_smallest_task_work_unit_by_dir("read")
+                fleet_cnt = work_unit/mem.get_block_bus_width()
+                fleet_generation_latency = fleet_cnt/buses[0].get_block_freq()  # any bus would work
+                latency = hop_latency + fleet_generation_latency
+                self.path_structural_latency[(pe, mem, dir)] = latency
+
+        for pe in pes:
+            for mem in write_mems:
+                dir = "write"
+                path = design.get_hardware_graph().get_path_between_two_vertecies(pe, mem)
+                hop_latency = 0
+                # hop latency
+                for blk in path:
+                    hop_latency += 1/blk.get_block_freq()
+
+                # fleet generation latency
+                # at the moment, we only do it for a representative bus.
+ # we might need to add extra latency if there is a mismatch between the bus width of the adjacent buses + work_unit = self.get_task().get_smallest_task_work_unit_by_dir("write") + fleet_cnt = (work_unit/ mem.get_block_bus_width()) + fleet_generation_latency = fleet_cnt/buses[0].get_block_freq() # any bus would work + latency = hop_latency + fleet_generation_latency + self.path_structural_latency[(pe, mem, dir)] = latency + + + # update paths (inpipe-outpipe) work rate + # We are not using the following as it was not usefull (and with verification didn't give us good results) + def update_pipe_clusters_pathlet_latency(self, scheduled_kernels): + if "souurce" in self.get_task_name() or "siink" in self.get_task_name() or "dummy_last" in self.get_task_name(): + return + for block in self.get_blocks(): + for pipe_cluster in block.get_pipe_clusters(): + if pipe_cluster.cluster_type == "dummy": + return + for pathlet_ in pipe_cluster.get_pathlets(): + hop_latency = self.get_hop_latency(pipe_cluster) + #arbiteration_latency = self.get_arbiteration_latency(pathlet_, pipe_cluster, scheduled_kernels) + #fleet_generation_latency = self.get_fleet_generation_latency(pathlet_, pipe_cluster, scheduled_kernels) + traffic_latency = self.get_traffic_latency(pathlet_, pipe_cluster, scheduled_kernels) + #arbiteration_latency, fleet_generation_latency = 0 + #latency_in_cycles_dict = {"hop": hop_latency, "arbiteration": arbiteration_latency, "fleet_generation": fleet_generation_latency, "traffic_latency":traffic_latency} + latency_in_cycles_dict = {"hop": hop_latency, "traffic_latency":traffic_latency} + _, last_phase = pipe_cluster.get_pathlet_last_phase_work_rate() # last phase work_rate + pathlet_latency = pipe_cluster.set_pathlet_latency(pathlet_, last_phase, latency_in_cycles_dict) + + if pathlet_.get_in_pipe().is_task_present(self.get_task()): + self.pathlet_phase_latency_dict[pathlet_] = pathlet_latency + + def get_family_tasks_on_the_pipe_cluster(self, dir_): + if dir_ == "read": + family_tasks = self.get_task().get_parents() + elif dir_ == "write": + family_tasks = self.get_task().get_children() + else: + print("this mode is no suppported") + return family_tasks + + def get_masters_relevant_tasks_on_the_pipe_cluster(self, in_pipe, dir_): + if dir_ == "write": + relevant_tasks = [el.get_child() for el in in_pipe.get_traffic() if + el.get_parent().name == self.get_task_name()] + else: + relevant_tasks = [el.get_parent() for el in in_pipe.get_traffic() if + el.get_child().name == self.get_task_name()] + return relevant_tasks + + # update each path's (inpipe-outpipe) workrate + def update_pipe_cluster_pathlet_work_rate(self, pipe_cluster, bottleneck_work_rate): + def get_pathlet_work_rate_and_phase(path, work_ratio, bottleneck_work_rate): + if "souurce" in self.get_task_name(): + return 0, self.phase_num + 1 + else: + if self.phase_num == -1: + phase_num = self.phase_num + 2 + else: + phase_num = self.phase_num +1 + if "siink" in self.get_task_name(): + return 0, phase_num + else: + return work_ratio*bottleneck_work_rate, phase_num + + # don't need to deal with dummy clusters + if pipe_cluster.cluster_type == "dummy": + return + + # get cluster info + dir_ = pipe_cluster.get_dir() + family_tasks = self.get_family_tasks_on_the_pipe_cluster(dir_) + pipe_ref_block = pipe_cluster.get_block_ref() + + # iterate through incoming pipes, calculate work ratio and update workrate + for pathlet_ in pipe_cluster.get_pathlets(): + in_pipe = pathlet_.get_in_pipe() + out_pipe = pathlet_.get_in_pipe() + work_ratio =0 + 
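+            # (editorial note) work_ratio accumulates the fraction of the kernel's reference
+            # work that flows through this pathlet: each family task (parent for reads, child
+            # for writes) whose traffic enters through in_pipe contributes its per-block work
+            # ratio, so a pathlet carrying half of the kernel's traffic ends up with
+            # work_rate = 0.5 * bottleneck_work_rate.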
#pipe_master = in_pipe.get_master() + relevant_tasks_on_pipe = self.get_masters_relevant_tasks_on_the_pipe_cluster(in_pipe, dir_) + # calculate work ratio + for family_task in family_tasks: + if family_task in relevant_tasks_on_pipe: + work_ratio += self.__task_to_blocks_map.get_workRatio_by_block_name_and_family_member_names_and_channel_eliminating_fake( + pipe_ref_block.instance_name, [(family_task.name, dir_)], dir_) + + #update work rate + work_rate, phase_number = get_pathlet_work_rate_and_phase(pathlet_, work_ratio, bottleneck_work_rate) + pipe_cluster.set_pathlet_phase_work_rate(pathlet_, phase_number, work_rate) + + def read_latency_per_request(self, mem, pe): + pass + + def write_latency_per_request(self, mem, pe): + pass + + # consolidate read/write channels in order to emulate DMA read/write serialization. + # Consolidation is essentially the process of ensuring that we can closely emulate the fact + # that DMA might serialize read/writes. + def consolidate_channels(self, block_normalized_work_rate): + block_normalized_work_rate_consolidated = defaultdict(dict) + """ + assert config.DMA_mode in ["serialized_read_write", "parallelized_read_write"] + if config.DMA_mode == "serialized_read_write": + for block, pipe_cluster_work_rate in block_normalized_work_rate.items(): + for pipe_cluster, work_rate in pipe_cluster_work_rate.items(): + block_normalized_work_rate_consolidated[block][pipe_cluster] = 1/sum([1/norm_wr for norm_wr in pipe_cluster_work_rate.values()]) + elif config.DMA_mode == "parallelized_read_write": + """ + block_normalized_work_rate_consolidated = block_normalized_work_rate + + return block_normalized_work_rate_consolidated + + + def get_latency_if_krnel_run_in_isolation(self): + if self.get_task().is_task_dummy(): + return 0 + if config.sw_model == "sequential": + print("can not do kerel calculation in isolation for sequential mode yet") + return .1 + block_att_work_rate_dict = self.update_block_att_work_rate_in_isolation() + time_to_completion_in_isolation = self.get_total_work()/block_att_work_rate_dict[self.get_ref_block()][self.get_ref_block().get_pipe_clusters()[0]] + return time_to_completion_in_isolation + + # only for one krnel + def update_block_att_work_rate_in_isolation(self): + scheduled_kernels = [self] + block_normalized_work_rate_unconsolidated = self.calc_all_block_normalized_work_rate(scheduled_kernels) + + # consolidate read/write channels since DMA serializes read/writes + block_normalized_work_rate = self.consolidate_channels(block_normalized_work_rate_unconsolidated) + + # identify the block bottleneck + cur_phase_dir_bottleneck, cur_phase_dir_bottleneck_work_rate = self.calc_block_s_bottleneck(block_normalized_work_rate) + cur_phase_bottleneck = list(cur_phase_dir_bottleneck.values())[0] + cur_phase_bottleneck_work_rate = list(cur_phase_dir_bottleneck_work_rate.values())[0] + for dir, work_rate in cur_phase_dir_bottleneck_work_rate.items(): + if work_rate < cur_phase_bottleneck_work_rate: + cur_phase_bottleneck = cur_phase_dir_bottleneck[dir] + cur_phase_bottleneck_work_rate = work_rate + + #.kernel_phase_bottleneck_blocks_dict[self.phase_num] = self.cur_phase_bottleneck + ref_block = self.get_ref_block() + # unnormalized the results (unnormalizing means that actually provide the work rate as opposed + # to normalizing it to the ref block (which is usally PE) work rate + block_att_work_rate_dict = self.calc_unnormalize_work_rate(block_normalized_work_rate, cur_phase_bottleneck_work_rate) + block_dir_att_work_rate_dict = 
self.calc_unnormalize_work_rate_by_dir(block_normalized_work_rate, cur_phase_dir_bottleneck_work_rate) + + return block_att_work_rate_dict + + + # calculate the attainable work rate of the block + def update_block_att_work_rate(self, scheduled_kernels): + + # + scheduled_kernels_tmp = [] + if config.sw_model == "sequential": + krnls_operating_state = self.operating_state + scheduled_kernels_tmp = [krnl_ for krnl_ in scheduled_kernels if krnl_.operating_state == krnls_operating_state] + scheduled_kernels = scheduled_kernels_tmp + + # get block work rate. In this step we calculate the normalized work rate. + # normalized work rate is actual work rate normalized to the work rate of + # the ref block (which is usally PE) work rate. Normalizing allows us + # to identify the bottleneck easily since every one has the same unit and reference + self.block_normalized_work_rate_unconsolidated = self.calc_all_block_normalized_work_rate(scheduled_kernels) + + # consolidate read/write channels since DMA serializes read/writes + self.block_normalized_work_rate = self.consolidate_channels(self.block_normalized_work_rate_unconsolidated) + + # identify the block bottleneck + cur_phase_dir_bottleneck, cur_phase_dir_bottleneck_work_rate = self.calc_block_s_bottleneck(self.block_normalized_work_rate) + self.cur_phase_bottleneck = list(cur_phase_dir_bottleneck.values())[0] + self.cur_phase_bottleneck_work_rate = list(cur_phase_dir_bottleneck_work_rate.values())[0] + for dir, work_rate in cur_phase_dir_bottleneck_work_rate.items(): + if work_rate < self.cur_phase_bottleneck_work_rate: + self.cur_phase_bottleneck = cur_phase_dir_bottleneck[dir] + self.cur_phase_bottleneck_work_rate = work_rate + + self.kernel_phase_bottleneck_blocks_dict[self.phase_num] = self.cur_phase_bottleneck + ref_block = self.get_ref_block() + # unnormalized the results (unnormalizing means that actually provide the work rate as opposed + # to normalizing it to the ref block (which is usally PE) work rate + self.block_att_work_rate_dict = self.calc_unnormalize_work_rate(self.block_normalized_work_rate, self.cur_phase_bottleneck_work_rate) + self.block_dir_att_work_rate_dict = self.calc_unnormalize_work_rate_by_dir(self.block_normalized_work_rate, cur_phase_dir_bottleneck_work_rate) + + + def get_completion_time(self): + return self.completion_time + + def time_to_meet_throughput_for_a_pipe(self, operating_state, mem, mem_pipe_cluster): + return (self.data_work_left[self.operating_state] - self.firing_work_to_meet_throughput[self.operating_state][0]) / self.block_att_work_rate_dict[mem][ + mem_pipe_cluster] + """ + if len(self.data_work_left_to_meet_throughput[self.operating_state]) == 1: + return self.data_work_left_to_meet_throughput[self.operating_state][0] / self.block_att_work_rate_dict[mem][ + mem_pipe_cluster] + else: + return (self.data_work_left_to_meet_throughput[self.operating_state][0] - self.data_work_left_to_meet_throughput[self.operating_state][1])/ self.block_att_work_rate_dict[mem][ + mem_pipe_cluster] + """ + # calculate the completion time for the kernel + def calc_kernel_completion_time(self): + if config.sw_model == "sequential": + if self.operating_state == "execute": + return self.pe_s_work_left / self.block_att_work_rate_dict[self.get_ref_block()][self.get_ref_block().get_pipe_clusters()[0]] + else: + if self.get_task().is_task_dummy(): + return 0 + mem = self.get_a_mem_by_dir(self.operating_state) + mem_pipe_cluster = None + for pipe_cluster in mem.get_pipe_clusters(): + if pipe_cluster.get_dir() == 
self.operating_state: + mem_pipe_cluster = pipe_cluster + break + if pipe_cluster == None: + print("something went wrong since there is not pipe cluster associated with this memory") + exit(0) + if self.get_type() == "throughput_based": + return self.time_to_meet_throughput_for_a_pipe(self.operating_state, mem, mem_pipe_cluster) + else: + return self.data_work_left[self.operating_state]/self.block_att_work_rate_dict[mem][mem_pipe_cluster] + else: + return self.pe_s_work_left/self.block_att_work_rate_dict[self.get_ref_block()][self.get_ref_block().get_pipe_clusters()[0]] + + # launch the kernel + # Variables: + # cur_time: current time (s) + def launch(self, cur_time): + self.pe_s_work_left = self.kernel_total_work["execute"] + self.data_work_left["read"] = self.kernel_total_work["read"] + self.data_work_left["write"] = self.kernel_total_work["write"] + self.operating_state = "read" + + if self.get_type() == "throughput_based": + self.populate_throughput_data(self.operating_state, cur_time) + + # keeping track of how much work left for every block + for block_dir_work_ratio in self.__task_to_blocks_map.block_dir_workRatio_dict.keys(): + if block_dir_work_ratio not in self.block_dir_work_left.keys(): + self.block_dir_work_left[block_dir_work_ratio] = {} + for task, ratio in self.__task_to_blocks_map.block_dir_workRatio_dict[block_dir_work_ratio].items(): + self.block_dir_work_left[block_dir_work_ratio][task] = (self.pe_s_work_left*ratio) + + self.status = "in_progress" + if self.iteration_ctr == self.max_iteration_ctr: + self.starting_time = cur_time + self.update_krnl_iteration_ctr() + + # has kernel completed already + def has_completed(self): + return self.status == "completed" + + """ + # has kernel started already + def has_started(self): + return self.status == "in_progress" + """ + + def get_type(self): + return self.type + + def get_throughput_info(self): + return self.throughput_info + + def throughput_time_trigger_achieved(self, clock): + # for execute, throughput doesn't make sense + if self.operating_state == "execute": + return False + + error = .0002 + for el in self.firing_time_to_meet_throughput[self.operating_state]: + if math.fabs(clock - el)/el < error: + return True + return False + + def throughput_work_achieved(self): + error = 2 + # for execute, throughput doesn't make sense + if self.operating_state == "execute": + return False + """ + for el in self.firing_work_to_meet_throughput[self.operating_state]: + if math.fabs(self.data_work_left_to_meet_throughput[self.operating_state][0] - el) < error: + return True + return False + """ + if math.fabs(self.firing_work_to_meet_throughput[self.operating_state][0] - self.data_work_left[self.operating_state]) < error: + return True + return False + """ + for el in self.firing_work_to_meet_throughput[self.operating_state]: + if math.fabs(self.data_work_left[self.operating_state] - el) < error: + return True + return False + """ + + def populate_throughput_data(self, operating_state, clock_time): + self.firing_time_to_meet_throughput[operating_state] = [] + self.firing_work_to_meet_throughput[operating_state] = [] + self.data_work_left_to_meet_throughput[operating_state] = [] + if operating_state == "execute": + return + + + # delete this later + throughput__ = 4*min([el.clock_freq for el in self.get_blocks()] ) + self.throughput_info["read"] = self.throughput_info["write"] = throughput__ + print("we are over writing throughput. 
Delete later") + + + total_work = self.data_work_left[operating_state] + assert(self.throughput_info["clock_period"]*(10**-9)> 0) + unit_time_to_meet_throughput = self.throughput_info["clock_period"]*(10**-9) + unit_work_to_meet_throughput = self.throughput_info[operating_state]*unit_time_to_meet_throughput + + + if self.get_task().is_task_dummy(): + self.firing_work_to_meet_throughput[operating_state].append(0) + self.data_work_left_to_meet_throughput[operating_state].append(0) + self.firing_time_to_meet_throughput[operating_state].append(0) + return + + while (True): + total_work -= unit_work_to_meet_throughput + #total_work = max(0, total_work) + if total_work <= 0: + break + clock_time += unit_time_to_meet_throughput + self.firing_work_to_meet_throughput[operating_state].append(total_work) + self.data_work_left_to_meet_throughput[operating_state].append(total_work) + self.firing_time_to_meet_throughput[operating_state].append(clock_time) + + self.firing_work_to_meet_throughput[operating_state].append(0) # adding zero to make sure we finish the task + print("number of phases added"+ str(len(self.firing_work_to_meet_throughput["read"]))) + + def get_operating_state(self): + return self.operating_state + + # update the status of the kernel to specify whether it's done or not + def update_status(self, time_step_size, clock=0): + if config.sw_model == "sequential": + if self.operating_state == "read" and self.data_work_left["read"] <.001: + self.operating_state = "execute" + self.progress = .5 + elif self.operating_state == "execute" and self.pe_s_work_left <.001: + self.operating_state = "write" + self.progress= .5 + if self.type == "throughput_based": + self.populate_throughput_data(self.operating_state, clock) + elif self.operating_state == "write" and self.data_work_left["write"] < .001: + self.progress = 1 + else: + self.progress = .3 + else: + if self.kernel_total_work["execute"] == 0: self.progress = 1 # for dummy tasks (with the suffix of souurce and siink) + else: self.progress = 1 - float(self.pe_s_work_left/self.kernel_total_work['execute']) + + self.stats.latency += time_step_size + + if self.progress >= .99: + self.status = "completed" + #self.completion_time = self.stats.latency + self.starting_time + self.completion_time = clock + elif self.progress == 0: + self.status = "not_scheduled" + else: + self.status = "in_progress" + + self.update_stats(time_step_size) + + + + + def get_a_mem_by_dir(self, dir_): + for block in self.get_blocks(): + for pipe_cluster in block.get_pipe_clusters_of_task(self.get_task()): + dir = pipe_cluster.get_dir() + if dir == dir_ and block.type == "mem": + return block + + # given the time (time_step_size) of the tick, calculate how much work has the + # kernel accomplished. 
Note that work concept varies depending on the hardware block, i.e., + # work = bytes for memory/uses and its instructions for PEs + def calc_work_consumed(self, time_step_size): + + # iterate through each blocks attainable work rate and calculate + # initialize + for block in self.get_blocks(): + self.block_phase_work_dict[block][self.phase_num] = 0 + + # how much work it can do for the kernel of interest + for block, pipe_clusters_work_rate in self.block_att_work_rate_dict.items(): + read_work = write_work = 0 + for pipe_cluster, work_rate in pipe_clusters_work_rate.items(): + # update work + if self.phase_num in self.block_phase_work_dict[block].keys(): + self.block_phase_work_dict[block][self.phase_num] += work_rate* time_step_size + else: + self.block_phase_work_dict[block][self.phase_num] = work_rate* time_step_size + + if pipe_cluster.dir == "read": + read_work += work_rate * time_step_size + if pipe_cluster.dir == "write": + write_work += work_rate * time_step_size + + if self.get_type() == "throughput_based": + unit_time_to_meet_throughput = self.throughput_info["clock_period"] * (10 ** -9) + read_unit_of_work = self.throughput_info["read"] * unit_time_to_meet_throughput + write_unit_of_work = self.throughput_info["write"] * unit_time_to_meet_throughput + read_work = min(read_work, read_unit_of_work) + write_work = min(write_work, write_unit_of_work) + + # update read specifically + if self.phase_num in self.block_phase_read_dict[block].keys(): + self.block_phase_read_dict[block][self.phase_num] += read_work + else: + self.block_phase_read_dict[block][self.phase_num] = read_work + + if self.phase_num in self.block_phase_write_dict[block].keys(): + self.block_phase_write_dict[block][self.phase_num] += write_work + else: + self.block_phase_write_dict[block][self.phase_num] = write_work + + + # Calculates the leakage power of the phase for PE and IC + # memory leakage power should be accumulated for the whole execution time + # since we cannot turn off the memory but the rest can be in cut-off (C7) mode + def calc_leakage_energy_consumed(self, time_step_size): + for block, work in self.block_phase_work_dict.items(): + # taking care of dummy corner case + if "souurce" in self.get_task_name() or "siink" in self.get_task_name(): + self.block_phase_leakage_energy_dict[block][self.phase_num] = 0 + else: + if block.get_block_type_name() == "mem": + self.block_phase_leakage_energy_dict[block][self.phase_num] = 0 + else: + # changed to get by Hadi + self.block_phase_leakage_energy_dict[block][self.phase_num] = \ + block.get_leakage_power(self.get_power_knob_id()) * time_step_size + + # calculate energy consumed + def calc_energy_consumed(self): + for block, work in self.block_phase_work_dict.items(): + # Dynamic energy consumption + if "souurce" in self.get_task_name() or "siink" in self.get_task_name(): # taking care of dummy corner case + self.block_phase_energy_dict[block][self.phase_num] = 0 + else: + # changed to get by Hadi + this_phase_energy = self.block_phase_work_dict[block][self.phase_num] / block.get_work_over_energy(self.get_power_knob_id()) + if this_phase_energy < 0: + print("energy can't be a negative value") + block.get_work_over_energy(self.get_power_knob_id()) + exit(0) + + self.block_phase_energy_dict[block][self.phase_num] = this_phase_energy + + pass + + # for read, we release memory (the entire input worth of data) once the kernel is done + # for write, we assign memory (the entire output worth of data) once the kernel starts + # if a sibling of a task depends on the 
same data (that resides in the memory), we can't let + # go of that till the sibling is done + # coeff is gonna determine whether to retract or expand the memory + # Todo: include the case where there are multiple siblings + def update_mem_size(self, coef): + if "souurce" in self.get_task_name() and coef == -1: return + elif "siink" in self.get_task_name() and coef == 1: return + + dir_ = "write" + mems = self.get_kernel_s_mems(dir=dir_) + if "souurce" in self.get_task_name(): + if dir_ == "write": + #memory_total_work = config.souurce_memory_work + for mem in mems: + # mem_work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(mem.instance_name, dir_) + tasks_name = self.__task_to_blocks_map.get_tasks_of_block_dir( + mem.instance_name, dir_) + # Sends the memory size that needs to be occupied (positive for write and negative for read) + # then updates the memory mapping in the memory block to know what is the in use capacity + # changed to get by Hadi + for tsk in tasks_name: + memory_total_work = config.souurce_memory_work[tsk] + mem.update_area(coef*memory_total_work/mem.get_work_over_area(self.get_power_knob_id()), self.get_task_name()) + mem.update_area_in_bytes(coef*memory_total_work, self.get_task_name()) + # mem.update_area(coef*memory_total_work/mem.get_work_over_area(self.get_power_knob_id()), self.get_task_name()) + else: memory_total_work = 0 + elif "siink" in self.get_task_name(): + memory_total_work = 0 + else: + pe_s_total_work = self.kernel_total_work["execute"] + for mem in mems: + #mem_work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(mem.instance_name, dir_) + mem_work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir_eliminating_fake(mem.instance_name, dir_) + memory_total_work = pe_s_total_work * mem_work_ratio + # changed to get by Hadi + mem.update_area_in_bytes(coef*memory_total_work, self.get_task_name()) + mem.update_area(coef*memory_total_work/mem.get_work_over_area(self.get_power_knob_id()), self.get_task_name()) + + mem_work_ratio = self.__task_to_blocks_map.get_workRatio_by_block_name_and_dir(mem.instance_name, dir_) + + # update pe allocation -> allocate a part of pe quantum for current task + def update_pe_size(self): + pe = self.get_ref_block() + # the convention is ot provide 1/fixe_area for these blocks + # DSPs and Processors are among statically sized blocks while accelerators are not among the list + if pe.subtype in config.statically_sized_blocks: work = 1 + else: work = self.kernel_total_work["execute"] + pe.update_area(work/pe.get_work_over_area(self.get_power_knob_id()), self.get_task_name()) + + # TODO: need to include ic as well + def update_ic_size(self): + return 0 + + # calculate how much of the work is left for the kernel. 
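+    # In the sequential software model the kernel drains three work pools in turn (read
+    # bytes, then execute instructions, then write bytes); otherwise only the PE's
+    # instruction pool is drained.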
+ # Note that work concept varies depending on the hardware block, i.e., + # work = bytes for memory/uses and its instructions for PEs + def calc_work_left(self): + if not config.sw_model == "sequential": + self.pe_s_work_left -= self.block_phase_work_dict[self.get_ref_block()][self.phase_num] + else: + # use the following for phase based simulation + if config.sw_model == "sequential": + if self.operating_state == "execute": + self.pe_s_work_left -= self.block_phase_work_dict[self.get_ref_block()][self.phase_num] + else: + if self.get_task().is_task_dummy(): + self.data_work_left[self.operating_state] = 0 + else: + mem = self.get_a_mem_by_dir(self.operating_state) + if self.operating_state == "read": + self.data_work_left[self.operating_state] -= self.block_phase_read_dict[mem][self.phase_num] + #if self.get_type() == "throughput_based": + # self.data_work_left_to_meet_throughput[self.operating_state][0] -= self.block_phase_read_dict[mem][self.phase_num] + elif self.operating_state == "write": + self.data_work_left[self.operating_state] -= self.block_phase_write_dict[mem][self.phase_num] + #if self.get_type() == "throughput_based": + # self.data_work_left_to_meet_throughput[self.operating_state][0] -= \ + # self.block_phase_write_dict[mem][self.phase_num] + for block_dir_work_ratio in self.__task_to_blocks_map.block_dir_workRatio_dict.keys(): + if block_dir_work_ratio not in self.block_dir_work_left.keys(): + self.block_dir_work_left[block_dir_work_ratio] = {} + for task, ratio in self.__task_to_blocks_map.block_dir_workRatio_dict[block_dir_work_ratio].items(): + self.block_dir_work_left[block_dir_work_ratio][task] = (self.pe_s_work_left * ratio) + else: + self.pe_s_work_left -= self.block_phase_work_dict[self.get_ref_block()][self.phase_num] + for block_dir_work_ratio in self.__task_to_blocks_map.block_dir_workRatio_dict.keys(): + if block_dir_work_ratio not in self.block_dir_work_left.keys(): + self.block_dir_work_left[block_dir_work_ratio] = {} + for task, ratio in self.__task_to_blocks_map.block_dir_workRatio_dict[block_dir_work_ratio].items(): + self.block_dir_work_left[block_dir_work_ratio][task] = (self.pe_s_work_left*ratio) + + # accumulate how much area has been used for a phase of execution + def aggregate_area_of_phase(self): + total_area_consumed = 0 + for block, phase_area_dict in self.block_phase_area_dict.items(): + total_area_consumed += sum(list(phase_area_dict.values())) + return total_area_consumed + + + # aggregate the energy consumed for all the blocks for a specific phase + def aggregate_area_of_for_every_phase(self): + aggregate_phase_area = {} + for block, phase_area_dict in self.block_phase_area_dict.items(): + for phase, area in phase_area_dict.items(): + this_phase_area = phase_area_dict[phase] + if phase not in aggregate_phase_area.keys(): + aggregate_phase_area [phase] = 0 + aggregate_phase_area[phase] += this_phase_area + return aggregate_phase_area + + + # aggregate the energy consumed for all the blocks for a specific phase + def aggregate_energy_of_for_every_phase(self): + aggregate_phase_energy = {} + for block, phase_energy_dict in self.block_phase_energy_dict.items(): + for phase, energy in phase_energy_dict.items(): + this_phase_energy = phase_energy_dict[phase] + if phase not in aggregate_phase_energy.keys(): + aggregate_phase_energy[phase] = 0 + aggregate_phase_energy[phase] += this_phase_energy + return aggregate_phase_energy + + + # aggregate the energy consumed for all the blocks for a specific phase + def aggregate_energy_of_phase(self): + 
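+        # (editorial note) sums the current phase's dynamic energy across all blocks the
+        # kernel touches; block_phase_energy_dict is filled by calc_energy_consumed.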
total_energy_consumed = 0 + for block, phase_energy_dict in self.block_phase_energy_dict.items(): + this_phase_energy = phase_energy_dict[self.phase_num] + total_energy_consumed += this_phase_energy + return total_energy_consumed + + def aggregate_leakage_energy_of_phase(self): + total_leakage_energy_consumed = 0 + for block, phase_leakage_energy_dict in self.block_phase_leakage_energy_dict.items(): + this_phase_leakage_energy = phase_leakage_energy_dict[self.phase_num] + total_leakage_energy_consumed += this_phase_leakage_energy + return total_leakage_energy_consumed + + # Checks if there was memory bounded phases in the kernel execution + def is_kernel_memory_bounded(self): + blocks = self.kernel_phase_bottleneck_blocks_dict.values() + for block in blocks: + if block.get_block_type_name() == "mem": + return True + return False + + # Checks if there was any compute intensive phases in the kernel execution + def is_kernel_processing_bounded(self): + blocks = self.kernel_phase_bottleneck_blocks_dict.values() + for block in blocks: + if block.get_block_type_name() == "pe": + return True + return False + + # Checks if there was any IC intensive phases in the kernel execution + def is_kernel_interconnects_bounded(self): + blocks = self.kernel_phase_bottleneck_blocks_dict.values() + for block in blocks: + if block.get_block_type_name() == "ic": + return True + return False + + # update the progress of kernel (i.e., how much work is left) + def update_progress(self, time_step_size): + # calculate the metric consumed for each phase + self.calc_work_consumed(time_step_size) + self.calc_work_left() + self.calc_energy_consumed() + #self.calc_leakage_energy_consumed(time_step_size) + + # Calculates the leakage power of the phase for PE and IC + # memory leakage power should be accumulated for the whole execution time + # since we cannot turn off the memory but the rest can be in cut-off (C7) mode + if config.simulation_method == "power_knobs": + self.calc_leakage_energy_consumed(time_step_size) + + # update the stats for a kernel (e.g., energy, start time, ...) + def update_stats(self, time_step_size): + self.stats.phase_block_duration_bottleneck[self.phase_num] = (self.cur_phase_bottleneck, time_step_size) + self.stats.phase_energy_dict[self.phase_num] = self.aggregate_energy_of_phase() + self.stats.phase_latency_dict[self.phase_num] = time_step_size + self.stats.block_phase_energy_dict = self.block_phase_energy_dict + if config.simulation_method == "power_knobs": + # aggregate the energy consumed among all the blocks corresponding to the task in current phase + self.stats.phase_leakage_energy_dict[self.phase_num] = \ + self.aggregate_leakage_energy_of_phase() + + # Update the starting and completion time of the kernel -> used for power knob simulator + self.stats.starting_time = self.starting_time + self.stats.completion_time = self.completion_time + + # step the kernel progress forward + # Variables: + # phase_num: phase number + def step(self, time_step_size, phase_num): + self.phase_num = phase_num + # update the amount of work remaining per block + self.update_progress(time_step_size) + + + def get_schedule(self): + return self.__schedule diff --git a/Project_FARSI/design_utils/components/mapping.py b/Project_FARSI/design_utils/components/mapping.py new file mode 100644 index 00000000..ce0139a2 --- /dev/null +++ b/Project_FARSI/design_utils/components/mapping.py @@ -0,0 +1,388 @@ +#Copyright (c) Facebook, Inc. and its affiliates. 
+#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +import json +import os +from settings import config +from typing import Dict +import warnings +from design_utils.components.hardware import * +from design_utils.components.workload import * +from typing import List + + +# This task maps a task to a processing block +class TaskToBlocksMap: #which block(s) the task is mapped to. a task is usually mapped to a PE, interconnect and memory + def __init__(self, task:Task, blocks_dir_work_ratio:Dict[Tuple[Block, str], float]={}): + self.task = task # list of tasks + self.block_dir_workRatio_dict: Dict[Tuple[Block, str], float] = blocks_dir_work_ratio # block work ratio with its direction (read/write), + # direction specifies whether a bock is used for reading or writing for this task + # work_ratio specifies the task's computation load for the block. + # find a task (object) by its name + def find_task_by_name(self, task_name): + for task in self.tasks: + if task.name == task_name: + return task + + # return the task's work. Work can be specified in terms of bytes or instructions + # We use instructions as the reference (instead of bytes) + def get_ref_task_work(self): + return self.task.get_task_work_distribution() + + # return tuples containing tasks and their direction for all tasks that the block hosts. + # note that direction can be write/read (for tasks on memory/bus) and loop (for PE). + def get_block_family_members_allocated(self, block_name): + result = [] + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name: + for task_name in work_ratio_.keys(): + result.append((task_name, block_dir[1])) + + return result + + # ------------------------------ + # Functionality + # get work ratio associated of the task and its direction (read/write) + # Variables: + # block_name: block name of interest + # dir_: direction of interest, i.e., read or write + # ------------------------------ + def get_workRatio_by_block_name_and_dir(self, block_name, dir_): + blocks = [block_dir[0] for block_dir in self.block_dir_workRatio_dict.keys()] + work_ratio = 0 + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name and block_dir[1] == dir_: + for family_member, work_ratio_value in work_ratio_.items(): + if not (family_member in self.task.get_fake_children_name()): + work_ratio += work_ratio_value + return work_ratio + + # ------------------------------ + # Functionality + # get work ratio associated of the task and its channel name. + # channel can be read/write (for buses) or "same" as in only one channel + # for memory and ips. 
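+    # (illustrative, hypothetical block name) e.g.,
+    #   get_workRatio_by_block_name_and_channel_eliminating_fake("mem_0", "read")
+    # sums the read-direction work ratios of every non-fake family task hosted on mem_0,
+    # while channel_name == "same" falls back to the direction-agnostic total.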
+ # fake means that there are tasks that the current + # task doesn't need to recalculate the result for, rather the result + # has been already calculated for another task, and only need to be copied + # Variables: + # block_name: block name of interest + # ------------------------------ + def get_workRatio_by_block_name_and_channel_eliminating_fake(self, block_name, channel_name): + if channel_name == "same": + return self.get_workRatio_by_block_name(block_name) + else: + return self.get_workRatio_by_block_name_and_dir_eliminating_fake(block_name, channel_name) + + + def get_workRatio_by_block_name_and_family_member_names_and_channel_eliminating_fake(self, block_name, family_members, channel_name): + if channel_name == "same": + return self.get_workRatio_by_block_name_and_family_member_names(block_name, [el[0] for el in family_members]) + else: + return self.get_workRatio_by_block_name_and_family_member_names_and_dir_eliminating_fake(block_name, + [el[0] for el in family_members], channel_name) + + # given a block, how much work (measured in terms of instructions or bytes depending on the block type) + # does it need to deliver + # channel_name: read/write for memory and buses and "same" for PE + # block_name: name of the block of interest. + # Note that if we have already decided that read/write channels are the same for buses + # and memory, then same will be used + def get_task_total_work_for_block(self, block_name, channel_name, task_name): + work_ratio = None + if channel_name == "same": + work_ratio = self.get_workRatio_by_block_name_and_family_member_names(block_name, [task_name]) + else: + work_ratio = self.get_workRatio_by_block_name_and_family_member_names_and_dir_eliminating_fake(block_name, + [task_name], channel_name) + + ref_task_work = self.get_ref_task_work()[0][0] + return work_ratio*ref_task_work + + # get all the tasks that a block hosts + def get_tasks_of_block(self, block_name): + tasks = [] + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name: + for family_member in work_ratio_.keys(): + if not (family_member in self.task.get_fake_children_name()): + tasks.append(family_member) + + return tasks + + # get all the tasks that the block hosts with a certain direction + # dir: read/write for memory and buses and loop for PE + def get_tasks_of_block_dir(self, block_name, dir): + tasks = [] + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name and dir == block_dir[1]: + for family_member in work_ratio_.keys(): + if not (family_member in self.task.get_fake_children_name()): + tasks.append(family_member) + + return tasks + + def get_tasks_of_block_with_src_dest(self, block_name): + print("this is not supported at the moment") + exit(0) + pass + tasks = [] + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name: + for family_member in work_ratio_.keys(): + if not (family_member in self.task.get_fake_children_name()): + tasks.append(family_member) + + return tasks + + # get tasks of a block with a certain channel + # channel can be read/write (for buses) or "same" as in only one channel + def get_tasks_of_block_channel(self, block_name, channel_name): + if channel_name == "same": + return self.get_tasks_of_block_with_src_dest(block_name) + else: + return self.get_tasks_of_block_dir(block_name, channel_name) + + # ------------------------------ + # Functionality + # get work ratio associated 
of the task and its dir_ (read/write) or loop (for PEs. + # fake means that there are tasks that the current + # task doesn't need to recalculate the resul for, rather the result + # has been already calculated for another task, and only need to be copied + # Variables: + # block_name: block name of interest + # dir_: read/write loop + # ------------------------------ + def get_workRatio_by_block_name_and_dir_eliminating_fake(self, block_name, dir_): + blocks = [block_dir[0] for block_dir in self.block_dir_workRatio_dict.keys()] + work_ratio = 0 + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name and block_dir[1] == dir_: + for family_member, work_ratio_value in work_ratio_.items(): + if not (family_member in self.task.get_fake_family_name()): + work_ratio += work_ratio_value + return work_ratio + + # ------------------------------ + # Functionality + # get work ratio associated of the task and its direction (read/write) using the block name + # Variables: + # block_name: block name of interest + # ------------------------------ + def get_workRatio_by_block_name(self, block_name): + work_ratio = 0 + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name: + for family_member, work_ratio_value in work_ratio_.items(): + if not (family_member in self.task.get_fake_children_name()): + work_ratio += work_ratio_value + if work_ratio == 0: + if config.WARN: + warnings.warn("workratio for block" + str(block_name) + "is zero") + return work_ratio + + # only return work ratio for certain family members + def get_workRatio_by_block_name_and_family_member_names(self, block_name, family_member_names): + work_ratio = 0 + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name: + for family_member, work_ratio_value in work_ratio_.items(): + if not (family_member in self.task.get_fake_children_name()) and (family_member in family_member_names): + work_ratio += work_ratio_value + if work_ratio == 0: + if config.WARN: + warnings.warn("workratio for block" + str(block_name) + "is zero") + return work_ratio + + def get_workRatio_by_block_name_and_family_member_names_and_dir_eliminating_fake(self, block_name, family_member_names, dir_): + blocks = [block_dir[0] for block_dir in self.block_dir_workRatio_dict.keys()] + work_ratio = 0 + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block_name == block_dir[0].instance_name and block_dir[1] == dir_: + for family_member, work_ratio_value in work_ratio_.items(): + if not (family_member in self.task.get_fake_family_name()) and (family_member in family_member_names): + work_ratio += work_ratio_value + return work_ratio + + # get direction for a block + def get_dir_of_block(self, block): + for block_dir, work_ratio_ in self.block_dir_workRatio_dict.items(): + if block.instance_name == block_dir[0].instance_name: + return block_dir[1] + + # get all the blocks that a task is mapped to . 
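+    # (editorial note) the set() below deduplicates, since the same block can appear under
+    # several (block, dir) keys, e.g., a hypothetical ("mem_0", "read") and ("mem_0", "write").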
+ def get_blocks(self): + return list(set([block_dir[0] for block_dir in (self.block_dir_workRatio_dict.keys())])) + + def channel_blocks(self): + self.blocks_with_channels = [] + + def get_blocks_with_channel(self): + return self.blocks_with_channels + + + def get_blocks_with_channel(self): + exit(0) + pass + results = [] + for block_dir in (self.block_dir_workRatio_dict.keys()): + block = block_dir[0] + dir_ = block_dir[1] + if block_dir[0].type == "ic": + if (block, dir_) not in results: + results.append((block, dir_)) + elif block_dir[0].type == "mem": + if (block, dir_) not in results: + results.append((block, dir_)) + elif block_dir[0].type == "pe": + if (block, dir_) not in results: + results.append((block, dir_)) + #if (block, "same") not in results: + # results.append((block, "same")) + + return results + + # ------------------------------ + # Functionality + # get the PE (processing element, i.e., ip or general purpose processor or dsp) associated with the task + # ------------------------------ + def get_pe_name(self): + for block_dir in self.block_dir_workRatio_dict.keys(): + if self.block_dir_workRatio_dict[block_dir] == 1: + return block_dir[0].block_instance_name + + # ------------------------------ + # Functionality + # print the task name and its corresponding blocks. Used for debugging purposes. + # ------------------------------ + def print(self): + print(self.task.name + str(list(map(lambda block: block.name, self.block_workRatio_dict.keys())))) + + +# This task maps a all the tasks within the workload to a processing blocks +class WorkloadToHardwareMap: + def __init__(self, input_file="scheduling.json", mode=""): + self.input_file = input_file # input file to read the mapping from. TODO: this is not supported yet. + + # vars keeping status + self.tasks_to_blocks_map_list:List[TaskToBlocksMap] = [] # list of all the tasks and the blocks they have mapped to. + + if mode == "from_json": # TODO: this is not supported yet. + self.populate_tasks_to_blocks_map_list_from_json() + + # ------------------------------ + # Functionality + # map the tasks to the blocks using information from an input json file. TODO: not supported yet. + # ------------------------------ + def populate_tasks_to_blocks_map_list_from_json(self): + raise Exception("deprecated function") + with open(os.path.join(config.design_input_folder, self.input_file), 'r') as json_file: + data = json.load(json_file) + + for task_data in data["workload_hardware_mapping"]: + task_to_blocks_map = TaskToBlocksMap(task_data["task_name"], task_data["task_fraction"], \ + task_data["task_fraction_index"]) #single task to blocks map + for block in task_data["blocks"]: + task_to_blocks_map.block_workRatio_dict[block] = block["work_ratio"] + self.tasks_to_blocks_map_list.append(task_to_blocks_map) + + # ------------------------------ + # Functionality + # get a task within the workload. + # Variables: + # task: task of interest. + # ------------------------------ + def get_by_task(self, task: Task): + for el in self.tasks_to_blocks_map_list: + if el.task == task: + return el + return None + + # ------------------------------ + # Functionality + # get a task within the workload by its name. + # Variables: + # task: task of interest. 
+    # ------------------------------
+    def get_by_task_name(self, task_name: str):  # look the task up by its name
+        for el in self.tasks_to_blocks_map_list:
+            if el.task.name == task_name:
+                return el
+        raise Exception("no task_to_blocks_map found for a task with the name " + task_name)
+
+    # ------------------------------
+    # Functionality
+    #       find out whether a task is already mapped.
+    # Variables:
+    #       task: task of interest.
+    # ------------------------------
+    def find_task(self, task: Task):
+        for task_mapped in self.tasks_to_blocks_map_list:
+            if task == task_mapped.task:
+                return task
+
+        raise Exception("couldn't find a task with the name " + str(task.name))
+
+    # ------------------------------
+    # Functionality
+    #       get the blocks associated with a task.
+    # Variables:
+    #       task: task of interest.
+    # ------------------------------
+    def get_blocks_associated_with_task(self, task: Task):
+        return task.get_blocks()
+
+    # ------------------------------
+    # Functionality
+    #       get all the tasks within the workload.
+    # ------------------------------
+    def get_tasks(self):
+        tasks = []
+        for el in self.tasks_to_blocks_map_list:
+            tasks.append(el.task)
+        return tasks
+
+    # ------------------------------
+    # Functionality
+    #       get all the blocks used in the mapping, given their type
+    # Variables:
+    #       type: type of the block (pe, mem, ic)
+    # ------------------------------
+    def get_blocks_by_type(self, type):
+        blocks = self.get_blocks()
+        return list(filter(lambda x: x.type == type, blocks))
+
+    # ------------------------------
+    # Functionality
+    #       get all the blocks used in the mapping.
+    # ------------------------------
+    def get_blocks(self):
+        blocks = []
+        for task_to_blocks_map in self.tasks_to_blocks_map_list:
+            task_to_blocks_map_blocks = task_to_blocks_map.get_blocks()
+            for block in task_to_blocks_map_blocks:
+                if block not in blocks:
+                    blocks.append(block)
+        return blocks
+
+    # ------------------------------
+    # Functionality
+    #       get all the tasks that are mapped to a block
+    # Variables:
+    #       block: block of interest.
+    # ------------------------------
+    def get_tasks_associated_with_block(self, block: Block):
+        tasks = []
+        for task_mapped in self.tasks_to_blocks_map_list:
+            if block in task_mapped.get_blocks():
+                tasks.append(task_mapped.task)
+        return tasks
+
+    # ------------------------------
+    # Functionality
+    #       print all the information in the mapping class. Used mainly for debugging purposes.
+    # ------------------------------
+    def print(self):
+        for el in self.tasks_to_blocks_map_list:
+            el.print()
\ No newline at end of file
diff --git a/Project_FARSI/design_utils/components/scheduling.py b/Project_FARSI/design_utils/components/scheduling.py
new file mode 100644
index 00000000..024b0495
--- /dev/null
+++ b/Project_FARSI/design_utils/components/scheduling.py
@@ -0,0 +1,81 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from settings import config
+from design_utils.components.workload import *
+
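+
+# Usage sketch (editor's addition; the task object below is assumed to come
+# from an already-built workload, and the names are hypothetical):
+#
+#   schedule = WorkloadToPEBlockSchedule()
+#   unit = TaskToPEBlockSchedule(some_task, starting_time=0)
+#   schedule.task_to_pe_block_schedule_list_sorted.append(unit)
+#   unit.reschedule(starting_time=10)   # move the task's start time
+#
+# The classes below only track starting times; swapping two schedule units
+# re-sorts the list by starting_time (see swap_task_to_pe_block_schedule).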
+
+# This class sets the schedule for a single task.
+class TaskToPEBlockSchedule:  # information about one task's schedule
+    def __init__(self, task, starting_time):
+        self.task = task
+        self.starting_time = starting_time
+
+    # ------------------------------
+    # Functionality
+    #       set the starting time of the task
+    # Variables:
+    #       starting_time: starting time of the task
+    # ------------------------------
+    def reschedule(self, starting_time):
+        self.starting_time = starting_time
+
+
+# This class sets the schedule for the whole workload.
+class WorkloadToPEBlockSchedule:
+    def __init__(self, input_file="scheduling.json", mode=""):
+        self.input_file = input_file  # input file to read the schedules from
+        self.task_to_pe_block_schedule_list_sorted = []  # list of tasks sorted based on their starting time
+        if mode == "from_json":  # TODO: this is not supported yet
+            self.populate_task_to_pe_block_schedule_list()
+
+    # ------------------------------
+    # Functionality
+    #       set the schedule for the whole workload.
+    # ------------------------------
+    def populate_task_to_pe_block_schedule_list(self):
+        raise Exception("not supporting this anymore")
+        with open(os.path.join(config.design_input_folder, self.input_file), 'r') as json_file:
+            data = json.load(json_file)
+
+        task_to_pe_block_schedule_list_unsorted = []  # contains the unsorted scheduling units
+
+    # ------------------------------
+    # Functionality
+    #       schedule out a task and schedule in another task on to a PE (processing element)
+    # Variables:
+    #       old_task_to_pe_block_schedule: task to schedule out
+    #       new_task_to_pe_block_schedule: task to schedule in
+    # ------------------------------
+    def swap_task_to_pe_block_schedule(self, old_task_to_pe_block_schedule, new_task_to_pe_block_schedule):
+        self.task_to_pe_block_schedule_list_sorted.remove(old_task_to_pe_block_schedule)
+        self.task_to_pe_block_schedule_list_sorted.append(new_task_to_pe_block_schedule)
+        self.task_to_pe_block_schedule_list_sorted = sorted(self.task_to_pe_block_schedule_list_sorted,
+                                                            key=lambda schedule: schedule.starting_time)
+
+    # ------------------------------
+    # Functionality
+    #       get a task's schedule unit from the sorted list
+    # Variables:
+    #       task: task of interest
+    # ------------------------------
+    def get_by_task(self, task: Task):
+        for el in self.task_to_pe_block_schedule_list_sorted:
+            if task == el.task:
+                return el
+        raise Exception("no task scheduled with the name " + task.name)
+
+    # ------------------------------
+    # Functionality
+    #       get a task's schedule unit from the sorted list by task name
+    # Variables:
+    #       task: task of interest
+    # ------------------------------
+    def get_by_task_name(self, task: Task):  # match by task name
+        for el in self.task_to_pe_block_schedule_list_sorted:
+            if task.name == el.task.name:
+                return el
+        raise Exception("no task scheduled with the name " + task.name)
\ No newline at end of file
diff --git a/Project_FARSI/design_utils/components/workload.py b/Project_FARSI/design_utils/components/workload.py
new file mode 100644
index 00000000..855cb54a
--- /dev/null
+++ b/Project_FARSI/design_utils/components/workload.py
@@ -0,0 +1,597 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from settings import config
+from typing import Dict, List
+import collections
+import random
+from datetime import datetime
+import numpy as np
+import time
+import math
+
+
+# This class models a task, i.e., the smallest software execution unit.
+class Task:
+    task_id_for_debugging_static = 0
+
+    def __init__(self, name, work, iteration_ctr=1, type="latency_based", throughput_info={}):
+        self.iteration_ctr = iteration_ctr
+        self.name = name
+        self.progress = 0  # progress percentage (how much of the task has been finished)
+        self.task_fraction = 1  # TODO: right now always set to 1
+        self.task_fraction_index = 0  # TODO: always set to zero for now
+        self.__children = []  # task children, i.e., the tasks that read from it
+        self.__parents = []  # task parents, i.e., the tasks that write to the task
+        self.__siblings = []  # task siblings, i.e., tasks with the same parents.
+        self.PA_prop_dict = collections.OrderedDict()  # all the props used by PA
+        self.PA_prop_auto_tuning_list = []  # list of auto tuned variables
+        self.__task_to_family_task_work = {}  # work exchanged with each family task
+        self.__task_to_family_task_work[self] = work  # the task's own work
+        self.task_id_for_debugging = self.task_id_for_debugging_static
+        self.updated_task_work_for_debug = False
+        self.__task_work_distribution = []  # TODO: double check and deprecate later
+        self.__task_to_child_work_distribution = {}  # TODO: double check and deprecate later
+        self.__fake_children = []  # from the parent's perspective, no new data is transferred to these
+                                   # children; previously generated data is reused, so they are excluded
+                                   # from the work-ratio calculation
+
+        self.__fake_parents = []  # from the child's perspective, no new data is transferred from these
+                                  # parents; previously generated data is reused, so they are excluded
+                                  # from the work-ratio calculation
+
+        self.__task_to_family_task_work_unit = {}  # task to family task unit of work. For example,
+                                                   # from the bus and memory perspective the unit of
+                                                   # work is the burst size (in bytes)
+        self.burst_size = config.default_burst_size
+
+        self.type = type
+        self.throughput_info = throughput_info
+
+    def get_throughput_info(self):
+        return self.throughput_info
+
+    def get_iteration_cnt(self):
+        return self.iteration_ctr
+
+    def set_burst_size(self, burst_size):
+        self.burst_size = burst_size
+
+    def get_burst_size(self):
+        return self.burst_size
+
+    def get_name(self):
+        return self.name
+
+    # ---------------
+    # Functionality:
+    #       get one of the task's family tasks (parents, siblings, children, or itself) by name.
+    # Variables:
+    #       task_name: name of the task.
+    # ---------------
+    def get_family_task_by_name(self, task_name):
+        for task in self.__parents + self.__siblings + self.__children + [self]:
+            if task.name == task_name:
+                return task
+
+    # ---------------
+    # Functionality:
+    #       reset the Platform Architect (PA) props. Used in PA design generation.
+    # ---------------
+    def reset_PA_props(self):
+        self.PA_prop_dict = collections.OrderedDict()
+
+    # ---------------
+    # Functionality:
+    #       update the Platform Architect (PA) props.
+    # Variables:
+    #       PA_prop_dict: dictionary containing all the PA props
+    # ---------------
+    def update_PA_props(self, PA_prop_dict):
+        self.PA_prop_dict.update(PA_prop_dict)
+
+    def update_PA_auto_tunning_knob_list(self, prop_auto_tuning_list):
+        self.PA_prop_auto_tuning_list = prop_auto_tuning_list
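+
+    # ---------------
+    # Example (editor's sketch; names and numbers are hypothetical). Family
+    # work is recorded symmetrically on both sides by add_child (defined below):
+    #
+    #   producer = Task("producer", work=1000)
+    #   consumer = Task("consumer", work=400)
+    #   producer.add_child(consumer, work=64)
+    #   producer.get_self_to_family_task_work(consumer)   # -> 64
+    #   producer.get_work_ratio(consumer)                 # -> 64 / 1000
+    # ---------------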
+    # ---------------
+    # Functionality:
+    #       pick one of the children at random.
+    # ---------------
+    def sample_child(self):
+        random.seed(datetime.now().microsecond)
+        return random.choice(self.get_children())
+
+    # ---------------
+    # Functionality:
+    #       sample the task work distribution. Used for jitter modeling/incorporation.
+    # ---------------
+    def sample_self_task_work(self):
+        time.sleep(.00005)
+        np.random.seed(datetime.now().microsecond)
+        task_work = [task_work for task_work, work_prob in self.get_task_work_distribution()]
+        work_prob = [work_prob for task_work, work_prob in self.get_task_work_distribution()]
+        return np.random.choice(task_work, p=work_prob)
+
+    # ---------------
+    # Functionality:
+    #       sample the task-to-child work distribution (how much data gets written into the child task).
+    #       Used for jitter modeling/incorporation.
+    # Variables:
+    #       child: task's child
+    # ---------------
+    def sample_self_to_child_task_work(self, child):
+        np.random.seed(datetime.now().microsecond)
+        task_to_child_work = [task_work for task_work, work_prob in self.get_task_to_child_work_distribution(child)]
+        work_prob = [work_prob for task_work, work_prob in self.get_task_to_child_work_distribution(child)]
+        return np.random.choice(task_to_child_work, p=work_prob)
+
+    # ---------------
+    # Functionality:
+    #       update the task work (used in jitter modeling after a new work is assigned to the task).
+    # Variables:
+    #       self_work: work to assign to the task.
+    # ---------------
+    def update_task_work(self, self_work):
+        self.updated_task_work_for_debug = True
+        self.__task_to_family_task_work[self] = self_work
+
+    # ---------------
+    # Functionality:
+    #       update the task-to-child work (used in jitter modeling after a new work is assigned to the task).
+    # Variables:
+    #       child: task's child.
+    #       child_work: task's child work.
+    # ---------------
+    def update_task_to_child_work(self, child, child_work):
+        self.__task_to_family_task_work[child] = child_work
+        child.__task_to_family_task_work[self] = child_work
+        if self.updated_task_work_for_debug:
+            self.updated_task_work_for_debug = False
+            self.task_id_for_debugging += 1
+
+    # ---------------
+    # Functionality:
+    #       add a task-to-child work distribution. Used for jitter modeling.
+    # Variables:
+    #       child: task's child.
+    #       work: the distribution to record for that child.
+    # ---------------
+    def add_task_to_child_work_distribution(self, child, work):
+        self.__task_to_child_work_distribution[child] = work
+
+    # ---------------
+    # Functionality:
+    #       add a parent (a task that it reads data from) for the task.
+    # Variables:
+    #       parent: task's parent.
+    #       child_nature: whether the data coming from this parent is "real" or "fake" (reused).
+    # ---------------
+    def add_parent(self, parent, child_nature="real"):
+        self.__parents.append(parent)
+        if child_nature == "fake":
+            self.__fake_parents.append(parent)
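+
+    # ---------------
+    # Example (editor's sketch, hypothetical values): the jitter machinery
+    # above expects a distribution given as (work, probability) pairs whose
+    # probabilities sum to 1, e.g.:
+    #
+    #   t = Task("filter", work=1000)
+    #   t.add_task_work_distribution([(900, .5), (1100, .5)])
+    #   t.update_task_work(t.sample_self_task_work())  # picks 900 or 1100
+    # ---------------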
+    # add a child (a task that it writes to).
+    # nature determines whether the work is real or fake. real means real generation of the data (which
+    # needs to be passed along to the child) whereas fake means that the data has already been generated
+    # and just needs to be passed along.
+    def add_child(self, child, work, child_nature="real"):
+        self.__task_to_family_task_work[child] = work
+        for other_child in self.__children:
+            other_child.add_sibling(child)
+            child.add_sibling(other_child)
+        self.__children.append(child)
+        child.add_parent(self, child_nature)
+        child.__task_to_family_task_work[self] = work
+        assert (child_nature in ["fake", "real"]), "child nature can only be fake or real but " + child_nature + " was given"
+        if child_nature == "fake":
+            self.__fake_children.append(child)
+
+    # ---------------
+    # Functionality:
+    #       fake children are children that we need to pass data to, but we don't need to generate the data
+    #       since it has already been generated. This situation happens when two children use the exact same data.
+    # ---------------
+    def get_fake_children(self):
+        return self.__fake_children
+
+    def get_fake_children_name(self):
+        return [task.name for task in self.__fake_children]
+
+    # ---------------
+    # Functionality:
+    #       fake parents are parents that pass data to this task but don't need to generate the data
+    #       since it has already been generated.
+    # ---------------
+    def get_fake_parent_name(self):
+        return [task.name for task in self.__fake_parents]
+
+    def get_fake_family_name(self):
+        return self.get_fake_children_name() + self.get_fake_parent_name()
+
+    # ---------------
+    # Functionality:
+    #       remove a child (a task that it writes data to) from the task.
+    # Variables:
+    #       child: task's child.
+    # ---------------
+    def remove_child(self, child):
+        for other_child in self.__children:
+            other_child.remove_sibling(child)
+            child.remove_sibling(other_child)
+        self.__children.remove(child)
+        del self.__task_to_family_task_work[child]
+        child.__parents.remove(self)
+        del child.__task_to_family_task_work[self]
+
+    def remove_parent(self, parent):
+        self.__parents.remove(parent)
+        del self.__task_to_family_task_work[parent]
+        parent.__children.remove(self)
+        del parent.__task_to_family_task_work[self]
+
+    # ---------------
+    # Functionality:
+    #       add a sibling (a task with the same parent) for the task.
+    # Variables:
+    #       task: sibling task.
+    # ---------------
+    def add_sibling(self, task):
+        if task not in self.__siblings:
+            self.__siblings.append(task)
+
+    # ---------------
+    # Functionality:
+    #       remove a sibling (a task with the same parent) from the task.
+    # Variables:
+    #       task: sibling task.
+    # ---------------
+    def remove_sibling(self, task):
+        if task in self.__siblings:
+            self.__siblings.remove(task)
+
+    # ---------------
+    # Functionality:
+    #       get the relationship of the task with the input task.
+    # Variables:
+    #       task_: the task to find the relationship for.
+    # ---------------
+    def get_relationship(self, task_):
+        if any([task__.name == task_.name for task__ in self.__children]):
+            return "child"
+        elif any([task__.name == task_.name for task__ in self.__parents]):
+            return "parent"
+        elif any([task__.name == task_.name for task__ in self.__siblings]):
+            return "sibling"
+        elif task_.name == self.name:
+            return "self"
+        else:
+            return "no relationship"
+
+    # ---------------
+    # Functionality:
+    #       get the task's own work
+    # ---------------
+    def get_self_task_work(self):
+        return self.__task_to_family_task_work[self]
+
+    # ---------------
+    # Functionality:
+    #       get the self-to-family-task work (how much data is passed from/to the family task).
+    # Variables:
+    #       family_task: family task.
+    # ---------------
+    def get_self_to_family_task_work(self, family_task):
+        if family_task in self.get_children():
+            return self.__task_to_family_task_work[family_task]
+        elif family_task in self.get_parents():
+            return family_task.get_self_to_family_task_work(self)
+        elif family_task == self:
+            return self.get_self_task_work()
+        else:
+            print(family_task.name + " is not a family task of " + self.name)
+            exit(0)
+
+    def get_type(self):
+        return self.type
+
+    def get_self_total_work(self, mode):
+        total_work = 0
+        if mode == "execute":
+            total_work = self.__task_to_family_task_work[self]
+        if mode == "read":
+            for family_task in self.get_parents():
+                total_work += family_task.get_self_to_family_task_work(self)
+        if mode == "write":
+            for family_task in self.get_children():
+                total_work += self.__task_to_family_task_work[family_task]
+        return total_work
+
+    # return the self-to-family-task unit of work. For example,
+    # from the bus and memory perspective the unit of work is the burst size
+    # (in bytes)
+    def get_self_to_family_task_work_unit(self, family_task):
+        return self.__task_to_family_task_work_unit[family_task]
+
+    # determines what the dicing "grain" should be such that the
+    # work unit (e.g., burst size) is respected.
+    # Note that we ensure that the smallest "read" (just as a convention) will respect the
+    # burst size. Everything else is adjusted accordingly.
+    def set_dice_factor(self, block_size):
+        smallest_read = self.get_smallest_task_work_by_dir("read")
+        smallest_write = self.get_smallest_task_work_by_dir("write")
+        smallest_instructions = self.get_smallest_task_work_by_dir("loop")
+
+        dice_factor = math.floor(smallest_read/block_size)  # scale by the smallest read; an arbitrary
+                                                            # but consistent convention
+
+        if dice_factor == 0:
+            dice_factor = 1
+        else:
+            smallest_read_scaled = math.floor(smallest_read/dice_factor)
+            smallest_write_scaled = math.floor(smallest_write/dice_factor)
+            task_instructions_scaled = math.floor(smallest_instructions/dice_factor)
+
+            if smallest_write_scaled == 0 or task_instructions_scaled == 0:
+                dice_factor = 1
+        return dice_factor
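+
+    # ---------------
+    # Worked example (editor's addition, hypothetical numbers): with a burst
+    # size of 64 bytes, a smallest read of 1024, a smallest write of 256, and
+    # 4096 loop instructions:
+    #   dice_factor = floor(1024 / 64) = 16
+    #   floor(256 / 16) = 16 and floor(4096 / 16) = 256 are both non-zero,
+    #   so the dice factor stays 16 and every family work amount is later
+    #   divided by 16 in calc_work_unit (defined below).
+    # ---------------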
+    # based on some reference work unit (the block size), determine the rest of the
+    # work units.
+    def calc_work_unit(self):
+        dice_factor = self.set_dice_factor(self.burst_size)
+        for family in self.get_family():
+            self.__task_to_family_task_work_unit[family] = int(self.get_self_to_family_task_work(family)/dice_factor)
+            assert (self.get_self_to_family_task_work(family)/dice_factor > .1)
+
+    def get_smallest_task_work_by_dir(self, dir):
+        if dir == "write":
+            family_tasks = self.get_children()
+        elif dir == "read":
+            family_tasks = self.get_parents()
+        elif dir == "loop":
+            family_tasks = [self]
+
+        if "souurce" in self.name:
+            return 0
+        if "siink" in self.name:
+            return 0
+
+        if len(family_tasks) == 0:
+            raise ValueError("task " + self.name + " has no family tasks for direction " + dir)
+
+        return min([self.get_self_to_family_task_work(task_) for task_ in family_tasks])
+
+    def get_biggest_task_work_by_dir(self, dir):
+        if dir == "write":
+            family_tasks = self.get_children()
+        elif dir == "read":
+            family_tasks = self.get_parents()
+        elif dir == "loop":
+            family_tasks = [self]
+
+        if "souurce" in self.name:
+            return 0
+        if "siink" in self.name:
+            return 0
+        return max([self.get_self_to_family_task_work(task_) for task_ in family_tasks])
+
+    def get_smallest_task_work_unit_by_dir(self, dir):
+        if dir == "write":
+            family_tasks = self.get_children()
+        elif dir == "read":
+            family_tasks = self.get_parents()
+        elif dir == "loop":
+            family_tasks = [self]
+
+        if "souurce" in self.name:
+            return 0
+        if "siink" in self.name:
+            return 0
+        return min([self.get_self_to_family_task_work_unit(task_) for task_ in family_tasks])
+
+    # ---------------
+    # Functionality:
+    #       set the task's work distribution.
+    # Variables:
+    #       work: the distribution, as (work, probability) pairs.
+    # ---------------
+    def add_task_work_distribution(self, work):
+        self.task_work_distribution = work
+
+    # ---------------
+    # Functionality:
+    #       get the task's work distribution.
+    # ---------------
+    def get_task_work_distribution(self):
+        return self.task_work_distribution
+
+    # ---------------
+    # Functionality:
+    #       get the task's to-child work distribution.
+    # ---------------
+    def get_task_to_child_work_distribution(self, child):
+        return self.__task_to_child_work_distribution[child]
+
+    # ---------------
+    # Functionality:
+    #       get the work ratio (how much data is written to/read from) for the family task.
+    # Variables:
+    #       family_task_name: name of the family (parent/child) task.
+    # ---------------
+    def get_work_ratio_by_family_task_name(self, family_task_name):
+        family_task = self.get_family_task_by_name(family_task_name)
+        return self.get_work_ratio(family_task)
+
+    # ---------------
+    # Functionality:
+    #       get the work ratio (how much data is written to/read from) for the family task.
+    # Variables:
+    #       family_task: the family (parent/child) task.
+    # ---------------
+    def get_work_ratio(self, family_task):
+        """
+        if not (self.task_id_for_debugging == self.task_id_for_debugging_static):
+            print("debugging not matching")
+            exit(0)
+        """
+        if self.get_self_task_work() == 0:  # dummy tasks
+            return 1
+        return self.get_self_to_family_task_work(family_task)/self.get_self_task_work()
+
+    # ---------------
+    # Functionality:
+    #       getters
+    # ---------------
+    def get_children(self):
+        return self.__children
+
+    def get_parents(self):
+        return self.__parents
+
+    def get_siblings(self):
+        return self.__siblings
+
+    def get_family(self):
+        return self.__parents + self.__children
+
+    def is_task_dummy(self):
+        return "souurce" in self.name or "siink" in self.name or "dummy_last" in self.name
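+
+
+# Example (editor's sketch; a tiny diamond-shaped graph with hypothetical
+# names and work amounts, built with the Task API above and analyzed with the
+# TaskGraph class below):
+#
+#   root = Task("souurce_dummy", work=0)
+#   a = Task("decode", work=500)
+#   b = Task("filter", work=300)
+#   sink = Task("siink_dummy", work=0)
+#   root.add_child(a, work=64); root.add_child(b, work=64)
+#   a.add_child(sink, work=64); b.add_child(sink, work=64)
+#   tg = TaskGraph([root, a, b, sink])
+#   tg.tasks_can_run_in_parallel(a, b)   # -> True: neither depends on the other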
+
+
+# Task Graph for the workload.
+class TaskGraph:
+    def __init__(self, tasks):
+        self.__tasks = tasks
+        _ = [task_.calc_work_unit() for task_ in self.__tasks]
+
+    # -----------
+    # Functionality:
+    #       get the root of the task graph.
+    # -----------
+    def get_root(self):
+        roots = []
+        for task in self.__tasks:
+            if not task.get_parents():
+                roots.append(task)
+        assert (len(roots) == 1), "the task graph must have exactly one root; add a dummy root task otherwise"
+        return roots[0]
+
+    def get_all_tasks(self):
+        return self.__tasks
+
+    def get_task_by_name(self, name):
+        for tsk in self.__tasks:
+            if tsk.name == name:
+                return tsk
+
+        print("task with the name " + name + " does not exist")
+        exit(0)
+
+    # -----------
+    # Functionality:
+    #       get a task's parents (tasks that it reads from)
+    # -----------
+    def get_task_s_parents(self, task):
+        return task.get_parents()
+
+    # -----------
+    # Functionality:
+    #       get a task's parents by the task's name
+    # Variables:
+    #       task_name: name of the task
+    # -----------
+    def get_task_s_parent_by_name(self, task_name):
+        for task in self.__tasks:
+            if task.name == task_name:
+                return task.get_parents()
+        raise Exception("task: " + task_name + " not in the task graph")
+
+    # -----------
+    # Functionality:
+    #       get a task's children
+    # Variables:
+    #       task: the task of interest
+    # -----------
+    def get_task_s_children(self, task):
+        return task.get_children()
+
+    # -----------
+    # Functionality:
+    #       get a task's children by the task's name
+    # Variables:
+    #       task_name: name of the task
+    # -----------
+    def get_task_s_children_by_task_name(self, task_name):
+        for task in self.__tasks:
+            if task.name == task_name:
+                return task.get_children()
+        raise Exception("task: " + task_name + " not in the task graph")
+
+    # determine whether two tasks can run in parallel or not
+    def task_can_run_in_parallel_helper(self, task_1, task_2, task_to_look_at, root):
+        if task_to_look_at == task_2:
+            return False
+        elif task_to_look_at == root:
+            return True
+
+        result = []
+        for task in task_to_look_at.get_parents():
+            result.append(self.task_can_run_in_parallel_helper(task_1, task_2, task, root))
+        return all(result)
+
+    # ---------------
+    # Functionality:
+    #       establish whether two tasks can run in parallel.
+    # ---------------
+    def tasks_can_run_in_parallel(self, task_1, task_2):
+        root = self.get_root()
+        return (self.task_can_run_in_parallel_helper(task_1, task_2, task_1, root) and
+                self.task_can_run_in_parallel_helper(task_2, task_1, task_2, root))
+
+
+# This class emulates the software workload containing the task set.
+class Workload:
+    def __init__(self, input_file="workload.json", mode=""):
+        self.tasks = []  # set of tasks.
+        self.input_file = input_file  # input file to populate the tasks from. Not supported yet.
+        if mode == "from_json":
+            self.populate_tasks()
+
+    # -----------
+    # Functionality:
+    #       populate the tasks from a file. To be finished in the next round.
+    # -----------
+    def populate_tasks(self):
+        raise Exception('not supported any more')
+        with open(os.path.join(config.design_input_folder, self.input_file), 'r') as json_file:
+            data = json.load(json_file)
+
+        for datum in data["workload"]:
+            self.tasks.append(Task(datum["task_name"], datum["work"]))
+
+    # -----------
+    # Functionality:
+    #       get the task that has the same name as the input task. Used when we duplicate a design.
+    # -----------
+    def get_task_by_name(self, task):
+        for task_ in self.tasks:
+            if task.name == task_.name:
+                return task_
+        raise Exception("task instance with name:" + str(task.name) + " doesn't exist in this workload")
\ No newline at end of file
diff --git a/Project_FARSI/design_utils/des_handler.py b/Project_FARSI/design_utils/des_handler.py
new file mode 100644
index 00000000..b3fd27b3
--- /dev/null
+++ b/Project_FARSI/design_utils/des_handler.py
@@ -0,0 +1,2539 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+from design_utils.design import *
+from settings import config
+from specs.data_base import *
+from specs.LW_cl import *
+import time
+import random
+from datetime import datetime
+from typing import Dict, List, Tuple
+from design_utils.components.hardware import *
+from error_handling.custom_error import *
+import importlib
+from DSE_utils.exhaustive_DSE import *
+from visualization_utils.vis_hardware import *
+import _pickle as cPickle
+
+# This class allows us to modify the design. A move is applied to each design
+# to transform it into another.
+# A move at the moment has 4 different parts (metric, kernel, block, transformation) that need to be
+# set.
+class move:
+    def __init__(self, transformation_name, transformation_sub_name, batch_mode, dir, metric, blck, krnel, krnl_prob_dict_sorted):
+        self.workload = ""
+        self.transformation_name = transformation_name
+        self.parallelism_type = []
+        self.locality_type = []
+        self.transformation_sub_name = transformation_sub_name
+        self.customization_type = ""
+        self.metric = metric
+        self.batch_mode = batch_mode
+        self.dir = dir
+        self.blck = blck
+        self.krnel = krnel
+        self.tasks = "_"
+        self.dest_block = "_"
+        self.valid = True
+        self.breadth = 0
+        self.depth = 0
+        self.mini_breadth = 0
+        self.sorted_kernels = []
+        self.sorted_blocks = []
+        self.kernel_rnk_to_consider = 0
+        self.sorted_metrics = []
+        self.ref_des_dist_to_goal_all = 0
+        self.ref_des_dist_to_goal_non_cost = 0
+        self.cost = 0
+        self.sorted_metric_dir = {}
+        self.pre_move_ex = None
+        self.moved_ex = None
+        self.validity_meta_data = ""
+        self.krnel_prob_dict_sorted = krnl_prob_dict_sorted
+        self.generation_time = 0
+        self.system_improvement_dict = {}
+        #self.populate_system_improvement_log()
+
+        self.transformation_selection_time = 0
+        self.block_selection_time = 0
+        self.kernel_selection_time = 0
+        self.pickling_time = 0
+
+        self.design_space_size = {"pe_mapping": 0, "ic_mapping": 0, "mem_mapping": 0,
+                                  "pe_allocation": 0, "ic_allocation": 0, "mem_allocation": 0,
+                                  "transfer": 0, "routing": 0, "hardening": 0, "softening": 0,
+                                  "bus_width_modulation": 0, "mem_frequency_modulation": 0, "ic_frequency_modulation": 0, "pe_frequency_modulation": 0, "identity": 0,
+                                  "loop_iteration_modulation": 0
+                                  }
+
+    def get_design_space_size(self):
+        return self.design_space_size
+
+    def get_system_improvement_log(self):
+        return self.system_improvement_dict
+
+    def populate_system_improvement_log(self):
+        # communication vs computation
+        if self.get_block_ref().type in ["ic", "mem"]:
+            comm_comp = "comm"
+        else:
+            comm_comp = "comp"
+
+        # which exact optimization was targeted: topology/mapping/tuning
+        if self.get_transformation_name() in ["swap"]:
+            exact_optimization = self.get_customization_type(self.get_block_ref(), self.get_des_block())
+        elif self.get_transformation_name() in ["split_swap"]:
+            if self.get_block_ref().type == "pe":
+                exact_optimization = self.get_customization_type(self.get_block_ref(), self.get_des_block()) + ";" + self.get_block_ref().type + "_" + "allocation"
+            else:
+                exact_optimization = self.get_customization_type(self.get_block_ref(), self.get_des_block())
+        elif self.get_transformation_name() in ["migrate"]:
+            exact_optimization = self.get_block_ref().type + "_" + "mapping"
+        elif self.get_transformation_name() in ["split_swap", "split", "transfer", "routing"]:
+            exact_optimization = self.get_block_ref().type + "_" + "allocation"
+        elif self.get_transformation_name() in ["transfer", "routing"]:
+            exact_optimization = self.get_block_ref().type + "_" + self.get_transformation_name()
+        elif self.get_transformation_name() in ["cost"]:
+            exact_optimization = "cost"
+        elif self.get_transformation_name() in ["dram_fix"]:
+            exact_optimization = "dram_fix"
+        elif self.get_transformation_name() in ["identity"]:
+            exact_optimization = "identity"
+        else:
+            print(self.get_transformation_name() + " exact optimization is not specified")
+            exit(0)
+
+        # which high level optimization was targeted: topology/mapping/tuning
+        if self.get_transformation_name() in ["swap", "split_swap"]:
+            high_level_optimization = "hardware_tunning"
+        elif self.get_transformation_name() in ["split_swap"]:
+            high_level_optimization = "hardware_tunning;topology"
+        elif self.get_transformation_name() in ["migrate"]:
+            high_level_optimization = "mapping"
+        elif self.get_transformation_name() in ["split_swap", "split", "transfer", "routing"]:
+            high_level_optimization = "topology"
+        elif self.get_transformation_name() in ["cost"]:
+            high_level_optimization = "cost"
+        elif self.get_transformation_name() in ["dram_fix"]:
+            high_level_optimization = "dram_fix"
+        elif self.get_transformation_name() in ["identity"]:
+            high_level_optimization = "identity"
+        else:
+            print(self.get_transformation_name() + " high level optimization is not specified")
+            exit(0)
+
+        # which architectural principle was targeted: parallelism/locality/customization
+        if self.get_transformation_name() in ["split"]:
+            architectural_principle = ""
+            for el in self.parallelism_type:
+                architectural_principle += el + ";"
+            architectural_principle = architectural_principle[:-1]
+        elif self.get_transformation_name() in ["migrate"]:
+            architectural_principle = ""
+            for el in self.parallelism_type:
+                architectural_principle += el + ";"
+            for el in self.locality_type:
+                architectural_principle += el + ";"
+            architectural_principle = architectural_principle[:-1]
+        elif self.get_transformation_name() in ["split_swap", "swap"]:
+            if "loop_iteration_modulation" in exact_optimization:
+                architectural_principle = "loop_level_parallelism"
+            else:
+                architectural_principle = "customization"
+        elif self.get_transformation_name() in ["transfer", "routing"]:
+            architectural_principle = "spatial_locality"
+        elif self.get_transformation_name() in ["cost"]:
+            architectural_principle = "cost"
+        elif self.get_transformation_name() in ["dram_fix"]:
+            architectural_principle = "dram_fix"
+        elif self.get_transformation_name() in ["identity"]:
+            architectural_principle = "identity"
+        else:
+            print(self.get_transformation_name() + " architectural principle is not specified")
+            exit(0)
+
+        self.system_improvement_dict["comm_comp"] = comm_comp
+        self.system_improvement_dict["high_level_optimization"] = high_level_optimization
+        self.system_improvement_dict["exact_optimization"] = exact_optimization
+        self.system_improvement_dict["architectural_principle"] = architectural_principle
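+
+    # ---------------
+    # Example (editor's sketch, hypothetical move): for a "migrate" move whose
+    # reference block is a pe and whose parallelism/locality types are
+    # ["task_level_parallelism"] / ["spatial_locality"], the log above becomes:
+    #   {"comm_comp": "comp",
+    #    "high_level_optimization": "mapping",
+    #    "exact_optimization": "pe_mapping",
+    #    "architectural_principle": "task_level_parallelism;spatial_locality"}
+    # ---------------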
+    def set_parallelism_type(self, parallelism_type):
+        self.parallelism_type = parallelism_type
+
+    def set_locality_type(self, locality_type):
+        self.locality_type = locality_type
+
+    def get_transformation_sub_name(self):
+        return self.transformation_sub_name
+
+    def set_krnel_ref(self, krnel):
+        self.krnel = krnel
+
+    # how long it took to come up with the move
+    def set_generation_time(self, generation_time):
+        self.generation_time = generation_time
+
+    def get_generation_time(self):
+        return self.generation_time
+
+    def set_logs(self, data, type_):
+        if type_ == "cost":
+            self.cost = data
+        if type_ == "workload":
+            self.workload = data
+        if type_ == "pickling_time":
+            self.pickling_time = data
+        if type_ == "metric_selection_time":
+            self.metric_selection_time = data
+        if type_ == "dir_selection_time":
+            self.dir_selection_time = data
+        if type_ == "kernel_selection_time":
+            self.kernel_selection_time = data
+        if type_ == "block_selection_time":
+            self.block_selection_time = data
+        if type_ == "transformation_selection_time":
+            self.transformation_selection_time = data
+        if type_ == "kernels":
+            self.sorted_kernels = data
+        if type_ == "blocks":
+            self.sorted_blocks = data
+        if type_ == "metrics":
+            self.sorted_metrics = data
+        if type_ == "kernel_rnk_to_consider":
+            self.kernel_rnk_to_consider = data
+        if type_ == "ref_des_dist_to_goal_all":
+            self.ref_des_dist_to_goal_all = data
+        if type_ == "ref_des_dist_to_goal_non_cost":
+            self.ref_des_dist_to_goal_non_cost = data
+
+    def get_transformation_batch(self):
+        return self.batch_mode
+
+    def get_logs(self, type_):
+        if type_ == "workload":
+            return self.workload
+        if type_ == "cost":
+            return self.cost
+        if type_ == "kernels":
+            return self.sorted_kernels
+        if type_ == "blocks":
+            return self.sorted_blocks
+        if type_ == "metrics":
+            return self.sorted_metrics
+        if type_ == "kernel_rnk_to_consider":
+            return self.kernel_rnk_to_consider
+        if type_ == "ref_des_dist_to_goal_all":
+            return self.ref_des_dist_to_goal_all
+        if type_ == "ref_des_dist_to_goal_non_cost":
+            return self.ref_des_dist_to_goal_non_cost
+        if type_ == "pickling_time":
+            return self.pickling_time
+        if type_ == "metric_selection_time":
+            return self.metric_selection_time
+        if type_ == "dir_selection_time":
+            return self.dir_selection_time
+        if type_ == "kernel_selection_time":
+            return self.kernel_selection_time
+        if type_ == "block_selection_time":
+            return self.block_selection_time
+        if type_ == "transformation_selection_time":
+            return self.transformation_selection_time
+
+    # breadth and depth determine how many designs to generate around (breadth) the
+    # current design and how many to chain from it (depth)
+    def set_breadth_depth(self, breadth, depth, mini_breadth):
+        self.breadth = breadth
+        self.depth = depth
+        self.mini_breadth = mini_breadth
+
+    def get_depth(self):
+        return self.depth
+
+    def get_mini_breadth(self):
+        return self.mini_breadth
+
+    def get_breadth(self):
+        return self.breadth
+
+    def set_metric(self, metric):
+        self.metric = metric
+
+    # the block to target in the move
+    def set_ref_block(self, blck):
+        self.blck = blck
+
+    # the tasks to target in the move
+    def set_tasks(self, tasks_):
+        self.tasks = tasks_
+
+    # the transformation to target in the move
+    def set_transformation(self, trans):
+        self.transformation_name = trans
+
+    # the transformation sub-name to target in the move
+    def set_transformation_sub_name(self, trans_sub_name):
+        self.transformation_sub_name = trans_sub_name
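+
+    # ---------------
+    # Typical lifecycle (editor's sketch; exact call sites live in the DSE
+    # engine, and the objects/arguments below are hypothetical):
+    #
+    #   mv = move("migrate", "none", batch_mode, "dir", "latency",
+    #             some_block, some_kernel, krnl_prob_dict_sorted)
+    #   mv.set_dest_block(other_block)
+    #   mv.set_validity(True)
+    #   mv.validity_check()      # raises a specific exception if invalid
+    #   mv.sanity_check()        # compares pre/post designs (see below)
+    # ---------------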
+    # set this to specify that the move's parameters are fully set
+    # and hence the move is ready to be applied
+    def set_validity(self, validity, reason=""):
+        self.valid = validity
+        self.validity_meta_data = reason
+
+    def is_valid(self):
+        return self.valid
+
+    # set the block that we will change the ref_block to.
+    def set_dest_block(self, block_):
+        self.dest_block = block_
+
+    def get_customization_type(self, ref_block, imm_block):
+        return self.customization_type
+
+    # determine the customization type implied by swapping ref_block for imm_block.
+    def set_customization_type(self, ref_block, imm_block):
+        if not ref_block.subtype == imm_block.subtype:
+            # type difference
+            if ref_block.type == "pe":
+                if ref_block.subtype == "gpp" and imm_block.subtype == "ip":
+                    self.customization_type = "hardening"
+                elif ref_block.subtype == "ip" and imm_block.subtype == "gpp":
+                    self.customization_type = "softening"
+                else:
+                    self.customization_type = "unknown"
+                    #print("we should have covered all the customizations. what is missing then (1)")
+                    #exit(0)
+            elif ref_block.type == "mem":
+                #self.customization_type = "memory_cell_" + ref_block.subtype + "_to_" + imm_block.subtype
+                self.customization_type = "mem_allocation"
+            else:
+                self.customization_type = "unknown"
+                # print("we should have covered all the customizations. what is missing then (2)")
+                # exit(0)
+        else:
+            if not ref_block.get_loop_itr_cnt() == imm_block.get_loop_itr_cnt():
+                self.customization_type = "loop_iteration_modulation"
+            elif not ref_block.get_block_freq() == imm_block.get_block_freq():
+                self.customization_type = ref_block.type + "_" + "frequency_modulation"
+            elif not ref_block.get_block_bus_width() == imm_block.get_block_bus_width():
+                self.customization_type = "bus_width_modulation"
+            else:
+                self.customization_type = "unknown"
+                # print("we should have covered all the customizations. what is missing then (3)")
+                # exit(0)
+
+    def get_block_attr(self, selected_metric):
+        if selected_metric == "latency":
+            selected_metric_to_sort = 'peak_work_rate'
+        elif selected_metric == "power":
+            #selected_metric_to_sort = 'work_over_energy'
+            selected_metric_to_sort = 'one_over_power'
+        elif selected_metric == "area":
+            selected_metric_to_sort = 'one_over_area'
+        else:
+            print("selected_metric: " + selected_metric + " is not defined")
+        return selected_metric_to_sort
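+
+    # ---------------
+    # Classification sketch (editor's addition) of set_customization_type:
+    #   gpp -> ip                      : "hardening"
+    #   ip  -> gpp                     : "softening"
+    #   mem subtype change             : "mem_allocation"
+    #   same subtype, diff loop count  : "loop_iteration_modulation"
+    #   same subtype, diff frequency   : "<type>_frequency_modulation"
+    #   same subtype, diff bus width   : "bus_width_modulation"
+    # e.g., swapping a gpp core for an ip block implementing the same function
+    # is classified as "hardening".
+    # ---------------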
+    # --------------------------------------
+    # getters.
+    # PS: look into the equivalent setters to understand the parameters
+    # --------------------------------------
+    def get_tasks(self):
+        return self.tasks
+
+    def get_des_block(self):
+        return self.dest_block
+
+    def get_des_block_name(self):
+        if self.dest_block == "_":
+            return "_"
+        else:
+            return self.dest_block.instance_name
+
+    def get_dir(self):
+        assert (not (self.dir == "deadbeef")), "dir is not initialized"
+        return self.dir
+
+    def get_metric(self):
+        assert (not (self.metric == "deadbeef")), "metric is not initialized"
+        return self.metric
+
+    def set_sorted_metric_dir(self, sorted_metric_dir):
+        self.sorted_metric_dir = sorted_metric_dir
+
+    def get_sorted_metric_dir(self):
+        return self.sorted_metric_dir
+
+    def get_transformation_name(self):
+        assert (not (self.transformation_name == "deadbeef")), "name is not initialized"
+        return self.transformation_name
+
+    def get_block_ref(self):
+        assert (not (self.blck == "deadbeef")), "block is not initialized"
+        return self.blck
+
+    def get_kernel_ref(self):
+        assert (not (self.krnel == "deadbeef")), "kernel is not initialized"
+        return self.krnel
+
+    def print_info(self, mode="all"):
+        if mode == "all":
+            print("info:" + " tp::" + self.get_transformation_name() + ", mtrc::" + self.get_metric() + ", blck_ref::" +
+                  self.get_block_ref().instance_name + ", block_des:" + self.get_des_block_name() +
+                  ", tsk:" + self.get_kernel_ref().get_task().name)
+        else:
+            print("mode:" + mode + " is not supported for move printing")
+
+    # see if the move can be applied.
+    # this is different than the sanity check, which checks whether the applied move messed up the design
+    def safety_check(self, ex):
+        return True
+
+    # Check the validity of the move.
+    def validity_check(self):
+        if not self.is_valid():
+            if self.validity_meta_data == "NoMigrantException":
+                raise NoMigrantException
+            elif self.validity_meta_data == "ICMigrationException":
+                raise ICMigrationException
+            elif self.validity_meta_data == "CostPairingException":
+                raise CostPairingException
+            elif self.validity_meta_data == "NoAbsorbee(er)Exception":
+                raise NoAbException
+            elif self.validity_meta_data == "TransferException":
+                raise TransferException
+            elif self.validity_meta_data == "RoutingException":
+                raise RoutingException
+            elif self.validity_meta_data == "IPSplitException":
+                raise IPSplitException
+            elif self.validity_meta_data == "NoValidTransformationException":
+                raise NoValidTransformationException
+            elif self.validity_meta_data == "NoParallelTaskException":
+                raise NoParallelTaskException
+            else:
+                print("this invalidity reason is not supported: " + self.validity_meta_data)
+                exit(0)
+
+    # Log the design before and after for post-processing
+    def set_before_after_designs(self, pre_moved_ex, moved_ex):
+        self.pre_move_ex = pre_moved_ex
+        self.moved_ex = moved_ex
+
+    # note that since we make a copy of the design (and call it prev_des),
+    # the instance_names will not be exactly the same
+    # Variables:
+    #       pre_moved_ex: example design pre moving (transformation)
+    #       moved_ex: example design after moving (transformation)
+    #       mode = {"pre_application", "after_application"}  # pre means before applying the move
+    def sanity_check(self):
+        pre_moved_ex = self.pre_move_ex
+        moved_ex = self.moved_ex
+        insanity_list = []
+
+        # ---------------------
+        # pre/post designs are not specified. This is an indicator that the move was not
+        # really applied, typically because of errors while applying the move
+        # ---------------------
+        if moved_ex is None or pre_moved_ex is None:
+            insanity = Insanity("_", "_", "no_design_provided")
+            print(insanity.gen_msg())
+            self.set_validity(False)
+            raise MoveNoDesignException
+            #return False
+
+        # ---------------------
+        # number of fronts sanity check
+        # ---------------------
+        pre_mvd_fronts_1 = sum([len(block.get_fronts("task_name_dir")) for block in pre_moved_ex.get_blocks()])
+        pre_mvd_fronts_2 = sum([len(block.get_fronts("task_dir_work_ratio")) for block in pre_moved_ex.get_blocks()])
+        if not pre_mvd_fronts_1 == pre_mvd_fronts_2:
+            pre_mvd_fronts_1 = [block.get_fronts("task_name_dir") for block in pre_moved_ex.get_blocks()]
+            pre_mvd_fronts_2 = [block.get_fronts("task_dir_work_ratio") for block in pre_moved_ex.get_blocks()]
+            raise UnEqualFrontsError
+
+        mvd_fronts_1 = sum([len(block.get_fronts("task_name_dir")) for block in moved_ex.get_blocks()])
+        mvd_fronts_2 = sum([len(block.get_fronts("task_dir_work_ratio")) for block in moved_ex.get_blocks()])
+        if not mvd_fronts_1 == mvd_fronts_2:
+            mvd_fronts_1 = [block.get_fronts("task_name_dir") for block in moved_ex.get_blocks()]
+            mvd_fronts_2 = [block.get_fronts("task_dir_work_ratio") for block in moved_ex.get_blocks()]
+            raise UnEqualFrontsError
+
+        # ---------------------
+        # block count sanity checks
+        # ---------------------
+        if self.get_transformation_name() == "split":
+            if self.get_block_ref().type == "ic":
+                if not (len(moved_ex.get_blocks()) in [len(pre_moved_ex.get_blocks()),
+                                                       # when the split can't succeed
+                                                       len(pre_moved_ex.get_blocks()) + 1,
+                                                       len(pre_moved_ex.get_blocks()) + 2,
+                                                       len(pre_moved_ex.get_blocks()) + 3]):
+                    insanity = Insanity("_", "_", "block_count_deviation")
+                    insanity_list.append(insanity)
+                    print("previous block count:" + str(
+                        len(pre_moved_ex.get_blocks())) + " moved_ex block count:" + str(
+                        len(moved_ex.get_blocks())))
+                    print(insanity.gen_msg())
+                    self.set_validity(False)
+                    raise BlockCountDeviationError
+            else:
+                if not (len(moved_ex.get_blocks()) in [len(pre_moved_ex.get_blocks()),
+                                                       # when the split can't succeed
+                                                       len(pre_moved_ex.get_blocks()) + 1]):
+                    insanity = Insanity("_", "_", "block_count_deviation")
+                    insanity_list.append(insanity)
+                    print("previous block count:" + str(
+                        len(pre_moved_ex.get_blocks())) + " moved_ex block count:" + str(
+                        len(moved_ex.get_blocks())))
+                    print(insanity.gen_msg())
+                    self.set_validity(False)
+                    raise BlockCountDeviationError
+        elif self.get_transformation_name() in ["migrate"]:
+            if not (len(moved_ex.get_blocks()) in [len(pre_moved_ex.get_blocks()),
+                                                   len(pre_moved_ex.get_blocks()) - 1]):
+                insanity = Insanity("_", "_", "block_count_deviation")
+                insanity_list.append(insanity)
+                print(
+                    "previous block count:" + str(len(pre_moved_ex.get_blocks())) + " moved_ex block count:" + str(
+                        len(moved_ex.get_blocks())))
+                print(insanity.gen_msg())
+                self.set_validity(False)
+                raise BlockCountDeviationError
+        elif self.get_transformation_name() in ["swap"]:
+            if not (len(pre_moved_ex.get_blocks()) == len(moved_ex.get_blocks())):
+                insanity = Insanity("_", "_", "block_count_deviation")
+                insanity_list.append(insanity)
+                print(
+                    "previous block count:" + str(len(pre_moved_ex.get_blocks())) + " moved_ex block count:" + str(
+                        len(moved_ex.get_blocks())))
+                print(insanity.gen_msg())
+                self.set_validity(False)
+                raise BlockCountDeviationError
+
+        # ---------------------
+        # disconnection check
+        # ---------------------
+        if self.get_transformation_name() == "swap":
+            if len(self.get_block_ref().neighs) > 0:
+                insanity = Insanity("_", "_", "incomplete_swap")
+                insanity_list.append(insanity)
+                print("block " + self.get_block_ref().instance_name + " wasn't completely disconnected:")
+                print(insanity.gen_msg())
+                self.set_validity(False)
+                raise IncompleteSwapError
+
+        return
+
+
+# This class takes care of instantiating, sampling, and modifying the design objects.
+# Used by the design space exploration framework to generate neighbouring design points.
+class DesignHandler:
+    def __init__(self, database):
+        # instantiate a database object (this object contains all the information in the database)
+        self.database = database  # hardware/software database used for design selection.
+        self.__tasks = database.get_tasks()  # software tasks to include in the design.
+        self.pruned_one = True  # used for moves that include adding a NOC, which requires reconnecting
+                                # memory and processing elements (reconnect = prune and connect)
+        self.boost_SOC = False  # for multi-SOC designs. Not activated yet.
+        self.DMA_task_ctr = 0  # number of DMA engines used
+        if config.FARSI_performance == "fast":
+            self.get_immediate_block = self.get_immediate_block_fast
+            self.get_immediate_block_multi_metric = self.get_immediate_block_multi_metric_fast
+            self.get_equal_immediate_blocks_present = self.get_equal_immediate_blocks_present_fast
+        else:
+            self.get_immediate_block = self.get_immediate_block_slow
+            self.get_equal_immediate_block_present = self.get_equal_immediate_block_present_slow
+
+    # -------------------------------------------
+    # Functionality:
+    #       loading (mapping) and unloading tasks to the blocks.
+    # Variables:
+    #       pe: processing element to load.
+    #       mem: memory element to load.
+    #       tasks: task set (within the workload) to load the PE and MEM with.
+    # -----------------------------------------
+    # only have one pe and write mem, so easy. TODO: how about multiple ones
+    def load_tasks_to_pe_and_write_mem(self, pe, mem, tasks):  # is used only for initializing
+        get_work_ratio = self.database.get_block_work_ratio_by_task_dir
+        _ = [pe.load_improved(task, task) for task in tasks]  # load the PE with all the tasks
+        for task in tasks:
+            for task_child in task.get_children():
+                mem.load_improved(task, task_child)  # load the memory with the tasks
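+
+    # ---------------
+    # Example (editor's sketch, hypothetical objects): with a single pe and a
+    # single write memory, initialization loads every task onto the pe and
+    # records one entry per parent->child transfer on the memory:
+    #
+    #   handler.load_tasks_to_pe_and_write_mem(pe0, mem0, workload_tasks)
+    #   # pe0 now hosts all tasks; mem0 holds one load per (task, child) pair
+    # ---------------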
+    # only have one pe and write mem, so easy. TODO: how about multiple ones
+    def load_tasks_to_pe_and_write_mem_for_hops(self, pes, mems, tasks, num_of_hops):  # is used only for initializing
+        get_work_ratio = self.database.get_block_work_ratio_by_task_dir
+
+        # souurce/siink
+        task = tasks[0]
+        pes[0].load_improved(task, task)
+        for task_child in task.get_children():
+            mems[0].load_improved(task, task_child)  # load the memory with the tasks
+        task = tasks[-1]
+        pes[-1].load_improved(task, task)
+        for task_child in task.get_children():
+            mems[-1].load_improved(task, task_child)  # load the memory with the tasks
+
+        last_task = 1
+        for i in range(1, len(tasks) - (num_of_hops-2) - 1):
+            task = tasks[i]
+            if (i % 2) == 0:
+                pe_mem_idx = 0
+            else:
+                pe_mem_idx = len(pes) - 1
+
+            pes[pe_mem_idx].load_improved(task, task)
+            for task_child in task.get_children():
+                mems[pe_mem_idx].load_improved(task, task_child)  # load the memory with the tasks
+            last_task += 1
+
+        # dummy tasks in the middle
+        idx = 1
+        for i in range(last_task, len(tasks)-1):
+            task = tasks[i]
+            pes[idx].load_improved(task, task)
+            for task_child in task.get_children():
+                mems[idx].load_improved(task, task_child)  # load the memory with the tasks
+            idx += 1
+
+    # ------------------------------
+    # Functionality:
+    #       loading (mapping) and unloading tasks to the reading memory/IC blocks.
+    #       Assigning a read mem means finding the parent task's writing mem.
+    #       Routing means assigning (loading) the already-existing buses to a task such that
+    #       its read route is the fastest route from the PE to read_mem and
+    #       its write route is the fastest route from the PE to write_mem.
+    #       Note that these two conditions are considered independently.
+    # Variables:
+    #       ex_dp: example design.
+    # ------------------------------
+    def load_tasks_to_read_mem_and_ic(self, ex_dp):
+        # assign buses (for both read and write) and mem for read
+        self.load_read_mem_and_ic_recursive(ex_dp, [], ex_dp.get_hardware_graph().get_task_graph().get_root(), [], None)
+        # prune whatever ic connection has no traffic on it
+        self.disconnect_ics_with_no_shared_task(ex_dp)
+
+    def disconnect_ics_with_no_shared_task(self, ex_dp):
+        design_ics = [blck for blck in ex_dp.get_blocks() if blck.type == "ic"]
+        for ic in design_ics:
+            ic_tasks = [el.get_name() for el in ic.get_tasks_of_block()]
+            ic_neighs = [neigh for neigh in ic.get_neighs() if neigh.type == "ic"]
+            for neigh in ic_neighs:
+                neigh_tasks = [el.get_name() for el in neigh.get_tasks_of_block()]
+                no_task_shared = len(ic_tasks) == len(list(set(ic_tasks) - set(neigh_tasks)))
+                if no_task_shared:
+                    neigh.disconnect(ic)
+                    print("===========          ==========")
+                    print("=========== disconnected a path ==========")
+                    print("===========          ==========")
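+
+    # ---------------
+    # Clustering sketch (editor's addition) for the method below: given a
+    # reference task t on a block hosting tasks {t, a, b, c}, if only a and b
+    # can overlap with t (per the task graph, or per the dynamic profile), the
+    # result is [cluster_one, cluster_two] = [[a, b], [t, c]].
+    # ---------------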
+    # check if there is another task on the block that can run in parallel with the task of interest
+    def find_parallel_tasks_of_task_in_block(self, ex_dp, sim_dp, task, block):
+        task_synced = [task__ for task__ in block.get_tasks_of_block() if task__.name == task.name][0]
+        parallel_tasks = []
+        tasks_of_block = [task_ for task_ in block.get_tasks_of_block() if not ("souurce" in task_.name) and not ("siink" in task_.name)]
+
+        if config.parallelism_analysis == "static":
+            for task_ in tasks_of_block:
+                if ex_dp.get_hardware_graph().get_task_graph().tasks_can_run_in_parallel(task_, task_synced):
+                    parallel_tasks.append(task_)
+        elif config.parallelism_analysis == "dynamic":
+            parallel_tasks_names_ = sim_dp.get_dp_rep().get_tasks_parallel_task_dynamically(task)
+            tasks_using_diff_pipe_cluster = sim_dp.get_dp_rep().get_tasks_using_the_different_pipe_cluster(task, block)
+
+            parallel_tasks_names = list(set(parallel_tasks_names_) - set(tasks_using_diff_pipe_cluster))
+            for task_ in tasks_of_block:
+                if task_.get_name() in parallel_tasks_names:
+                    parallel_tasks.append(task_)
+
+        cluster_one = parallel_tasks
+        cluster_two = list(set(tasks_of_block) - set(cluster_one))
+        return [cluster_one, cluster_two]
+
+    # can any other task on the block run in parallel with the ref task
+    def can_any_task_on_block_run_in_parallel(self, ex_dp, sim_dp, ref_task, block):
+        clusters_run_in_parallel = self.find_parallel_tasks_of_task_in_block(ex_dp, sim_dp, ref_task, block)
+        return len(clusters_run_in_parallel[0]) > 0
+
+    # ------------------------------
+    # Functionality:
+    #       load a single memory with a task
+    # Variables:
+    #       ex_dp: example design to use.
+    #       mem: memory to load.
+    #       task: task that will occupy the memory
+    #       dir_: direction (read/write) in which the task will use the memory
+    #       family_task: parent/child task to read from/write to.
+    # ------------------------------
+    def load_single_mem(self, ex_dp, mem, task, dir_, family_task):
+        #get_work_ratio = self.database.get_block_work_ratio_by_task_dir
+        if (len(ex_dp.get_blocks_of_task_by_block_type(task, "pe")) == 0):
+            print("This should not happen. Something went wrong")
+            raise NoPEError
+        pe = ex_dp.get_blocks_of_task_by_block_type(task, "pe")[0]  # get the pe blocks associated with the task
+        mem.load_improved(task, family_task)  # load the memory with the task. Use the family task for the work ratio.
+
+    # ------------------------------
+    # Functionality:
+    #       load a single bus with a task
+    # Variables:
+    #       ex_dp: example design to use.
+    #       mem: memory connected to the bus.
+    #       task: task that will occupy the bus.
+    #       dir_: direction (read/write) in which the task will use the bus
+    #       family_task: parent task to read from.
+    # ------------------------------
+    def load_single_bus(self, ex_dp, mem, task, dir_, family_task):
+        #get_work_ratio = self.database.get_block_work_ratio_by_task_dir
+        pe_list = ex_dp.get_blocks_of_task_by_block_type(task, "pe")
+        if len(pe_list) == 0:
+            print("This should not happen. Something went wrong")
+            raise NoPEError
+        pe = pe_list[0]
+        get_work_ratio = self.database.get_block_work_ratio_by_task_dir
+        buses = ex_dp.hardware_graph.get_path_between_two_vertecies(pe, mem)[1:-1]
+        for bus in buses:
+            #bus.load((task, dir_), get_work_ratio(task, pe, dir_), father_task)
+            bus.load_improved(task, family_task)
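+
+    # ---------------
+    # Path-loading sketch (editor's addition, hypothetical topology): if the
+    # route between a task's pe and the target mem is
+    #   pe -> ic_0 -> ic_1 -> mem
+    # get_path_between_two_vertecies returns that chain, and the [1:-1] slice
+    # keeps [ic_0, ic_1], so only the buses along the route get loaded.
+    # ---------------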
+    # ------------------------------
+    # Functionality:
+    #       load the read buses and memories recursively. Iterate through the hardware graph
+    #       and fill out all the memories and buses for read. Note that we populate the read
+    #       elements after write. You need to populate one, and then the other gets filled accordingly.
+    # Variables:
+    #       ex_dp: example design to use.
+    #       read_mem: memory to read from.
+    #       task: task that will occupy the bus.
+    #       tasks_seen: a task list to prevent considering a task twice.
+    #       father_task: parent task to read from.
+    # ------------------------------
+    def load_read_mem_and_ic_recursive(self, ex_dp, read_mem, task, tasks_seen, father_task):
+        mem_blocks = ex_dp.get_blocks_of_task_by_block_type(task, "mem")  # all the memory blocks of the task
+        if "souurce" in task.name:
+            write_mem = ex_dp.get_blocks_of_task_by_block_type(task, "mem")
+        elif "siink" in task.name:
+            write_mem = None
+        else:
+            write_mem = ex_dp.get_blocks_of_task_by_block_type_and_task_dir(task, "mem", "write")
+
+        # get the tasks that this task writes to and the memory within which the transaction happens
+        if write_mem:
+            write_mem_tasks = [(mem, ex_dp.get_write_mem_tasks(task, mem)) for mem in write_mem]
+        else:
+            write_mem_tasks = None
+
+        # add read memory, buses
+        if read_mem:
+            self.load_single_mem(ex_dp, read_mem, task, "read", father_task)
+            self.load_single_bus(ex_dp, read_mem, task, "read", father_task)
+
+        if task in tasks_seen: return  # terminate
+        else: tasks_seen.append(task)
+
+        if not (len(mem_blocks) == 0) and write_mem:
+            for mem, child_tasks in write_mem_tasks:
+                for child in child_tasks:
+                    self.load_single_bus(ex_dp, mem, task, "write", child)  # we assume the task writes its results
+                                                                            # into the same memory even when it has
+                                                                            # multiple children
+
+        """
+        # Add write buses
+        if not(len(mem_blocks) == 0) and write_mem:
+            if (len(task.get_children()) == 0):
+                print("what")
+                raise TaskNoChildrenError
+            for child in task.get_children():
+                self.load_single_bus(ex_dp, write_mem, task, "write", child)  # we make an assumption that the task
+                #self.load_single_bus(ex_dp, write_mem, task, "write", task.get_children()[0])  # we make an assumption that the task
+                                                                           # even if having multiple children, it will be
+                                                                           # writing its results in the same memory
+        """
+        if write_mem_tasks:
+            for mem, child_tasks in write_mem_tasks:
+                for child in child_tasks:
+                    self.load_read_mem_and_ic_recursive(ex_dp, mem, child, tasks_seen, task)
+        """
+        # recurse down
+        for task_ in task.get_children():
+            if len(mem_blocks) == 0:
+                mem_blocks_ = ex_dp.get_blocks_of_task_by_block_type(task, "mem")  # all the memory blocks of the task
+                print("what")
+            self.load_read_mem_and_ic_recursive(ex_dp, write_mem, task_, tasks_seen, task)
+        """
+
+    # ------------------------------
+    # Functionality:
+    #       unload the read buses. Needed first, to prepare the design for the next iteration.
+    # Variables:
+    #       ex_dp: example_design
+    # ------------------------------
+    def unload_read_buses(self, ex_dp):
+        busses = ex_dp.get_blocks_by_type("ic")
+        _ = [bus.unload_read() for bus in busses]
+
+    # ------------------------------
+    # Functionality:
+    #       unload all the buses. Needed first, to prepare the design for the next iteration.
+    # Variables:
+    #       ex_dp: example_design
+    # ------------------------------
+    def unload_buses(self, ex_dp):
+        busses = ex_dp.get_blocks_by_type("ic")
+        _ = [bus.unload_all() for bus in busses]
+
+    # ------------------------------
+    # Functionality:
+    #       unload the read memories. Needed first, to prepare the design for the next iteration.
+    # Variables:
+    #       ex_dp: example_design
+    # ------------------------------
+    def unload_read_mem(self, ex_dp):
+        mems = ex_dp.get_blocks_by_type("mem")
+        _ = [mem.unload_read() for mem in mems]
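+
+    # ---------------
+    # Refresh-cycle sketch (editor's addition): a design is typically
+    # re-routed by unloading the read-side state and re-deriving it, e.g.:
+    #
+    #   handler.unload_read_mem_and_ic(ex_dp)   # defined right below
+    #   handler.load_tasks_to_read_mem_and_ic(ex_dp)
+    # ---------------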
+ # Variables: + # ex_dp: example_design + # ------------------------------ + def unload_read_mem_and_ic(self, ex_dp): + self.unload_buses(ex_dp) + self.unload_read_mem(ex_dp) + + # ------------------------------ + # Functionality: + # find out whether a task needs a DMA (read and write memory are different) + # Variables: + # ex_dp: example_design + # task: task to consider + # ------------------------------ + def find_task_s_DMA_needs(self, ex_dp, task): + mem_blocks = ex_dp.get_blocks_of_task_by_block_type(task, "mem") # get the memory blocks of task + dir_mem_dict = {} + dir_mem_dict["read"] =[] + dir_mem_dict["write"] =[] + + # iterate through memory blocks and get their directions (read/write) + for mem_block in mem_blocks: + task_dir_list = mem_block.get_task_dir_by_task_name(task) + for task, dir in task_dir_list: + dir_mem_dict[dir].append(mem_block) + + # find out whether read/write memories are different + if len(dir_mem_dict["write"]) > 1: + raise Exception("a tas can only write to one memory") + # calc diff to see if write and read mems are diff + set_diff = set(dir_mem_dict["read"]) - set(dir_mem_dict["write"]) + DMA_src_dest_list = [] + for mem in list(set_diff): + DMA_src_dest_list.append((mem, dir_mem_dict["write"][0])) + return DMA_src_dest_list + + # ------------------------------ + # Functionality: + # Add a DMA block for the task. TODO: add code. + # ------------------------------ + def inject_DMA_blocks(self): + return 0 + + # possible implementations of DMA injection + # at the moment, it's comment out. + # TODO: uncomment and ensure it's correctness + """ + def inject_DMA_task_for_a_single_task(self, ex_dp:ExDesignPoint, task): + if config.VIS_GR_PER_GEN: vis_hardware(ex_dp) + task_s_DMA_s_src_dest_list = self.find_task_s_DMA_needs(ex_dp, task) + task_s_pe = ex_dp.get_blocks_of_task_by_block_type("pe") + bus_task_unloaded_list = [] #to keep track of what you already unloaded to avoid unloading again + for task_s_mem_src, task_s_mem_dest in task_s_DMA_s_src_dest_list: + if self.pe_hanging_off_of_src_block(task_s_pe, task_s_mem_src): + ic_neigh_of_src = [neigh for neigh in task_s_mem_src.neighs if neigh.type == "ic"][0] + DMA_block = self.database.sample_DMA_blocks() + DMA_block = self.database.copy_SOC(DMA_block, task_s_mem_src) + DMA_block.connect(ic_neigh_of_src) + self.copy_task(task) + + + src.load((task,"write")) + src.unload((task, "read")) + ic_neigh_of_src.unload((task, "read")) + + ic_neigh_of_src = [neigh for neigh in src.neighs if neigh.type == "ic"][0] + ic_neigh_of_des = [neigh for neigh in dest.neighs if neigh.type == "ic"][0] + if + + + + reads_work_ratio = src.get_task_s_work_ratio_by_task_and_dir(task, "read") + # unload task from the src memory and the immediate bus connected to it + + buses = ex_dp.hardware_graph.get_path_between_two_vertecies(src, dest)[1:-1] + for bus in buses: + if (bus, task) not in bus_task_unloaded_list: + bus.unload((task, "write")) + bus_task_unloaded_list.append((bus,task)) + + DMA_block = self.database.sample_DMA_blocks() + DMA_block = self.database.copy_SOC(DMA_block, src) + DMA_block.connect(ic_neigh_of_src) + + for task_to_read_from, work_ratio_value in reads_work_ratio.items(): + task_s_bytes_to_transfer = task.work * work_ratio_value + DMA_task = Task("DMA_from_" + task_to_read_from + "_to_" + task.name , task_s_bytes_to_transfer) + # load DMA task to the appropriate mem and buses + DMA_block.load((DMA_task,"loop_back"), {DMA_task.name:1}) + src.load((DMA_task, "read"), {task_to_read_from:1}) + 
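+            # (annotation) the DMA task is registered along every hop of the
+            # src -> dest path: the source memory and its IC as "read", the
+            # intermediate buses and the destination memory as "write"; the
+            # original reader is then re-parented (add_child/remove_child below)
+            # so that it reads from the DMA task instead of its original parent.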
ic_neigh_of_src.load((DMA_task, "read"), {task_to_read_from:1}) + for bus in buses: bus.load((DMA_task, "write"), {task.name:1}) + dest.load((DMA_task, "write"), {task.name: 1}) + # load destination and it's immediate bus with the task + dest.load((task, "read"), {DMA_task.name: work_ratio_value}) + ic_neigh_of_des.load((task, "read"), {DMA_task.name: work_ratio_value}) + + DMA_task.add_child(task) + parent = ex_dp.get_task_by_name(task_to_read_from) + parent.remove_child(task) + parent.add_child(DMA_task) + self.DMA_task_ctr += 1 + ex_dp.hardware_graph.update_graph(block_to_prime_with=ex_dp.get_blocks()[0]) + if config.VIS_GR_PER_GEN: vis_hardware(ex_dp) + # self.unload_read_mem_and_ic(ex_dp) + # self.load_tasks_to_read_mem_and_ic(ex_dp) + + if config.VIS_GR_PER_GEN: vis_hardware(ex_dp) + """ + """ + def inject_DMA_tasks_for_all_tasks(self, ex_dp:ExDesignPoint): + tasks = [task for task in ex_dp.get_tasks() if not self.task_dummy(task)] + for task in tasks: + self.inject_DMA_task_for_a_single_task(ex_dp, task) + task_s_DMA_s_src_dest_dict = {} + for task in tasks: + task_s_DMA_s_src_dest_dict[task] = self.find_task_s_DMA_if_necessary(ex_dp, task) + self.inject_DMA_task(task_s_DMA_s_src_dest_dict) + self.inject_DMA_blocks(task_s_DMA_s_src_dest_dict) + """ + + # ------------------------------ + # Functionality: + # find all whether pe is connected to another block + # Varaibles: + # pe: processing element to query + # src: the src block to see if pe is connected to + # ------------------------------ + def pe_hanging_of_of_src_block(self, pe, src): + return pe in src.neighs + + # ------------------------------ + # Functionality: + # add DMA for all the task that need it (their read/write memory is not the same) + # Varaibles: + # ex_dp: example design + # ------------------------------ + def inject_DMA_tasks_for_all_tasks(self, ex_dp: ExDesignPoint): + tasks = [task for task in ex_dp.get_tasks() if not task.is_task_dummy()] + for task in tasks: + self.inject_DMA_task_for_a_single_task(ex_dp, task) + self.inject_DMA_blocks() + + # generate and populate the queues that connect different hardware block, e.g., bus and memory, pe and bus, ... + def pipe_design(self, ex_dp, workload_to_blocks_map): + ex_dp.get_hardware_graph().pipe_design() + + # ------------------------------------------- + # Functionality: + # converts the example design point (ex_dp) to a simulatable design point (sim_dp) + # ex_dp and sim_dp are kept separate since sim_dp contains information about the simulation such + # as latency/power/area ... 
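+    #       A minimal usage sketch (gen_init_des is the bootstrap generator defined
+    #       later in this class):
+    #
+    #           ex_dp = self.gen_init_des()
+    #           sim_dp = self.convert_to_sim_des_point(ex_dp)
+    #           # sim_dp bundles the hardware graph, the task-to-block map and a
+    #           # per-task PE schedule for the simulator to consume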
+    # Variables:
+    #       ex_dp: example design
+    # ------------------------------
+    def convert_to_sim_des_point(self, ex_dp):
+        # generate simulation semantics such as the task to block map (workload_to_blocks_map)
+        # and the workload to block schedule (task schedules)
+
+        # generate task to hardware mapping
+        workload_to_blocks_map = self.gen_workload_to_blocks_from_blocks(ex_dp.get_blocks())
+        self.pipe_design(ex_dp, workload_to_blocks_map)  # populate the queues
+
+        # generate task schedules
+        workload_to_pe_block_schedule = WorkloadToPEBlockSchedule()
+        for task in ex_dp.get_tasks():
+            workload_to_pe_block_schedule.task_to_pe_block_schedule_list_sorted.append(TaskToPEBlockSchedule(task, 0))
+
+        # convert to the sim design
+        return SimDesignPoint(ex_dp.get_hardware_graph(), workload_to_blocks_map, workload_to_pe_block_schedule)
+
+    # -------------------------------------------
+    # Functionality:
+    #       generate the task to block mapping (used by the simulator)
+    # Variables:
+    #       blocks: hardware blocks to consider for mapping
+    # -------------------------------------------
+    def gen_workload_to_blocks_from_blocks(self, blocks):
+        workload_to_blocks_map = WorkloadToHardwareMap()
+        # make task_to_block out of blocks
+        for block in blocks:
+            for task_dir, work_ratio in block.get_tasks_dir_work_ratio().items():  # get the task to task work ratio (gables)
+                task, dir = task_dir
+                task_to_blocks = workload_to_blocks_map.get_by_task(task)
+                if task_to_blocks:
+                    task_to_blocks.block_dir_workRatio_dict[(block, dir)] = work_ratio
+                else:
+                    task_to_blocks = TaskToBlocksMap(task, {(block, dir): work_ratio})
+                    workload_to_blocks_map.tasks_to_blocks_map_list.append(task_to_blocks)
+        return workload_to_blocks_map
+
+    # generate light systems (i.e., systems that are not FARSI compatible, but have all the characterizations necessary
+    # to generate FARSI systems from). This step allows us to quickly generate all the systems
+    # without worrying about generating FARSI systems
+    def light_system_gen_exhaustively(self, system_workers, database):
+        all_systems = exhaustive_system_generation(system_workers, database.db_input.gen_config)  # not farsi compatible
+        return all_systems
+
+    # generate FARSI systems from the light systems inputted.
This is used for exhaustive search + # comparison + def FARSI_system_gen_exhaustively(self, light_systems, system_workers): + mapping_process_cnt = system_workers[0] + mapping_process_id = system_workers[1] + FARSI_gen_process_cnt = system_workers[2] + FARSI_gen_process_id = system_workers[3] + + FARSI_systems = [] # generating FARSI compatible systems from all systems + + # upper_bound and lower_bound used for allowing for parallel execution + #light_systems_lower_bound = 0 + #light_systems_upper_bound = len(light_systems) + light_systems_lower_bound = int(FARSI_gen_process_id*(len(light_systems) /FARSI_gen_process_cnt)) + light_systems_upper_bound = int(min((FARSI_gen_process_id + 1) * (len(light_systems) / FARSI_gen_process_cnt), len(light_systems))) + + num_of_sys_to_gen = light_systems_upper_bound - light_systems_lower_bound + + # adding and loading the pes + for idx, system in enumerate(light_systems[light_systems_lower_bound:light_systems_upper_bound]): + + if (idx % max(int(num_of_sys_to_gen/10.0),1) == 0): # debugging + print("---------" + str(idx/num_of_sys_to_gen) + "% of systems generated for process " + + str(mapping_process_id) + "_" + str(FARSI_gen_process_id)) + + buses_name = system.get_BUS_list() + pe_primer = None + pe_idx_global = 0 + mem_idx_global = 0 + for bus_idx, bus_name in enumerate(buses_name): + prev_bus_block = None + if not (bus_idx == 0): + prev_bus_block = bus_block + bus_block = self.database.sample_similar_block(self.database.get_block_by_name(bus_name)) + if not (prev_bus_block is None): + bus_block.connect(prev_bus_block) + + pe_neigh_names = system.get_bus_s_pe_neighbours(bus_idx) + for pe_idx, pe_name in enumerate(pe_neigh_names): + pe_block = self.database.sample_similar_block(self.database.get_block_by_name(pe_name)) + pe_primer = pe_block # can be any pe block. 
used to generate the hardware graph
+                pe_block.connect(bus_block)
+                tasks = [self.database.get_task_by_name(task_) for task_ in system.get_pe_task_set()[pe_idx_global]]
+                for task in tasks:
+                    pe_block.load_improved(task, task)
+                pe_idx_global += 1
+
+            """
+            # adding and loading the memories
+            for bus_idx, bus_name in enumerate(buses_name):
+                prev_bus_block = None
+                if not (bus_idx == 0):
+                    prev_bus_block = bus_block
+                bus_block = self.database.sample_similar_block(self.database.get_block_by_name(bus_name))
+                if not (prev_bus_block is None):
+                    bus_block.connect(prev_bus_block)
+            """
+            mem_neigh_names = system.get_bus_s_mem_neighbours(bus_idx)
+            for mem_idx, mem_name in enumerate(mem_neigh_names):
+                mem_block = self.database.sample_similar_block(self.database.get_block_by_name(mem_name))  # must sample a similar block; otherwise we always get the exact same block
+                mem_block.connect(bus_block)
+                #if mem_idx_global >= len(system.get_mem_task_set()):
+                #    print("what")
+                tasks = [self.database.get_task_by_name(task_) for task_ in system.get_mem_task_set()[mem_idx_global]]
+                for task in tasks:
+                    for child in task.get_children():
+                        mem_block.load_improved(task, child)
+                mem_idx_global += 1
+
+            # generate a hardware graph and load read mem and ic
+            hardware_graph = HardwareGraph(pe_primer)
+            ex_dp = ExDesignPoint(hardware_graph)
+            self.load_tasks_to_read_mem_and_ic(ex_dp)
+            ex_dp.sanity_check()
+            FARSI_systems.append(ex_dp)
+
+            #for farsi_system in FARSI_systems:
+
+        print("----- all systems generated for process: " + str(mapping_process_id) + "_" + str(FARSI_gen_process_id))
+        return FARSI_systems
+
+    # ------------------------------
+    # Functionality:
+    #       generate the initial design from the specified parsed design
+    # ------------------------------
+    def gen_specific_parsed_ex_dp(self, database):
+        # the hw_g parsed from the csv has double connections, i.e., mem to pe and pe to mem. However, we only
+        # need to specify it one way (since connecting pe to mem automatically connects mem to pe as well).
+        # This function gets rid of these duplications. This is not optional.
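+        # (Illustration with hypothetical names: a parsed graph such as
+        #  {"pe0": {"bus0": 1}, "bus0": {"pe0": 1, "mem0": 1}, "mem0": {"bus0": 1}}
+        #  keeps only one direction per edge, since connect() mirrors the link anyway.)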
+ def filter_duplicates(hw_g): + seen_connections = [] # keep track of connections already seen to avoid double counting + for blk, neighs in hw_g.items(): + for neigh in neighs: + if not (blk, neigh) in seen_connections: + seen_connections.append((neigh, blk)) + + for el in seen_connections: + key = el[0] + value = el[1] + del hw_g[key][value] + + # get the (light, i.e, in sting format) hardware_graph and task_to_hw_mappings from csv + hw_g = database.db_input.get_parsed_hardware_graph() + # get rid of mutual connection, as we already connect both ways with "connect" APIs + filter_duplicates(hw_g) + task_to_hw_mapping = database.db_input.get_parsed_task_to_hw_mapping() + + # use the hw_g to generate the topology (connect blocks together) + block_seen = {} # dictionary with key:blk_name, value of block object + for blk_name, children_names in hw_g.items(): + if blk_name not in block_seen.keys(): # generate and memoize + blk = self.database.gen_one_block_by_name(blk_name) + block_seen[blk_name] = blk + else: + blk = block_seen[blk_name] + for child_name in children_names: + if child_name not in block_seen.keys(): # generate and memoize + child = self.database.gen_one_block_by_name(child_name) + block_seen[child_name] = child + else: + child = block_seen[child_name] + blk.connect(child) + # get a block to prime with + if blk.type == "pe": + pe_primer = blk + + # load the blocks with tasks + for blk_name, task_names in task_to_hw_mapping.items(): + blk = block_seen[blk_name] + for task in task_names: + task_parent_name = task[0] + task_child_name = task[1] + task_parent = self.database.get_task_by_name(task_parent_name) + task_child = self.database.get_task_by_name(task_child_name) + blk.load_improved(task_parent, task_child) + + # generate a hardware graph and load read mem and ic + hardware_graph = HardwareGraph(pe_primer, "user_generated") # user_generated will deactivate certain checks + # noBUSerror that the tool can/must not generate + # but are ok if inputed by user + ex_dp = ExDesignPoint(hardware_graph) + self.load_tasks_to_read_mem_and_ic(ex_dp) + ex_dp.sanity_check() + ex_dp.hardware_graph.update_graph_without_prunning() + ex_dp.hardware_graph.pipe_design() + return ex_dp + + # ------------------------------ + # Functionality: + # generate initial design from the hardcoded specified design + # ------------------------------ + def gen_specific_hardcoded_ex_dp(self, database): + lib_relative_addr = config.database_data_dir.replace(config.home_dir, "") + lib_relative_addr_pythony_fied = lib_relative_addr.replace("/", ".") + # only supporting SLAM at the moment -> Iulian: starting to support any hardcoded design + workload_name = list(database.db_input.sw_hw_database_population["workloads"])[0] #supporting only one hardcoded workload + files_to_import = [lib_relative_addr_pythony_fied + ".hardcoded." 
+ workload + ".input" for workload in [workload_name]] + imported_databases = [importlib.import_module(el) for el in files_to_import][0] + ex_dp = imported_databases.gen_hardcoded_design(database) + self.load_tasks_to_read_mem_and_ic(ex_dp) + return ex_dp + + def get_most_inferior_block(self, block, tasks): + return self.database.sample_most_inferior_blocks_by_type(block_type=block.type, tasks=self.__tasks) + + + def get_most_inferior_block_before_unrolling(self, block, tasks): + return self.database.sample_most_inferior_blocks_before_unrolling_by_type(block_type=block.type, tasks=block.get_tasks_of_block(), block=block) + + + def gen_specific_design_with_a_star_noc(self, database): + + num_of_hops = database.db_input.sw_hw_database_population["misc_knobs"]["num_of_hops"] # supporting only one hardcoded workload + #if database.db_input.parallel_task_names == {}: + if len(list(database.db_input.parallel_task_names.values())) == 0: + max_parallelism = 0 + else: + max_parallelism = max([len(el) for el in database.db_input.parallel_task_names.values()]) + pes = [] + mems = [] + ics = [] + parallel_task_names =database.db_input.parallel_task_names + for el in range (0, max(max_parallelism,1)): + pe = self.database.sample_most_inferior_blocks_by_type(block_type="pe", tasks=self.__tasks) + mem = self.database.sample_most_inferior_blocks_by_type(block_type="mem", tasks=self.__tasks) + if el == 0: + ic = self.database.sample_most_inferior_blocks_by_type(block_type="ic", tasks=self.__tasks) + if el == 0: + ic = self.database.sample_most_inferior_SOC(ic, "power") + ics.append(ic) + + pe = self.database.sample_most_inferior_SOC(pe, config.sorting_SOC_metric) + mem = self.database.sample_most_inferior_SOC(mem, config.sorting_SOC_metric) + pes.append(pe) + mems.append(mem) + + for pe,mem in zip(pes, mems): + pe.connect(ics[0]) + ics[0].connect(mem) + + serial_task_names = set([el.name for el in self.__tasks]) + for parallel_task_names in database.db_input.parallel_task_names.values(): + parallel_tasks = [database.get_task_by_name(tsk_name) for tsk_name in parallel_task_names] + for idx, task in enumerate(parallel_tasks): + serial_task_names = serial_task_names - set([task.name]) + pes[idx].load_improved(task, task) + for task_child in task.get_children(): + mems[idx].load_improved(task, task_child) # load memory with tasks + + serial_tasks = [database.get_task_by_name(tsk_name) for tsk_name in serial_task_names] + for task in serial_tasks: + pes[0].load_improved(task, task) + for task_child in task.get_children(): + mems[0].load_improved(task, task_child) # load memory with tasks + + # generate a hardware graph and load read mem and ic + """ + for pe in pes: + for tsk in pe.get_tasks_of_block(): + if "souurce" in tsk.name: + pe_ = pe + """ + + hardware_graph = HardwareGraph(pes[0]) + ex_dp = ExDesignPoint(hardware_graph) + self.load_tasks_to_read_mem_and_ic(ex_dp) + ex_dp.hardware_graph.update_graph() + ex_dp.hardware_graph.pipe_design() + return ex_dp + + + def gen_specific_design_with_hops_and_stars(self, database): + num_of_hops = database.db_input.sw_hw_database_population["misc_knobs"]["num_of_hops"] # supporting only one hardcoded workload + num_of_NoCs = database.db_input.sw_hw_database_population["misc_knobs"]["num_of_NoCs"] # supporting only one hardcoded workload + if num_of_hops > num_of_NoCs: + print("number of hops can't be greater than number of NoCs") + exit(0) + + #if database.db_input.parallel_task_names == {}: + if len(list(database.db_input.parallel_task_names.values())) == 0: + 
max_parallelism = 0 + else: + max_parallelism = max([len(el) for el in database.db_input.parallel_task_names.values()]) + pes = [] + mems = [] + ics = [] + parallel_task_names =database.db_input.parallel_task_names + for el in range (0, max(max_parallelism,1)): + pe = self.database.sample_most_inferior_blocks_by_type(block_type="pe", tasks=self.__tasks) + mem = self.database.sample_most_inferior_blocks_by_type(block_type="mem", tasks=self.__tasks) + if el == 0: + ic = self.database.sample_most_inferior_blocks_by_type(block_type="ic", tasks=self.__tasks) + if el == 0: + ic = self.database.sample_most_inferior_SOC(ic, "power") + ics.append(ic) + + pe = self.database.sample_most_inferior_SOC(pe, config.sorting_SOC_metric) + mem = self.database.sample_most_inferior_SOC(mem, config.sorting_SOC_metric) + pes.append(pe) + mems.append(mem) + + for pe,mem in zip(pes, mems): + pe.connect(ics[0]) + ics[0].connect(mem) + + serial_task_names = set([el.name for el in self.__tasks]) + for parallel_task_names in database.db_input.parallel_task_names.values(): + parallel_tasks = [database.get_task_by_name(tsk_name) for tsk_name in parallel_task_names] + for idx, task in enumerate(parallel_tasks): + serial_task_names = serial_task_names - set([task.name]) + pes[idx].load_improved(task, task) + for task_child in task.get_children(): + mems[idx].load_improved(task, task_child) # load memory with tasks + + listified_serial_task_names = list(serial_task_names) + listified_serial_task_names.sort() + serial_task_names = listified_serial_task_names + + added_tasks = max(0, num_of_hops-2) + serial_tasks = [database.get_task_by_name(tsk_name) for tsk_name in serial_task_names] + # dummy tasks in the middle + hoppy_task_names =database.db_input.hoppy_task_names + + last_non_hoppy_task = list(set(serial_task_names) - set(hoppy_task_names) - set(["synthetic_siink"]) - set(["synthetic_souurce"])) + last_non_hoppy_task.sort() + last_non_hoppy_task = last_non_hoppy_task[-1] + + for el in hoppy_task_names: + pe = self.database.sample_most_inferior_blocks_by_type(block_type="pe", tasks=self.__tasks) + mem = self.database.sample_most_inferior_blocks_by_type(block_type="mem", tasks=self.__tasks) + pe = self.database.sample_most_inferior_SOC(pe, config.sorting_SOC_metric) + mem = self.database.sample_most_inferior_SOC(mem, config.sorting_SOC_metric) + pes.append(pe) + mems.append(mem) + ic = self.database.sample_most_inferior_blocks_by_type(block_type="ic", tasks=self.__tasks) + ic = self.database.sample_most_inferior_SOC(ic, "power") + ics.append(ic) + pe.connect(ic) + ic.connect(mem) + + hoppy_tasks = [database.get_task_by_name(tsk_name) for tsk_name in hoppy_task_names] + for idx, task in enumerate(hoppy_tasks): + pes[-1*(1+idx)].load_improved(task, task) + for task_child in task.get_children(): + mems[-1*(1+idx)].load_improved(task, task_child) # load memory with tasks + + + # serial tasks + if num_of_hops > 1: + pe = self.database.sample_most_inferior_blocks_by_type(block_type="pe", tasks=self.__tasks) + mem = self.database.sample_most_inferior_blocks_by_type(block_type="mem", tasks=self.__tasks) + pe = self.database.sample_most_inferior_SOC(pe, config.sorting_SOC_metric) + mem = self.database.sample_most_inferior_SOC(mem, config.sorting_SOC_metric) + pes.append(pe) + mems.append(mem) + ic = self.database.sample_most_inferior_blocks_by_type(block_type="ic", tasks=self.__tasks) + ic = self.database.sample_most_inferior_SOC(ic, "power") + ics.append(ic) + pe.connect(ic) + ic.connect(mem) + + for idx, task in 
enumerate(serial_tasks): + if task.name in hoppy_task_names: + continue + if "souurce" in task.name: + idx_ =0 + elif "synthetic_0" in task.name: + idx_ = -1 + else: + idx_ = -1*(idx%2 ) + pes[idx_].load_improved(task, task) + for task_child in task.get_children(): + mems[idx_].load_improved(task, task_child) # load memory with tasks + + for idx, ic in enumerate(ics): + if idx == 0: + continue + ic.connect(ics[idx - 1]) + + if num_of_hops < num_of_NoCs: + if num_of_hops == 2: + ics[-1].connect(ics[0]) + if num_of_hops == 3: + ics[-2].connect(ics[0]) + #ics[1].connect(ics[-1]) + + + hardware_graph = HardwareGraph(pes[0]) + ex_dp = ExDesignPoint(hardware_graph) + self.load_tasks_to_read_mem_and_ic(ex_dp) + ex_dp.hardware_graph.update_graph() + ex_dp.hardware_graph.pipe_design() + return ex_dp + + + def gen_specific_design_with_hops(self, database): + num_of_hops = database.db_input.sw_hw_database_population["misc_knobs"]["num_of_hops"] # supporting only one hardcoded workload + pes = [] + mems = [] + ics = [] + + for el in range (0, num_of_hops): + pe = self.database.sample_most_inferior_blocks_by_type(block_type="pe", tasks=self.__tasks) + mem = self.database.sample_most_inferior_blocks_by_type(block_type="mem", tasks=self.__tasks) + ic = self.database.sample_most_inferior_blocks_by_type(block_type="ic", tasks=self.__tasks) + pe = self.database.sample_most_inferior_SOC(pe, config.sorting_SOC_metric) + mem = self.database.sample_most_inferior_SOC(mem, config.sorting_SOC_metric) + ic = self.database.sample_most_inferior_SOC(ic, "power") + pes.append(pe) + mems.append(mem) + ics.append(ic) + + for pe,mem, ic in zip(pes, mems, ics): + pe.connect(ic) + ic.connect(mem) + + for idx, ic in enumerate(ics): + if idx == 0: + continue + ic.connect(ics[idx-1]) + + + self.load_tasks_to_pe_and_write_mem_for_hops(pes, mems, self.__tasks, num_of_hops) + + # generate a hardware graph and load read mem and ic + hardware_graph = HardwareGraph(pe) + ex_dp = ExDesignPoint(hardware_graph) + self.load_tasks_to_read_mem_and_ic(ex_dp) + ex_dp.hardware_graph.update_graph() + ex_dp.hardware_graph.pipe_design() + return ex_dp + + + + #print("ok") + + + # ------------------------------ + # Functionality: + # generate initial design. Used to boot strap the exploration + # ------------------------------ + def gen_init_des(self): + pe = self.database.sample_most_inferior_blocks_by_type(block_type="pe", tasks=self.__tasks) + mem = self.database.sample_most_inferior_blocks_by_type(block_type="mem", tasks=self.__tasks) + ic = self.database.sample_most_inferior_blocks_by_type(block_type="ic", tasks=self.__tasks) + pe = self.database.sample_most_inferior_SOC(pe, config.sorting_SOC_metric) + mem = self.database.sample_most_inferior_SOC(mem, config.sorting_SOC_metric) + ic = self.database.sample_most_inferior_SOC(ic, "power") + + # connect blocks together + pe.connect(ic) + ic.connect(mem) + + self.load_tasks_to_pe_and_write_mem(pe, mem, self.__tasks) + + # generate a hardware graph and load read mem and ic + hardware_graph = HardwareGraph(pe) + ex_dp = ExDesignPoint(hardware_graph) + self.load_tasks_to_read_mem_and_ic(ex_dp) + ex_dp.hardware_graph.update_graph() + ex_dp.hardware_graph.pipe_design() + return ex_dp + + # ------------------------------ + # Functionality: + # find the hot blocks' block bottleneck. This means we find the bottleneck associated + # with the hottest (longest latency) block. 
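+    #       (The match below is done by instance_name rather than object identity,
+    #       since the block handed back by the simulated design is generally a
+    #       distinct copy of the one living in ex_dp.)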
+ # Variables: + # ex_dp: example design + # hot_kernel_block_bottlneck: block bottlneck + # ------------------------------ + def find_cores_hot_kernel_blck_bottlneck(self, ex_dp:ExDesignPoint, hot_kernel_blck_bottleneck:Block): + # iterate through the blocks and compare agains the name + for block in ex_dp.get_blocks(): + if block.instance_name == hot_kernel_blck_bottleneck.instance_name: + return block + raise Exception("did not find a corresponding block in the ex_dp with the name:" + + str(hot_kernel_blck_bottleneck.instance_name)) + + # ------------------------------ + # Functionality: + # for debugging purposes. Making sure that each ic (noc) has at least one connected block. + # Variables: + # ex_dp: example design + # ------------------------------ + def sanity_check_ic(self, ex_dp:ExDesignPoint): + ic_blocks = ex_dp.get_blocks_by_type("ic") + for block in ic_blocks: + pe_neighs = [block for block in block.neighs if block.type == "pe"] + if len(pe_neighs) == 0 : + ex_dp.hardware_graph.update_graph(block_to_prime_with=block) + if config.VIS_GR_PER_GEN: vis_hardware(ex_dp) + vis_hardware(ex_dp) + raise Exception("this design is not valid") + + # Get the immediate supperior/inferior block according to the metric (and direction). + # This version is the fast version (by the use of caching), however, it's unintuive + # Variables: + # metric_dir: -1 (increase) and 1 (decrease) + # block: the block to improve/de-improve + def get_immediate_block_fast(self, block, metric, metric_dir, tasks): + imm_blck_non_unique = self.database.up_sample_down_sample_block_fast(block, metric, metric_dir, tasks)[0] # get the first value + blkL = self.database.cached_block_to_blockL[imm_blck_non_unique] + imm_blck = self.database.cast(blkL) + return self.database.copy_SOC(imm_blck, block) + + def get_immediate_block_multi_metric_fast(self, block, metric, sorted_metric_dir, tasks): + imm_blck_non_unique = self.database.up_sample_down_sample_block_multi_metric_fast(block, sorted_metric_dir, tasks)[0] # get the first value + blkL = self.database.cached_block_to_blockL[imm_blck_non_unique] + imm_blck = self.database.cast(blkL) + return self.database.copy_SOC(imm_blck, block) + + + + # Get the immediate supperior/inferior block according to the metric (and direction) + # Variables: + # metric_dir: -1 (increase) and 1 (decrease) + # block: the block to improve/de-improve + def get_immediate_block_slow(self, block, metric, metric_dir, tasks): + imm_blck = self.database.up_sample_down_sample_block(block, metric, metric_dir, tasks)[0] # get the first value + return self.database.copy_SOC(imm_blck, block) + + # Get the immediate superior/inferior block according to the metric (and direction) that already + # exist in the design + # Variables: + # metric_dir: -1 (increase) and 1 (decrease) + # block: the block to improve/de-improve + def get_equal_immediate_block_present_slow(self, ex_dp, block, metric, metric_dir, tasks): + imm_blcks = self.database.equal_sample_up_sample_down_sample_block(block, metric, metric_dir, tasks) # get the first value + des_blocks = ex_dp.get_blocks() + for block_ in imm_blcks: + for des_block in des_blocks: + if block_.get_generic_instance_name() == des_block.get_generic_instance_name() and \ + not (block.instance_name == des_block.instance_name): + return des_block + + return block + + def get_all_compatible_blocks_of_certain_char(self, ex_dp, block, metric, metric_dir, tasks, mode): + metric_dir = -1 + imm_blcks_non_unique_1 = [el for el in 
self.database.equal_sample_up_sample_down_sample_block_fast(block, metric, metric_dir, + tasks) if not isinstance(el,str)] # get the first value + + + metric_dir = 1 + imm_blcks_non_unique_2 = [el for el in self.database.equal_sample_up_sample_down_sample_block_fast(block, metric, metric_dir, + tasks) if not isinstance(el,str)] # get the first value + + results = [] + if mode =="frequency_modulation": + for blk in imm_blcks_non_unique_1 + imm_blcks_non_unique_2: + if not blk.get_block_freq() == block.get_block_freq() and (blk.subtype == block.subtype) and (blk.get_block_bus_width() == block.get_block_bus_width()) and blk.get_loop_itr_cnt() == block.get_loop_itr_cnt(): + results.append(blk) + elif mode =="allocation": + for blk in imm_blcks_non_unique_1 + imm_blcks_non_unique_2: + if not blk.subtype == block.subtype and blk.get_block_freq() == block.get_block_freq() and (blk.get_block_bus_width() == block.get_block_bus_width()) and blk.get_loop_itr_cnt() == block.get_loop_itr_cnt(): + results.append(blk) + elif mode == "bus_width_modulation": + for blk in imm_blcks_non_unique_1 + imm_blcks_non_unique_2: + if not (blk.get_block_bus_width() == block.get_block_bus_width() and blk.subtype == block.subtype and blk.get_block_freq() == block.get_block_freq() and blk.get_loop_itr_cnt() == block.get_loop_itr_cnt()): + results.append(blk) + elif mode == "loop_iteration_modulation": + for blk in imm_blcks_non_unique_1 + imm_blcks_non_unique_2: + if not blk.get_loop_itr_cnt() == block.get_loop_itr_cnt() and blk.subtype == block.subtype and blk.get_block_freq() == block.get_block_freq() and (blk.get_block_bus_width() == block.get_block_bus_width()): + results.append(blk) + else: + print("this mode:" + mode + "not supported") + exit(0) + + return results + + + def get_equal_immediate_blocks_present_fast(self, ex_dp, block, metric, metric_dir, tasks): + imm_blcks_non_unique = self.database.equal_sample_up_sample_down_sample_block_fast(block, metric, metric_dir, + tasks) # get the first value + + imm_blcks_names = [blck.get_generic_instance_name() for blck in imm_blcks_non_unique] + #all_compatible_blocks = [blck.get_generic_instance_name() for blck in self.database.find_all_compatible_blocks_fast(block.type, tasks)] + + blocks_present = ex_dp.get_blocks() + result_blocks = [] + for block_present in blocks_present: + if not (block.instance_name == block_present.instance_name) and block_present.get_generic_instance_name() in imm_blcks_names: + result_blocks.append(block_present) + + if len(result_blocks) == 0: + result_blocks = [block] + return result_blocks + + def get_equal_immediate_block_present_multi_metric_fast(self, ex_dp, block, metric, sorted_metric_dir, tasks): + imm_blcks_non_unique = self.database.equal_sample_up_sample_down_sample_block_multi_metric_fast(block, metric, sorted_metric_dir, + tasks) # get the first value + des_blocks = ex_dp.get_blocks() + for block_ in imm_blcks_non_unique: + for des_block in des_blocks: + if block_.get_generic_instance_name() == des_block.get_generic_instance_name() and \ + not (block.instance_name == des_block.instance_name): + return des_block + + return block + + + # ------------------------------ + # Functionality: + # transforming the current design (by either applying a swap, for block improvement or a split, for reducing block contention) + # Variables: + # move_name: type of the move to apply (currently only supporting swap or split) + # sup_block: (swaper block) block to swap with + # hot_bloc: block to be swapped + # des_tup: (design tuple) containing 
ex_dp, sim_dp
+    #          mode: not used any more. TODO: get rid of this
+    #          hot_kernel_pos: position of the hot kernel. Used for finding the hot kernel.
+    # ------------------------------
+    def move_to(self, move_name, sup_block, hot_block, des_tup, mode, hot_kernel_pos):
+        if move_name == "swap":
+            if not hot_block.type == "ic":
+                self.unload_buses(des_tup[0])  # unload buses
+            else:
+                self.unload_read_buses(des_tup[0])  # unload buses
+            self.swap_block(hot_block, sup_block)  # swap
+            self.mig_cur_tasks_of_src_to_dest(hot_block, sup_block)  # migrate tasks over
+            des_tup[0].hardware_graph.update_graph(block_to_prime_with=sup_block)  # update the hardware graph
+            self.unload_buses(des_tup[0])  # unload buses
+            self.unload_read_mem(des_tup[0])  # unload memories
+            if config.VIS_GR_PER_GEN: vis_hardware(des_tup[0])
+        elif move_name == "split":
+            self.unload_buses(des_tup[0])  # unload buses
+            self.reduce_contention(des_tup, mode, hot_kernel_pos)  # reduce contention by allocating an extra block
+        else:
+            raise Exception("move:" + move_name + " is not supported")
+
+    """
+    # we assume that DRAM can be hanging only from one router.
+    # This check ensures that we only have dram present around one router
+    def improve_locality(self, ex_dp, move_to_apply):
+        # find all the drams and their ics
+        src_block = move_to_apply.get_block_ref()  # memory to move
+        src_block_ic = [el for el in src_block.get_neighs() if el.subtype == "ic"]
+        dest_block = move_to_apply.get_des_block()  # destination IC
+        src_block.disconnect(src_block_ic)
+        dest_block.connect(src_block)
+    """
+
+    # we assume that DRAM can be hanging only from one router.
+    # This check ensures that we only have dram present around one router
+    def fix_dram(self, ex_dp):
+        # find all the drams and their ics
+        all_drams = [el for el in ex_dp.get_hardware_graph().get_blocks() if el.subtype == "dram"]
+        dram_ics = []
+        for dram in all_drams:
+            for neigh in dram.get_neighs():
+                if neigh.type == "ic":
+                    dram_ics.append(neigh)
+
+        # find or allocate (if none exists) the system IC (from which DRAM hangs)
+        system_ic = None
+        for ic in dram_ics:
+            ic_has_no_pe = len([neigh for neigh in ic.get_neighs() if neigh.type == "pe"]) == 0
+            if ic_has_no_pe:
+                system_ic = ic
+                break
+        if system_ic == None:
+            dram_ics_sorted = sorted(dram_ics, key=attrgetter("peak_work_rate"), reverse=True)
+            system_ic_alike = dram_ics_sorted[0]
+            system_ic = self.allocate_similar_block(system_ic_alike, [])
+
+        # iterate through the drams and connect them to the system ic
+        for dram_ic in dram_ics:
+            if dram_ic not in system_ic.get_neighs():
+                dram_ic.connect(system_ic)
+            hanging_drams = [el for el in dram_ic.get_neighs() if el.subtype == "dram"]
+            for hanging_dram in hanging_drams:
+                hanging_dram.disconnect(dram_ic)
+                hanging_dram.connect(system_ic)
+
+    # ------------------------------
+    # Functionality:
+    #       By applying the move, the initial design is transformed into a new design
+    # Variables:
+    #       des_tup is the design tuple, concretely (design, simulated design)
+    # ------------------------------
+    def apply_move(self, des_tup, move_to_apply):
+        ex_dp, sim_dp = des_tup
+        blck_ref = move_to_apply.get_block_ref()
+        #print("applying move " + move.name + " -----" )
+        #pre_moved_ex = copy.deepcopy(ex_dp)  # this is just for move sanity checking
+        gc.disable()
+        pre_moved_ex = cPickle.loads(cPickle.dumps(ex_dp, -1))  # fast deep copy for move sanity checking
+        gc.enable()
+
+        if
move_to_apply.get_transformation_name() == "identity": + return ex_dp, True + if move_to_apply.get_transformation_name() == "swap": + if not blck_ref.type == "ic": self.unload_buses(ex_dp) # unload buses + else: self.unload_read_buses(ex_dp) # unload buses + succeeded = self.swap_block(blck_ref, move_to_apply.get_des_block()) + #succeeded = self.mig_cur_tasks_of_src_to_dest(move_to_apply.get_block_ref(), move_to_apply.get_des_block()) # migrate tasks over + succeeded = self.mig_tasks_of_src_to_dest(ex_dp, blck_ref, + move_to_apply.get_des_block(), move_to_apply.get_tasks()) + self.unload_buses(ex_dp) # unload buses + self.unload_read_mem(ex_dp) # unload memories + ex_dp.hardware_graph.update_graph(block_to_prime_with=move_to_apply.get_des_block()) # update the hardware graph + if config.DEBUG_SANITY:ex_dp.sanity_check() # sanity check + elif move_to_apply.get_transformation_name() == "split": + self.unload_buses(ex_dp) # unload buss + if blck_ref.type == "ic": + succeeded,_ = self.fork_bus(ex_dp, blck_ref, move_to_apply.get_tasks()) + else: + succeeded,_ = self.fork_block(ex_dp, blck_ref, move_to_apply.get_tasks()) + if config.DEBUG_SANITY:ex_dp.sanity_check() # sanity check + elif move_to_apply.get_transformation_name() == "split_swap": + # first split + previous_designs_blocks = ex_dp.get_blocks() + self.unload_buses(ex_dp) # unload buss + if blck_ref.type == "ic": + succeeded, new_block = self.fork_bus(ex_dp, blck_ref, move_to_apply.get_tasks()) + else: + succeeded, new_block = self.fork_block(ex_dp, blck_ref, move_to_apply.get_tasks()) + ex_dp.hardware_graph.update_graph(block_to_prime_with=blck_ref) # update the hardware graph + if succeeded: + """ + current_blocks = ex_dp.get_blocks() + new_block = list(set(current_blocks) - set(previous_designs_blocks))[0] + + # we need to do this, because sometimes the migrant tasks are swap and hence the new block gets the + # other migrant tasks + if len(new_block.get_tasks_of_block()) == 1: + block_to_swap = new_block + else: + block_to_swap =blck_ref + """ + block_to_swap = new_block + if config.DEBUG_SANITY:ex_dp.sanity_check() # sanity check + succeeded = self.swap_block(block_to_swap, move_to_apply.get_des_block()) + succeeded = self.mig_tasks_of_src_to_dest(ex_dp, block_to_swap, + move_to_apply.get_des_block(), move_to_apply.get_tasks()) + self.unload_buses(ex_dp) # unload buses + self.unload_read_mem(ex_dp) # unload memories + ex_dp.hardware_graph.update_graph(block_to_prime_with=move_to_apply.get_des_block()) # update the hardware graph + else: + print("something went wrong with split swap") + + if config.DEBUG_SANITY:ex_dp.sanity_check() # sanity check + elif move_to_apply.get_transformation_name() == "cleanup": + self.unload_buses(ex_dp) # unload buses + self.unload_read_mem(ex_dp) # unload memories + if move_to_apply.get_transformation_sub_name() == "absorb": + ref_block_neighs = blck_ref.get_neighs() + for block in ref_block_neighs: + block.disconnect(blck_ref) + block.connect(move_to_apply.get_des_block()) + ex_dp.hardware_graph.update_graph_without_prunning(block_to_prime_with=move_to_apply.get_des_block()) # update the hardware graph + succeeded = True + else: + if not blck_ref.type == "ic": # ic migration is not supported + succeeded = self.mig_tasks_of_src_to_dest(ex_dp, blck_ref, move_to_apply.get_des_block(), move_to_apply.get_tasks()) + + ex_dp.hardware_graph.update_graph(block_to_prime_with=move_to_apply.get_des_block()) # update the hardware graph + else: + succeeded = False + if config.DEBUG_SANITY:ex_dp.sanity_check() # 
sanity check
+        elif move_to_apply.get_transformation_name() == "migrate" or move_to_apply.get_transformation_name() == "cleanup":
+            self.unload_buses(ex_dp)  # unload buses
+            self.unload_read_mem(ex_dp)  # unload memories
+            if not blck_ref.type == "ic":  # ic migration is not supported
+                succeeded = self.mig_tasks_of_src_to_dest(ex_dp, blck_ref, move_to_apply.get_des_block(), move_to_apply.get_tasks())
+                ex_dp.hardware_graph.update_graph(block_to_prime_with=move_to_apply.get_des_block())  # update the hardware graph
+            else:
+                succeeded = False
+            if config.DEBUG_SANITY: ex_dp.sanity_check()  # sanity check
+        elif move_to_apply.get_transformation_name() == "dram_fix":
+            self.unload_buses(ex_dp)  # unload buses
+            self.unload_read_mem(ex_dp)  # unload memories
+            self.fix_dram(ex_dp)
+            ex_dp.hardware_graph.update_graph_without_prunning()  # update the hardware graph
+            succeeded = True
+        elif move_to_apply.get_transformation_name() == "transfer":
+            self.unload_buses(ex_dp)  # unload buses
+            self.unload_read_mem(ex_dp)  # unload memories
+            src_block = blck_ref  # memory to move
+            src_block_ic = [el for el in src_block.get_neighs() if el.subtype == "ic"][0]
+            dest_block = move_to_apply.get_des_block()  # destination IC
+            src_block.disconnect(src_block_ic)
+            dest_block_ic = [el for el in dest_block.get_neighs() if el.subtype == "ic"][0]
+            dest_block_ic.connect(src_block)
+            ex_dp.hardware_graph.update_graph(src_block)  # update the hardware graph
+            succeeded = True
+        elif move_to_apply.get_transformation_name() == "routing":
+            self.unload_buses(ex_dp)  # unload buses
+            self.unload_read_mem(ex_dp)  # unload memories
+            src_block = blck_ref  # memory to move
+            dest_block = move_to_apply.get_des_block()  # destination IC
+            src_block.connect(dest_block)
+            ex_dp.hardware_graph.update_graph(dest_block)  # update the hardware graph
+            succeeded = True
+        else:
+            raise Exception("transformation :" + move_to_apply.get_transformation_name() + " is not supported")
+
+        ex_dp.hardware_graph.pipe_design()
+        return ex_dp, succeeded
+
+    # ------------------------------
+    # Functionality:
+    #       Relax the bottleneck either by: 1. using an improved block (swap) from the DB, or 2. reducing contention (split)
+    # Variables:
+    #       des_tup: design tuple containing (ex_dp, sim_dp)
+    #       mode: whether to use hot_kernel or not for deciding the move
+    #       hot_kernel_pos: position of the hot kernel
+    # ------------------------------
+    def relax_bottleneck(self, des_tup, mode, hot_kernel_pos=0):
+        split_coeff = 5  # determines the probability of picking split (directing the hill-climber to pick a move)
+        swap_coeff = 1  # determines the probability of picking swap
+        if mode == "hot_kernel":
+            self.unload_read_mem(des_tup[0])
+
+            # determine if swap is beneficial (if there is a better hardware block in the database
+            # that can improve the current design's performance)
+            swap_beneficial, sup_block, hot_block = self.block_is_improvable(des_tup, hot_kernel_pos)
+            if config.DEBUG_FIX: random.seed(0)
+            else: time.sleep(.00001), random.seed(datetime.now().microsecond)
+
+            # if swap is beneficial, (possibly) give swap a shot
+            if swap_beneficial:
+                print(" block to improve is of type " + hot_block.type)
+                # favoring swap for non-pe's and split for pe's (saving financial cost)
+                if not hot_block.type == "pe":
+                    split_coeff = 1
+                    swap_coeff = 5
+                else:
+                    # favoring split over swap to avoid
+                    # unnecessary customizations.
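+                    # (note) these coefficients weight the categorical draw below:
+                    # random.choice over swap_coeff copies of "swap" plus split_coeff
+                    # copies of "split" picks "split" with probability
+                    # split_coeff / (split_coeff + swap_coeff), i.e. 5/6 for pe's here.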
+                    split_coeff = 5
+                    swap_coeff = 1
+                move_to_choose = random.choice(swap_coeff*["swap"] + split_coeff*["split"])
+            else: move_to_choose = "split"
+            self.move_to(move_to_choose, sup_block, hot_block, des_tup, mode, hot_kernel_pos)
+        else:
+            raise Exception("mode:" + mode + " not defined")
+
+    # ------------------------------
+    # Functionality:
+    #       swapping a block with an improved one to improve the performance
+    # Variables:
+    #       swapee: the block to swap
+    #       swaper: the block to swap with
+    # ------------------------------
+    def swap_block(self, swapee, swaper):
+        # detach the swapee from its neighbours and attach the swaper in its place
+        neighs = swapee.neighs[:]
+        for neigh in neighs:
+            neigh.disconnect(swapee)
+        for neigh in neighs:
+            neigh.connect(swaper)
+
+    # ------------------------------
+    # Functionality:
+    #       determine whether there is a block in the database (of the same kind) that is superior
+    # Variables:
+    #       des_tup: design tuple containing ex_dp, sim_dp
+    #       hot_kernel_pos: position of the hottest kernel. Helps locating the kernel.
+    # ------------------------------
+    def block_is_improvable(self, des_tup, hot_kernel_pos):
+        ex_dp, sim_dp = des_tup
+        hot_blck = sim_dp.get_dp_stats().get_hot_block_system_complex(des_tup, hot_kernel_pos)
+        hot_blck_synced = self.find_cores_hot_kernel_blck_bottlneck(ex_dp, hot_blck)
+        new_block = self.database.up_sample_blocks(hot_blck_synced, "immediate_superior", hot_blck.get_tasks_of_block())  # up_sample = finding a superior block
+        if self.boost_SOC: new_block = self.database.up_sample_SOC(new_block, config.sorting_SOC_metric)
+        else: new_block = self.database.copy_SOC(new_block, hot_blck_synced)
+        if new_block.get_generic_instance_name() != hot_blck_synced.get_generic_instance_name():
+            return True, new_block, hot_blck_synced
+        return False, "_", "_"
+
+    # ------------------------------
+    # Functionality:
+    #       clustering tasks based on their dependencies, i.e., having the same child or parent (this is to improve locality)
+    # Variables:
+    #       task: task of interest to look at.
+ # num_of_tasks_to_migrate: how many task to migrate to a new block (if we decide to split) + # residing_tasks_on_pe: tasks that are currently occupied the processing element of interest + # clusters: clusters already formed (This is because the helper is called recursively) + # ------------------------------ + def cluster_tasks_based_on_dependency_helper(self, task, num_of_tasks_to_migrate, residing_tasks_on_pe, residing_tasks_on_pe_copy, clusters): + if num_of_tasks_to_migrate == 0: + return 0 + else: + # go through the children and if they are on the pe and not already selected in the cluster + task_s_children_on_pe = [] + for child in task.get_children(): + if child in residing_tasks_on_pe and child not in clusters[0]: + task_s_children_on_pe.append(child) + task_children_queue = [] # list to keep the childs for breadth first search + + # no children that satisfies the requirement (including no children at all) + if not task_s_children_on_pe: + task = random.choice(residing_tasks_on_pe_copy) + clusters[0].append(task) + residing_tasks_on_pe_copy.remove(task) + num_of_tasks_to_migrate -= 1 + return self.cluster_tasks_based_on_dependency_helper(task, num_of_tasks_to_migrate, residing_tasks_on_pe, residing_tasks_on_pe_copy, clusters) + + # iterate through children and add them to teh cluster + for child in task_s_children_on_pe: + if num_of_tasks_to_migrate == 0: + return 0 + task_children_queue.append(child) + clusters[0].append(child) + residing_tasks_on_pe_copy.remove(child) + num_of_tasks_to_migrate -= 1 + + # generate tasks to migrate + # and recursively call the helper + for child in task_children_queue: + num_of_tasks_to_migrate = self.cluster_tasks_based_on_dependency_helper(child, num_of_tasks_to_migrate, residing_tasks_on_pe, + residing_tasks_on_pe_copy, clusters) + if num_of_tasks_to_migrate == 0: + return 0 + + + def cluster_tasks_based_on_tasks_parallelism(self, ex_dp, sim_dp, task, block): + return self.find_parallel_tasks_of_task_in_block(ex_dp, sim_dp, task, block) + + def cluster_tasks_based_on_tasks_serialism(self, ex_dp, sim_dp, task, block): + return self.find_serial_tasks_of_task_in_block(ex_dp, sim_dp, task, block) + + # ------------------------------ + # Functionality: + # clustering tasks bases on their dependencies, i.e, having the same child or parent (this is to improve locality) + # Variables: + # residing_task_on_block: tasks that are already occupying the block (that we want to split + # num_clusters: how many clusters to generate for migration. + # ------------------------------ + def cluster_tasks_based_on_tasks_dependencies(self, task_ref, residing_tasks_on_block, num_clusters): + cluster_0 = [] + clusters_length = int(len(residing_tasks_on_block)/num_clusters) + residing_tasks_copy = residing_tasks_on_block[:] + for tsk in residing_tasks_copy: + if tsk.name == task_ref.name: + ref_task = tsk + break + + #ref_task = random.choice(residing_tasks_on_block) + + cluster_0.append(ref_task) + residing_tasks_copy.remove(ref_task) + tasks_to_add_pool = [ref_task] # list containing all the tasks that have been added so far. 
+ # we sample from it to pick a ref block with some condition + + if config.DEBUG_FIX: random.seed(0) + else: time.sleep(.00001); random.seed(datetime.now().microsecond) + + break_completely = False + # generate clusters + while len(cluster_0) < clusters_length: + ref_tasks_family = ref_task.get_family() + family_on_block = set(ref_tasks_family).intersection(set(residing_tasks_copy)) + family_on_block_not_already_in_cluster = family_on_block - set(cluster_0) + # find tasks family to migrate (task family is used for clustering to improve locality) + tasks_to_add = list(family_on_block_not_already_in_cluster) + for task in tasks_to_add: + if len(cluster_0) < clusters_length: + cluster_0.append(task) + residing_tasks_copy.remove(task) + else: + break_completely = True + break + if break_completely: + break + if tasks_to_add: ref_task = random.choice(tasks_to_add) + else: + if tasks_to_add_pool: + ref_task = random.choice(tasks_to_add_pool) + tasks_to_add_pool.remove(ref_task) + else: + ref_task = random.choice(residing_tasks_copy) + tasks_to_add_pool.append(ref_task) + tasks_to_add_pool = list(set(tasks_to_add_pool + tasks_to_add)) + + return [cluster_0, residing_tasks_copy] + + # ------------------------------ + # Functionality: + # clustering tasks if they share input/outputs to ensure high locality. + # Variables: + # residing_task_on_block: tasks that are already occupying the block (that we want to split + # num_clusters: how many clusters to generate for migration. + # ------------------------------ + def cluster_tasks_based_on_data_sharing(self, task_ref, residing_tasks_on_block, num_clusters): + if (config.tasks_clustering_data_sharing_method == "task_dep"): + return self.cluster_tasks_based_on_tasks_dependencies(task_ref, residing_tasks_on_block, num_clusters) + else: + raise Exception("tasks_clustering_data_sharing_method:" + config.tasks_clustering_data_sharing_method + "not defined") + + # ------------------------------ + # Functionality: + # random clustering of the tasks (for split). To introduce stochasticity in the system. + # Variables: + # residing_task_on_block: tasks that are already occupying the block (that we want to split + # num_clusters: how many clusters to generate for migration. + # ------------------------------ + def cluster_tasks_randomly(self, residing_tasks_on_pe, num_clusters=2): + clusters:List[List[Task]] = [[] for i in range(num_clusters)] + residing_tasks_on_pe_copy = residing_tasks_on_pe[:] + + if (config.DEBUG_FIX): + random.seed(0) + else: + time.sleep(.00001) + random.seed(datetime.now().microsecond) + + # pick some random number of tasks to migrate + num_of_tasks_to_migrate = random.choice(range(1, len(residing_tasks_on_pe_copy))) + # assign them to the first cluster + for _ in range(num_of_tasks_to_migrate): + task = random.choice(residing_tasks_on_pe_copy) + clusters[0].append(task) + residing_tasks_on_pe_copy.remove(task) + + # assign the rest of tasks to cluster 2 + for task in residing_tasks_on_pe_copy: + clusters[1].append(task) + + # get rid of the empty clusters (happens if num of clusters is one less than the total number of tasks) + return [cluster for cluster in clusters if cluster] + + # ------------------------------ + # Functionality: + # Migrate all the tasks, from the known src to known destination block + # Variables: + # dp: design + # dest_blck: destination block. Block to migrate the tasks to. + # src_blck: source block, where task currently lives in. 
+ # tasks: the tasks to migrate + # ------------------------------ + def mig_tasks_of_src_to_dest(self, dp: ExDesignPoint, src_blck, dest_blck, tasks): + + # sanity check + for task in tasks: + matched_block = False # if we there is an equality between src_block and current_src_blck (these need to be the same ensuring that nothing has gone wrong) + # check if task is only a read task on ic or memory. If so raise the exception + # since we have already unloaded it. + task_dirs = [task_dir[1] for task_dir in src_blck.get_tasks_dir_work_ratio()] + if "read" in task_dirs and len(task_dirs) == 1: + raise NoMigrantException # this scenario is an exception, but non read scenarios are errors + + cur_src_blocks = dp.get_blocks_of_task(task) + for cur_src_blck in cur_src_blocks: + if dest_blck.type == cur_src_blck.type: # only pay attention to the block of the type similar to the one that you migrate to + if cur_src_blck == src_blck: + matched_block = True + if not matched_block: + print("task does not exist int the block") + raise NoMigrantError + for task in tasks: + self.mig_one_task(dest_blck, src_blck, task) + + # ------------------------------ + # Functionality: + # Migrate one task, from the known src to known destination block + # Variables: + # dest_blck: destination block. Block to migrate the tasks to. + # src_blck: source block, where task currently lives in. + # task: the task to migrate + # ------------------------------ + def mig_one_task(self, dest_blck, src_blck, task): + # prevent migrating to yourself + # random bugs would pop up if we do so + if src_blck.instance_name == dest_blck.instance_name: + return + #work_ratio = src_blck.get_task_s_work_ratio_by_task(task) + + if (src_blck.type == "pe"): + #print("blah blah the taks to migrate is" + task.name + " from " + src_blck.instance_name + "to" + dest_blck.instance_name) + + dest_blck.load_improved(task, task) + dir_ = "loop_back" + else: # write + family_tasks_on_block = src_blck.get_task_s_family_by_task_and_dir(task, "write") # just the names + #assert (len(work_ratio) == 1), "only can have write work ratio at this stage" + #print("blah blah the taks to migrate is" + task.name + " from " + src_blck.instance_name + "to" + dest_blck.instance_name) + if len(task.get_children()) ==0: + print("This should not be happening") + raise TaskNoChildrenError + + for family in family_tasks_on_block: + family_task = [child for child in task.get_children() if family== child.name][0] + dest_blck.load_improved(task, family_task) + dir_ = "write" + src_blck.unload((task, dir_)) + + # delete this later + tasks_left = [] + for task in src_blck.get_tasks_of_block(): + tasks_left.append(task.name) + #print("blah blah tasks left on src block is " + str(tasks_left)) + + if len(src_blck.get_tasks_of_block()) == 0: # prunning out the block + src_blck.disconnect_all() + + # ------------------------------ + # Functionality: + # Migrate all of the tasks from a known src to a known destination + # Variables: + # dest_blck: destination block. Block to migrate the tasks to. + # src_blck: source block, where task currently lives in. + # ------------------------------ + def mig_cur_tasks_of_src_to_dest(self, src_blck, dest_blck): + tasks = src_blck.get_tasks_of_block() + _ = [self.mig_one_task(dest_blck, src_blck, task) for task in tasks] + + # used when forking buses. 
+ # a special procedure needs to be followed + # just as deciding on disconnecting and connecting the neighbouring PEs/MEMS and buses + def attach_alloc_block_to_bus(self, ex_dp, block_to_mimic, mimicee_block, migrant_tasks): + pe_neighs = [neigh for neigh in block_to_mimic.neighs if neigh.type == "pe"] + mem_neighs = [neigh for neigh in block_to_mimic.neighs if neigh.type == "mem"] + ic_neighs = [neigh for neigh in block_to_mimic.neighs if neigh.type == "ic"] + migrant_tasks_names = [el.get_name() for el in migrant_tasks] + for neigh in pe_neighs + mem_neighs: + neigh_tasks = [el.get_name() for el in neigh.get_tasks_of_block()] + # if no overlap skip + if len(list(set(migrant_tasks_names) - set(neigh_tasks) )) == len(migrant_tasks_names): + continue + else: + mimicee_block.connect(neigh) + block_to_mimic.disconnect(neigh) + + # randomly connect the ics. + # we can do better here + ctr = 0 + for ic in ic_neighs: + if (ctr %2) == 1: + ic.disconnect(block_to_mimic) + ic.connect(mimicee_block) + ctr +=1 + mimicee_block.connect(block_to_mimic) + + + + + # ------------------------------ + # attach through the shared bus + # block_to_mimic: this is the block which we use as a template for connections + # mimicee_block: this the block that needs to mimic the block_to_mimic block's behavior (by attaching to its blocks) + # ------------------------------ + def attach_alloc_block(self, ex_dp, block_to_mimic, mimicee_block): + if(block_to_mimic.type == "ic"): # connect to all the mems and pes of the block_to_mimic + pe_neighs = [neigh for neigh in block_to_mimic.neighs if neigh.type == "pe"] + mem_neighs = [neigh for neigh in block_to_mimic.neighs if neigh.type == "mem"] + _ = [mimicee_block.connect(neigh) for neigh in pe_neighs] + _ = [mimicee_block.connect(neigh) for neigh in mem_neighs] + #connect = block_to_mimic.connect(mimicee_block) + else: + bus_neighs = [neigh for neigh in block_to_mimic.neighs if neigh.type == "ic"] + assert(len(bus_neighs) == 1), "can only have one bus neighbour" + #print("attaching block to " + bus_neighs[0].instance_name) + mimicee_block.connect(bus_neighs[0]) + + ex_dp.hardware_graph.update_graph(block_to_prime_with=block_to_mimic) + #self.sanity_check_ic(ex_dp) + + # ------------------------------ + # Functionality: + # allocating similar block for split purposes. We find a block of the same type and also superior. + # Variables: + # old_blck: block to improve on. + # tasks: tasks residing on the old block. We need this because the new block should support all the + # these tasks + # ------------------------------ + def allocate_similar_block(self, old_block, tasks): + new_block = self.database.sample_similar_block(old_block) + new_block = self.database.copy_SOC(new_block, old_block) + return new_block + + # ------------------------------ + # Functionality: + # split the tasks between two clusters, one cluster having only one + # task (found based on the selected kernel) and the other, the rest + # Variables: + # tasks_of_block: tasks resident on the block to choose from + # selected_kernel: the selected kernel to separate + # ------------------------------ + def separate_a_task(self, tasks_of_block, selected_kernel): + clusters = [[],[]] + for task in tasks_of_block: + if selected_kernel.get_task().name == task.name: + clusters[0].append(task) + else: + clusters[1].append(task) + return clusters + + # ------------------------------ + # Functionality: + # cluster tasks. This decides which task to migrate together. Used in split. 
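+    #       Selection modes dispatched below: "random" (stochastic clusters),
+    #       "tasks_dependency" (data-sharing/locality driven), "single" (isolate the
+    #       selected kernel's task), "single_serialism" (just that task), and
+    #       "batch" (parallelism-driven clustering).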
+    # ------------------------------
+    # Functionality:
+    #       cluster tasks. This decides which tasks to migrate together. Used in split.
+    # Variables:
+    #       block: block where the tasks reside.
+    #       num_clusters: how many clusters we want to have.
+    # ------------------------------
+    def cluster_tasks(self, ex_dp, sim_dp, block, selected_kernel, selection_mode):
+        if selection_mode == "random":
+            return self.cluster_tasks_randomly(block.get_tasks_of_block())
+        elif selection_mode == "tasks_dependency":
+            return self.cluster_tasks_based_on_data_sharing(selected_kernel.get_task(), block.get_tasks_of_block(), 2)
+        elif selection_mode == "single":
+            return self.separate_a_task(block.get_tasks_of_block(), selected_kernel)
+        elif selection_mode == "single_serialism":
+            return [[selected_kernel.get_task()]]
+        #elif selection_mode == "batch":
+        #    return self.cluster_tasks_based_on_data_sharing(selected_kernel.get_task(), block.get_tasks_of_block(), 2)
+        elif selection_mode == "batch":
+            return self.cluster_tasks_based_on_tasks_parallelism(ex_dp, sim_dp, selected_kernel.get_task(), block)
+        #elif selection_mode == "batch_nonparallel":
+        #    return self.cluster_tasks_based_on_tasks_serialism(ex_dp, sim_dp, selected_kernel.get_task(), block)
+        else:
+            raise Exception("migrant clustering policy: " + selection_mode + " not supported")
+
+    # ------------------------------
+    # Functionality:
+    #       decide which one of the clusters to migrate.
+    # Variables:
+    #       block: block where the tasks reside.
+    # ------------------------------
+    def migrant_selection(self, ex_dp, sim_dp, block_after_unload, block_before_unload, selected_kernel, selection_mode):
+        if config.DEBUG_FIX:
+            random.seed(0)
+        else:
+            time.sleep(.00001)
+            random.seed(datetime.now().microsecond)
+        try:
+            clustered_tasks = self.cluster_tasks(ex_dp, sim_dp, block_after_unload, selected_kernel, selection_mode)
+        except Exception:
+            print("migrant selection went wrong. This needs to be fixed. Most likely occurs with a random (as opposed to arch-aware) transformation_selection_mode")
+            return []
+        return clustered_tasks[0]
+
+    # ------------------------------
+    # Functionality:
+    #       see if we can split a block (i.e., there is more than one task on it)
+    # Variables:
+    #       ex_dp: example design
+    #       block: block of interest
+    #       mode: deprecated. TODO: get rid of it.
+    # ------------------------------
+    def block_forkable(self, ex_dp, block):
+        return len(block.get_tasks_of_block()) >= config.num_clusters
+
+    def task_in_block(self, block, task_):
+        return (task_.name in [task.name for task in block.get_tasks_of_block()])
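+    # ------------------------------
+    # Illustrative sketch (not part of the original code): picking migrants for a
+    # split, assuming a hypothetical hot block `blk` and hot kernel `hot_krnl`:
+    #
+    #     migrants = self.migrant_selection(ex_dp, sim_dp, blk, blk, hot_krnl, "single")
+    #     # "single" isolates the hot kernel's task; "batch"/"tasks_dependency" return
+    #     # multi-task clusters, and migrant_selection always hands back cluster 0.
+    # ------------------------------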
+    # ------------------------------
+    # Functionality:
+    #       finds another block similar to the input block and attaches it similarly
+    # Variables:
+    #       ex_dp: example design
+    #       block: block of interest
+    #       migrant_tasks_non_filtered: candidate tasks to migrate to the new block
+    # ------------------------------
+    def fork_block(self, ex_dp, block, migrant_tasks_non_filtered):
+
+        # filter out the tasks that don't exist on the block.
+        # This usually happens because we might unload the bus/memory
+        migrant_tasks = []
+        # transformation guards
+        if len(block.get_tasks_of_block()) < config.num_clusters:
+            return False, ""
+        else:
+            for task__ in migrant_tasks_non_filtered:
+                # if a task to migrate does not exist on the src block
+                if not (task__.name in [task.name for task in block.get_tasks_of_block()]):
+                    # this should only occur for reads, since we unload the reads
+                    continue
+                else:
+                    migrant_tasks.append(task__)
+
+        if len(migrant_tasks) == 0:
+            return False, ""
+
+        # find and attach a similar block
+        alloc_block = self.allocate_similar_block(block, migrant_tasks)
+        #print("allocated block name: " + alloc_block.instance_name)
+        self.attach_alloc_block(ex_dp, block, alloc_block)
+        # migrate tasks
+        #self.mig_tasks_of_diff_blocks(ex_dp, migrant_tasks, alloc_block)
+        self.mig_tasks_of_src_to_dest(ex_dp, block, alloc_block, migrant_tasks)
+        ex_dp.hardware_graph.update_graph(block_to_prime_with=alloc_block)
+        if config.VIS_GR_PER_GEN: vis_hardware(ex_dp)
+
+        ex_dp.check_mem_fronts_sanity()
+        return True, alloc_block
+
+    # ------------------------------
+    # Functionality:
+    #       find out how many tasks two blocks share. Used for the split move.
+    # Variables:
+    #       pe: processing element of interest
+    #       mem: memory of interest
+    # ------------------------------
+    def calc_data_sharing_amount(self, pe, mem):
+        pe_tasks_name = [task.name for task in pe.get_tasks_of_block()]
+        mem_tasks_name = [task.name for task in mem.get_tasks_of_block()]
+        sharing_ctr = 0
+
+        for pe_task_name in pe_tasks_name:
+            for mem_task_name in mem_tasks_name:
+                if pe_task_name == mem_task_name:
+                    sharing_ctr += 1
+        return sharing_ctr
+
+    # ------------------------------
+    # Functionality:
+    #       cluster blocks based on the number of tasks they share. This is used when deciding which block to split.
+    # Variables:
+    #       ex_dp: example design
+    #       block: block to prime the selection with.
+    # ------------------------------
+    def cluster_blocks_based_on_data_sharing(self, ex_dp, block):
+        self.pruned_one = False
+        ex_dp.hardware_graph.update_graph(block_to_prime_with=block)
+        if config.VIS_GR_PER_GEN: vis_hardware(ex_dp)
+
+        pe_neighs = [neigh for neigh in block.neighs if neigh.type == "pe"]
+        mem_neighs = [neigh for neigh in block.neighs if neigh.type == "mem"]
+
+        sharing_dict = {}
+        for pe in pe_neighs:
+            for mem in mem_neighs:
+                sharing_dict[(pe, mem)] = self.calc_data_sharing_amount(pe, mem)
+
+        sorted_sharing_list = sorted(sharing_dict.items(), key=operator.itemgetter(1))
+        return sorted_sharing_list
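+    # Worked example (illustrative): if a pe hosts tasks {A, B, C} and a mem hosts
+    # {B, C, D}, calc_data_sharing_amount returns 2 (B and C are shared), so that
+    # (pe, mem) pair sorts after any pair sharing fewer tasks in the list above.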
+    # ------------------------------
+    # Functionality:
+    #       fork (split) a bus.
+    # Variables:
+    #       ex_dp: example design
+    #       block: bus to split
+    #       migrant_tasks: tasks to move over to the new bus's cluster
+    # ------------------------------
+    def fork_bus(self, ex_dp, block, migrant_tasks):
+
+        self.pruned_one = False
+        ex_dp.hardware_graph.update_graph(block_to_prime_with=block)
+        if config.VIS_GR_PER_GEN: vis_hardware(ex_dp)
+
+        pe_neighs = [neigh for neigh in block.neighs if neigh.type == "pe"]  # find pe neighbours
+        mem_neighs = [neigh for neigh in block.neighs if neigh.type == "mem"]  # find memory neighbours
+
+        pe_forked = False
+        pe_good_to_go = False
+        mem_forked = False
+        mem_good_to_go = False
+
+        # first check the forkability of the neighbouring pe and mem (before attempting to fork either),
+        # in case either (mem, pe) needs to be forked
+        if len(pe_neighs) == 1 or len(mem_neighs) == 1:
+            pe_forkability = True
+            mem_forkability = True
+            # see if you can fork
+            if len(pe_neighs) == 1:
+                for task_ in migrant_tasks:
+                    if not (self.block_forkable(ex_dp, pe_neighs[0]) or self.task_in_block(pe_neighs[0], task_)):
+                        pe_forkability = False
+                        break
+            # see if you can fork
+            if len(mem_neighs) == 1:
+                for task_ in migrant_tasks:
+                    if not (self.block_forkable(ex_dp, mem_neighs[0]) or self.task_in_block(mem_neighs[0], task_)):
+                        mem_forkability = False
+                        break
+
+            if not (mem_forkability and pe_forkability):
+                return False, ""
+
+        # now fork the neighbours if necessary
+        if len(pe_neighs) == 1:
+            pe_forked, _ = self.fork_block(ex_dp, pe_neighs[0], migrant_tasks)
+            ex_dp.hardware_graph.update_graph(block_to_prime_with=block)
+        else:
+            pe_good_to_go = True
+
+        if len(mem_neighs) == 1:
+            mem_forked, _ = self.fork_block(ex_dp, mem_neighs[0], migrant_tasks)
+            ex_dp.hardware_graph.update_graph(block_to_prime_with=block)
+        else:
+            mem_good_to_go = True
+
+        if not ((pe_forked or pe_good_to_go) and (mem_forked or mem_good_to_go)):
+            return False, ""
+
+        # allocate and attach a similar bus
+        alloc_block = self.allocate_similar_block(block, [])
+
+        # at the moment, we don't support the smarter version of attaching the blocks for buses
+        if config.RUN_VERIFICATION_PER_GEN or config.RUN_VERIFICATION_PER_IMPROVMENT or config.RUN_VERIFICATION_PER_NEW_CONFIG:
+            self.attach_alloc_block(ex_dp, block, alloc_block)
+            ex_dp.hardware_graph.update_graph(block_to_prime_with=block)
+            #if config.VIS_GR_PER_GEN: vis_hardware(ex_dp)
+
+            # prune blocks (memory and processing elements) from the previous bus
+            self.prune(ex_dp, block, alloc_block)
+            block.connect(alloc_block)
+        else:
+            self.attach_alloc_block_to_bus(ex_dp, block, alloc_block, migrant_tasks)
+
+        ex_dp.hardware_graph.update_graph(block_to_prime_with=block)
+        return True, alloc_block
+
+    # ------------------------------
+    # Functionality:
+    #       reduce contention on the bottleneck block. Used specifically for split.
+    # Variables:
+    #       des_tup: design tuple containing ex_dp, sim_dp
+    #       mode: mode used for splitting. At the moment, we only use the hottest kernel as the deciding factor.
+    #       hot_kernel_pos: position of the hot kernel. Helps locate the hot kernel.
+    # ------------------------------
+    def reduce_contention(self, des_tup, mode, hot_kernel_pos):
+        ex_dp, sim_dp = des_tup
+        if mode == "hot_kernel":
+            hot_block = sim_dp.get_dp_stats().get_hot_block_system_complex(des_tup, hot_kernel_pos)
+            hot_blck_synced = self.find_cores_hot_kernel_blck_bottlneck(ex_dp, hot_block)
+            if hot_block.type == "ic":
+                self.fork_bus(ex_dp, hot_blck_synced)
+            else:
+                self.fork_block(ex_dp, hot_blck_synced)
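+    # ------------------------------
+    # Illustrative end-to-end fork flow (editorial sketch; `bottleneck_bus` and
+    # `migrant_tasks` are hypothetical names):
+    #
+    #     ok, new_bus = self.fork_bus(ex_dp, bottleneck_bus, migrant_tasks)
+    #     if ok:
+    #         # fork_bus has (1) forked a lone pe/mem neighbour if needed,
+    #         # (2) allocated a similar bus via allocate_similar_block, and
+    #         # (3) attached it and moved the migrant tasks' pe/mem over to it.
+    #         ex_dp.hardware_graph.update_graph(block_to_prime_with=new_bus)
+    # ------------------------------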
+    # ------------------------------
+    # Functionality:
+    #       allocate a new block, connect it to the current design and migrate some of the tasks to it.
+    # Variables:
+    #       dp: design point
+    #       blck_bottleneck: block bottleneck (this block is forked)
+    # ------------------------------
+    def alloc_mig(self, dp, blck_bottleneck):
+
+        # find tasks to migrate
+        clustered_tasks = []
+        if blck_bottleneck.type == "ic":
+            pe_neighs = [neigh for neigh in blck_bottleneck.neighs if neigh.type == "pe"]
+            mem_neighs = [neigh for neigh in blck_bottleneck.neighs if neigh.type == "mem"]
+
+            if len(pe_neighs) == 1:
+                self.alloc_mig(dp, pe_neighs[0])
+
+            if len(mem_neighs) == 1:
+                self.alloc_mig(dp, mem_neighs[0])
+        else:
+            clustered_tasks = self.naively_cluster_tasks(blck_bottleneck.get_tasks_of_block())
+            if config.DEBUG_FIX:
+                random.seed(0)
+            else:
+                time.sleep(.00001)
+                random.seed(datetime.now().microsecond)
+            tasks_to_migrate = clustered_tasks[random.choice(range(0, len(clustered_tasks)))]  # grab a random cluster
+
+        alloc_block = self.allocate_similar_block(blck_bottleneck, clustered_tasks)
+        self.attach_alloc_block(dp, blck_bottleneck, alloc_block)
+        if not (blck_bottleneck.type == "ic"):
+            self.mig_tasks(dp, tasks_to_migrate, alloc_block)
+
+    # ------------------------------
+    # Functionality:
+    #       used in pruning a bus (ic) from the blocks that are connected (through the fork process) to the new bus.
+    # Variables:
+    #       dp: current design point.
+    #       original_ic: ic to prune
+    #       new_ic: new ic that has inherited some of the old ic's neighbours (pe, mem)
+    # ------------------------------
+    def prune(self, dp: ExDesignPoint, original_ic, new_ic):
+        # recursively prune the blocks in the dp
+        # start with any block and traverse (since blocks
+        self.prune_smartly(original_ic, dp, 2, new_ic)
+        dp.hardware_graph.update_graph(block_to_prime_with=new_ic)
+
+    # ------------------------------
+    # Functionality:
+    #       find memories that share the largest number of tasks with a pe. Used for pruning smartly.
+    # Variables:
+    #       mem_clusters: clusters of memories under question.
+    #       pe: pe to measure closeness with.
+    # ------------------------------
+    def find_closest_mem_cluster(self, mem_clusters, pe):
+        def calc_mem_similarity(pe, mem):
+            pe_tasks_name = [task.name for task in pe.get_tasks_of_block()]
+            mem_tasks_name = [task.name for task in mem.get_tasks_of_block()]
+            sharing_ctr = 0
+
+            for pe_task_name in pe_tasks_name:
+                for mem_task_name in mem_tasks_name:
+                    if pe_task_name == mem_task_name:
+                        sharing_ctr += 1
+            return sharing_ctr
+
+        def calc_cluster_similarity(pe, mem_cluster):
+            similarity = 0
+            for mem in mem_cluster:
+                similarity += calc_mem_similarity(pe, mem)
+            return similarity
+
+        cluster_similarity = []
+        for mem_cluster in mem_clusters:
+            cluster_similarity.append(calc_cluster_similarity(pe, mem_cluster))
+        return cluster_similarity.index(max(cluster_similarity))
+
+    # ------------------------------
+    # Functionality:
+    #       used in pruning a bus (ic) from the blocks that are connected (through the fork process) to the new bus.
+ # Variables: + # original_ic: ic to prune + # ex_dp: example design (current design) + # num_clusters: number of clusters for task migration + # new_ic: new ic that has inherited some of the old ic neighbours (pc, mem) + # ------------------------------ + def prune_smartly(self, original_ic, ex_dp, num_clusters, new_ic): + + # this is used for balancing the clusters (in case one + # of the pe clusters is left out empty + def reshuffle_clusters_if_necessary(mem_clusters, pe_clusters): + if len(pe_clusters) > 2: + raise Exception("more than two cluster not supported yet") + # if any of the clusters are empty, reshuffle + if any([True for pe_cluster in pe_clusters if len(pe_cluster) <= 0]): + pes = pe_clusters[0] + pe_clusters[1] + pe_clusters[0] = pes[0: int(len(pes)/2)] + pe_clusters[1] = pes[int(len(pes)/2): len(pes)] + + pe_neighs = [neigh for neigh in original_ic.neighs if neigh.type == "pe"] + mem_neighs = [neigh for neigh in original_ic.neighs if neigh.type == "mem"] + ic_neighs = [neigh for neigh in original_ic.neighs if neigh.type == "ic"] + + # disconnect the original ic from all of its neighbours + original_ic.disconnect_all() + + # connect back some of the neighbours + for ic_neigh in ic_neighs: + original_ic.connect(ic_neigh) + + # cluster memory and pe neighbours (each cluster would be assigned to one ic and then get connected to it) + mem_clusters = [mem_neighs[0: int(len(mem_neighs)/2)], mem_neighs[int(len(mem_neighs)/2): len(mem_neighs)]] + pe_clusters = [[] for _ in mem_clusters] + for pe in pe_neighs: + cluster_idx = self.find_closest_mem_cluster(mem_clusters, pe) + pe_clusters[cluster_idx].append(pe) + reshuffle_clusters_if_necessary(mem_clusters, pe_clusters) + pe_mem_clusters = [pe_clusters[0]+mem_clusters[0], pe_clusters[1] + mem_clusters[1]] + + # connect/disconnect the clusters to the ics + for block in pe_mem_clusters[0]: + block.disconnect_all() + original_ic.connect(block) \ No newline at end of file diff --git a/Project_FARSI/design_utils/design.py b/Project_FARSI/design_utils/design.py new file mode 100644 index 00000000..cca6b8f7 --- /dev/null +++ b/Project_FARSI/design_utils/design.py @@ -0,0 +1,3014 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
+import _pickle as cPickle
+from design_utils.components.hardware import *
+from design_utils.components.workload import *
+from design_utils.components.mapping import *
+from design_utils.components.scheduling import *
+from design_utils.components.krnel import *
+from design_utils.common_design_utils import *
+import collections
+import datetime
+from datetime import datetime
+from error_handling.custom_error import *
+import gc
+import statistics as st
+if config.use_cacti:
+    from misc.cacti_hndlr import cact_handlr
+
+if config.simulation_method == "power_knobs":
+    from specs import database_input_powerKnobs as database_input
+elif config.simulation_method == "performance":
+    from specs import database_input
+else:
+    raise NameError("Simulation method unavailable")
+
+
+# This class logs an insanity (the opposite of a sanity (check), i.e., a flaw) within the design
+class Insanity:
+    def __init__(self, task, block, name):
+        self.name = name
+        self.task = task
+        self.block = block
+
+    # the problematic task
+    def set_task(self, task_):
+        self.task = task_
+
+    # the problematic block
+    def set_block(self, block_):
+        self.block = block_
+
+    # name of the insanity
+    def set_name(self, name_):
+        self.name = name_
+
+    def gen_msg(self):
+        output = "sanity check failed with: "
+        output += "insanity name:" + self.name
+        if not (self.task == "_"):
+            output += self.task.name
+        if not (self.block == "_"):
+            output += self.block.instance_name
+        return output
+
+
+# This class emulates a design point containing
+# hardware/software, their mapping and scheduling
+class ExDesignPoint:
+    def __init__(self, hardware_graph: HardwareGraph):
+        self.hardware_graph = hardware_graph  # hardware graph contains the hardware blocks
+                                              # and their connections
+        self.PA_prop_dict = {}  # PA prop is used for PA design generation
+        self.id = str(-1)  # this means it hasn't been set
+        self.valid = True
+        self.FARSI_ex_id = str(-1)
+        self.PA_knob_ctr_id = str(-1)
+        self.check_pointed_population_generation_cnt = 0  # only for checkpointing purposes, and only works if the design has been checkpointed
+        self.check_pointed_total_iteration_cnt = 0
+
+    def set_check_pointed_population_generation_cnt(self, generation_cnt):
+        self.check_pointed_population_generation_cnt = generation_cnt
+
+    def set_check_pointed_total_iteration_cnt(self, total_iteration):
+        self.check_pointed_total_iteration_cnt = total_iteration
+
+    def get_check_pointed_population_generation_cnt(self):
+        return self.check_pointed_population_generation_cnt
+
+    def get_check_pointed_total_iteration_cnt(self):
+        return self.check_pointed_total_iteration_cnt
+
+    def eliminate_system_bus(self):
+        all_drams = [el for el in self.get_hardware_graph().get_blocks() if el.subtype == "dram"]
+        ics_with_dram = []
+        for dram in all_drams:
+            for neigh in dram.get_neighs():
+                if neigh.type == "ic":
+                    ics_with_dram.append(neigh)
+
+        # can only have one ic with dram hanging from it
+        return len(ics_with_dram) == 1
+
+    def has_system_bus(self):
+        return False  # currently hard-wired off; the code below is unreachable
+        # find all the drams and their ics
+        all_drams = [el for el in self.get_hardware_graph().get_blocks() if el.subtype == "dram"]
+        ics_with_dram = []
+        for dram in all_drams:
+            for neigh in dram.get_neighs():
+                if neigh.type == "ic":
+                    ics_with_dram.append(neigh)
+
+        # can only have one ic with dram hanging from it
+        return len(ics_with_dram) == 1
+
+    def get_system_bus(self):
+        if not self.has_system_bus():
+            return None
+        else:
+            all_drams = [el for el in self.get_hardware_graph().get_blocks() if el.subtype == "dram"]
+            ics_with_dram = []
+ for dram in all_drams: + for neigh in dram.get_neighs(): + if neigh.type == "ic": + neigh.set_as_system_bus() + return neigh + + # get hardware blocks of a design + def get_blocks(self): + return self.hardware_graph.blocks + + # get hardware blocks within a specific SOC of the design + def get_blocks_of_SOC(self,SOC_type, SOC_id): + return [block for block in self.hardware_graph.blocks if block.SOC_type == SOC_type and SOC_id == SOC_id] + + # get tasks (software tasks) of the design + def get_tasks(self): + return self.hardware_graph.get_all_tasks() + + def get_tasks_of_SOC(self, SOC_type, SOC_id): + return [task for task in self.get_tasks() if task.SOC_type == SOC_type and SOC_id == SOC_id] + + # samples the task distribution within the hardware graph. + # used for jitter modeling. + def sample_hardware_graph(self, hw_sampling): + self.hardware_graph.sample(hw_sampling) + + # get blocks that a task uses (host the task) + def get_blocks_of_task(self, task): + blocks = [] + for block in self.get_blocks(): + if task in block.get_tasks_of_block(): + blocks.append(block) + return blocks + + # if set, the design is complete and valid + def set_validity(self, validity): + self.valid = validity + + def get_validity(self): + return self.valid + + # delete this later. Used for debugging + def check_mem_fronts_sanity(self): + fronts_1 = sum([len(block.get_fronts("task_name_dir")) for block in self.get_blocks() if block.type == "mem"]) + fronts_2 = sum( + [len(block.get_fronts("task_dir_work_ratio")) for block in self.get_blocks() if block.type == "mem"]) + + + def check_system_ic_exist(self, block): + assert (block.type == "ic"), "should be checking this with non ic input" + system_ic_exist = False + connectd_ics = [block_ for block_ in block.get_neighs() if block_.type == "ic"] + + # iterate though the connected ics, get their neighbouring ics + # and make sure there is a ic with only dram + system_ic_list = [] + for neigh_ic in connectd_ics: + has_dram = len([neigh for neigh in neigh_ic.get_neighs() if neigh.subtype == "dram"]) >= 1 + has_pe = len([neigh for neigh in neigh_ic.get_neighs() if neigh.type == "pe"]) >= 1 + if has_dram: + if has_pe: + pass + #return False + #print(" system ic can not have a pe") + #exit(0) + else: + system_ic_list.append(neigh_ic) + + if self.block_is_system_ic(block): + system_ic_list.append(block) + + if len(set(system_ic_list)) > 1: + print("can only have one system ic") + exit(0) + + return len(system_ic_list) == 1 + + + def block_is_system_ic(self, block): + assert (block.type == "ic"), "should be checking this with non ic input" + # iterate though the connected ics, get their neighbouring ics + # and make sure there is a ic with only dram + system_ic_list = [] + has_dram = len([neigh for neigh in block.get_neighs() if neigh.subtype == "dram"]) >= 1 + has_pe = len([neigh for neigh in block.get_neighs() if neigh.type == "pe"]) >= 1 + if has_dram: + if has_pe: + pass + #print(" system ic can not have a pe") + #exit(0) + else: + return True + else: + return False + return False + + # sanity check the design + def sanity_check(self): + insanity_list = [] # list of Inanities + + # fronts check + fronts_1 = sum([len(block.get_fronts("task_name_dir")) for block in self.get_blocks() if block.type == "mem"]) + fronts_2 = sum([len(block.get_fronts("task_dir_work_ratio")) for block in self.get_blocks() if block.type== "mem"]) + if not fronts_1 == fronts_2: + pre_mvd_fronts_1 = [block.get_fronts("task_name_dir") for block in self.get_blocks() if block.type == "mem"] + 
pre_mvd_fronts_2 = [block.get_fronts("task_dir_work_ratio") for block in self.get_blocks() if block.type == "mem"] + raise UnEqualFrontsError + + # all the tasks have pe and mem + for task in self.get_tasks(): + pe_blocks = self.get_blocks_of_task_by_block_type(task, "pe") + mem_blocks = self.get_blocks_of_task_by_block_type(task, "mem") + if len(pe_blocks) == 0: + print("task:" + task.name + " does not have any pes") + insanity = Insanity("_", "_", "none") + insanity.set_block("_") + insanity.set_name("no_pe") + insanity_list.append(insanity) + pe_blocks = self.get_blocks_of_task_by_block_type(task, "pe") + print(insanity.gen_msg()) + raise NoPEError + #break + elif (len(mem_blocks) == 0 and not("siink" in task.name)): + print("task:" + task.name + " does not have any mems") + insanity = Insanity("_", "_", "none") + insanity.set_block("_") + insanity.set_name("no_mem") + insanity_list.append(insanity) + print(insanity.gen_msg()) + mem_blocks = self.get_blocks_of_task_by_block_type(task, "mem") + raise NoMemError + #break + + # every pe or memory needs to be connected to a bus + for block in self.get_blocks(): + if block.type in ["pe", "mem"]: + connectd_ics = [True for block_ in block.get_neighs() if block_.type =="ic" ] + if len(connectd_ics) > 1: + print("block: " + block.instance_name + " is connected to more than one ic") + insanity = Insanity("_", "_", "multi_bus") + insanity.set_name("multi_bus") + insanity_list.append(insanity) + print(insanity.gen_msg()) + raise MultiBusBlockError + #break + elif len(connectd_ics) < 1: + print("block: " + block.instance_name + " is not connected any ic") + insanity = Insanity("_", "_", "none") + insanity.set_block(block) + insanity.set_name("no_bus") + insanity_list.append(insanity) + print(insanity.gen_msg()) + raise NoBusError + #break + + # every bus needs to have at least one pe and mem + for block in self.get_blocks(): + if block.type in ["ic"]: + connectd_pes = [True for block_ in block.get_neighs() if block_.type =="pe" ] + connectd_mems = [True for block_ in block.get_neighs() if block_.type =="mem" ] + connectd_ics = [True for block_ in block.get_neighs() if block_.type =="ic" ] + + system_ic_exist = self.check_system_ic_exist(block) + + if len(connectd_mems) == 0 and not system_ic_exist: + insanity = Insanity("_",block, "bus_with_no_mem") + print(insanity.gen_msg()) + if self.hardware_graph.generation_mode == "user_generated": + print("deactivated Bus with No memory error, since hardware graph was directly user generated/parsed ") + else: + raise BusWithNoMemError + """ + elif len(connectd_pes) > 0 and self.block_is_system_ic(block): + insanity = Insanity("_", block, "system_ic_with_pe") + insanity_list.append(insanity) + print(insanity.gen_msg()) + if self.hardware_graph.generation_mode == "user_generated": + print( + "deactivated Bus with No Bus error, since hardware graph was directly user generated/parsed ") + else: + raise SystemICWithPEException + """ + elif len(connectd_pes) == 0 and not self.block_is_system_ic(block): + insanity = Insanity("_", block, "bus_with_no_pes") + insanity_list.append(insanity) + print(insanity.gen_msg()) + if self.hardware_graph.generation_mode == "user_generated": + print("deactivated Bus with No Bus error, since hardware graph was directly user generated/parsed ") + else: + raise BusWithNoPEError + # every design needs to have at least on pe, mem, and bus + block_type_count_dict = {} + block_type_count_dict["mem"] = 0 + block_type_count_dict["pe"] = 0 + block_type_count_dict["ic"] = 0 + for block in 
self.get_blocks(): + block_type_count_dict[block.type] +=1 + for type_, count in block_type_count_dict.items(): + if count < 1: + print("no block of type " + type_ + " found") + insanity = Insanity("_", "_", "none") + insanity.set_name("not_enough_ip_of_certain_type") + insanity_list.append(insanity) + print(insanity.gen_msg()) + raise NotEnoughIPOfCertainType + #break + + # every block should host at least one task + for block in self.get_blocks(): + if block.type == "ic": # since we unload + continue + if len(block.get_tasks_of_block()) == 0: + print( "block: " + block.instance_name + " does not host any tasks") + insanity = Insanity("_", "_", "none") + insanity.set_block(block) + insanity.set_name("no_task") + insanity_list.append(insanity) + print(insanity.gen_msg()) + raise BlockWithNoTaskError + + # get blocks within the design (filtered by type) + def get_blocks_by_type(self, block_type): + return [block for block in self.get_blocks() if block.type == block_type] + + # gets blocks for a task, and filter them based on hardware type (pe, mem, ic) + def get_blocks_of_task_by_block_type(self, task, block_type): + blocks_of_task = self.get_blocks_of_task(task) + blocks_by_type = [] + for block in blocks_of_task: + if block.type == block_type: + blocks_by_type.append(block) + return blocks_by_type + + + def get_write_mem_tasks(self, task, mem): + # get conncted ic + ics = [el for el in mem.get_neighs() if el.type =="ic"] + assert(len(ics) <= 1), "Each memory can be only connected to one bus master" + + # get the pipes + pipes = self.get_hardware_graph().get_pipes_between_two_blocks(ics[0], mem, "write") + assert(len(pipes) <= 1), "can only have one pipe (in a direction) between a memory and a ic" + + # traffic + traffics = pipes[0].get_traffic() + + return [trf.child for trf in traffics if trf.parent.name == task.name] + + + # for a specific task, find all the specific blocks of a type and their direction + def get_blocks_of_task_by_block_type_and_task_dir(self, task, block_type, task_dir=""): + assert ((block_type == "pe") != task_dir) # XORing the expression + blocks_of_task = self.get_blocks_of_task(task) + blocks_of_task_by_type = [block for block in blocks_of_task if block.type == block_type] + blocks_of_task_by_type_and_task_dir = [block for block in blocks_of_task_by_type if block.get_task_dir_by_task_name(task)[0][1] == task_dir] + return blocks_of_task_by_type_and_task_dir + + # get the properties of the design. This is used for the more accurate simulation + def filter_props_by_keyword(self, knob_order, knob_values, type_name): + prop_value_dict = collections.OrderedDict() + for knob_name, knob_value in zip(knob_order, knob_values): + if type_name+"_props" in knob_name: + knob_name_refined = knob_name.split("__")[-1] + prop_value_dict[knob_name_refined] = knob_value + return prop_value_dict + + def filter_auto_tune_props(self, type_name, auto_tune_props): + auto_tune_list = [] + for knob_name in auto_tune_props: + if type_name+"_props" in knob_name: + knob_name_refined = knob_name.split("__")[-1] + auto_tune_list.append(knob_name_refined) + return auto_tune_list + + # get id associated with a design. Each design has it's unique id. + def get_ex_id(self): + if self.id == str(-1): + print("experiments id is:" + str(self.id) + ". This means id has not been set") + exit(0) + return self.id + + def update_ex_id(self, id): + self.id = id + + def get_FARSI_ex_id(self): + if self.FARSI_ex_id == str(-1): + print("experiments id is:" + str(self.id) + ". 
This means id has not been set") + exit(0) + return self.FARSI_ex_id + + def get_PA_knob_ctr_id(self): + if self.PA_knob_ctr_id == str(-1): + print("experiments id is:" + str(self.PA_knob_ctr_id) + ". This means id has not been set") + exit(0) + return self.PA_knob_ctr_id + + def update_FARSI_ex_id(self, FARSI_ex_id): + self.FARSI_ex_id = FARSI_ex_id + + def update_PA_knob_ctr_id(self, knob_ctr): + self.PA_knob_ctr_id = knob_ctr + + def reset_PA_knobs(self, mode="batch"): + if mode == "batch": + # parse and set design props + self.PA_prop_dict = collections.OrderedDict() + # parse and set hw and update the props + for keyword in ["pe", "ic", "mem"]: + blocks_ = self.get_blocks_by_type(keyword) + for block in blocks_: + block.reset_PA_props() + # parse and set sw props + for keyword in ["sw"]: + tasks = self.get_tasks() + for task_ in tasks: + task_.reset_PA_props() + else: + print("mode:" + mode + " is not defind for apply_PA_knobs") + exit(0) + + def update_PA_knobs(self, knob_values, knob_order, all_auto_tunning_knobs, mode="batch"): + if mode == "batch": + # parse and set design props + prop_value_dict = {} + prop_value_dict["ex_id"] = self.get_ex_id() + prop_value_dict["FARSI_ex_id"] = self.get_FARSI_ex_id() + prop_value_dict["PA_knob_ctr_id"] = self.get_PA_knob_ctr_id() + self.PA_prop_dict.update(prop_value_dict) + # parse and set hw and update the props + for keyword in ["pe", "ic", "mem"]: + blocks_ = self.get_blocks_by_type(keyword) + prop_value_dict = self.filter_props_by_keyword(knob_order, knob_values, keyword) + prop_auto_tuning_list = self.filter_auto_tune_props(keyword, all_auto_tunning_knobs) + for block in blocks_: + block.update_PA_props(prop_value_dict) + block.update_PA_auto_tunning_knob_list(prop_auto_tuning_list) + # parse and set sw props + for keyword in ["sw"]: + tasks = self.get_tasks() + prop_value_dict = self.filter_props_by_keyword(knob_order, knob_values, keyword) + prop_auto_tuning_list = self.filter_auto_tune_props(keyword, all_auto_tunning_knobs) + for task_ in tasks: + task_.update_PA_props(prop_value_dict) + task_.update_PA_auto_tunning_knob_list(prop_auto_tuning_list) + else: + print("mode:"+ mode +" is not defined for apply_PA_knobs") + exit(0) + + # write all the props into the design file. 
+ # design file is set in the config file (config.verification_result_file) + def dump_props(self, result_folder, mode="batch"): # batch means that all the blocks of similar type have similar props + file_name = config.verification_result_file + file_addr = os.path.join(result_folder, file_name) + if mode == "batch": + with open(file_addr, "a+") as output: + for prop_name, prop_value in self.PA_prop_dict.items(): + prop_name_modified = "\"design" + "__" + prop_name +"\"" + if not str(prop_value).isdigit(): + prop_value = "\"" + prop_value + "\"" + output.write(prop_name_modified + ": " + str(prop_value) + ",\n") + # writing the hardware props + for keyword in ["pe", "ic", "mem"]: + + block = self.get_blocks_by_type(keyword)[0] # since in batch mode, the first element shares the same prop values + # as all + for prop_name, prop_value in block.PA_prop_dict.items(): + prop_name_modified = "\""+ keyword+"__"+prop_name + "\"" + if "ic__/Buffer/enable" in prop_name_modified: # this is just because now parsing throws an error + continue + if not str(prop_value).isdigit(): + prop_value = "\"" + prop_value +"\"" + output.write(prop_name_modified+": " + str(prop_value) + ",\n") + # writing the software props + for keyword in ["sw"]: + task_ = self.get_tasks()[0] + for prop_name, prop_value in task_.PA_prop_dict.items(): + prop_name_modified = "\""+ keyword + "__" + prop_name +"\"" + if not str(prop_value).isdigit(): + prop_value = "\"" + prop_value +"\"" + output.write(prop_name_modified + ": " + str(prop_value) + ",\n") + + else: + print("mode:" + mode + " is not defind for apply_PA_knobs") + + def get_hardware_graph(self): + return self.hardware_graph + + def get_task_by_name(self, task_name): + return [task_ for task_ in self.get_tasks() if task_.name == task_name][0] + + +# collection of the simulated design points. +# not that you can query this container with the same functions as the +# SimDesignPoint (i.e, same modules are provide). However, this is not t +class SimDesignPointContainer: + def __init__(self, design_point_list, database, reduction_mode = "avg"): + self.design_point_list = design_point_list + self.reduction_mode = reduction_mode # how to reduce the results. 
+        self.database = database  # hw/sw database
+        self.dp_rep = self.design_point_list[0]  # design representative
+        self.dp_stats = DPStatsContainer(self, self.dp_rep, self.database, reduction_mode)  # design point stats
+        self.dp = self.dp_rep  # design point is used to fill up some default values;
+                               # we use dp_rep, i.e., the design point representative, for this
+
+        self.move_applied = None
+        self.dummy_tasks = [krnl.get_task() for krnl in self.dp.get_kernels() if (krnl.get_task()).is_task_dummy()]
+        self.exploration_and_simulation_approximate_time = 0
+        self.neighbouring_design_space_size = 0
+
+    def get_neighbouring_design_space_size(self):
+        return self.neighbouring_design_space_size
+
+    def get_dummy_tasks(self):
+        return self.dummy_tasks
+
+    # bootstrap the design from scratch
+    def reset_design(self, workload_to_hardware_map=[], workload_to_hardware_schedule=[]):
+        self.dp_rep.reset_design()
+
+    def set_move_applied(self, move_applied):
+        self.move_applied = move_applied
+
+    def get_move_applied(self):
+        return self.move_applied
+
+    def add_exploration_and_simulation_approximate_time(self, time):
+        # the reason that this is approximate is that we take
+        # the entire generation time and divide it by the number of iterations
+        self.exploration_and_simulation_approximate_time += time
+
+    def get_exploration_and_simulation_approximate_time(self):
+        return self.exploration_and_simulation_approximate_time
+
+    def get_phase_calculation_time(self):
+        return self.dp.simulation_time_phase_calculation_portion
+
+    def get_phase_scheduling_time(self):
+        return self.dp.simulation_time_phase_scheduling_portion
+
+    def get_task_update_time(self):
+        return self.dp.simulation_time_task_update_portion
+
+    def get_dp_stats(self):
+        return self.dp_stats
+
+    # -----------------
+    # getters
+    # -----------------
+    def get_task_graph(self):
+        return self.dp_rep.get_hardware_graph().get_task_graph()
+
+    # get the sw to hw mapping
+    def get_workload_to_hardware_map(self):
+        return self.dp_rep.get_workload_to_hardware_map()
+
+    # get the scheduling
+    def get_workload_to_hardware_schedule(self):
+        return self.dp_rep.get_workload_to_hardware_schedule()
+
+    def get_kernel_by_task_name(self, task: Task):
+        return self.dp_rep.get_kernel_by_task_name(task)
+
+    # get the kernels of the design
+    def get_kernels(self):
+        return self.dp_rep.get_kernels()
+
+    # get the SOCs that the design resides in
+    def get_designs_SOCs(self):
+        return self.dp_rep.get_designs_SOCs()
+
+    # get all the design points
+    def get_design_point_list(self):
+        return self.design_point_list
+
+    # get the representative design point.
+    def get_dp_rep(self):
+        return self.dp_rep
+
+
+# Container for all the design point stats.
+# In order to collect profiling information, we reduce the statistics
+# according to the desired reduction function.
+# Reduction semantically happens at two different levels depending on the question
+# that we are asking.
+# Level 1 Questions: within/intra-design questions to compare components of a
+#                    single design. Example: finding the hottest kernel.
+#                    To answer, reduce the results at the task/kernel level.
+# Level 2 Questions: across/inter-design questions to compare different designs.
+# To answer, reduce the results from the end-to-end perspective, i.e., +# reduce(end-to-end latency), reduce(end-to-end energy), ... +# PS: at the moment, a design here is defined as a sw/hw tuple with only sw +# characteristic changing. +class DPStatsContainer(): + def __init__(self, sim_dp_container, dp_rep, database, reduction_mode): + self.comparison_mode = "latency" # metric to compare different design points + self.sim_dp_container = sim_dp_container + self.design_point_list = self.sim_dp_container.design_point_list # design point container (containing list of designs) + self.dp_rep = dp_rep #self.dp_container[0] # which design to use as representative (for plotting and so on + self.__kernels = self.sim_dp_container.design_point_list[0].get_kernels() + self.SOC_area_dict = defaultdict(lambda: defaultdict(dict)) # area of all blocks within each SOC + self.SOC_area_subtype_dict = defaultdict(lambda: defaultdict(dict)) # area of all blocks within each SOC + self.system_complex_area_dict = defaultdict() + self.SOC_metric_dict = defaultdict(lambda: defaultdict(dict)) + self.system_complex_metric_dict = defaultdict(lambda: defaultdict(dict)) + self.system_complex_area_dram_non_dram = defaultdict(lambda: defaultdict(dict)) + self.database = database # hw/sw database + self.latency_list =[] # list of latency values associated with each design point + self.power_list =[] # list of power values associated with each design point + self.energy_list =[] # list of energy values associated with each design point + self.reduction_mode = reduction_mode # how to statistically reduce the values + # collect the data + self.collect_stats() + self.dp = self.sim_dp_container # container that has all the designs + self.parallel_kernels = dp_rep.parallel_kernels + + + def get_parallel_kernels(self): + return self.parallel_kernels + + # helper function to apply an operator across two dictionaries + def operate_on_two_dic_values(self,dict1, dict2, operator): + dict_res = {} + for key in list(dict2.keys()) + list(dict1.keys()): + if key in dict1.keys() and dict2.keys(): + dict_res[key] = operator(dict2[key], dict1[key]) + else: + if key in dict1.keys(): + dict_res[key] = dict1[key] + elif key in dict2.keys(): + dict_res[key] = dict2[key] + return dict_res + + # operate on multiple dictionaries. 
The operation is determined by the operator input + def operate_on_dicionary_values(self, dictionaries, operator): + res = {} + for SOCs_latency in dictionaries: + #res = copy.deepcopy(self.operate_on_two_dic_values(res, SOCs_latency, operator)) + #gc.disable() + res = cPickle.loads(cPickle.dumps(self.operate_on_two_dic_values(res, SOCs_latency, operator), -1)) + #gc.enable() + return res + + # reduce the (list of) values based on a statistical parameter (such as average) + def reduce(self, list_): + if self.reduction_mode == 'avg': + if isinstance(list_[0],dict): + dict_added = self.operate_on_dicionary_values(list_, operator.add) + for key,val in dict_added.items(): + dict_added[key] = val/len(list_) + return dict_added + else: + return sum(list_)/len(list_) + elif self.reduction_mode == 'min': + return min(list_) + elif self.reduction_mode == 'max': + #if (len(list_) == 0): + # print("What") + return max(list_) + else: + print("reduction mode "+ self.reduction_mode + ' is not defined') + exit(0) + + + def get_number_blocks_of_all_sub_types(self): + subtype_cnt = [] + for block in self.dp_rep.get_blocks(): + if block.subtype not in subtype_cnt: + subtype_cnt[block.subtype] = 0 + subtype_cnt[block.subtype] += 1 + return subtype_cnt + + def get_compute_system_attr(self): + ips = [el for el in self.dp_rep.get_blocks() if el.subtype == "ip"] + gpps = [el for el in self.dp_rep.get_blocks() if el.subtype == "gpp"] + + + # get frequency data + ips_freqs = [mem.get_block_freq() for mem in ips] + gpp_freqs = [mem.get_block_freq() for mem in gpps] + if len(ips_freqs) == 0: + ips_avg_freq = 0 + else: + ips_avg_freq= sum(ips_freqs)/max(len(ips_freqs),1) + + loop_itr_ratio = [] + for ip in ips: + loop_itr_ratio.append(ip.get_loop_itr_cnt()/ip.get_loop_max_possible_itr_cnt()) + + if len(ips) == 0: + loop_itr_ratio_avg = 0 + else: + loop_itr_ratio_avg = st.mean(loop_itr_ratio) + + + if len(ips_freqs) in [0,1]: + ips_freq_std = 0 + ips_freq_coeff_var = 0 + loop_itr_ratio_std = 0 + loop_itr_ratio_var = 0 + else: + ips_freq_std = st.stdev(ips_freqs) + ips_freq_coeff_var = st.stdev(ips_freqs)/st.mean(ips_freqs) + loop_itr_ratio_std = st.stdev(loop_itr_ratio) + loop_itr_ratio_var = st.stdev(loop_itr_ratio)/st.mean(loop_itr_ratio) + + if len(gpp_freqs) == 0: + gpps_avg_freq = 0 + else: + gpps_avg_freq= sum(gpp_freqs)/max(len(gpp_freqs),1) + + if len(gpp_freqs + ips_freqs) in [0,1]: + pes_freq_std = 0 + pes_freq_coeff_var = 0 + else: + pes_freq_std = st.stdev(ips_freqs + gpp_freqs) + pes_freq_coeff_var = st.stdev(ips_freqs + gpp_freqs) / st.mean(ips_freqs + gpp_freqs) + + + # get area data + ips_area = [mem.get_area() for mem in ips] + gpp_area = [mem.get_area() for mem in gpps] + + if len(ips_area) == 0: + ips_total_area = 0 + + else: + ips_total_area = sum(ips_area) + + if len(ips_area) in [0,1]: + ips_area_std = 0 + ips_area_coeff_var = 0 + else: + ips_area_std = st.stdev(ips_area) + ips_area_coeff_var = st.stdev(ips_area) / st.mean(ips_area) + + if len(gpp_area) == 0: + gpps_total_area = 0 + else: + gpps_total_area = sum(gpp_area) + + + if len(ips_area + gpp_area) in [0,1]: + pes_area_std = 0 + pes_area_coeff_var = 0 + else: + pes_area_std = st.stdev(ips_area+gpp_area) + pes_area_coeff_var = st.stdev(ips_area+gpp_area)/st.mean(ips_area+gpp_area) + + phase_accelerator_parallelism = {} + for phase, krnls in self.dp_rep.phase_krnl_present.items(): + accelerators_in_parallel = [] + for krnl in krnls: + accelerators_in_parallel.extend([blk for blk in krnl.get_blocks() if blk.subtype == "ip"]) + if 
len(accelerators_in_parallel) == 0: + continue + phase_accelerator_parallelism[phase] = len(accelerators_in_parallel) + + if len(phase_accelerator_parallelism.keys()) == 0: + avg_accel_parallelism = 0 + max_accel_parallelism = 0 + else: + avg_accel_parallelism = sum(list(phase_accelerator_parallelism.values()))/len(list(phase_accelerator_parallelism.values())) + max_accel_parallelism = max(list(phase_accelerator_parallelism.values())) + + + phase_gpp_parallelism = {} + for phase, krnls in self.dp_rep.phase_krnl_present.items(): + gpps_in_parallel = [] + for krnl in krnls: + gpps_in_parallel.extend([blk for blk in krnl.get_blocks() if blk.subtype == "gpp"]) + if len(gpps_in_parallel) == 0: + continue + phase_gpp_parallelism[phase] = len(gpps_in_parallel) + + if len(phase_gpp_parallelism.keys()) == 0: + avg_gpp_parallelism = 0 + max_gpp_parallelism = 0 + else: + avg_gpp_parallelism = sum(list(phase_gpp_parallelism.values()))/len(list(phase_gpp_parallelism.values())) + max_gpp_parallelism = max(list(phase_gpp_parallelism.values())) + + + + buses = [el for el in self.dp_rep.get_blocks() if el.subtype == "ic"] + bus_neigh_count = [] + for bus in buses: + pe_neighs = [neigh for neigh in bus.get_neighs() if neigh.type == "pe"] + bus_neigh_count.append(len(pe_neighs)) + + cluster_pe_cnt_avg = st.mean(bus_neigh_count) + if len(bus_neigh_count) in [0,1]: + cluster_pe_cnt_std = 0 + cluster_pe_cnt_coeff_var = 0 + else: + cluster_pe_cnt_std = st.stdev(bus_neigh_count) + cluster_pe_cnt_coeff_var = st.stdev(bus_neigh_count)/st.mean(bus_neigh_count) + + return { + "avg_accel_parallelism": avg_accel_parallelism, "max_accel_parallelism":max_accel_parallelism, + "avg_gpp_parallelism": avg_gpp_parallelism, "max_gpp_parallelism": max_gpp_parallelism, + "ip_cnt":len(ips), "gpp_cnt": len(gpps), + "ips_avg_freq": ips_avg_freq, "gpps_avg_freq":gpps_avg_freq, + "ips_freq_std": ips_freq_std, "pes_freq_std": pes_freq_std, + "ips_freq_coeff_var": ips_freq_coeff_var, "pes_freq_coeff_var": pes_freq_coeff_var, + "ips_total_area": ips_total_area, "gpps_total_area":gpps_total_area, + "ips_area_std": ips_area_std, "pes_area_std": pes_area_std, + "ips_area_coeff_var": ips_area_coeff_var, "pes_area_coeff_var": pes_area_coeff_var, + "pe_total_area":ips_total_area+gpps_total_area, + "loop_itr_ratio_avg":loop_itr_ratio_avg, + "loop_itr_ratio_std":loop_itr_ratio_std, + "loop_itr_ratio_var":loop_itr_ratio_var, + "cluster_pe_cnt_avg":cluster_pe_cnt_avg, + "cluster_pe_cnt_std":cluster_pe_cnt_std, + "cluster_pe_cnt_coeff_var":cluster_pe_cnt_coeff_var + } + + + + def get_speedup_analysis(self,dse): + + # for now just fill it out. 
something goes wrong + speedup_avg = {"customization_first_speed_up_avg": 1, + "parallelism_second_speed_up_avg": 1, + "customization_second_speed_up_avg":1 , + "parallelism_first_speed_up_avg": 1, + "interference_degradation_avg":1, + "customization_speed_up_full_system":1, + "loop_unrolling_parallelism_speed_up_full_system": 1, + "task_level_parallelism_speed_up_full_system":1 + } + + workload_speed_up = {} + for workload in self.database.get_workloads_last_task().keys(): + workload_speed_up[workload] = {"customization_first_speed_up": 1, + "parallelism_second_speed_up": 1, + "customization_second_speed_up": 1, + "parallelism_first_speed_up": 1, + "interference_degradation": 1} + return workload_speed_up,speedup_avg + + + + + + + # lower the design + workload_speed_up = {} + customization_first_speed_up_list =[] + customization_second_speed_up_list = [] + parallelism_first_speed_up_list = [] + parallelism_second_speed_up_list = [] + interference_degradation_list = [] + + for workload in self.database.get_workloads_last_task().keys(): + # single out workload in the current best + cur_best_ex_singled_out_workload,database = dse.single_out_workload(dse.so_far_best_ex_dp, self.database, workload, self.database.db_input.workload_tasks[workload]) + cur_best_sim_dp_singled_out_workload = dse.eval_design(cur_best_ex_singled_out_workload, database) + + # lower the cur best with single out + most_infer_ex_dp = dse.transform_to_most_inferior_design(dse.so_far_best_ex_dp) + most_infer_ex_dp_singled_out_workload, database = dse.single_out_workload(most_infer_ex_dp, self.database, workload, self.database.db_input.workload_tasks[workload]) + most_infer_sim_dp_singled_out_workload = dse.eval_design(most_infer_ex_dp_singled_out_workload,database) + + # speed ups + customization_first_speed_up = most_infer_sim_dp_singled_out_workload.dp.get_serial_design_time()/cur_best_sim_dp_singled_out_workload.dp.get_serial_design_time() + parallelism_second_speed_up = cur_best_sim_dp_singled_out_workload.dp.get_par_speedup() + + parallelism_first_speed_up = most_infer_sim_dp_singled_out_workload.dp.get_par_speedup() + customization_second_speed_up = most_infer_sim_dp_singled_out_workload.dp_stats.get_system_complex_metric("latency")[workload]/cur_best_sim_dp_singled_out_workload.dp.get_serial_design_time() + + interference_degradation = dse.so_far_best_sim_dp.dp_stats.get_system_complex_metric("latency")[workload]/cur_best_sim_dp_singled_out_workload.dp_stats.get_system_complex_metric("latency")[workload] + + + workload_speed_up[workload] = {"customization_first_speed_up":customization_first_speed_up, + "parallelism_second_speed_up":parallelism_second_speed_up, + "customization_second_speed_up": customization_second_speed_up, + "parallelism_first_speed_up": parallelism_first_speed_up, + "interference_degradation":interference_degradation} + customization_first_speed_up_list.append(customization_first_speed_up) + customization_second_speed_up_list.append(customization_second_speed_up) + parallelism_first_speed_up_list.append(parallelism_first_speed_up) + parallelism_second_speed_up_list.append(parallelism_second_speed_up) + interference_degradation_list.append(interference_degradation) + + + # for the entire design + most_infer_ex_dp = dse.transform_to_most_inferior_design(dse.so_far_best_ex_dp) + most_infer_sim_dp = dse.eval_design(most_infer_ex_dp, self.database) + + most_infer_ex_before_unrolling_dp = dse.transform_to_most_inferior_design_before_loop_unrolling(dse.so_far_best_ex_dp) + 
most_infer_sim_before_unrolling_dp = dse.eval_design(most_infer_ex_before_unrolling_dp, self.database) + + #customization_first_speed_up_full_system = most_infer_sim_dp.dp.get_serial_design_time()/dse.so_far_best_sim_dp.dp.get_serial_design_time() + #parallelism_second_speed_up_full_system = dse.so_far_best_sim_dp.dp.get_par_speedup() + + #parallelism_first_speed_up_full_system = most_infer_sim_dp.dp.get_par_speedup() + #customization_second_speed_up_full_system = max(list((most_infer_sim_dp.dp_stats.get_system_complex_metric("latency")).values()))/max(list((dse.so_far_best_sim_dp.dp_stats.get_system_complex_metric("latency")).values())) + + customization_speed_up_full_system = most_infer_sim_dp.dp.get_serial_design_time()/most_infer_sim_before_unrolling_dp.dp.get_serial_design_time() + loop_unrolling_parallelism_speed_up_full_system = most_infer_sim_before_unrolling_dp.dp.get_serial_design_time()/dse.so_far_best_sim_dp.dp.get_serial_design_time() + task_level_parallelism_speed_up_full_system = dse.so_far_best_sim_dp.dp.get_serial_design_time()/max(list((dse.so_far_best_sim_dp.dp_stats.get_system_complex_metric("latency")).values())) + + # + + + speedup_avg = {"customization_first_speed_up_avg": st.mean(customization_first_speed_up_list), + "parallelism_second_speed_up_avg": st.mean(parallelism_second_speed_up_list), + "customization_second_speed_up_avg": st.mean(customization_second_speed_up_list), + "parallelism_first_speed_up_avg": st.mean(parallelism_first_speed_up_list), + "interference_degradation_avg": st.mean(interference_degradation_list), + "customization_speed_up_full_system": customization_speed_up_full_system, + "loop_unrolling_parallelism_speed_up_full_system": loop_unrolling_parallelism_speed_up_full_system, + "task_level_parallelism_speed_up_full_system": task_level_parallelism_speed_up_full_system + } + + return workload_speed_up,speedup_avg + + def get_memory_system_attr(self): + memory_system_attr = {} + local_memories = [el for el in self.dp_rep.get_blocks() if el.subtype == "sram"] + global_memories = [el for el in self.dp_rep.get_blocks() if el.subtype == "dram"] + buses = [el for el in self.dp_rep.get_blocks() if el.subtype == "ic"] + + # get frequency data + local_memory_freqs = [mem.get_block_freq() for mem in local_memories] + global_memory_freqs = [mem.get_block_freq() for mem in global_memories] + if len(local_memory_freqs) == 0: + local_memory_avg_freq = 0 + else: + local_memory_avg_freq= sum(local_memory_freqs)/max(len(local_memory_freqs),1) + + if len(local_memory_freqs) in [0, 1]: + local_memory_freq_std = 0 + local_memory_freq_coeff_var = 0 + else: + local_memory_freq_std = st.stdev(local_memory_freqs) + local_memory_freq_coeff_var = st.stdev(local_memory_freqs) / st.mean(local_memory_freqs) + + if len(global_memory_freqs) == 0: + global_memory_avg_freq = 0 + else: + global_memory_avg_freq= sum(global_memory_freqs)/max(len(global_memory_freqs),1) + + + # get bus width data + local_memory_bus_widths = [mem.get_block_bus_width() for mem in local_memories] + global_memory_bus_widths = [mem.get_block_bus_width() for mem in global_memories] + if len(local_memory_bus_widths) == 0: + local_memory_avg_bus_width = 0 + else: + local_memory_avg_bus_width= sum(local_memory_bus_widths)/max(len(local_memory_bus_widths),1) + + if len(local_memory_bus_widths) in [0, 1]: + local_memory_bus_width_std = 0 + local_memory_bus_width_coeff_var = 0 + else: + local_memory_bus_width_std = st.stdev(local_memory_bus_widths) + local_memory_bus_width_coeff_var = 
st.stdev(local_memory_bus_widths) / st.mean(local_memory_bus_widths) + + if len(global_memory_bus_widths) == 0: + global_memory_avg_bus_width = 0 + else: + global_memory_avg_bus_width= sum(global_memory_bus_widths)/max(len(global_memory_bus_widths),1) + + + #get bytes data + local_memory_bytes = [] + for mem in local_memories: + mem_bytes = max(mem.get_area_in_bytes(), config.cacti_min_memory_size_in_bytes) # to make sure we don't go smaller than cacti's minimum size + local_memory_bytes.append((math.ceil(mem_bytes / config.min_mem_size[mem.subtype])) * config.min_mem_size[mem.subtype]) # modulo calculation + if len(local_memory_bytes) == 0: + local_memory_total_bytes = 0 + local_memory_bytes_avg = 0 + else: + local_memory_total_bytes = sum(local_memory_bytes) + local_memory_bytes_avg = st.mean(local_memory_bytes) + + if len(local_memory_bytes) in [0,1]: + local_memory_bytes_std = 0 + local_memory_bytes_coeff_var = 0 + else: + local_memory_bytes_std = st.stdev(local_memory_bytes) + local_memory_bytes_coeff_var = st.stdev(local_memory_bytes)/max(st.mean(local_memory_bytes),.0000000001) + + global_memory_bytes = [] + for mem in global_memories: + mem_bytes = max(mem.get_area_in_bytes(), config.cacti_min_memory_size_in_bytes) # to make sure we don't go smaller than cacti's minimum size + global_memory_bytes.append((math.ceil(mem_bytes / config.min_mem_size[mem.subtype])) * config.min_mem_size[mem.subtype]) # modulo calculation + if len(global_memory_bytes) == 0: + global_memory_total_bytes = 0 + else: + global_memory_total_bytes = sum(global_memory_bytes) + + if len(global_memory_bytes) in [0,1]: + global_memory_bytes_std = 0 + global_memory_bytes_coeff_var = 0 + else: + global_memory_bytes_std = st.stdev(global_memory_bytes) + global_memory_bytes_coeff_var = st.stdev(global_memory_bytes) / max(st.mean(global_memory_bytes),.00000001) + + + # get area data + local_memory_area = [mem.get_area() for mem in local_memories] + global_memory_area = [mem.get_area() for mem in global_memories] + if len(local_memory_area) == 0: + local_memory_total_area = 0 + else: + local_memory_total_area = sum(local_memory_area) + + if len(local_memory_area) in [0,1]: + local_memory_area_std = 0 + local_memory_area_coeff_var = 0 + else: + local_memory_area_std = st.stdev(local_memory_area) + local_memory_area_coeff_var = st.stdev(local_memory_area) / st.mean(local_memory_area) + + if len(global_memory_area) == 0: + global_memory_total_area = 0 + else: + global_memory_total_area = sum(global_memory_area) + + # get traffic data + local_total_traffic = 0 + for mem in local_memories: + block_s_krnels = self.get_krnels_of_block(mem) + for krnl in block_s_krnels: + local_total_traffic += krnl.calc_traffic_per_block(mem) + + + local_traffic_per_mem = {} + for mem in local_memories: + local_traffic_per_mem[mem] =0 + block_s_krnels = self.get_krnels_of_block(mem) + for krnl in block_s_krnels: + local_traffic_per_mem[mem] += krnl.calc_traffic_per_block(mem) + + + global_total_traffic = 0 + for mem in global_memories: + block_s_krnels = self.get_krnels_of_block(mem) + for krnl in block_s_krnels: + global_total_traffic += krnl.calc_traffic_per_block(mem) + + local_bus_traffic = {} + for mem in local_memories: + local_traffic = 0 + block_s_krnels = self.get_krnels_of_block(mem) + for krnl in block_s_krnels: + local_traffic += krnl.calc_traffic_per_block(mem) + for bus in buses: + if mem in bus.get_neighs(): + if bus not in local_bus_traffic.keys(): + local_bus_traffic[bus] = 0 + local_bus_traffic[bus] += local_traffic + break + 
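+        # (illustrative note on the reuse metrics computed below, with a made-up example:
+        #  reuse_no_read = max(traffic/bytes - 2, 0) discounts one initial write and one
+        #  final read of the memory's footprint, while reuse_with_read = max(traffic/bytes - 1, 0)
+        #  discounts only the write. E.g., 10 KB of traffic through a 2 KB SRAM gives
+        #  reuse ratios of 3 and 4 respectively.)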
+ # get traffic reuse + local_traffic_reuse_no_read_ratio = [] + local_traffic_reuse_no_read_in_bytes = [] + local_traffic_reuse_no_read_in_size = [] + local_traffic_reuse_with_read_ratio= [] + local_traffic_reuse_with_read_in_bytes = [] + local_traffic_reuse_with_read_in_size = [] + mem_local_traffic = {} + for mem in local_memories: + local_traffic = 0 + block_s_krnels = self.get_krnels_of_block(mem) + for krnl in block_s_krnels: + local_traffic += krnl.calc_traffic_per_block(mem) + mem_local_traffic[mem] = local_traffic + mem_bytes = max(mem.get_area_in_bytes(), config.cacti_min_memory_size_in_bytes) # to make sure we don't go smaller than cacti's minimum size + #mem_bytes_modulo = (math.ceil(mem_bytes/config.min_mem_size[mem.subtype]))*config.min_mem_size[mem.subtype] # modulo calculation + mem_size = mem.get_area() + reuse_ratio_no_read = max((local_traffic/mem_bytes)-2, 0) + local_traffic_reuse_no_read_ratio.append(reuse_ratio_no_read) + local_traffic_reuse_no_read_in_bytes.append(reuse_ratio_no_read*mem_bytes) + local_traffic_reuse_no_read_in_size.append(reuse_ratio_no_read*mem_size) + reuse_ratio_with_read = max((local_traffic/mem_bytes)-1, 0) + local_traffic_reuse_with_read_ratio.append(reuse_ratio_with_read) + local_traffic_reuse_with_read_in_bytes.append(reuse_ratio_with_read*mem_bytes) + local_traffic_reuse_with_read_in_size.append(reuse_ratio_with_read*mem_size) + + if len(local_memories) == 0: + local_total_traffic_reuse_no_read_ratio = 0 + local_total_traffic_reuse_no_read_in_bytes = 0 + local_total_traffic_reuse_no_read_in_size = 0 + local_total_traffic_reuse_with_read_ratio = 0 + local_total_traffic_reuse_with_read_in_bytes = 0 + local_total_traffic_reuse_with_read_in_size = 0 + local_traffic_per_mem_avg = 0 + else: + local_total_traffic_reuse_no_read_ratio = max((local_total_traffic/local_memory_total_bytes)-2, 0) + local_total_traffic_reuse_no_read_in_bytes = sum(local_traffic_reuse_no_read_in_bytes) + local_total_traffic_reuse_no_read_in_size = sum(local_traffic_reuse_no_read_in_size) + local_total_traffic_reuse_with_read_ratio = max((local_total_traffic/local_memory_total_bytes)-1, 0) + local_total_traffic_reuse_with_read_in_bytes = sum(local_traffic_reuse_with_read_in_bytes) + local_total_traffic_reuse_with_read_in_size = sum(local_traffic_reuse_with_read_in_size) + local_traffic_per_mem_avg = st.mean(list(local_traffic_per_mem.values())) + + + if len(local_bus_traffic) == 0: + local_bus_traffic_avg = 0 + else: + local_bus_traffic_avg = st.mean(list(local_bus_traffic.values())) + + if len(local_memories) in [0,1]: + local_traffic_per_mem_std = 0 + local_traffic_per_mem_coeff_var = 0 + + else: + local_traffic_per_mem_std = st.stdev(list(local_traffic_per_mem.values())) + local_traffic_per_mem_coeff_var = st.stdev(list(local_traffic_per_mem.values()))/st.mean(list(local_traffic_per_mem.values())) + + + if len(local_bus_traffic) in [0,1]: + local_bus_traffic_std = 0 + local_bus_traffic_coeff_var = 0 + else: + local_bus_traffic_std = st.stdev(list(local_bus_traffic.values())) + local_bus_traffic_coeff_var = st.stdev(list(local_bus_traffic.values()))/st.mean(list(local_bus_traffic.values())) + + # get traffic reuse + global_traffic_reuse_no_read_ratio= [] + global_traffic_reuse_no_read_in_bytes = [] + global_traffic_reuse_no_read_in_size = [] + global_traffic_reuse_with_read_ratio= [] + global_traffic_reuse_with_read_in_bytes = [] + global_traffic_reuse_with_read_in_size = [] + for mem in global_memories: + global_traffic = 0 + block_s_krnels = 
self.get_krnels_of_block(mem) + for krnl in block_s_krnels: + global_traffic += krnl.calc_traffic_per_block(mem) + mem_bytes = max(mem.get_area_in_bytes(), config.cacti_min_memory_size_in_bytes) # to make sure we don't go smaller than cacti's minimum size + #mem_bytes_modulo = (math.ceil(mem_bytes/config.min_mem_size[mem.subtype]))*config.min_mem_size[mem.subtype] # modulo calculation + mem_size = mem.get_area() + reuse_ratio_no_read = max((global_traffic/mem_bytes)-2, 0) + global_traffic_reuse_no_read_ratio.append(reuse_ratio_no_read) + global_traffic_reuse_no_read_in_bytes.append(reuse_ratio_no_read*mem_bytes) + global_traffic_reuse_no_read_in_size.append(reuse_ratio_no_read*mem_size) + reuse_ratio_with_read = max((global_traffic/mem_bytes)-1, 0) + global_traffic_reuse_with_read_ratio.append(reuse_ratio_with_read) + global_traffic_reuse_with_read_in_bytes.append(reuse_ratio_with_read*mem_bytes) + global_traffic_reuse_with_read_in_size.append(reuse_ratio_with_read*mem_size) + + if len(global_memories) == 0: + global_total_traffic_reuse_no_read_ratio = 0 + global_total_traffic_reuse_no_read_in_bytes = 0 + global_total_traffic_reuse_no_read_in_size = 0 + global_total_traffic_reuse_with_read_ratio = 0 + global_total_traffic_reuse_with_read_in_bytes = 0 + global_total_traffic_reuse_with_read_in_size = 0 + else: + global_total_traffic_reuse_no_read_ratio = max((global_total_traffic/global_memory_total_bytes)-2, 0) + global_total_traffic_reuse_no_read_in_bytes = sum(global_traffic_reuse_no_read_in_bytes) + global_total_traffic_reuse_no_read_in_size = sum(global_traffic_reuse_no_read_in_size) + global_total_traffic_reuse_with_read_ratio = max((global_total_traffic/global_memory_total_bytes)-1, 0) + global_total_traffic_reuse_with_read_in_bytes = sum(global_traffic_reuse_with_read_in_bytes) + global_total_traffic_reuse_with_read_in_size = sum(global_traffic_reuse_with_read_in_size) + + + + # per cluster start + # get traffic reuse + local_traffic_reuse_no_read_in_bytes_per_cluster = {} + local_traffic_reuse_no_read_in_size_per_cluster = {} + local_traffic_reuse_with_read_ratio_per_cluster = {} + local_traffic_reuse_with_read_in_bytes_per_cluster = {} + local_traffic_reuse_with_read_in_size_per_cluster = {} + + for bus in buses: + mems = [blk for blk in bus.get_neighs() if blk.subtype == "sram"] + local_traffic_reuse_no_read_in_bytes_per_cluster[bus] = 0 + local_traffic_reuse_no_read_in_size_per_cluster[bus] = 0 + local_traffic_reuse_with_read_in_bytes_per_cluster[bus] = 0 + local_traffic_reuse_with_read_in_size_per_cluster[bus] = 0 + for mem in mems: + local_traffic = 0 + block_s_krnels = self.get_krnels_of_block(mem) + for krnl in block_s_krnels: + local_traffic += krnl.calc_traffic_per_block(mem) + mem_bytes = max(mem.get_area_in_bytes(), config.cacti_min_memory_size_in_bytes) # to make sure we don't go smaller than cacti's minimum size + #mem_bytes_modulo = (math.ceil(mem_bytes/config.min_mem_size[mem.subtype]))*config.min_mem_size[mem.subtype] # modulo calculation + mem_size = mem.get_area() + reuse_ratio_no_read_per_cluster = max((local_traffic/mem_bytes)-2, 0) + local_traffic_reuse_no_read_in_bytes_per_cluster[bus]+= (reuse_ratio_no_read_per_cluster*mem_bytes) + local_traffic_reuse_no_read_in_size_per_cluster[bus]+=(reuse_ratio_no_read_per_cluster*mem_size) + reuse_ratio_with_read_per_cluster = max((local_traffic/mem_bytes)-1, 0) + local_traffic_reuse_with_read_in_bytes_per_cluster[bus] += (reuse_ratio_with_read_per_cluster*mem_bytes) + local_traffic_reuse_with_read_in_size_per_cluster[bus] 
+= (reuse_ratio_with_read_per_cluster*mem_size)
+
+
+        local_total_traffic_reuse_no_read_in_size_per_cluster_avg = st.mean(list(local_traffic_reuse_no_read_in_size_per_cluster.values()))
+        local_total_traffic_reuse_with_read_in_size_per_cluster_avg = st.mean(list(local_traffic_reuse_with_read_in_size_per_cluster.values()))
+        local_total_traffic_reuse_no_read_in_bytes_per_cluster_avg = st.mean(list(local_traffic_reuse_no_read_in_bytes_per_cluster.values()))
+        local_total_traffic_reuse_with_read_in_bytes_per_cluster_avg = st.mean(list(local_traffic_reuse_with_read_in_bytes_per_cluster.values()))
+
+        if len(buses) in [0, 1]:
+            local_total_traffic_reuse_no_read_in_size_per_cluster_std = 0
+            local_total_traffic_reuse_with_read_in_size_per_cluster_std = 0
+            local_total_traffic_reuse_no_read_in_bytes_per_cluster_std = 0
+            local_total_traffic_reuse_with_read_in_bytes_per_cluster_std = 0
+            local_total_traffic_reuse_no_read_in_size_per_cluster_var = 0
+            local_total_traffic_reuse_with_read_in_size_per_cluster_var = 0
+            local_total_traffic_reuse_no_read_in_bytes_per_cluster_var = 0
+            local_total_traffic_reuse_with_read_in_bytes_per_cluster_var = 0
+        else:
+            local_total_traffic_reuse_no_read_in_size_per_cluster_std = st.stdev(
+                list(local_traffic_reuse_no_read_in_size_per_cluster.values()))
+            local_total_traffic_reuse_with_read_in_size_per_cluster_std = st.stdev(
+                list(local_traffic_reuse_with_read_in_size_per_cluster.values()))
+            local_total_traffic_reuse_no_read_in_bytes_per_cluster_std = st.stdev(
+                list(local_traffic_reuse_no_read_in_bytes_per_cluster.values()))
+            local_total_traffic_reuse_with_read_in_bytes_per_cluster_std = st.stdev(
+                list(local_traffic_reuse_with_read_in_bytes_per_cluster.values()))
+            local_total_traffic_reuse_no_read_in_size_per_cluster_var = st.stdev(list(local_traffic_reuse_no_read_in_size_per_cluster.values()))/max(st.mean(list(local_traffic_reuse_no_read_in_size_per_cluster.values())), .000001)
+            local_total_traffic_reuse_with_read_in_size_per_cluster_var = st.stdev(list(local_traffic_reuse_with_read_in_size_per_cluster.values()))/max(st.mean(list(local_traffic_reuse_with_read_in_size_per_cluster.values())), .0000001)
+            local_total_traffic_reuse_no_read_in_bytes_per_cluster_var = st.stdev(list(local_traffic_reuse_no_read_in_bytes_per_cluster.values()))/max(st.mean(list(local_traffic_reuse_no_read_in_bytes_per_cluster.values())), .000000001)
+            local_total_traffic_reuse_with_read_in_bytes_per_cluster_var = st.stdev(list(local_traffic_reuse_with_read_in_bytes_per_cluster.values()))/max(st.mean(list(local_traffic_reuse_with_read_in_bytes_per_cluster.values())), .00000001)
+
+
+        # per cluster end
+        locality_in_bytes = 0
+        for krnl in self.__kernels:
+            pe = [blk for blk in krnl.get_blocks() if blk.type == "pe"][0]
+            mems = [blk for blk in krnl.get_blocks() if blk.type == "mem"]
+            for mem in mems:
+                path_length = len(self.dp_rep.get_hardware_graph().get_path_between_two_vertecies(pe, mem))
+                locality_in_bytes += krnl.calc_traffic_per_block(mem)/(path_length - 2)
+
+        """
+        # parallelism data
+        for mem in local_memories:
+            bal_traffic = 0
+            block_s_krnels = self.get_krnels_of_block(mem)
+            for krnl in block_s_krnels:
+
+                krnl.block_phase_read_dict[mem][self.phase_num] += read_work
+        """
+
+
+        return {"local_total_traffic": local_total_traffic, "global_total_traffic": global_total_traffic,
+                "local_total_traffic_reuse_no_read_ratio": local_total_traffic_reuse_no_read_ratio, "global_total_traffic_reuse_no_read_ratio": global_total_traffic_reuse_no_read_ratio,
+
"local_total_traffic_reuse_no_read_in_bytes": local_total_traffic_reuse_no_read_in_bytes, "global_total_traffic_reuse_no_read_in_bytes": global_total_traffic_reuse_no_read_in_bytes, + "local_total_traffic_reuse_no_read_in_size": local_total_traffic_reuse_no_read_in_size, "global_total_traffic_reuse_no_read_in_size": global_total_traffic_reuse_no_read_in_size, + "local_total_traffic_reuse_with_read_ratio": local_total_traffic_reuse_with_read_ratio, + "global_total_traffic_reuse_with_read_ratio": global_total_traffic_reuse_with_read_ratio, + "local_total_traffic_reuse_with_read_in_bytes": local_total_traffic_reuse_with_read_in_bytes, + "global_total_traffic_reuse_with_read_in_bytes": global_total_traffic_reuse_with_read_in_bytes, + "local_total_traffic_reuse_with_read_in_size": local_total_traffic_reuse_with_read_in_size, + "global_total_traffic_reuse_with_read_in_size": global_total_traffic_reuse_with_read_in_size, + "local_total_traffic_reuse_no_read_in_bytes_per_cluster_avg": local_total_traffic_reuse_no_read_in_bytes_per_cluster_avg, + "local_total_traffic_reuse_no_read_in_bytes_per_cluster_std": local_total_traffic_reuse_no_read_in_bytes_per_cluster_std, + "local_total_traffic_reuse_no_read_in_bytes_per_cluster_var": local_total_traffic_reuse_no_read_in_bytes_per_cluster_var, + "local_total_traffic_reuse_no_read_in_size_per_cluster_avg": local_total_traffic_reuse_no_read_in_size_per_cluster_avg, + "local_total_traffic_reuse_no_read_in_size_per_cluster_std": local_total_traffic_reuse_no_read_in_size_per_cluster_std, + "local_total_traffic_reuse_no_read_in_size_per_cluster_var": local_total_traffic_reuse_no_read_in_size_per_cluster_var, + "local_total_traffic_reuse_with_read_in_bytes_per_cluster_avg": local_total_traffic_reuse_with_read_in_bytes_per_cluster_avg, + "local_total_traffic_reuse_with_read_in_bytes_per_cluster_std": local_total_traffic_reuse_with_read_in_bytes_per_cluster_std, + "local_total_traffic_reuse_with_read_in_bytes_per_cluster_var": local_total_traffic_reuse_with_read_in_bytes_per_cluster_var, + "local_total_traffic_reuse_with_read_in_size_per_cluster_avg": local_total_traffic_reuse_with_read_in_size_per_cluster_avg, + "local_total_traffic_reuse_with_read_in_size_per_cluster_std": local_total_traffic_reuse_with_read_in_size_per_cluster_std, + "local_total_traffic_reuse_with_read_in_size_per_cluster_var": local_total_traffic_reuse_with_read_in_size_per_cluster_var, + "global_memory_avg_freq": global_memory_avg_freq, + "local_memory_avg_freq": local_memory_avg_freq, + "local_memory_freq_coeff_var": local_memory_freq_coeff_var, "local_memory_freq_std": local_memory_freq_std, + "global_memory_avg_bus_width": global_memory_avg_bus_width, + "local_memory_avg_bus_width": local_memory_avg_bus_width, + "local_memory_bus_width_coeff_var": local_memory_bus_width_coeff_var, + "local_memory_bus_width_std": local_memory_bus_width_std, + "global_memory_total_area": global_memory_total_area, + "local_memory_total_area":local_memory_total_area, + "local_memory_area_coeff_var": local_memory_area_coeff_var, "local_memory_area_std": local_memory_area_std, + "global_memory_total_bytes": global_memory_total_bytes, + "local_memory_total_bytes": local_memory_total_bytes, + "local_memory_bytes_avg": local_memory_bytes_avg, + "local_memory_bytes_coeff_var": local_memory_bytes_coeff_var, "local_memory_bytes_std": local_memory_bytes_std, + "memory_total_area":global_memory_total_area+local_memory_total_area, + "local_mem_cnt":len(local_memory_freqs), + "local_memory_traffic_per_mem_avg": 
local_traffic_per_mem_avg,
+                "local_memory_traffic_per_mem_std": local_traffic_per_mem_std,
+                "local_memory_traffic_per_mem_coeff_var": local_traffic_per_mem_coeff_var,
+                "local_bus_traffic_avg": local_bus_traffic_avg,
+                "local_bus_traffic_std": local_bus_traffic_std,
+                "local_bus_traffic_coeff_var": local_bus_traffic_coeff_var,
+                "locality_in_bytes": locality_in_bytes
+                }
+
+
+
+    def get_krnels_of_block(self, block):
+        block_s_tasks = block.get_tasks_of_block()
+        block_s_krnels = []
+        # get krnels of the block
+        for task in block_s_tasks:
+            for krnl in self.__kernels:
+                if krnl.get_task_name() == task.get_name():
+                    block_s_krnels.append(krnl)
+        return block_s_krnels
+
+    def get_krnels_of_block_clustered_by_workload(self, block):
+        workload_kernels = {}
+        block_s_tasks = block.get_tasks_of_block()
+        block_s_krnels = []
+        # get krnels of the block, clustered by the workload they belong to
+        for task in block_s_tasks:
+            for krnl in self.__kernels:
+                if krnl.get_task_name() == task.get_name():
+                    workload = self.database.db_input.task_workload[krnl.get_task_name()]
+                    if workload not in workload_kernels.keys():
+                        workload_kernels[workload] = []
+                    workload_kernels[workload].append(krnl)
+        return workload_kernels
+
+
+    def get_bus_system_attr(self):
+        bus_system_attr = {}
+        # in reality there is only one system bus
+        for el, val in self.infer_system_bus_attr().items():
+            bus_system_attr[el] = val
+        for el, val in self.infer_local_buses_attr().items():
+            bus_system_attr[el] = val
+
+        return bus_system_attr
+
+
+    def infer_system_bus_attr(self):
+        # has to get the max, since for now the system bus is inferred and not imposed
+        highest_freq = 0
+        highest_width = 0
+        system_mems = []
+        for block in self.dp_rep.get_blocks():
+            if block.subtype == "dram":
+                system_mems.append(block)
+
+
+        system_mems_avg_work_rates = []
+        system_mems_max_work_rates = []
+        for mem in system_mems:
+            block_work_phase = {}
+            phase_write_work_rate = {}
+            phase_read_work_rate = {}
+            krnls_of_block = self.get_krnels_of_block(mem)
+            for krnl in krnls_of_block:
+                for phase, work in krnl.block_phase_write_dict[mem].items():
+                    if phase not in phase_write_work_rate.keys():
+                        phase_write_work_rate[phase] = 0
+
+                    if krnl.stats.phase_latency_dict[phase] == 0:
+                        phase_write_work_rate[phase] += 0
+                    else:
+                        phase_write_work_rate[phase] += (work/krnl.stats.phase_latency_dict[phase])
+
+            for krnl in krnls_of_block:
+                for phase, work in krnl.block_phase_read_dict[mem].items():
+                    if phase not in phase_read_work_rate.keys():
+                        phase_read_work_rate[phase] = 0
+
+                    if krnl.stats.phase_latency_dict[phase] == 0:
+                        phase_read_work_rate[phase] += 0
+                    else:
+                        phase_read_work_rate[phase] += (work/krnl.stats.phase_latency_dict[phase])
+
+            avg_write_work_rate = sum(list(phase_write_work_rate.values()))/len(list(phase_write_work_rate.values()))
+            avg_read_work_rate = sum(list(phase_read_work_rate.values()))/len(list(phase_read_work_rate.values()))
+            max_write_work_rate = max(list(phase_write_work_rate.values()))
+            max_read_work_rate = max(list(phase_read_work_rate.values()))
+            system_mems_avg_work_rates.append(max(avg_read_work_rate, avg_write_work_rate))
+            system_mems_max_work_rates.append(max(max_write_work_rate, max_read_work_rate))
+
+
+        # there might be no system bus at the moment
+        if len(system_mems) == 0:
+            count = 0
+            system_mem_theoretical_bandwidth = 0
+            highest_width = 0
+            highest_freq = 0
+            system_mem_avg_work_rate = system_mem_max_work_rate = 0
+        else:
+            count = 1
+            highest_width = max([system_mem.get_block_bus_width() for system_mem in system_mems])
+            highest_freq = max([system_mem.get_block_freq() for system_mem in system_mems])
+            system_mem_theoretical_bandwidth = highest_width*highest_freq
+            system_mem_avg_work_rate = sum(system_mems_avg_work_rates)/len(system_mems_avg_work_rates)
+            # average of max
+            system_mem_max_work_rate = sum(system_mems_max_work_rates)/len(system_mems_max_work_rates)
+
+        return {"system_bus_count": count, "system_bus_avg_freq": highest_freq, "system_bus_avg_bus_width": highest_width,
+                "system_bus_avg_theoretical_bandwidth": system_mem_theoretical_bandwidth,
+                "system_bus_avg_actual_bandwidth": system_mem_avg_work_rate,
+                "system_bus_max_actual_bandwidth": system_mem_max_work_rate
+                }
+
+
+    def infer_if_is_a_local_bus(self, block):
+        if block.type == "ic":
+            block_ic_mem_neighs = [el for el in block.get_neighs() if el.type == "mem"]
+            block_ic_dram_mem_neighs = [el for el in block.get_neighs() if el.subtype == "dram"]
+            if not len(block_ic_mem_neighs) == len(block_ic_dram_mem_neighs):
+                return True
+        return False
+
+    # find the number of buses that do not have a dram connected to them.
+    # Note that it would be better if we had already set the system bus instead of inferring it.
+    # TODO for later
+    def infer_local_buses_attr(self):
+        attr_val = {}
+        # get all the local buses
+        local_buses = []
+        for block in self.dp_rep.get_blocks():
+            if self.infer_if_is_a_local_bus(block):
+                local_buses.append(block)
+
+        # get all the frequencies
+        freq_list = []
+        for bus in local_buses:
+            freq_list.append(bus.get_block_freq())
+
+        # get all the bus widths
+        bus_width_list = []
+        for bus in local_buses:
+            bus_width_list.append(bus.get_block_bus_width())
+
+        bus_bandwidth_list = []
+        for bus in local_buses:
+            bus_bandwidth_list.append(bus.get_block_bus_width()*bus.get_block_freq())
+
+        local_buses_avg_work_rate_list = []
+        local_buses_max_work_rate_list = []
+        for bus in local_buses:
+            work_rate = []
+            for pipe_cluster in bus.get_pipe_clusters():
+                pathlet_phase_work_rate = pipe_cluster.get_pathlet_phase_work_rate()
+                for pathlet, phase_work_rate in pathlet_phase_work_rate.items():
+                    if not pathlet.get_out_pipe().get_slave().subtype == "dram":
+                        work_rate.extend(list(phase_work_rate.values()))
+            local_buses_avg_work_rate_list.append(sum(work_rate)/len(work_rate))
+            local_buses_max_work_rate_list.append(max(work_rate))
+
+
+        local_channels_avg_work_rate_list = []
+        local_channels_max_work_rate_list = []
+        for bus in local_buses:
+            for pipe_cluster in bus.get_pipe_clusters():
+                work_rate = []
+                pathlet_phase_work_rate = pipe_cluster.get_pathlet_phase_work_rate()
+                for pathlet, phase_work_rate in pathlet_phase_work_rate.items():
+                    if not pathlet.get_out_pipe().get_slave().subtype == "dram":
+                        work_rate.extend(list(phase_work_rate.values()))
+                if len(work_rate) == 0:
+                    continue
+                local_channels_avg_work_rate_list.append(sum(work_rate)/max(len(work_rate), 1))
+                local_channels_max_work_rate_list.append(max(work_rate))
+
+
+        local_channels_cnt_per_bus = {}
+        for bus in local_buses:
+            local_channels_cnt_per_bus[bus] = 0
+            work_rate = []
+            for pipe_cluster in bus.get_pipe_clusters():
+                pathlet_phase_work_rate = pipe_cluster.get_pathlet_phase_work_rate()
+                for pathlet, phase_work_rate in pathlet_phase_work_rate.items():
+                    if not pathlet.get_out_pipe().get_slave().subtype == "dram":
+                        work_rate.extend(list(phase_work_rate.values()))
+                if len(work_rate) == 0:
+                    continue
+                local_channels_cnt_per_bus[bus] += 1
+
+        attr_val["local_bus_count"] = len(local_buses)
+        if len(local_buses) == 0:
+            attr_val["avg_freq"] = 0
+            attr_val["local_bus_avg_freq"] = 0
+
attr_val["local_bus_avg_bus_width"] = 0 + attr_val["local_bus_avg_theoretical_bandwidth"] = 0 + attr_val["local_bus_avg_actual_bandwidth"] = 0 + attr_val["local_bus_max_actual_bandwidth"] = 0 + attr_val["local_bus_cnt"] = 0 + attr_val["local_channel_avg_actual_bandwidth"] = 0 + attr_val["local_channel_max_actual_bandwidth"] = 0 + attr_val["local_channel_count_per_bus_avg"] = 0 + else: + attr_val["avg_freq"] = sum(freq_list) / len(freq_list) + attr_val["local_bus_avg_freq"] = sum(freq_list) / len(freq_list) + attr_val["local_bus_avg_bus_width"] = sum(bus_width_list)/len(freq_list) + attr_val["local_bus_avg_theoretical_bandwidth"] = sum(bus_bandwidth_list)/len(bus_bandwidth_list) + attr_val["local_bus_avg_actual_bandwidth"] = sum(local_buses_avg_work_rate_list)/len(local_buses_avg_work_rate_list) + # getting average of max + attr_val["local_bus_max_actual_bandwidth"] = sum(local_buses_max_work_rate_list)/len(local_buses_max_work_rate_list) + attr_val["local_bus_cnt"] = len(bus_width_list) + attr_val["local_channel_avg_actual_bandwidth"] = st.mean(local_channels_avg_work_rate_list) + attr_val["local_channel_max_actual_bandwidth"] = st.mean(local_channels_max_work_rate_list) + attr_val["local_channel_count_per_bus_avg"] = st.mean(list(local_channels_cnt_per_bus.values())) + + + if len(local_buses) in [0,1]: + attr_val["local_bus_freq_std"] = 0 + attr_val["local_bus_freq_coeff_var"] = 0 + attr_val["local_bus_bus_width_std"] = 0 + attr_val["local_bus_bus_width_coeff_var"] = 0 + attr_val["local_bus_actual_bandwidth_std"] = 0 + attr_val["local_bus_actual_bandwidth_coeff_var"] = 0 + attr_val["local_channel_actual_bandwidth_std"] = 0 + attr_val["local_channel_actual_bandwidth_coeff_var"] = 0 + attr_val["local_channel_count_per_bus_std"] = 0 + attr_val["local_channel_count_per_bus_coeff_var"] = 0 + else: + attr_val["local_bus_freq_std"] = st.stdev(freq_list) + attr_val["local_bus_freq_coeff_var"] = st.stdev(freq_list)/st.mean(freq_list) + attr_val["local_bus_bus_width_std"] = st.stdev(bus_width_list) + attr_val["local_bus_bus_width_coeff_var"] = st.stdev(bus_width_list)/st.mean(bus_width_list) + attr_val["local_bus_actual_bandwidth_std"] = st.stdev(local_buses_avg_work_rate_list) + attr_val["local_bus_actual_bandwidth_coeff_var"] = st.stdev(local_buses_avg_work_rate_list)/st.mean(local_buses_avg_work_rate_list) + attr_val["local_channel_actual_bandwidth_std"] = st.stdev(local_channels_avg_work_rate_list) + attr_val["local_channel_actual_bandwidth_coeff_var"] = st.stdev(local_channels_avg_work_rate_list)/st.mean(local_channels_avg_work_rate_list) + attr_val["local_channel_count_per_bus_std"] = st.stdev(list(local_channels_cnt_per_bus.values())) + attr_val["local_channel_count_per_bus_coeff_var"] = st.stdev(list(local_channels_cnt_per_bus.values()))/st.mean(list(local_channels_cnt_per_bus.values())) + + return attr_val + + + # iterate through all the design points and + # collect their stats + def collect_stats(self): + for type, id in self.dp_rep.get_designs_SOCs(): + # level 1 reduction for intra design questions + self.intra_design_reduction(type, id) + # level 2 questions for across/inter design questions + self.inter_design_reduction(type, id) + + # level 1 reduction for intra design questions + def intra_design_reduction(self, SOC_type, SOC_id): + kernel_latency_dict = {} + latency_list = [] + kernel_metric_values = defaultdict(lambda: defaultdict(list)) + for dp in self.sim_dp_container.design_point_list: + for kernel_ in dp.get_kernels(): + for metric in config.all_metrics: + 
kernel_metric_values[kernel_.get_task_name()][metric].append\
+                        (kernel_.stats.get_metric(metric))
+
+        for kernel in self.__kernels:
+            for metric in config.all_metrics:
+                kernel.stats.set_stats_directly(metric,
+                                                self.reduce(kernel_metric_values[kernel.get_task_name()][metric]))
+
+    def get_kernels(self):
+        return self.__kernels
+
+    # Functionality: level 2 questions for across/inter design questions
+    def inter_design_reduction(self, SOC_type, SOC_id):
+        for metric_name in config.all_metrics:
+            self.set_SOC_metric_value(metric_name, SOC_type, SOC_id)
+            self.set_system_complex_metric(metric_name)  # data per System
+
+    # hot = longest latency
+    def get_hot_kernel_SOC(self, SOC_type, SOC_id, metric="latency", krnel_rank=0):
+        kernels_on_SOC = [kernel for kernel in self.__kernels if kernel.SOC_type == SOC_type and kernel.SOC_id == SOC_id]
+        for k in kernels_on_SOC:
+            if (k.stats.get_metric(metric) is None):
+                print("metric is " + metric)
+        sorted_kernels_hot_to_cold = sorted(kernels_on_SOC, key=lambda kernel: kernel.stats.get_metric(metric), reverse=True)
+        return sorted_kernels_hot_to_cold[krnel_rank]
+
+    # get the hot kernel of the system. Hot means the bottleneck or rather the
+    # most power/energy/area/performance consuming kernel of the system. This is determined
+    # by the input argument metric.
+    # Variables:
+    #       krnel_rank: rank of the kernel to pick once we have already sorted the kernels
+    #       based on hotness. 0 means the hottest, and higher values mean colder ones.
+    #       We use this value to sometimes target less hot kernels when the hot kernel
+    #       can not be improved any more.
+    def get_hot_kernel_system_complex(self, metric="latency", krnel_rank=0):
+        hot_krnel_list = []
+        for SOC_type, SOC_id in self.get_designs_SOCs():
+            hot_krnel_list.append(self.get_hot_kernel_SOC(SOC_type, SOC_id, metric, krnel_rank))
+
+        return sorted(hot_krnel_list, key=lambda kernel: kernel.stats.get_metric(metric), reverse=True)[0]
+
+    # sort the blocks of a kernel based on how much impact they have on a metric
+    def get_hot_block_of_krnel_sorted(self, krnl_task_name, metric="latency"):
+        # find the kernel of interest
+        #hot_krnel = self.get_hot_kernel_SOC(SOC_type, SOC_id, metric, krnel_rank)
+        krnel_of_interest = [krnel for krnel in self.__kernels if krnel.get_task_name() == krnl_task_name]
+        assert(len(krnel_of_interest) == 1), "exactly one kernel must match this task name"
+        krnl = krnel_of_interest[0]
+
+        # find the hot block accordingly
+        # TODO: this is not quite right, since the hot kernel of different designs
+        # might have different block bottlenecks; here we just use the block
+        # bottleneck of the representative design, since self.__kernels is set
+        # to this design's kernels
+        kernel_blck_sorted : Block = krnl.stats.get_block_sorted(metric)
+        return kernel_blck_sorted
+
+    # -------------------------------------------
+    # Functionality:
+    #       get the hot block among the blocks that a kernel resides in, based on how much impact they have on a metric.
+    #       Hot means the bottleneck or rather the
+    #       most power/energy/area/performance consuming block of the system. This is determined
+    #       by the input argument metric.
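+    #       Hypothetical usage sketch (names assumed, not taken from this file):
+    #           blck = dp_stats.get_hot_block_of_krnel("task_A", metric="latency")
+    #           print(blck.instance_name)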
+    # -------------------------------------------
+    def get_hot_block_of_krnel(self, krnl_task_name, metric="latency"):
+        # find the kernel of interest
+        #hot_krnel = self.get_hot_kernel_SOC(SOC_type, SOC_id, metric, krnel_rank)
+        krnel_of_interest = [krnel for krnel in self.__kernels if krnel.get_task_name() == krnl_task_name]
+        assert(len(krnel_of_interest) == 1), "exactly one kernel must match this task name"
+        krnl = krnel_of_interest[0]
+
+
+        # find the hot block accordingly
+        # TODO: this is not quite right, since the hot kernel of different designs
+        # might have different block bottlenecks; here we just use the block
+        # bottleneck of the representative design, since self.__kernels is set
+        # to this design's kernels
+        kernel_blck_bottleneck: Block = krnl.stats.get_block_bottleneck(metric)
+        return kernel_blck_bottleneck
+
+    # -------------------------------------------
+    # Functionality:
+    #       get the hot block among the blocks of the entire SOC based on the metric and kernel rank.
+    #       Hot means the bottleneck or rather the
+    #       most power/energy/area/performance consuming block of the system. This is determined
+    #       by the input argument metric.
+    # Variables:
+    #       krnel_rank: rank of the kernel to pick once we have already sorted the kernels
+    #       based on hotness. 0 means the hottest, and higher values mean colder ones.
+    #       We use this value to sometimes target less hot kernels when the hot kernel
+    #       can not be improved any more.
+    # -------------------------------------------
+    def get_hot_block_SOC(self, SOC_type, SOC_id, metric="latency", krnel_rank=0):
+        # find the hottest kernel
+        hot_krnel = self.get_hot_kernel_SOC(SOC_type, SOC_id, metric, krnel_rank)
+
+        # find the hot block accordingly
+        # TODO: this is not quite right, since the hot kernel of different designs
+        # might have different block bottlenecks; here we just use the block
+        # bottleneck of the representative design, since self.__kernels is set
+        # to this design's kernels
+        hot_kernel_blck_bottleneck:Block = hot_krnel.stats.get_block_bottleneck(metric)
+        return hot_kernel_blck_bottleneck
+        # corresponding block bottleneck. We need this since we make a copy of the sim_dp,
+        # and hence, sim_dp and ex_dp won't be synced any more
+        #coress_hot_kernel_blck_bottleneck = self.find_cores_hot_kernel_blck_bottleneck(ex_dp, hot_kernel_blck_bottleneck)
+        #return cores_hot_kernel_blck_bottleneck
+
+    # -------------------------------------------
+    # Functionality:
+    #       get the hot block among the blocks of the entire system complex based on the metric and kernel rank.
+    #       Hot means the bottleneck or rather the
+    #       most power/energy/area/performance consuming block of the system. This is determined
+    #       by the input argument metric.
+    # Variables:
+    #       krnel_rank: rank of the kernel to pick once we have already sorted the kernels
+    #       based on hotness. 0 means the hottest, and higher values mean colder ones.
+    #       We use this value to sometimes target less hot kernels when the hot kernel
+    #       can not be improved any more.
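+    #       Illustration (hypothetical values): with kernels sorted hot-to-cold as
+    #       [kA, kB, kC], krnel_rank=0 starts the block search from kA, and
+    #       krnel_rank=1 falls back to kB when kA can no longer be improved.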
+ # ------------------------------------------- + def get_hot_block_system_complex(self, metric="latency", krnel_rank=0): + hot_blck_list = [] + for SOC_type, SOC_id in self.get_designs_SOCs(): + hot_blck_list.append(self.get_hot_block_SOC(SOC_type, SOC_id, metric, krnel_rank)) + + return sorted(hot_blck_list, key=lambda blck: blck.get_metric(metric), reverse=True)[0] + + # ------------------------------------------- + # Functionality: + # calculating the metric (power,performance,area) value + # Variables: + # metric_type: which metric to calculate for + # SOC_type: type of the SOC, since we can accept multi SOC designs + # SOC_id: id of the SOC to target + # ------------------------------------------- + def calc_SOC_metric_value(self, metric_type, SOC_type, SOC_id): + self.unreduced_results = [] + # call dp_stats of each design and then reduce + for dp in self.sim_dp_container.design_point_list: + self.unreduced_results.append(dp.dp_stats.get_SOC_metric_value(metric_type, SOC_type, SOC_id)) + return self.reduce(self.unreduced_results) + + # ------------------------------------------- + # Functionality: + # calculating the area value + # Variables: + # type: mem, ic, pe + # SOC_type: type of the SOC, since we can accept multi SOC designs + # SOC_id: id of the SOC to target + # ------------------------------------------- + def calc_SOC_area_base_on_type(self, type_, SOC_type, SOC_id): + area_list = [] + for dp in self.sim_dp_container.design_point_list: + area_list.append(dp.dp_stats.get_SOC_area_base_on_type(type_, SOC_type, SOC_id)) + return self.reduce(area_list) + + def calc_SOC_area_base_on_subtype(self, subtype_, SOC_type, SOC_id): + area_list = [] + for dp in self.sim_dp_container.design_point_list: + area_list.append(dp.dp_stats.get_SOC_area_base_on_subtype(subtype_, SOC_type, SOC_id)) + return self.reduce(area_list) + + + def set_SOC_metric_value(self,metric_type, SOC_type, SOC_id): + assert(metric_type in config.all_metrics), metric_type + " is not supported" + if metric_type == "area": + for block_type in ["mem", "ic", "pe"]: + self.SOC_area_dict[block_type][SOC_type][SOC_id] = self.calc_SOC_area_base_on_type(block_type, SOC_type, SOC_id) + for block_subtype in ["dram", "sram", "ic", "ip", "gpp"]: + self.SOC_area_subtype_dict[block_subtype][SOC_type][SOC_id] = self.calc_SOC_area_base_on_subtype(block_subtype, SOC_type, SOC_id) + self.SOC_metric_dict[metric_type][SOC_type][SOC_id] = self.calc_SOC_metric_value(metric_type, SOC_type, SOC_id) + + + def set_system_complex_metric(self, metric_type): + type_id_list = self.dp_rep.get_designs_SOCs() + # only corner case is for area as + # we want area specific even to the blocks + if (metric_type == "area"): + for block_type in ["pe", "mem", "ic"]: + for type_, id_ in type_id_list: + self.system_complex_area_dict[block_type] = sum([self.get_SOC_area_base_on_type(block_type, type_, id_) + for type_, id_ in type_id_list]) + + self.system_complex_area_dram_non_dram["non_dram"] = 0 + for block_subtype in ["sram", "ic", "gpp", "ip"]: + for type_, id_ in type_id_list: + self.system_complex_area_dram_non_dram["non_dram"] += sum([self.get_SOC_area_base_on_subtype(block_subtype, type_, id_) + for type_, id_ in type_id_list]) + for block_subtype in ["dram"]: + for type_, id_ in type_id_list: + self.system_complex_area_dram_non_dram["dram"] = sum([self.get_SOC_area_base_on_subtype(block_subtype, type_, id_) + for type_, id_ in type_id_list]) + + if metric_type in ["area", "energy", "cost"]: + self.system_complex_metric_dict[metric_type] = 
sum([self.get_SOC_metric_value(metric_type, type_, id_)
+                                                             for type_, id_ in type_id_list])
+        elif metric_type in ["latency"]:
+            self.system_complex_metric_dict[metric_type] = self.operate_on_dicionary_values([self.get_SOC_metric_value(metric_type, type_, id_)
+                                                             for type_, id_ in type_id_list], operator.add)
+        elif metric_type in ["power"]:
+            self.system_complex_metric_dict[metric_type] = max([self.get_SOC_metric_value(metric_type, type_, id_)
+                                                             for type_, id_ in type_id_list])
+        else:
+            raise Exception("metric_type:" + metric_type + " is not supported")
+
+    # ------------------------
+    # getters
+    # ------------------------
+    # sort kernels. At the moment we just sort based on latency.
+    def get_kernels_sort(self):
+        sorted_kernels_hot_to_cold = sorted(self.__kernels, key=lambda kernel: kernel.stats.latency, reverse=True)
+        return sorted_kernels_hot_to_cold
+
+    # return the metric of interest for the SOC. metric_type is the metric you are interested in
+    def get_SOC_metric_value(self, metric_type, SOC_type, SOC_id):
+        return self.SOC_metric_dict[metric_type][SOC_type][SOC_id]
+
+    def get_SOC_area_base_on_type(self, block_type, SOC_type, SOC_id):
+        assert(block_type in ["pe", "ic", "mem"]), "block_type " + block_type + " is not supported"
+        return self.SOC_area_dict[block_type][SOC_type][SOC_id]
+
+    def get_SOC_area_base_on_subtype(self, block_subtype, SOC_type, SOC_id):
+        assert(block_subtype in ["dram", "sram", "gpp", "ip", "ic"]), "block_subtype " + block_subtype + " is not supported"
+        if block_subtype not in self.SOC_area_subtype_dict.keys():  # this element does not exist
+            return 0
+        return self.SOC_area_subtype_dict[block_subtype][SOC_type][SOC_id]
+
+    # return the metric of interest for the system complex. metric_type is the metric you are interested in.
+    # Note that a system complex can contain multiple SOCs.
+    def get_system_complex_metric(self, metric_type):
+        return self.system_complex_metric_dict[metric_type]
+
+    def get_system_complex_area_stacked_dram(self):
+        return self.system_complex_area_dram_non_dram
+
+
+    # get the system_complex area. type_ is selected from ("pe", "mem", "ic")
+    def get_system_complex_area_base_on_type(self, type_):
+        return self.system_complex_area_type[type_]
+
+    def get_designs_SOCs(self):
+        return self.dp_rep.get_designs_SOCs()
+
+    # check if dp_rep is meeting the budget for the given metric and workload
+    def workload_fits_budget_for_metric(self, workload, metric_name, budget_coeff):
+        for type, id in self.dp_rep.get_designs_SOCs():
+            if not all(self.fits_budget_for_metric_and_workload(type, id, metric_name, workload, 1)):
+                return False
+        return True
+
+
+
+    # check if dp_rep is meeting the budget for the given workload
+    def workload_fits_budget(self, workload, budget_coeff):
+        for type, id in self.dp_rep.get_designs_SOCs():
+            for metric_name in self.database.get_budgetted_metric_names(type):
+                if not all(self.fits_budget_for_metric_and_workload(type, id, metric_name, workload, 1)):
+                    return False
+        return True
+
+
+    # check if dp_rep is meeting the budget
+    def fits_budget(self, budget_coeff):
+        for type, id in self.dp_rep.get_designs_SOCs():
+            for metric_name in self.database.get_budgetted_metric_names(type):
+                if not all(self.fits_budget_for_metric(type, id, metric_name, 1)):
+                    return False
+        return True
+
+    def fits_budget_for_metric_for_SOC(self, metric_name, budget_coeff):
+        for type, id in self.dp_rep.get_designs_SOCs():
+            if not all(self.fits_budget_for_metric(type, id, metric_name, 1)):
+                return False
+        return True
+
+
+    # returns a list of values
+    def fits_budget_for_metric_and_workload(self, type, id, metric_name, workload, budget_coeff):
+        dist = self.normalized_distance_for_workload(type, id, metric_name, workload)
+        if not isinstance(dist, list):
+            dist = [dist]
+        return [dist_el < .001 for dist_el in dist]
+
+
+    # returns a list of values
+    def fits_budget_for_metric(self, type, id, metric_name, budget_coeff):
+        dist = self.normalized_distance(type, id, metric_name)
+        if not isinstance(dist, list):
+            dist = [dist]
+        return [dist_el < .001 for dist_el in dist]
+
+    def normalized_distance_for_workload(self, type, id, metric_name, workload, dampening_coeff=1):
+        if config.dram_stacked:
+            return self.normalized_distance_for_workload_for_stacked_dram(type, id, metric_name, workload, dampening_coeff)
+        else:
+            return self.normalized_distance_for_workload_for_non_stacked_dram(type, id, metric_name, workload, dampening_coeff)
+
+    # normalize the metric to the budget
+    def normalized_distance_for_workload_for_non_stacked_dram(self, type, id, metric_name, workload, dampening_coeff=1):
+        metric_val = self.get_SOC_metric_value(metric_name, type, id)
+        if isinstance(metric_val, dict):
+            value_list = []
+            for workload_, val in metric_val.items():
+                if not (workload == workload_):
+                    continue
+                dict_ = self.database.get_ideal_metric_value(metric_name, type)
+                value_list.append((val - dict_[workload])/(dampening_coeff*dict_[workload]))
+            return value_list
+        else:
+            return [(metric_val - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))]
+
+    def normalized_distance_for_workload_for_stacked_dram(self, type, id, metric_name, workload, dampening_coeff=1):
+        metric_val = self.get_SOC_metric_value(metric_name, type, id)
+        if metric_name == 'latency':
+            value_list = []
+            for workload_, val in metric_val.items():
+                if not (workload == workload_):
+                    continue
+                dict_ = self.database.get_ideal_metric_value(metric_name, type)
+                value_list.append((val - dict_[workload_])/(dampening_coeff*dict_[workload_]))
+            return value_list
+        elif metric_name == "area":
+            # get the area aggregate of the whole SOC minus dram and normalize it
+            subtypes_no_dram = ["gpp", "ip", "ic", "sram"]
+            area_no_dram = 0
+            for el in subtypes_no_dram:
+                area_no_dram += self.get_SOC_area_base_on_subtype(el, type, id)
+            area_no_dram_norm = (area_no_dram - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))
+            # get dram area and normalize it
+            area_dram = self.get_SOC_area_base_on_subtype("dram", type, id)
+            area_dram_norm = (area_dram - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))
+            return [area_no_dram_norm, area_dram_norm]
+        else:
+            return [(metric_val - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))]
+
+    def normalized_distance(self, type, id, metric_name, dampening_coeff=1):
+        if config.dram_stacked:
+            return self.normalized_distance_for_stacked_dram(type, id, metric_name, dampening_coeff)
+        else:
+            return self.normalized_distance_for_non_stacked_dram(type, id, metric_name, dampening_coeff)
+
+    # normalize the metric to the budget
+    def normalized_distance_for_non_stacked_dram(self, type, id, metric_name, dampening_coeff=1):
+        metric_val = self.get_SOC_metric_value(metric_name, type, id)
+        if isinstance(metric_val, dict):
+            value_list = []
+            for workload, val in metric_val.items():
+                dict_ = self.database.get_ideal_metric_value(metric_name, type)
+                value_list.append((val - dict_[workload])/(dampening_coeff*dict_[workload]))
+            return value_list
+        else:
+            return [(metric_val - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))]
+
+    def normalized_distance_for_stacked_dram(self, type, id, metric_name, dampening_coeff=1):
+        metric_val = self.get_SOC_metric_value(metric_name, type, id)
+        if metric_name == 'latency':
+            value_list = []
+            for workload, val in metric_val.items():
+                dict_ = self.database.get_ideal_metric_value(metric_name, type)
+                value_list.append((val - dict_[workload])/(dampening_coeff*dict_[workload]))
+            return value_list
+        elif metric_name == "area":
+            # get the area aggregate of the whole SOC minus dram and normalize it
+            subtypes_no_dram = ["gpp", "ip", "ic", "sram"]
+            area_no_dram = 0
+            for el in subtypes_no_dram:
+                area_no_dram += self.get_SOC_area_base_on_subtype(el, type, id)
+            area_no_dram_norm = (area_no_dram - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))
+            # get dram area and normalize it
+            area_dram = self.get_SOC_area_base_on_subtype("dram", type, id)
+            area_dram_norm = (area_dram - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))
+            return [area_no_dram_norm, area_dram_norm]
+        else:
+            return [(metric_val - self.database.get_ideal_metric_value(metric_name, type))/(dampening_coeff*self.database.get_ideal_metric_value(metric_name, type))]
+
+
+
+
+    # normalize to the budget
+    def dist_to_goal_per_metric(self, metric_name, mode):
+        dist_list = []
+        for type, id in self.dp_rep.get_designs_SOCs():
+            meet_the_budgets = self.fits_budget_for_metric(type, id, metric_name, 1)
+            for idx, meet_the_budget in enumerate(meet_the_budgets):
+                if meet_the_budget:
+                    if mode == "eliminate":
+                        dist_list.append(0.000000001)
+                    elif mode == "dampen":
+                        norm_dist = [math.fabs(el) for el in
+                                     self.normalized_distance(type, id,
metric_name, config.annealing_dampening_coef)] + dist_list.append(math.fabs(norm_dist[idx])) + elif mode == "simple": + norm_dist = [math.fabs(el) for el in + self.normalized_distance(type, id, metric_name, 1)] + dist_list.append(math.fabs(norm_dist[idx])) + else: + print("mode: " + mode + " is not defined for dist_to_goal_per_metric") + exit(0) + else: + norm_dist = [math.fabs(el) for el in + self.normalized_distance(type, id, metric_name, 1)] + dist_list.append(math.fabs(norm_dist[idx])) + city_dist = sum(dist_list) + return city_dist + + + # check if dp_rep is meeting the budget + # modes: {"simple", "eliminate", "dampen"}. + # Simple: just calculates the city distance + # eliminates: eliminate the metric that has already met the budget + # dampen: dampens the impact of the metric that has already met the budget + def dist_to_goal(self, metrics_to_look_into=["all"], mode="simple"): # mode simple, just calculate + if metrics_to_look_into == ["all"]: + metrics_to_look_into = self.database.get_budgetted_metric_names_all_SOCs() + self.database.get_other_metric_names_all_SOCs() # which metrics to use for distance calculation + + dist_list = [] + for metric_name in metrics_to_look_into: + dist_list.append(self.dist_to_goal_per_metric(metric_name, mode)) + + city_dist = sum(dist_list) # we use city distance to allow for probability prioritizing + return city_dist + + # todo: change, right now it only uses the reduce value + def __lt__(self, other): + comp_list = [] + for metric in config.objectives: + comp_list.append(self.get_system_complex_metric(metric) < other.get_system_complex_metric(metric)) + return all(comp_list) + + # todo: change, right now it only uses the reduce value + def __gt__(self, other): + comp_list = [] + for metric in config.objectives: + comp_list.append(self.get_system_complex_metric(metric) > other.get_system_complex_metric(metric)) + return all(comp_list) + + +# This module emulates the simulated design point. +# It contains the information for the simulation of a design point +class SimDesignPoint(ExDesignPoint): + def __init__(self, hardware_graph, workload_to_hardware_map=[], workload_to_hardware_schedule=[]): + # primitive variables + self.__workload_to_hardware_map:WorkloadToHardwareMap = None + self.__workload_to_hardware_schedule:WorkloadToPEBlockSchedule = None + self.hardware_graph = hardware_graph # contains all the hardware blocks + their topology (how they are connected) + + self.__hardware, self.__workload, self.__kernels = [[]]*3 + # bootstrap the design and it's stats + self.reset_design(workload_to_hardware_map, workload_to_hardware_schedule) + + self.SOC_phase_energy_dict = defaultdict(dict) # energy associated with each phase + self.phase_latency_dict = {} # duration (time) for each phase. 
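+        # Sketch of the bookkeeping shape (values hypothetical, for illustration only):
+        #     self.phase_latency_dict    ~ {phase_number: duration}, e.g. {0: 1.2e-6, 1: 3.4e-6}
+        #     self.SOC_phase_energy_dict ~ {(SOC_type, SOC_id): {phase_number: energy}}
+        # so a design's total latency would be sum(self.phase_latency_dict.values()).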
+ self.dp_stats = None # design point statistics + self.block_phase_work_dict = {} # work done by the block as the system goes through different phases + self.block_phase_utilization_dict = {} # utilization done by the block as the system goes through different phases + self.pipe_cluster_path_phase_work_rate_dict = {} + self.parallel_kernels = {} + self.krnl_phase_present = {} + self.krnl_phase_present_operating_state = {} + self.phase_krnl_present = {} + self.iteration_number = 0 # the iteration which the simulation is done + self.population_observed_number = 0 + self.population_generated_number = 0 + self.depth_number = 0 # the depth (within on iteration) which the simulation is done + self.simulation_time = 0 # how long did it take to do the simulation + self.serial_design_time = 0 + self.par_speedup_time = 0 + if config.use_cacti: + self.cacti_hndlr = cact_handlr.CactiHndlr(config.cact_bin_addr, config.cacti_param_addr, + config.cacti_data_log_file, config.cacti_input_col_order, + config.cacti_output_col_order) + + for block in self.get_blocks(): + self.block_phase_work_dict[block] = {} + self.block_phase_utilization_dict[block] = {} + + + + def set_serial_design_time(self, serial_design_time): + self.serial_design_time = serial_design_time + + def get_serial_design_time(self): + return self.serial_design_time + + def set_par_speedup(self, speedup): + self.par_speedup_time = speedup + + + def set_simulation_time_phase_calculation_portion(self, time): + self.simulation_time_phase_calculation_portion = time + + def set_simulation_time_task_update_portion(self, time): + self.simulation_time_task_update_portion = time + + def set_simulation_time_phase_scheduling_portion(self, time): + self.simulation_time_phase_scheduling_portion = time + + + def set_simulation_time_analytical_portion(self, time): + self.simulation_time_analytical_portion = time + + def set_simulation_time_phase_driven_portion(self, time): + self.simulation_time_phase_driven_portion = time + + + def get_simulation_time_analytical_portion(self): + return self.simulation_time_analytical_portion + + def get_par_speedup(self): + return self.par_speedup_time + + def set_simulation_time(self, simulation_time): + self.simulation_time= simulation_time + + def get_simulation_time(self): + return self.simulation_time + + def set_population_generation_cnt(self, generation_cnt): + self.population_generation_cnt = generation_cnt + + def set_total_iteration_cnt(self, total_iteration): + self.total_iteration_cnt = total_iteration + + def set_population_observed_number(self, population_observed_number): + self.population_observed_number = population_observed_number + + def set_population_generated_number(self, population_generated_number): + self.population_generated_number = population_generated_number + + def set_depth_number(self, depth_number): + self.depth_number = depth_number + + def get_depth_number(self): + return self.depth_number + + def get_population_generation_cnt(self): + return self.population_generation_cnt + + def get_total_iteration_cnt(self): + return self.total_iteration_cnt + + + def get_population_observed_number(self): + return self.population_observed_number + + def get_population_generated_number(self): + return self.population_generated_number + + def get_tasks_parallel_task_dynamically(self, task): + if task.is_task_dummy(): + return [] + krnl = self.get_kernel_by_task_name(task) + + phases_present = self.krnl_phase_present[krnl] + parallel_krnls = [] + for phase_ in phases_present: + 
parallel_krnls.extend(self.phase_krnl_present[phase_])
+
+        # get rid of duplicates
+        parallel_tasks = set([el.get_task_name() for el in set(parallel_krnls) if not (task.get_name() == el.get_task_name())])
+
+        return list(parallel_tasks)
+
+
+    def get_tasks_using_the_different_pipe_cluster(self, task, block):
+        task_pipe_clusters = block.get_pipe_clusters_of_task(task)
+        tasks_of_block = block.get_tasks_of_block()
+        results = []
+        for task_ in tasks_of_block:
+            if task == task_:
+                continue
+            task__pipe_clusters = block.get_pipe_clusters_of_task(task_)
+            if len(list(set(task_pipe_clusters) - set(task__pipe_clusters))) == len(task_pipe_clusters):
+                results.append(task_.get_name())
+        return results
+
+
+    # Log the BW data about all the connections in the system
+    def dump_mem_bus_connection_bw(self, result_folder):  # batch means that all the blocks of a similar type have similar props
+        file_name = "bus_mem_connection_max_bw.txt"
+        file_addr = os.path.join(result_folder, file_name)
+        buses = self.get_blocks_by_type("ic")
+
+        with open(file_addr, "a+") as output:
+            output.write("MasterInstance" + "," + "SlaveInstance" + "," + "bus_bandwidth" + "," + "mode" + "\n")
+            for bus in buses:
+                connectd_pes = [block_ for block_ in bus.get_neighs() if block_.type == "pe"]  # PEs connected to bus
+                connectd_mems = [block_ for block_ in bus.get_neighs() if block_.type == "mem"]  # memories connected to the bus
+                connectd_ics = [block_ for block_ in bus.get_neighs() if block_.type == "ic"]
+                for ic in connectd_ics:
+                    for mode in ["read", "write"]:
+                        output.write(ic.instance_name + "," + bus.instance_name + "," +
+                                     str(ic.peak_work_rate) + "," + mode + "\n")
+                for pe in connectd_pes:
+                    for mode in ["read", "write"]:
+                        output.write(pe.instance_name + "," + bus.instance_name + "," +
+                                     str(bus.peak_work_rate) + "," + mode + "\n")
+                for mem in connectd_mems:
+                    for mode in ["read", "write"]:
+                        output.write(bus.instance_name + "," + mem.instance_name + "," +
+                                     str(mem.peak_work_rate) + "," + mode + "\n")
+
+    # -----------------------------------------
+    # -----------------------------------------
+    #          CACTI handling functions
+    # -----------------------------------------
+    # -----------------------------------------
+
+
+    # Conversion of memory type (naming) from FARSI to CACTI
+    def FARSI_to_cacti_mem_type_converter(self, mem_subtype):
+        if mem_subtype == "dram":
+            return "main memory"
+        elif mem_subtype == "sram":
+            return "ram"
+
+    # Conversion of memory cell type (naming) from FARSI to CACTI
+    def FARSI_to_cacti_cell_type_converter(self, mem_subtype):
+        if mem_subtype == "dram":
+            #return "lp-dram"
+            return "comm-dram"
+        elif mem_subtype == "sram":
+            return "itrs-lop"
+
+    # run CACTI to get results
+    def run_and_collect_cacti_data(self, blk, database):
+        tech_node = {}
+        tech_node["energy"] = 1
+        tech_node["area"] = 1
+        sw_hw_database_population = database.db_input.sw_hw_database_population
+        if "misc_knobs" in sw_hw_database_population.keys():
+            misc_knobs = sw_hw_database_population["misc_knobs"]
+            if "tech_node_SF" in misc_knobs.keys():
+                tech_node = misc_knobs["tech_node_SF"]
+
+        if not blk.type == "mem":
+            print("Only memory blocks are supported in CACTI")
+            exit(0)
+
+        # prime cacti
+        mem_bytes = max(blk.get_area_in_bytes(), config.cacti_min_memory_size_in_bytes)
+        subtype = blk.subtype
+        mem_bytes = (math.ceil(mem_bytes/config.min_mem_size[subtype]))*config.min_mem_size[subtype]  # modulo calculation
+
+        #subtype = "sram"  # TODO: change later to sram/dram
+        mem_subtype = self.FARSI_to_cacti_mem_type_converter(subtype)
+
cell_type = self.FARSI_to_cacti_cell_type_converter(subtype) + self.cacti_hndlr.set_cur_mem_type(mem_subtype) + self.cacti_hndlr.set_cur_mem_size(mem_bytes) + self.cacti_hndlr.set_cur_cell_type(cell_type) + + # run cacti + try: + cacti_area_energy_results = self.cacti_hndlr.collect_cati_data() + except Exception as e: + print("Using cacti, the following memory config tried and failed") + print(self.cacti_hndlr.get_config()) + raise e + + read_energy_per_byte = float(cacti_area_energy_results['Dynamic read energy (nJ)']) * (10 ** -9) / 16 + write_energy_per_byte = float(cacti_area_energy_results['Dynamic write energy (nJ)']) * (10 ** -9) / 16 + area = float(cacti_area_energy_results['Area (mm2)']) * (10 ** -6) + + read_energy_per_byte *= tech_node["energy"]["non_gpp"] + write_energy_per_byte *= tech_node["energy"]["non_gpp"] + area *= tech_node["area"]["mem"] + + # log values + self.cacti_hndlr.cacti_data_container.insert(list(zip(config.cacti_input_col_order + + config.cacti_output_col_order, + [mem_subtype, mem_bytes, read_energy_per_byte, write_energy_per_byte, area]))) + + return read_energy_per_byte, write_energy_per_byte, area + + # either run or look into the cached data (from CACTI) to get energy/area data + def collect_cacti_data(self, blk, database): + + if blk.type == "ic" : + return 0,0,0,1 + elif blk.type == "mem": + mem_bytes = max(blk.get_area_in_bytes(), config.cacti_min_memory_size_in_bytes) # to make sure we don't go smaller than cacti's minimum size + mem_subtype = self.FARSI_to_cacti_mem_type_converter(blk.subtype) + mem_bytes = (math.ceil(mem_bytes / config.min_mem_size[blk.subtype])) * config.min_mem_size[blk.subtype] # modulo calculation + #mem_subtype = "ram" #choose from ["main memory", "ram"] + found_results, read_energy_per_byte, write_energy_per_byte, area = \ + self.cacti_hndlr.cacti_data_container.find(list(zip(config.cacti_input_col_order,[mem_subtype, mem_bytes]))) + if not found_results: + read_energy_per_byte, write_energy_per_byte, area = self.run_and_collect_cacti_data(blk, database) + #read_energy_per_byte *= tech_node["energy"] + #write_energy_per_byte *= tech_node["energy"] + #area *= tech_node["area"] + area_per_byte = area/mem_bytes + return read_energy_per_byte, write_energy_per_byte, area, area_per_byte + + # For each kernel, update the energy and power using cacti + def cacti_update_energy_area_of_kernel(self, krnl, database): + # iterate through block/phases, collect data and insert them up + blk_area_dict = {} + for blk, phase_metric in krnl.block_phase_energy_dict.items(): + # only for memory and ic + if blk.type not in ["mem", "ic"]: + blk_area_dict[blk] = krnl.stats.get_block_area()[blk] + continue + read_energy_per_byte, write_energy_per_byte, area, area_per_byte = self.collect_cacti_data(blk, database) + for phase, metric in phase_metric.items(): + krnl.block_phase_energy_dict[blk][phase] = krnl.block_phase_read_dict[blk][ + phase] * read_energy_per_byte + krnl.block_phase_energy_dict[blk][phase] += krnl.block_phase_write_dict[blk][ + phase] * write_energy_per_byte + krnl.block_phase_area_dict[blk][phase] = area + + blk_area_dict[blk] = area + + # apply aggregates, which is iterate through every phase, scratch their values, and aggregates all the block energies + # areas. 
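+        # Hedged illustration of the aggregation below (made-up numbers):
+        #     krnl.block_phase_energy_dict ~ {blk: {phase: energy}}; summing over
+        #     blocks per phase yields phase_energy_dict ~ {phase: energy}, e.g.
+        #     {blk_A: {0: 1e-9}, blk_B: {0: 2e-9}}  ->  {0: 3e-9}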
+        krnl.stats.phase_energy_dict = krnl.aggregate_energy_of_for_every_phase()
+        krnl.stats.phase_area_dict = krnl.aggregate_area_of_for_every_phase()
+
+        """
+        # for debugging; delete later
+        for el in krnl.stats.get_block_area().keys():
+            if el not in blk_area_dict.keys():
+                print(" for debugging now delete later")
+                exit(0)
+        """
+        krnl.stats.set_block_area(blk_area_dict)
+        krnl.stats.set_stats()  # do not call set_stats directly, as it repopulates without cacti
+
+        return "_"
+
+    # For each block, get energy/area.
+    # At the moment, only setting up area. TODO: check whether energy matters
+    def cacti_update_area_of_block(self, block, database):
+        if block.type not in ["mem", "ic"]:
+            return
+        read_energy_per_byte, write_energy_per_byte, area, area_per_byte = self.collect_cacti_data(block, database)
+        block.set_area_directly(area)
+        #block.update_area_energy_power_rate(energy_per_byte, area_per_byte)
+
+    # update the design energy (after you have already updated the kernels' energy)
+    def cacti_update_energy_area_of_design(self):
+        # resetting all first
+        for soc, phase_value in self.SOC_phase_energy_dict.items():
+            for phase, value in self.SOC_phase_energy_dict[soc].items():
+                self.SOC_phase_energy_dict[soc][phase] = 0
+
+        # iterate through SOCs and update
+        for soc, phase_value in self.SOC_phase_energy_dict.items():
+            for phase, value in self.SOC_phase_energy_dict[soc].items():
+                SOC_type = soc[0]
+                SOC_id = soc[1]
+                for kernel in self.get_kernels():
+                    if kernel.SOC_type == SOC_type and kernel.SOC_id == SOC_id:
+                        if phase in kernel.stats.phase_energy_dict.keys():
+                            self.SOC_phase_energy_dict[(SOC_type, SOC_id)][phase] += kernel.stats.phase_energy_dict[phase]
+
+    def correct_power_area_with_cacti(self, database):
+        # the dictionaries below are used for debugging purposes; they can be deleted later
+        krnl_ratio_phase = {}  # for debugging; delete later
+
+        # update in 3 stages
+        # (1) fix kernel energy first
+        for krnl in self.__kernels:
+            krnl_ratio_phase[krnl] = self.cacti_update_energy_area_of_kernel(krnl, database)
+
+        # (2) fix the block's area
+        for block in self.get_blocks():
+            self.cacti_update_area_of_block(block, database)
+
+        # (3) update/fix the entire design accordingly
+        self.cacti_update_energy_area_of_design()
+
+    def get_hardware_graph(self):
+        return self.hardware_graph
+
+    # collect (profile) design point stats.
+    def collect_dp_stats(self, database):
+        self.dp_stats = DPStats(self, database)
+
+    def get_designs_SOCs(self):
+        blocks = self.get_workload_to_hardware_map().get_blocks()
+        designs_SOCs = []
+        for block in blocks:
+            if (block.SOC_type, block.SOC_id) not in designs_SOCs:
+                designs_SOCs.append((block.SOC_type, block.SOC_id))
+        return designs_SOCs
+
+    # This is used for idle power calculations
+    def get_blocks(self):
+        blocks = self.get_workload_to_hardware_map().get_blocks()
+        return blocks
+
+    # It is a wrapper around reset_design that includes all the necessary work to clear the stats
+    # and start the simulation again (used for changing power-knobs)
+    def reset_design_wrapper(self):
+        # We do not want to lose this information, since it is going to be the same
+        # and we do not have any other way to retain it
+        self.SOC_phase_energy_dict = defaultdict(dict)
+        self.phase_latency_dict = {}
+        self.dp_stats = None
+
+    # bootstrap the design from scratch
+    def reset_design(self, workload_to_hardware_map=[], workload_to_hardware_schedule=[]):
+        def update_kernels(self_):
+            self_.__kernels = []
+            for task_to_blocks_map in self.__workload_to_hardware_map.tasks_to_blocks_map_list:
+                task = task_to_blocks_map.task
+                self_.__kernels.append(Kernel(self_.__workload_to_hardware_map.get_by_task(task)))#,
+#                                             self_.__workload_to_hardware_schedule.get_by_task(task)))
+        if workload_to_hardware_map:
+            self.__workload_to_hardware_map = workload_to_hardware_map
+        if workload_to_hardware_schedule:
+            self.__workload_to_hardware_schedule = workload_to_hardware_schedule
+        update_kernels(self)
+
+    def get_workload_to_hardware_map(self):
+        return self.__workload_to_hardware_map
+
+    def get_workload_to_hardware_schedule(self):
+        return self.__workload_to_hardware_schedule
+
+    def get_kernels(self):
+        return self.__kernels
+
+    def get_kernel_by_task_name(self, task:Task):
+        return list(filter(lambda kernel: task.name == kernel.task_name, self.get_kernels()))[0]
+
+
+# design point statistics (stats). This class contains the profiling information for a simulated design.
+# Note that the difference between a system complex and an SOC is that a system complex can contain multiple SOCs.
+class DPStats:
+    def __init__(self, sim_dp: SimDesignPoint, database):
+        self.comparison_mode = "latency"  # metric to compare designs against one another
+        self.dp = sim_dp  # simulated design point object
+        self.__kernels = self.dp.get_kernels()  # design kernels
+        self.SOC_area_dict = defaultdict(lambda: defaultdict(dict))  # area per block type
+        self.SOC_area_subtype_dict = defaultdict(lambda: defaultdict(dict))  # area per block subtype
+        self.system_complex_area_dict = defaultdict()  # system complex area values (for memory, PEs, buses)
+        self.power_duration_list = defaultdict(lambda: defaultdict(dict))  # power and duration entries of the power list
+        self.SOC_metric_dict = defaultdict(lambda: defaultdict(dict))  # dictionary containing various metrics for the SOC
+        self.system_complex_metric_dict = defaultdict(lambda: defaultdict(dict))  # dictionary containing the system complex metrics
+        self.database = database
+        self.pipe_cluster_pathlet_phase_work_rate_dict = {}
+        for pipe_cluster in self.dp.get_hardware_graph().get_pipe_clusters():
+            self.pipe_cluster_pathlet_phase_work_rate_dict[pipe_cluster] = pipe_cluster.get_pathlet_phase_work_rate()
+
+        self.pipe_cluster_pathlet_phase_latency_dict = {}
+        for pipe_cluster in self.dp.get_hardware_graph().get_pipe_clusters():
+            self.pipe_cluster_pathlet_phase_latency_dict[pipe_cluster] = pipe_cluster.get_pathlet_phase_latency()
+
+        use_slack_management_estimation = config.use_slack_management_estimation
+        # collect the data
+        self.collect_stats(use_slack_management_estimation)
+
+    # write the results into a file
+    def dump_stats(self, des_folder, mode="light_weight"):
+        file_name = config.verification_result_file
+        file_addr = os.path.join(des_folder, file_name)
+
+        for type, id in self.dp.get_designs_SOCs():
+            ic_count = len(self.dp.get_workload_to_hardware_map().get_blocks_by_type("ic"))
+            mem_count = len(self.dp.get_workload_to_hardware_map().get_blocks_by_type("mem"))
+            pe_count = 
len(self.dp.get_workload_to_hardware_map().get_blocks_by_type("pe")) + with open(file_addr, "w+") as output: + routing_complexity = self.dp.get_hardware_graph().get_routing_complexity() + simple_topology = self.dp.get_hardware_graph().get_simplified_topology_code() + blk_cnt = sum([int(el) for el in simple_topology.split("_")]) + bus_cnt = [int(el) for el in simple_topology.split("_")][0] + mem_cnt = [int(el) for el in simple_topology.split("_")][1] + pe_cnt = [int(el) for el in simple_topology.split("_")][2] + task_cnt = len(list(self.dp.krnl_phase_present.keys())) + channel_cnt = self.dp.get_hardware_graph().get_number_of_channels() + + output.write("{\n") + output.write("\"FARSI_predicted_latency\": "+ str(max(list(self.get_system_complex_metric("latency").values()))) +",\n") + output.write("\"FARSI_predicted_energy\": "+ str(self.get_system_complex_metric("energy")) +",\n") + output.write("\"FARSI_predicted_power\": "+ str(self.get_system_complex_metric("power")) +",\n") + output.write("\"FARSI_predicted_area\": "+ str(self.get_system_complex_metric("area")) +",\n") + output.write("\"parallel_task_cnt\": "+ str(self.get_parallel_task_count_analytically()) +",\n") + output.write("\"parallel_task_cnt_experimentally\": "+ str(self.get_parallel_task_count_experimentally()) +",\n") + output.write("\"serial_task_count\": "+ str(self.get_serial_task_count()) +",\n") + output.write("\"parallel_task_type\": "+ "\""+str(self.get_parallel_task_type()) +"\",\n") + output.write("\"memory_boundedness_ratio_analytically\": "+ str(self.get_memory_boundedness_ratio_analytically()) +",\n") + output.write("\"memory_boundedness_ratio_experimentally\": "+ str(self.get_memory_boundedness_ratio_experimentally()) +",\n") + output.write("\"data_movement_scaling_ratio\": "+ str(self.get_datamovement_scaling_ratio()) +",\n") + output.write("\"num_of_hops_experimentally\": "+ str(self.get_num_of_hops_experimentally()) +",\n") + output.write("\"num_of_hops_theoretically\": "+ str(self.get_num_of_hops_theoretically()) +",\n") + #output.write("\"config_code\": "+ str(ic_count) + str(mem_count) + str(pe_count)+",\n") + #output.write("\"config_code\": "+ self.dp.get_hardware_graph().get_config_code() +",\n") + output.write("\"simplified_topology_code\": "+ self.dp.get_hardware_graph().get_simplified_topology_code() +",\n") + output.write("\"blk_cnt\": "+ str(blk_cnt) +",\n") + output.write("\"pe_cnt\": "+ str(pe_cnt) +",\n") + output.write("\"mem_cnt\": "+ str(mem_cnt) +",\n") + output.write("\"bus_cnt\": "+ str(bus_cnt) +",\n") + output.write("\"task_cnt\": "+ str(task_cnt) +",\n") + output.write("\"routing_complexity\": "+ str(routing_complexity) +",\n") + output.write("\"channel_cnt\": "+ str(channel_cnt) +",\n") + output.write("\"simulation_time_analytical_portion\": "+ str(self.get_simulation_time_analytical_portion()) +",\n") + output.write("\"FARSI simulation time\": " + str(self.dp.get_simulation_time()) + ",\n") + + # Function: profile the simulated design, collecting information about + # latency, power, area, and phasal behavior + # This is called within the constructor + def collect_stats(self, use_slack_management_estimation=False): + for type, id in self.dp.get_designs_SOCs(): + for metric_name in config.all_metrics: + self.set_SOC_metric_value(metric_name, type, id) # data per SoC + self.set_system_complex_metric(metric_name) # data per System + + # estimate the behavior if slack management applied + for type, id in self.dp.get_designs_SOCs(): + if use_slack_management_estimation: + values_changed = 
self.apply_slack_management_estimation_improved(type, id)
+                if values_changed:
+                    for type, id in self.dp.get_designs_SOCs():
+                        for metric_name in config.all_metrics:
+                            self.set_system_complex_metric(metric_name)  # data per system
+
+    # Functionality:
+    #       "Hot" means the bottleneck, i.e., the kernel that consumes the most
+    #       power/energy/area/latency in the system. Which one is determined
+    #       by the input argument metric.
+    def get_hot_kernel_SOC(self, SOC_type, SOC_id, metric="latency", krnel_rank=0):
+        kernels_on_SOC = [kernel for kernel in self.__kernels if kernel.SOC_type == SOC_type and kernel.SOC_id == SOC_id]
+        sorted_kernels_hot_to_cold = sorted(kernels_on_SOC, key=lambda kernel: kernel.stats.get_metric(metric), reverse=True)
+        return sorted_kernels_hot_to_cold[krnel_rank]
+
+    # same notion of "hot" as above, but across all the SOCs of the system complex
+    def get_hot_kernel_system_complex(self, metric="latency", krnel_rank=0):
+        hot_krnel_list = []
+        for SOC_type, SOC_id in self.dp.get_designs_SOCs():
+            hot_krnel_list.append(self.get_hot_kernel_SOC(SOC_type, SOC_id, metric, krnel_rank))
+        return sorted(hot_krnel_list, key=lambda kernel: kernel.stats.get_metric(metric), reverse=True)[0]
+
+    # get the block bottleneck of the hottest kernel on an SOC
+    def get_hot_block_SOC(self, SOC_type, SOC_id, metric="latency", krnel_rank=0):
+        # find the hottest kernel
+        hot_krnel = self.get_hot_kernel_SOC(SOC_type, SOC_id, metric, krnel_rank)
+        hot_kernel_blck_bottleneck: Block = hot_krnel.stats.get_block_bottleneck(metric)
+        # corresponding block bottleneck. We need this since we make a copy of the sim_dp,
+        # and hence, sim_dp and ex_dp won't be synced any more
+        return hot_kernel_blck_bottleneck
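+    # Typical bottleneck queries built on the helpers above (a sketch; `dp_stats` is an
+    # illustrative name for a DPStats instance):
+    """
+    hot_krnl  = dp_stats.get_hot_kernel_system_complex("latency")     # hottest kernel across SOCs
+    hot_block = dp_stats.get_hot_block_system_complex("latency", 0)   # its bottleneck block; rank 0 = hottest
+    """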
+    # get the hot block of the whole system complex.
+    # krnel_rank is the rank of the kernel to pick once the kernels are sorted (0 means the hottest).
+    # This variable is used to unstick the heuristic when necessary (e.g., when modifying the hottest
+    # kernel is not helping the design, we move on to the second hottest)
+    def get_hot_block_system_complex(self, metric="latency", krnel_rank=0):
+        hot_blck_list = []
+        for SOC_type, SOC_id in self.dp.get_designs_SOCs():
+            hot_blck_list.append(self.get_hot_block_SOC(SOC_type, SOC_id, metric, krnel_rank))
+
+        return sorted(hot_blck_list, key=lambda kernel: kernel.stats.get_metric(metric), reverse=True)[0]
+
+    # get kernels sorted based on latency
+    def get_kernels_sort(self):
+        sorted_kernels_hot_to_cold = sorted(self.__kernels, key=lambda kernel: kernel.stats.latency, reverse=True)
+        return sorted_kernels_hot_to_cold
+
+    # -----------------------------------------
+    # Calculate profiling information per SOC
+    # -----------------------------------------
+    def calc_SOC_latency(self, SOC_type, SOC_id):
+        workload_latency_dict = {}
+        for workload, last_task in self.database.get_workloads_last_task().items():
+            kernel = self.dp.get_kernel_by_task_name(self.dp.get_task_by_name(last_task))
+            workload_latency_dict[workload] = kernel.get_completion_time()  #kernel.stats.latency + kernel.starting_time
+        return workload_latency_dict
+
+    # calculate SOC energy
+    def calc_SOC_energy(self, SOC_type, SOC_id):
+        return sum(list(self.dp.SOC_phase_energy_dict[(SOC_type, SOC_id)].values()))
+
+    # if estimate_slack_management_effect is set to true,
+    # we estimate what will happen if we can introduce slacks in order to reduce power
+    def calc_SOC_power(self, SOC_type, SOC_id, estimate_slack_management_effect=False):
+        self.power_duration_list[SOC_type][SOC_id] = []
+        sorted_listified_phase_latency_dict = sorted(self.dp.phase_latency_dict.items(), key=operator.itemgetter(0))
+        sorted_latencys = [latency for phase, latency in sorted_listified_phase_latency_dict]
+        sorted_phase_latency_dict = collections.OrderedDict(sorted_listified_phase_latency_dict)
+
+        # get the energy first
+        SOC_phase_energy_dict = self.dp.SOC_phase_energy_dict[(SOC_type, SOC_id)]
+        sorted_listified_phase_energy_dict = sorted(SOC_phase_energy_dict.items(), key=operator.itemgetter(0))
+        sorted_phase_energy_dict = collections.OrderedDict(sorted_listified_phase_energy_dict)
+
+        # convert to power by slicing the time with the smallest duration within which power should be
+        # calculated (PWP)
+        phase_bounds_lists = slice_phases_with_PWP(sorted_phase_latency_dict)
+        power_list = []  # list of power values collected based on the power collection frequency
+        for lower_bnd, upper_bnd in phase_bounds_lists:
+            if sum(sorted_latencys[lower_bnd:upper_bnd]) > 0:
+                power_this_phase = sum(list(sorted_phase_energy_dict.values())[lower_bnd:upper_bnd])/sum(sorted_latencys[lower_bnd:upper_bnd])
+                power_list.append(power_this_phase)
+                self.power_duration_list[SOC_type][SOC_id].append((power_this_phase, sum(sorted_latencys[lower_bnd:upper_bnd])))
+            else:
+                power_list.append(0)
+                self.power_duration_list[SOC_type][SOC_id].append((0, 0))
+
+        power = max(power_list)
+
+        return power
+
+    # estimate what happens if we can manage slack by optimizing kernel scheduling.
+    # note that this is just an estimation. Actual scheduling needs to be applied to get the exact numbers.
+    # note that if we apply slack, the comparison with the higher-fidelity simulation
+    # will be considerably hurt (since the higher-fidelity simulation doesn't typically apply slack
+    # management)
+    def apply_slack_management_estimation_improved(self, SOC_type, SOC_id):
+        power_duration_list = self.power_duration_list[SOC_type][SOC_id]
+        # relax power if possible
+        total_latency = sum([duration for power, duration in power_duration_list])
+        slack = self.database.get_budget("latency", "glass") - total_latency
+        power_duration_recalculated = copy.deepcopy(power_duration_list)
+        values_changed = False  # indicates whether slack was used to modify any value
+        while slack > 0 and (self.fits_budget_for_metric(SOC_type, SOC_id, "latency", 1) and
+                             not self.fits_budget_for_metric(SOC_type, SOC_id, "power", 1)):
+            power_duration_sorted = sorted(power_duration_recalculated, key=lambda x: x[0])
+            idx = power_duration_recalculated.index(power_duration_sorted[-1])
+            power, duration = power_duration_recalculated[idx]
+            slack_used = min(.0005, slack)
+            slack = slack - slack_used
+            duration_with_slack = duration + slack_used
+            power_duration_recalculated[idx] = ((power * duration) / duration_with_slack, duration_with_slack)
+            values_changed = True
+        power = max([power for power, duration in power_duration_recalculated])
+        self.SOC_metric_dict["power"][SOC_type][SOC_id] = power
+        self.SOC_metric_dict["latency"][SOC_type][SOC_id] = sum([duration for power, duration in power_duration_recalculated])
+        return values_changed
+
+    # get the total area of an SOC, filtered by block type
+    def calc_SOC_area_base_on_type(self, type_, SOC_type, SOC_id):
+        blocks = self.dp.get_workload_to_hardware_map().get_blocks()
+        total_area = sum([block.get_area() for block in blocks if block.SOC_type == SOC_type
+                          and block.SOC_id == SOC_id and block.type == type_])
+        return total_area
+
+    # get the total area of an SOC, filtered by block subtype
+    def calc_SOC_area_base_on_subtype(self, subtype_, SOC_type, SOC_id):
+        blocks = self.dp.get_workload_to_hardware_map().get_blocks()
+        total_area = 0
+        for block in blocks:
+            if block.SOC_type == SOC_type and block.SOC_id == SOC_id and block.subtype == subtype_:
+                total_area += block.get_area()
+        return total_area
+
+    # get the total area of an SOC
+    # Variables:
+    #       SOC_type: the type of SOC you need information for
+    #       SOC_id: id of the SOC you are interested in
+    def calc_SOC_area(self, SOC_type, SOC_id):
+        blocks = self.dp.get_workload_to_hardware_map().get_blocks()
+        # note: we can't use phase_area_dict for this, since:
+        # 1. we would double count the statically-sized blocks 2. if a memory is shared, we would be double counting it
+        total_area = sum([block.get_area() for block in blocks if block.SOC_type == SOC_type and block.SOC_id == SOC_id])
+        return total_area
+
+    # the cost model associated with a PE.
+    # This helps us calculate the financial (development) cost
+    # of using a specific PE
+    def PE_cost_model(self, task_normalized_work, block_type, model_type="linear"):
+        if model_type == "linear":
+            return task_normalized_work*self.database.db_input.porting_effort[block_type]
+        else:
+            print("this cost model is not defined")
+            exit(0)
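+    # A worked instance of the slack trick above (made-up numbers): stretching the
+    # highest-power slice by the used slack keeps its energy (power x duration)
+    # constant while lowering its power:
+    """
+    power, duration, slack_used = 2.0, 0.001, 0.0005
+    stretched = ((power * duration) / (duration + slack_used), duration + slack_used)
+    # stretched == (1.333..., 0.0015); energy stays 0.002 J, peak power drops from 2.0 to ~1.33
+    """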
+    # the cost model associated with a MEM.
+    # This helps us calculate the financial (development) cost
+    # of using a specific MEM
+    def MEM_cost_model(self, task_normalized_work, block, block_type, model_type="linear"):
+        if model_type == "linear":
+            return task_normalized_work*self.database.db_input.porting_effort[block_type]*block.get_num_of_banks()
+        else:
+            print("this cost model is not defined")
+            exit(0)
+
+    # the cost model associated with an IC.
+    # This helps us calculate the financial (development) cost
+    # of using a specific IC
+    def IC_cost_model(self, task_normalized_work, block, block_type, model_type="linear"):
+        if model_type == "linear":
+            return task_normalized_work*self.database.db_input.porting_effort[block_type]
+        else:
+            print("this cost model is not defined")
+            exit(0)
+
+    # calculate the development cost of an SOC
+    def calc_SOC_dev_cost(self, SOC_type, SOC_id):
+        blocks = self.dp.get_workload_to_hardware_map().get_blocks()
+        all_kernels = self.get_kernels_sort()
+
+        # find the simplest task's work (simple = the task with the least amount of work)
+        krnl_work_list = []  # the list of works associated with the different kernels (excluding dummy tasks)
+        for krnl in all_kernels:
+            krnl_task = krnl.get_task()
+            if not krnl_task.is_task_dummy():
+                krnl_work_list.append(krnl_task.get_self_task_work())
+        simplest_task_work = min(krnl_work_list)
+
+        num_of_tasks = len(all_kernels)
+        dev_cost = 0
+        # iterate through each block and add its cost
+        for block in blocks:
+            if block.type == "pe":
+                # for IPs
+                if block.subtype == "ip":
+                    tasks = block.get_tasks_of_block()
+                    task_work = max([task.get_self_task_work() for task in tasks])  # use max in case multiple tasks are mapped
+                    task_normalized_work = task_work/simplest_task_work
+                    dev_cost += self.PE_cost_model(task_normalized_work, "ip")
+                # for GPPs
+                elif block.subtype == "gpp":
+                    # for DSPs
+                    if "G3" in block.instance_name:
+                        for task in block.get_tasks_of_block():
+                            task_work = task.get_self_task_work()
+                            task_normalized_work = task_work/simplest_task_work
+                            dev_cost += self.PE_cost_model(task_normalized_work, "dsp")
+                    # for ARM
+                    elif "A53" in block.instance_name or "ARM" in block.instance_name:
+                        for task in block.get_tasks_of_block():
+                            task_work = task.get_self_task_work()
+                            task_normalized_work = task_work/simplest_task_work
+                            dev_cost += self.PE_cost_model(task_normalized_work, "arm")
+                    else:
+                        print("cost model for this GPP is not defined")
+                        exit(0)
+            elif block.type == "mem":
+                task_normalized_work = 1  # treat it as the simplest task's work
+                dev_cost += self.MEM_cost_model(task_normalized_work, block, "mem")
+            elif block.type == "ic":
+                task_normalized_work = 1  # treat it as the simplest task's work
+                dev_cost += self.IC_cost_model(task_normalized_work, block, "ic")
+            else:
+                print("cost model for block " + block.instance_name + " is not defined")
+                exit(0)
+
+        pes = [blk for blk in blocks if blk.type == "pe"]
+        mems = [blk for blk in blocks if blk.type == "mem"]
+        for pe in pes:
+            pe_tasks = [el.get_name() for el in pe.get_tasks_of_block()]
+            for mem in mems:
+                mem_tasks = [el.get_name() for el in mem.get_tasks_of_block()]
+                task_share_cnt = len(pe_tasks) - len(list(set(pe_tasks) - set(mem_tasks)))
+                if task_share_cnt == 0:  # this condition avoids finding paths between vertices, which is computationally intensive
+                    continue
+                path_length = len(self.dp.get_hardware_graph().get_path_between_two_vertecies(pe, mem))
+                #path_length = len(self.dp.get_hardware_graph().get_shortest_path(pe, mem, [], []))
+                effort = self.database.db_input.porting_effort["ic"]/10
+                dev_cost += (path_length*task_share_cnt)*.1
+
+        return dev_cost
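+    # The cost models above are simple linear products; a worked example with made-up
+    # porting efforts:
+    """
+    # an IP whose task has 4x the work of the simplest task, porting_effort["ip"] = 0.5:
+    #   PE_cost_model(4, "ip")               ->  4 * 0.5 = 2.0
+    # a two-bank memory with porting_effort["mem"] = 0.1:
+    #   MEM_cost_model(1, mem_block, "mem")  ->  1 * 0.1 * 2 = 0.2
+    """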
+    # pb_type: processing block type
+    def get_SOC_s_specific_area(self, SOC_type, SOC_id, pb_type):
+        assert(pb_type in ["pe", "ic", "mem"]), "block type " + pb_type + " is not supported"
+        return self.SOC_area_dict[pb_type][SOC_type][SOC_id]
+
+    # --------
+    # setters
+    # --------
+    def set_SOC_metric_value(self, metric_type, SOC_type, SOC_id):
+        if metric_type == "area":
+            self.SOC_metric_dict[metric_type][SOC_type][SOC_id] = self.calc_SOC_area(SOC_type, SOC_id)
+            for block_type in ["pe", "mem", "ic"]:
+                self.SOC_area_dict[block_type][SOC_type][SOC_id] = self.calc_SOC_area_base_on_type(block_type, SOC_type, SOC_id)
+            for block_subtype in ["sram", "dram", "ic", "gpp", "ip"]:
+                self.SOC_area_subtype_dict[block_subtype][SOC_type][SOC_id] = self.calc_SOC_area_base_on_subtype(block_subtype, SOC_type, SOC_id)
+        elif metric_type == "cost":
+            self.SOC_metric_dict[metric_type][SOC_type][SOC_id] = self.calc_SOC_dev_cost(SOC_type, SOC_id)
+        elif metric_type == "energy":
+            self.SOC_metric_dict[metric_type][SOC_type][SOC_id] = self.calc_SOC_energy(SOC_type, SOC_id)
+        elif metric_type == "power":
+            self.SOC_metric_dict[metric_type][SOC_type][SOC_id] = self.calc_SOC_power(SOC_type, SOC_id)
+        elif metric_type == "latency":
+            self.SOC_metric_dict[metric_type][SOC_type][SOC_id] = self.calc_SOC_latency(SOC_type, SOC_id)
+        else:
+            raise Exception("metric_type:" + metric_type + " is not supported")
+
+    # helper function to apply an operator across two dictionaries
+    def operate_on_two_dic_values(self, dict1, dict2, operator):
+        dict_res = {}
+        for key in list(dict2.keys()) + list(dict1.keys()):
+            if key in dict1.keys() and key in dict2.keys():
+                dict_res[key] = operator(dict2[key], dict1[key])
+            else:
+                if key in dict1.keys():
+                    dict_res[key] = dict1[key]
+                elif key in dict2.keys():
+                    dict_res[key] = dict2[key]
+        return dict_res
+
+    # fold the operator over a list of dictionaries
+    def operate_on_dictionary_values(self, dictionaries, operator):
+        res = {}
+        for SOCs_latency in dictionaries:
+            #res = copy.deepcopy(self.operate_on_two_dic_values(res, SOCs_latency, operator))
+            #gc.disable()
+            res = cPickle.loads(cPickle.dumps(self.operate_on_two_dic_values(res, SOCs_latency, operator), -1))
+            #gc.enable()
+        return res
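+    # What the two helpers above compute, on tiny illustrative inputs:
+    """
+    import operator
+    dict1 = {"a": 1, "b": 2}
+    dict2 = {"b": 3, "c": 4}
+    operate_on_two_dic_values(dict1, dict2, operator.add)       # -> {"a": 1, "b": 5, "c": 4}
+    operate_on_dictionary_values([dict1, dict2], operator.add)  # same merge, folded over a list of dicts
+    """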
+    # set the metric (power, area, ...) for the entire system complex
+    def set_system_complex_metric(self, metric_type):
+        type_id_list = self.dp.get_designs_SOCs()
+        # area is the only metric broken down by block type
+        if metric_type == "area":
+            for block_type in ["pe", "mem", "ic"]:
+                self.system_complex_area_dict[block_type] = sum([self.get_SOC_area_base_on_type(block_type, type_, id_)
+                                                                 for type_, id_ in type_id_list])
+        if metric_type in ["area", "energy", "cost"]:
+            self.system_complex_metric_dict[metric_type] = sum([self.get_SOC_metric_value(metric_type, type_, id_)
+                                                                for type_, id_ in type_id_list])
+        elif metric_type in ["latency"]:
+            self.system_complex_metric_dict[metric_type] = self.operate_on_dictionary_values([self.get_SOC_metric_value(metric_type, type_, id_)
+                                                                                              for type_, id_ in type_id_list], operator.add)
+        elif metric_type in ["power"]:
+            self.system_complex_metric_dict[metric_type] = max([self.get_SOC_metric_value(metric_type, type_, id_)
+                                                                for type_, id_ in type_id_list])
+        else:
+            raise Exception("metric_type:" + metric_type + " is not supported")
+
+    # --------
+    # getters
+    # --------
+    def get_SOC_metric_value(self, metric_type, SOC_type, SOC_id):
+        assert(metric_type in config.all_metrics), metric_type + " not supported"
+        return self.SOC_metric_dict[metric_type][SOC_type][SOC_id]
+
+    def get_SOC_area_base_on_type(self, block_type, SOC_type, SOC_id):
+        assert(block_type in ["pe", "ic", "mem"]), "block_type " + block_type + " is not supported"
+        return self.SOC_area_dict[block_type][SOC_type][SOC_id]
+
+    def get_SOC_area_base_on_subtype(self, block_subtype, SOC_type, SOC_id):
+        assert(block_subtype in ["dram", "sram", "ic", "gpp", "ip"]), "block_subtype " + block_subtype + " is not supported"
+        return self.SOC_area_subtype_dict[block_subtype][SOC_type][SOC_id]
+
+    # get the simulation progress (per-kernel (start, latency, value, unit) tuples)
+    def get_SOC_s_latency_sim_progress(self, SOC_type, SOC_id, metric):
+        kernels_on_SOC = [kernel for kernel in self.__kernels if kernel.SOC_type == SOC_type and kernel.SOC_id == SOC_id]
+        kernel_metric_value = {}
+        for kernel in kernels_on_SOC:
+            kernel_metric_value[kernel] = []
+        for kernel in kernels_on_SOC:
+            if metric == "latency":
+                kernel_metric_value[kernel].append((kernel.starting_time*10**3, kernel.stats.latency*10**3, kernel.stats.latency*10**3, "ms"))
+            elif metric == "bytes":
+                kernel_metric_value[kernel].append((kernel.starting_time * 10 ** 3, kernel.stats.latency * 10 ** 3,
+                                                    kernel.stats.latency * 10 ** 3, "bytes"))
+        return kernel_metric_value
+
+    def get_sim_progress(self, metric="latency"):
+        #for phase, krnls in self.dp.phase_krnl_present.items():
+        #    accelerators_in_parallel = []
+        if metric == "latency":
+            return [self.get_SOC_s_latency_sim_progress(type, id, metric) for type, id in self.dp.get_designs_SOCs()]
+        if metric == "bytes":
+            pass
+
+    def get_num_of_hops_theoretically(self):
+        return self.database.db_input.num_of_hops
+
+    def get_num_of_hops_experimentally(self):
+        total_time = 0
+        hop_time = 0
+
+        for krnl in self.dp.get_kernels():
+            phase_seen = []
+            for phase, block in krnl.stats.phase_block_duration_bottleneck.items():
+                if phase in phase_seen:
+                    continue
+                phase_seen.append(phase)
+
+                if block[0].type in ["mem", "ic"]:
+                    max_hop = 1
+                    mems = [blk for blk in krnl.get_blocks() if blk.type == "mem"]
+                    pe = [blk for blk in krnl.get_blocks() if blk.type == "pe"][0]
+                    for mem in mems:
+                        max_hop = max(len(self.dp.get_hardware_graph().get_path_between_two_vertecies(pe, mem))-2, max_hop)
+                    hop_time += max_hop*krnl.stats.phase_latency_dict[phase]
+                    total_time += krnl.stats.phase_latency_dict[phase]
+                else:
+                    total_time += krnl.stats.phase_latency_dict[phase]
+        ratio = hop_time/total_time
+        return ratio
+
+    def get_parallel_task_count_experimentally(self):
+        serial_time = 0
+        for krnl in self.dp.get_kernels():
+            for phase, latency in krnl.stats.phase_latency_dict.items():
+                serial_time += latency
+
+        execution_latency = 0
+        for type, id in self.dp.get_designs_SOCs():
+            execution_latency += list(self.get_SOC_metric_value("latency", type, id).values())[0]  # data per SoC
+        ratio = serial_time/execution_latency
+        return ratio
+
+    def get_memory_boundedness_ratio_experimentally(self):
+        mem_bottleneck_time = 0
+        cpu_bottleneck_time = 0
+        phase_seen = []
+        for krnl in self.dp.get_kernels():
+            for phase, block in krnl.stats.phase_block_duration_bottleneck.items():
+                if phase in phase_seen:
+                    continue
+                phase_seen.append(phase)
+                if block[0].type in ["mem", "ic"]:
+                    mem_bottleneck_time += krnl.stats.phase_latency_dict[phase]
+                else:
+                    cpu_bottleneck_time += krnl.stats.phase_latency_dict[phase]
+        ratio = mem_bottleneck_time/(mem_bottleneck_time+cpu_bottleneck_time)
+        return ratio
+
+    # returns the latency associated with the phases of the system execution
+    def get_phase_latency(self, SOC_type=1, SOC_id=1):
+        return self.dp.phase_latency_dict
+
+    # get the utilization associated with the phases of the execution
+    def get_SOC_s_sim_utilization(self, SOC_type, SOC_id):
+        return self.dp.block_phase_utilization_dict
+
+    def get_SOC_s_pipe_cluster_pathlet_phase_work_rate(self, SOC_type, SOC_id):
+        return self.pipe_cluster_pathlet_phase_work_rate_dict
+
+    def get_SOC_s_pipe_cluster_pathlet_phase_latency(self, SOC_type, SOC_id):
+        return self.pipe_cluster_pathlet_phase_latency_dict
+
+    def get_SOC_s_pipe_cluster_path_phase_latency(self, SOC_type, SOC_id):
+        return self.pipe_cluster_pathlet_phase_latency_dict
+
+    # get the work associated with the phases of the execution
+    def get_SOC_s_sim_work(self, SOC_type, SOC_id):
+        return self.dp.block_phase_work_dict
+
+    def get_parallel_task_count_analytically(self):
+        return self.database.db_input.parallel_task_count
+
+    def get_simulation_time_analytical_portion(self):
+        return self.dp.get_simulation_time_analytical_portion()
+
+    def get_serial_task_count(self):
+        return self.database.db_input.serial_task_count
+
+    def get_parallel_task_type(self):
+        return self.database.db_input.parallel_task_type
+
+    def get_memory_boundedness_ratio_analytically(self):
+        return self.database.db_input.memory_boundedness_ratio
+
+    def get_datamovement_scaling_ratio(self):
+        return self.database.db_input.datamovement_scaling_ratio
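+    # How set_system_complex_metric (above) combines per-SOC values, with illustrative
+    # two-SOC numbers:
+    """
+    # area / energy / cost are summed:   2.0 + 1.5        -> 3.5
+    # power takes the max:               max(0.04, 0.03)  -> 0.04
+    # latency merges the per-workload dicts with operator.add
+    """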
+    # get the total (considering all SoCs) system metrics
+    def get_system_complex_metric(self, metric_type):
+        assert(metric_type in config.all_metrics), metric_type + " not supported"
+        assert(not (self.system_complex_metric_dict[metric_type] == -1)), metric_type + " not calculated"
+        return self.system_complex_metric_dict[metric_type]
+
+    # check if dp_rep is meeting the budget
+    def fits_budget(self, budget_coeff):
+        for type, id in self.dp.get_designs_SOCs():
+            for metric_name in self.database.get_budgetted_metric_names(type):
+                if not self.fits_budget_for_metric(type, id, metric_name, budget_coeff):
+                    return False
+        return True
+
+    def fits_budget_per_metric(self, metric_name, budget_coeff):
+        for type, id in self.dp.get_designs_SOCs():
+            if not self.fits_budget_for_metric(type, id, metric_name, budget_coeff):
+                return False
+        return True
+
+    # whether the design fits the budget for the specified metric
+    # type and id specify the relevant parameters of the SOC
+    # budget_coeff is ignored for now
+    def fits_budget_for_metric(self, type, id, metric_name, budget_coeff):
+        return self.normalized_distance(type, id, metric_name) < .001
+
+    def __lt__(self, other):
+        comp_list = []
+        for metric in config.objectives:
+            comp_list.append(self.get_system_complex_metric(metric) < other.get_system_complex_metric(metric))
+        return all(comp_list)
+
+    def __gt__(self, other):
+        comp_list = []
+        for metric in config.objectives:
+            comp_list.append(self.get_system_complex_metric(metric) > other.get_system_complex_metric(metric))
+        return all(comp_list)
\ No newline at end of file
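A note on the comparison operators above: DPStats.__lt__ and DPStats.__gt__ implement
all-objective (Pareto-style) dominance rather than a total order, so two designs can be
incomparable. A sketch of the semantics, with made-up metric values:

    # objectives = ["latency", "power"]
    # a: {"latency": 0.9, "power": 0.03}   b: {"latency": 1.0, "power": 0.04}
    #   a < b  ->  True   (a is better on every objective)
    # a: {"latency": 0.9, "power": 0.05}   b: {"latency": 1.0, "power": 0.04}
    #   a < b  ->  False and a > b -> False   (the designs trade off; neither dominates)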
diff --git a/Project_FARSI/error_handling/custom_error.py b/Project_FARSI/error_handling/custom_error.py new file mode 100644 index 00000000..11ebf90b --- /dev/null +++ b/Project_FARSI/error_handling/custom_error.py @@ -0,0 +1,164 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import sys, inspect
+# We split the handling between exceptions and errors.
+# Errors are defined as critical failures that need to be debugged by the developers. Error class names end with the word "Error".
+# Exceptions are raised when the framework produces invalid designs due to careless modifications. These
+# are inevitable, as not every modification is valid. Exception class names end with the word "Exception".
+
+# ---------------------------
+# Error List
+# ---------------------------
+# a task with no children
+class TaskNoChildrenError(Exception):
+    pass
+
+# a task with no PE detected
+class NoPEError(Exception):
+    pass
+
+# a task with no memory detected
+class NoMemError(Exception):
+    pass
+
+# a task with no bus detected
+class NoBusError(Exception):
+    pass
+
+class BlockCountDeviationError(Exception):
+    pass
+
+# a swap transformation was not executed properly
+class IncompleteSwapError(Exception):
+    pass
+
+# the design doesn't have a block of a certain type ("pe", "mem", "ic")
+class NoBlockOfCertainType(Exception):
+    pass
+
+# each block can only be connected to one bus
+class MultiBusBlockError(Exception):
+    pass
+
+# a bus with no memory was detected
+class BusWithNoMemError(Exception):
+    pass
+
+# a system IC connected to a PE was detected
+class SystemICWithPEException(Exception):
+    pass
+
+
+# a bus with no PE was detected
+class BusWithNoPEError(Exception):
+    pass
+
+class NotEnoughIPOfCertainType(Exception):
+    pass
+
+# a block with no tasks mapped to it was detected
+class BlockWithNoTaskError(Exception):
+    pass
+
+# IPs (accelerators) can not be split because they only have one task on them
+class IPSplitException(Exception):
+    pass
+
+
+class NoAbException(Exception):
+    pass
+
+
+class TransferException(Exception):
+    pass
+
+class RoutingException(Exception):
+    pass
+
+
+# couldn't find two blocks of the same type to use for cleaning up
+class CostPairingException(Exception):
+    pass
+
+# ---------------------------
+# Exception List
+# ---------------------------
+class MoveNoDesignException(Exception):
+    pass
+
+
+class UnEqualFrontsError(Exception):
+    pass
+
+
+class MoveNotValidException(Exception):
+    pass
+# TODO: this exception is for the most part caused by the fact that
+# moves (and their corresponding tasks) are determined before the loading of memory.
+# Later, we need to fix this and get rid of this exception (basically, make sure whoever calls it is fixed)
+class NoMigrantException(Exception):
+    pass
+
+# could not find a task that can run in parallel (for split and migration)
+class NoParallelTaskException(Exception):
+    pass
+
+# This is a scenario where no migrant is detected to be moved, even though one must exist.
+# This is different from NoMigrantException (whose scenario is permissible)
+class NoMigrantError(Exception):
+    pass
+
+
+# ic migration not supported at the moment
+class ICMigrationException(Exception):
+    pass
+
+# no valid transformation could be found
+class NoValidTransformationException(Exception):
+    pass
+
+def get_error_classes():
+    error_class_list = []
+    for name, obj in inspect.getmembers(sys.modules[__name__]):
+        if inspect.isclass(obj):
+            error_class_list.append(obj)
+    return error_class_list
+
+def get_error_names():
+    error_name_list = []
+    for name, obj in inspect.getmembers(sys.modules[__name__]):
+        if inspect.isclass(obj) and "Error" in name:
+            error_name_list.append(name)
+    return error_name_list
+
+def get_exception_names():
+    exception_name_list = []
+    for name, obj in inspect.getmembers(sys.modules[__name__]):
+        if inspect.isclass(obj) and "Exception" in name:
+            exception_name_list.append(name)
+    return exception_name_list
+
+
+errors_names = get_error_names()
+exception_names = get_exception_names()
+
+"""
+# simple unit test
+
+def foo():
+    raise NoPEError
+    #raise TaskNoChildrenError
+def foo_2():
+    return 1/0
+
+try:
+    foo_2()
+except Exception as e:
+    if e.__class__.__name__ in errors_names:
+        print("have seen this error before")
+    else:
+        raise e
+"""
diff --git a/Project_FARSI/figures/DSE_on_the_map.pdf b/Project_FARSI/figures/DSE_on_the_map.pdf new file mode 100644 index 00000000..5e8e3aa4 Binary files /dev/null and b/Project_FARSI/figures/DSE_on_the_map.pdf differ diff --git a/Project_FARSI/figures/DSE_on_the_map.png b/Project_FARSI/figures/DSE_on_the_map.png new file mode 100644 index 00000000..21117595 Binary files /dev/null and b/Project_FARSI/figures/DSE_on_the_map.png differ diff --git a/Project_FARSI/figures/FARSI_methodology.pdf b/Project_FARSI/figures/FARSI_methodology.pdf new file mode 100644 index 00000000..e2192632 Binary files /dev/null and b/Project_FARSI/figures/FARSI_methodology.pdf differ diff --git a/Project_FARSI/figures/FARSI_methodology.png b/Project_FARSI/figures/FARSI_methodology.png new file mode 100644 index 00000000..707e7d35 Binary files /dev/null and b/Project_FARSI/figures/FARSI_methodology.png differ diff --git a/Project_FARSI/figures/FARSI_output.pdf b/Project_FARSI/figures/FARSI_output.pdf new file mode 100644 index 00000000..a0c72ec8 Binary files /dev/null and b/Project_FARSI/figures/FARSI_output.pdf differ diff --git a/Project_FARSI/figures/FARSI_output.png b/Project_FARSI/figures/FARSI_output.png new file mode 100644 index 00000000..106916f4 Binary files /dev/null and b/Project_FARSI/figures/FARSI_output.png differ diff --git a/Project_FARSI/misc/__init__.py b/Project_FARSI/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/misc/cacti_hndlr/cact_handlr.py b/Project_FARSI/misc/cacti_hndlr/cact_handlr.py new file mode 100644 index 00000000..63a46a03 --- /dev/null +++ b/Project_FARSI/misc/cacti_hndlr/cact_handlr.py @@ -0,0 +1,199 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +import os +import csv +import pandas as pd +import math +import numpy as np +import time +import shutil +import subprocess +#from settings import config + +# This class at the moment only handls very specific cases, +# concretely, we can provide the size of memory and get the power/area results back. +class CactiHndlr(): + def __init__(self, bin_addr, param_file , cacti_data_log_file, input_col_order, output_col_order): + self.bin_addr = bin_addr + self.param_file = param_file + self.cur_mem_size = 0 + self.input_cfg = "" + self.output_cfg = "" + self.cur_mem_type = "" + self.cur_cell_type = "" + self.input_col_order = input_col_order + self.cacti_data_log_file = cacti_data_log_file + self.output_col_order = output_col_order + self.cacti_data_container = CactiDataContainer(cacti_data_log_file, input_col_order, output_col_order) + + def set_cur_cell_type(self, cell_type): + self.cur_cell_type = cell_type + + def set_cur_mem_size(self, cur_mem_size): + self.cur_mem_size = math.ceil(cur_mem_size) + + def set_cur_mem_type(self, cur_mem_type): + self.cur_mem_type = cur_mem_type + + def set_params(self): + param_file_copy_name= "/".join(self.param_file.split("/")[:-1]) + "/" + self.param_file.split("/")[-1] +"_cp" + #os.system("cp " + self.param_file + " " + param_file_copy_name) + shutil.copy(self.param_file, param_file_copy_name) + time.sleep(.05) + file1 = open(param_file_copy_name, "a") # append mode + size_cmd = "-size (bytes) " + str(self.cur_mem_size) + cur_mem_type_cmd = "-cache type \"" + self.cur_mem_type + "\"" + cell_type_cmd = "-Data array cell type - \"" + self.cur_cell_type+ "\"" + file1.write(size_cmd + "\n") + file1.write(cur_mem_type_cmd + "\n") + file1.write(cell_type_cmd + "\n") + file1.close() + + self.input_cfg = param_file_copy_name + self.output_cfg = self.input_cfg +".out" + + def get_config(self): + return {"mem_size":self.cur_mem_size, "mem_type":self.cur_mem_type, "cell_type:":self.cur_cell_type} + + def run_bin(self): + bin_dir = "/".join(self.bin_addr.split("/")[:-1]) + os.chdir(bin_dir) + #cmd = self.bin_addr + " " + "-infile " + self.input_cfg + subprocess.call([self.bin_addr, "-infile", self.input_cfg]) + #os.system(cmd) + + def run_cacti(self): + self.set_params() + self.run_bin() + + def reset_cfgs(self): + self.input_cfg = "" + self.output_cfg = "" + + def parse_and_find(self, kwords): + results_dict = {} + ctr =0 + while not os.path.isfile(self.output_cfg) and ctr < 60: + time.sleep(1) + ctr +=1 + + f = open(self.output_cfg) + reader = csv.DictReader(f) + dict_list = [] + for line in reader: + dict_list.append(line) + + for kw in kwords: + results_dict [kw] = [] + + for dict_ in dict_list: + for kw in results_dict.keys(): + for key in dict_.keys(): + if key == " " +kw: + results_dict[kw] = dict_[key] + + f.close() + return results_dict + + def collect_cati_data(self): + self.run_cacti() + results = self.parse_and_find(["Dynamic read energy (nJ)", "Dynamic write energy (nJ)", "Area (mm2)"]) + os.system("rm " + self.output_cfg) + return results + + + +class CactiDataContainer(): + def __init__(self, cached_data_file_addr, input_col_order, output_col_order): + self.cached_data_file_addr = cached_data_file_addr + self.input_col_order = input_col_order + self.output_col_order = output_col_order + self.prase_cached_data() + + def prase_cached_data(self): + # create the file if doesn't exist + if not 
os.path.exists(self.cached_data_file_addr):
+            file = open(self.cached_data_file_addr, "w")
+            for col_val in (self.input_col_order + self.output_col_order)[:-1]:
+                file.write(str(col_val) + ",")
+            file.write(str((self.input_col_order + self.output_col_order)[-1]) + "\n")
+            file.close()
+
+        # populate the pandas data frame with it
+        try:
+            self.df = pd.read_csv(self.cached_data_file_addr)
+        except Exception as e:
+            if e.__class__.__name__ in "pandas.errors.EmptyDataError":
+                self.df = pd.DataFrame(columns=self.input_col_order + self.output_col_order)
+
+    def find(self, KVs):
+        df_ = self.df
+        for k, v in KVs:
+            df_temp = self.find_one_kv(df_, (k, v))
+            if isinstance(df_temp, bool) and df_temp == False:  # if it can't be found
+                return False, "_", "_", "_"
+            elif df_temp.empty:
+                return False, "_", "_", "_"
+            df_ = df_temp
+
+        if len(df_.index) > 1:  # more than one row means duplicated entries
+            print("can not have duplicated values")
+            exit(0)
+
+        output = [True] + [df_.iloc[0][col_name] for col_name in self.output_col_order]
+        return output
+
+    def find_one_kv(self, df, kv):
+        if df.empty:
+            return False
+
+        k = kv[0]
+        v = kv[1]
+        result = df.loc[(df[k] == v)]
+        return result
+
+    def insert(self, key_values_):
+        # if the data already exists, return
+        if not self.df.empty:
+            if self.find(key_values_)[0]:
+                return
+
+        # make the output file if it doesn't exist
+        if not os.path.exists(self.cached_data_file_addr):
+            file = open(self.cached_data_file_addr, "w")
+            for col_val in self.df.columns[:-1]:
+                file.write(col_val + ",")
+            file.write(self.df.columns[-1] + "\n")
+            file.close()
+
+        values_ = [kv[1] for kv in key_values_]
+        # add it to the pandas data frame
+        df2 = pd.DataFrame([values_], columns=self.input_col_order + self.output_col_order)
+        self.df = self.df.append(df2, ignore_index=True)
+
+        # append the results to the file
+        with open(self.cached_data_file_addr, "a") as output:
+            for key, value in key_values_[:-1]:
+                output.write(str(value) + ",")
+            output.write(str(values_[-1]) + "\n")
+
+
+# just a test case
+if __name__ == "__main__":
+    cact_bin_addr = "/Users/behzadboro/Downloads/cacti/cacti"
+    cacti_param_addr = "/Users/behzadboro/Downloads/cacti/farsi_gen.cfg"
+    cacti_data_log_file = "/Users/behzadboro/Downloads/cacti/data_log.csv"
+
+    cur_mem_size = 320000000
+    cur_mem_type = "main memory"  # ["main memory", "ram"]
+    input_col_order = ("mem_subtype", "mem_size")
+    output_col_order = ("energy_per_byte", "area")
+    cacti_hndlr = CactiHndlr(cact_bin_addr, cacti_param_addr, cacti_data_log_file, input_col_order, output_col_order)
+    cacti_hndlr.set_cur_mem_size(cur_mem_size)
+    cacti_hndlr.set_cur_mem_type(cur_mem_type)
+    area_power_results = cacti_hndlr.collect_cati_data()
+    print(area_power_results)
diff --git a/Project_FARSI/misc/cacti_hndlr/ddr3_.cfg b/Project_FARSI/misc/cacti_hndlr/ddr3_.cfg new file mode 100644 index 00000000..fc855cfb --- /dev/null +++ b/Project_FARSI/misc/cacti_hndlr/ddr3_.cfg @@ -0,0 +1,254 @@
+# Cache size
+//-size (bytes) 2048
+//-size (bytes) 4096
+//-size (bytes) 32768
+//-size (bytes) 131072
+//-size (bytes) 262144
+//-size (bytes) 1048576
+//-size (bytes) 2097152
+//-size (bytes) 4194304
+//-size (bytes) 8388608
+//-size (bytes) 16777216
+//-size (bytes) 33554432
+//-size (bytes) 134217728
+//-size (bytes) 67108864
+//-size (bytes) 1073741824
+
+# power gating
+-Array Power Gating - "false"
+-WL Power Gating - "false"
+-CL Power Gating - "false"
+-Bitline floating - "false"
+-Interconnect Power Gating - "false"
+-Power Gating Performance Loss 
0.01 + +# Line size +-block size (bytes) 32 +//-block size (bytes) 64 + +# To model Fully Associative cache, set associativity to zero +//-associativity 0 +//-associativity 2 +//-associativity 4 +//-associativity 8 +-associativity 8 + +-read-write port 1 +-exclusive read port 0 +-exclusive write port 0 +-single ended read ports 0 + +# Multiple banks connected using a bus +-UCA bank count 1 +//-technology (u) 0.022 +-technology (u) 0.040 +//-technology (u) 0.032 +//-technology (u) 0.090 + +# following three parameters are meaningful only for main memories + +-page size (bits) 8192 +-burst length 8 +-internal prefetch width 8 + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Data array cell type - "itrs-hp" +//-Data array cell type - "itrs-lstp" +//-Data array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Data array peripheral type - "itrs-hp" +//-Data array peripheral type - "itrs-lstp" +//-Data array peripheral type - "itrs-lop" + +# following parameter can have one of five values -- (itrs-hp, itrs-lstp, itrs-lop, lp-dram, comm-dram) +-Tag array cell type - "itrs-hp" +//-Tag array cell type - "itrs-lstp" +//-Tag array cell type - "itrs-lop" + +# following parameter can have one of three values -- (itrs-hp, itrs-lstp, itrs-lop) +-Tag array peripheral type - "itrs-hp" +//-Tag array peripheral type - "itrs-lstp" +//-Tag array peripheral type - "itrs-lop + +# Bus width include data bits and address bits required by the decoder +//-output/input bus width 16 +-output/input bus width 128 + +// 300-400 in steps of 10 +-operating temperature (K) 360 + +# Type of memory - cache (with a tag array) or ram (scratch ram similar to a register file) +# or main memory (no tag array and every access will happen at a page granularity Ref: CACTI 5.3 report) +//-cache type "cache" +//-cache type "ram" +//-cache type "main memory" + +# to model special structure like branch target buffers, directory, etc. +# change the tag size parameter +# if you want cacti to calculate the tagbits, set the tag size to "default" +-tag size (b) "default" +//-tag size (b) 22 + +# fast - data and tag access happen in parallel +# sequential - data array is accessed after accessing the tag array +# normal - data array lookup and tag access happen in parallel +# final data block is broadcasted in data array h-tree +# after getting the signal from the tag array +//-access mode (normal, sequential, fast) - "fast" +-access mode (normal, sequential, fast) - "normal" +//-access mode (normal, sequential, fast) - "sequential" + + +# DESIGN OBJECTIVE for UCA (or banks in NUCA) +-design objective (weight delay, dynamic power, leakage power, cycle time, area) 0:0:0:100:0 + +# Percentage deviation from the minimum value +# Ex: A deviation value of 10:1000:1000:1000:1000 will try to find an organization +# that compromises at most 10% delay. +# NOTE: Try reasonable values for % deviation. Inconsistent deviation +# percentage values will not produce any valid organizations. For example, +# 0:0:100:100:100 will try to identify an organization that has both +# least delay and dynamic power. Since such an organization is not possible, CACTI will +# throw an error. 
Refer CACTI-6 Technical report for more details +-deviate (delay, dynamic power, leakage power, cycle time, area) 20:100000:100000:100000:100000 + +# Objective for NUCA +-NUCAdesign objective (weight delay, dynamic power, leakage power, cycle time, area) 100:100:0:0:100 +-NUCAdeviate (delay, dynamic power, leakage power, cycle time, area) 10:10000:10000:10000:10000 + +# Set optimize tag to ED or ED^2 to obtain a cache configuration optimized for +# energy-delay or energy-delay sq. product +# Note: Optimize tag will disable weight or deviate values mentioned above +# Set it to NONE to let weight and deviate values determine the +# appropriate cache configuration +//-Optimize ED or ED^2 (ED, ED^2, NONE): "ED" +-Optimize ED or ED^2 (ED, ED^2, NONE): "ED^2" +//-Optimize ED or ED^2 (ED, ED^2, NONE): "NONE" + +-Cache model (NUCA, UCA) - "UCA" +//-Cache model (NUCA, UCA) - "NUCA" + +# In order for CACTI to find the optimal NUCA bank value the following +# variable should be assigned 0. +-NUCA bank count 0 + +# NOTE: for nuca network frequency is set to a default value of +# 5GHz in time.c. CACTI automatically +# calculates the maximum possible frequency and downgrades this value if necessary + +# By default CACTI considers both full-swing and low-swing +# wires to find an optimal configuration. However, it is possible to +# restrict the search space by changing the signaling from "default" to +# "fullswing" or "lowswing" type. +-Wire signaling (fullswing, lowswing, default) - "Global_30" +//-Wire signaling (fullswing, lowswing, default) - "default" +//-Wire signaling (fullswing, lowswing, default) - "lowswing" + +//-Wire inside mat - "global" +-Wire inside mat - "semi-global" +//-Wire outside mat - "global" +-Wire outside mat - "semi-global" + +-Interconnect projection - "conservative" +//-Interconnect projection - "aggressive" + +# Contention in network (which is a function of core count and cache level) is one of +# the critical factor used for deciding the optimal bank count value +# core count can be 4, 8, or 16 +//-Core count 4 +-Core count 8 +//-Core count 16 +-Cache level (L2/L3) - "L3" + +-Add ECC - "true" + +//-Print level (DETAILED, CONCISE) - "CONCISE" +-Print level (DETAILED, CONCISE) - "DETAILED" + +# for debugging +//-Print input parameters - "true" +-Print input parameters - "false" +# force CACTI to model the cache with the +# following Ndbl, Ndwl, Nspd, Ndsam, +# and Ndcm values +//-Force cache config - "true" +-Force cache config - "false" +-Ndwl 1 +-Ndbl 1 +-Nspd 0 +-Ndcm 1 +-Ndsam1 0 +-Ndsam2 0 + + + +#### Default CONFIGURATION values for baseline external IO parameters to DRAM. More details can be found in the CACTI-IO technical report (), especially Chapters 2 and 3. + +# Memory Type (D=DDR3, L=LPDDR2, W=WideIO). Additional memory types can be defined by the user in extio_technology.cc, along with their technology and configuration parameters. + +-dram_type "D" +//-dram_type "L" +//-dram_type "W" +//-dram_type "S" + +# Memory State (R=Read, W=Write, I=Idle or S=Sleep) + +//-iostate "R" +-iostate "W" +//-iostate "I" +//-iostate "S" + +#Address bus timing. To alleviate the timing on the command and address bus due to high loading (shared across all memories on the channel), the interface allows for multi-cycle timing options. 
+ +-addr_timing 0.5 //DDR +//-addr_timing 1.0 //SDR (half of DQ rate) +//-addr_timing 2.0 //2T timing (One fourth of DQ rate) +//-addr_timing 3.0 // 3T timing (One sixth of DQ rate) + +# Memory Density (Gbit per memory/DRAM die) + +-mem_density 8 Gb //Valid values 2^n Gb + +# IO frequency (MHz) (frequency of the external memory interface). + +-bus_freq 1000 MHz //As of current memory standards (2013), valid range 0 to 1.5 GHz for DDR3, 0 to 533 MHz for LPDDR2, 0 - 800 MHz for WideIO and 0 - 3 GHz for Low-swing differential. However this can change, and the user is free to define valid ranges based on new memory types or extending beyond existing standards for existing dram types. + +# Duty Cycle (fraction of time in the Memory State defined above) + +-duty_cycle 1.0 //Valid range 0 to 1.0 + +# Activity factor for Data (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_dq 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR +#-activity_dq .50 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR + +# Activity factor for Control/Address (0->1 transitions) per cycle (for DDR, need to account for the higher activity in this parameter. E.g. max. activity factor for DDR is 1.0, for SDR is 0.5) + +-activity_ca 1.0 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T +#-activity_ca 0.25 //Valid range 0 to 1.0 for DDR, 0 to 0.5 for SDR, 0 to 0.25 for 2T, and 0 to 0.17 for 3T + +# Number of DQ pins + +-num_dq 72 //Number of DQ pins. Includes ECC pins. + +# Number of DQS pins. DQS is a data strobe that is sent along with a small number of data-lanes so the source synchronous timing is local to these DQ bits. Typically, 1 DQS per byte (8 DQ bits) is used. The DQS is also typucally differential, just like the CLK pin. + +-num_dqs 36 //2 x differential pairs. Include ECC pins as well. Valid range 0 to 18. For x4 memories, could have 36 DQS pins. + +# Number of CA pins + +-num_ca 35 //Valid range 0 to 35 pins. +#-num_ca 25 //Valid range 0 to 35 pins. + +# Number of CLK pins. CLK is typically a differential pair. In some cases additional CLK pairs may be used to limit the loading on the CLK pin. + +-num_clk 2 //2 x differential pair. Valid values: 0/2/4. + +# Number of Physical Ranks + +-num_mem_dq 2 //Number of ranks (loads on DQ and DQS) per buffer/register. If multiple LRDIMMs or buffer chips exist, the analysis for capacity and power is reported per buffer/register. + +# Width of the Memory Data Bus + +-mem_data_width 32 //x4 or x8 or x16 or x32 memories. For WideIO upto x128. diff --git a/Project_FARSI/misc/converters/dot_to_png.py b/Project_FARSI/misc/converters/dot_to_png.py new file mode 100644 index 00000000..e80d42b5 --- /dev/null +++ b/Project_FARSI/misc/converters/dot_to_png.py @@ -0,0 +1,23 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
+
+import os
+import sys
+import pygraphviz as pgv
+
+
+def converter():
+    file = sys.argv[1]
+    hardware_dot_graph = pgv.AGraph(file)
+    hardware_dot_graph.layout()
+    hardware_dot_graph.layout(prog='circo')
+    output_file_name_1 = file.split(".dot")[0] + ".png"
+    output_file_1 = os.path.join("./", output_file_name_1)
+    # output_file_real_time_vis = os.path.join(".", output_file_name)  # this is used for realtime visualization
+    hardware_dot_graph.draw(output_file_1, format="png", prog='circo')
+
+
+if __name__ == "__main__":
+    converter()
diff --git a/Project_FARSI/misc/gotchas b/Project_FARSI/misc/gotchas new file mode 100644 index 00000000..cfcb7416 --- /dev/null +++ b/Project_FARSI/misc/gotchas @@ -0,0 +1 @@
+for blocks whose area we set statically, by convention, work_over_area values are 1/(their fixed area)
diff --git a/Project_FARSI/misc/scratch/__init__.py b/Project_FARSI/misc/scratch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/readme b/Project_FARSI/readme new file mode 100644 index 00000000..de8a96a9 --- /dev/null +++ b/Project_FARSI/readme @@ -0,0 +1,21 @@
+This is the repository for the FARSI code base.
+
+
+** to run and collect data for exploration:
+   (1) cd data_collection/collection_utils/what_ifs/   # go to the executable folder
+   (2) set the workload name properly in FARSI_what_ifs.py
+   (3) python FARSI_what_ifs.py                        # run FARSI
+
+
+** to run and collect data for a single simulation:
+   (1) cd data_collection/collection_utils/sim_run/    # go to the executable folder
+   (2) set the workload name properly in simple_sim_run.py
+   (3) python simple_sim_run.py                        # run FARSI
+
+
+
+** to modify the settings:
+Modify the settings/config.py file. This file contains many knobs that determine the exploration heuristic and the simulation
+features. Please refer to the in-file documentation for more details.
+
+
diff --git a/Project_FARSI/settings/__init__.py b/Project_FARSI/settings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/settings/config.py b/Project_FARSI/settings/config.py new file mode 100644 index 00000000..4b046af4 --- /dev/null +++ b/Project_FARSI/settings/config.py @@ -0,0 +1,320 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+# simulation mode
+from collections import defaultdict
+import home_settings
+from specs.LW_cl import *
+import os
+import settings.config_cacti as CC
+termination_mode = "workload_completion"  # when to terminate the exploration
+assert termination_mode in ["workload_completion", "time_budget_reached"], "termination_mode:" +\
+                                                                           termination_mode + " not defined"
+# json inputs
+design_input_folder = "model_utils/specs"
+
+simulation_method = "performance"  # whether to use the performance simulator or the power simulator
+#simulation_method = "power_knobs"
+
+# --------------------
+# DSE params
+# --------------------
+# algorithm
+dse_type = "hill_climbing"  # [exhaustive, hill_climbing]  # type of the design space exploration
+dse_collection_mode = "serial"  # [serial, parallel]  # only relevant for exhaustive for now
+                                # python parallel: python spawns the processes (used when running within a server)
+                                # bash_parallel: you need to manually run the processes (used when you want to
+                                # run across multiple servers)
+
+# exhaustive analysis output folder names
+exhaustive_output_file_prefix = stage_1_input_exhaustive_data = 'exhaustive_for_pid'
+stage_2_input_exhaustive_data = 'stage_2_for_pid'
+stage_2_output_exhaustive_data = 'stage_2_consolidated'
+exhaustive_target_data = "target_budget"
+exhaustive_result_dir = "exhaustive_search"
+FARSI_result_dir = "FARSI_search"
+FARSI_outputfile_prefix_verbose = "FARSI_search"
+FARSI_outputfile_prefix_minimal = "FARSI_search_minimal"
+FARSI_outputfile_prefix_error_integration = "FARSI_error_integration_study"  # compile the data for the error_integration vs exact IP library search
+FARSI_cost_correlation_study_prefix = "FARSI_cost_correlation_study"
+FARSI_input_error_output_cost_sensitivity_study_prefix = "FARSI_input_error_output_cost_sensitivity_study"
+FARSI_input_error_input_cost_sensitivity_study_prefix = "FARSI_input_error_input_cost_sensitivity_study"
+FARSI_input_error_output_PPA_sensitivity_study_prefix = "FARSI_input_error_output_PPA_sensitivity_study"
+FARSI_simple_run_prefix = "FARSI_simple_run"
+FARSI_simple_sim_run_study_prefix = "simple_sim_run"
+
+
+parallel_processes_count = 1  # number of processes to spawn to search the design space. Only used for exhaustive search
+process_id = 0  # used for parallel execution. Each process is assigned an id, through main
+
+warning_mode = "always"
+num_clusters = 2  # how many clusters to create every time we split
+TOTAL_RUN_THRESHOLD = 5000  # acceptable iteration count without improvement before termination
+DES_STAG_THRESHOLD = 50  # acceptable iteration count without improvement before termination
+
+neigh_gen_mode = "some"  # neighbouring design point generation mode ("all", "random_one", ...)
+num_neighs_to_try = 3  # how many neighbours to try around a design point
+neigh_sel_mode = "best"  # neighbouring design selection mode ("best", ...)
+dp_rank_obj = "latency"  # design point ranking objective (the metric designs are ranked by)
+
+sel_next_dp = "all_metrics"  # how to select the next design, ["all_metrics", "one_metric"]
+
+
+# selection algorithm (picking the best neighbour)
+neigh_sel_algorithm = "annealing"
+SA_breadth = 1  # breadth of the neighbour search
+SA_depth = 15  # depth of the neighbour search
+annealing_max_temp = 500
+annealing_temp_dec = 50
+annealing_dampening_coef = 10  # how much to dampen the metric that has met the design objectives
+                               # in the annealing_energy calculation
+
+metric_sel_dis_mode = "eliminate"  # {"eliminate", "dampen", "simple"}: "eliminate" removes the metric from
+                                   # the distance calculation and "dampen" dampens it (using annealing_dampening_coef)
+
+# scheduling policy
+scheduling_policy = "FRFS"  # first read, first serve
+
+# migration policy
+#migrant_clustering_policy = "tasks_dependency"  # the policy used for clustering the migrants ["random", "tasks_dependency"]
+migrant_clustering_policy = "selected_kernel"  # the policy used for clustering the migrants ["random", "tasks_dependency"]
+migration_policy = "random"  # the policy to use in picking the clustered tasks to move
+
+# objectives
+objectives = ["latency"]  # [power, area, latency] are the options
+#objective_function_type = "pareto"  # ["pareto", "weighted"]; if weighted, we need to provide the weights
+
+sorting_SOC_metric = "power"
+all_metrics = ["latency", "power", "area", "energy", "cost"]  # all the metrics for the evaluation of the SOC
+budgetted_metrics = ["latency", "power", "area"]
+other_metrics = ["cost"]
+
+budget_dict = {}
+budget_dict["glass"] = {}
+budget_dict["glass"]["power"] = .05
+budget_dict["glass"]["area"] = .000005
+
+home_dir = home_settings.home_dir
+#home_dir = os.getcwd()+"/../../"
+
+# ------------------------
+# verify parameters
+# ------------------------
+verify_home_dir = home_dir
+
+# ------------------------
+# move parameters
+# ------------------------
+# negative means the smaller the better. We only support minimization
+# problems (i.e., the smaller the better)
+metric_improvement_dir = {}
+metric_improvement_dir["latency"] = -1  # direction of improvement is reduction, and hence -1
+metric_improvement_dir["power"] = -1
+metric_improvement_dir["energy"] = -1
+metric_improvement_dir["area"] = -1
+metric_improvement_dir["cost"] = -1
+move_s_krnel_selection = ["bottleneck"]  # options are: bottleneck, improvement_ease
+
+
+for metric in all_metrics:
+    if (metric not in metric_improvement_dir.keys()) or \
+            not (metric_improvement_dir[metric] == -1):  # is not a minimization problem
+        print("---Error: we can only support metrics that require minimization. You need to change the metric "
+              "selection in the navigation heuristic if you want otherwise.")
+        exit(0)
+
+heuristic_type = "FARSI"  # {moos, FARSI, SA}
+moos_greedy_mode = "phv"
+MOOS_GREEDY_CTR_RUN = 10
+DESIGN_COLLECTED_PER_GREEDY = 20
+
+#objective_function = 0
+#objective_budget = .000000001
+metric_trans_dict = {"latency": ["split", "swap", "migrate", "split_swap"], "power": ["split", "swap", "migrate", "split_swap"],
+                     "area": ["split", "swap", "migrate", "split_swap"]}
+
+cleaning_threshold = 2000000000000000000000  # how often to activate cleaning
+cleaning_consecutive_iterations = 1  # how many consecutive iterations to clean
+
+move_metric_ranking_mode = "exact"  # exact, prob. 
+move_krnel_ranking_mode = "exact" # exact, prob. If exact, kernels are ranked (and hence selected) based on
+                                  # their distance to the goal. If prob, we sample probabilistically based on the
+                                  # distance
+
+move_blck_ranking_mode = "exact" # exact, prob. If exact, blocks are ranked (and hence selected) based on
+                                 # their distance to the goal. If prob, we sample probabilistically based on the
+                                 # distance
+
+max_krnel_stagnation_ctr = 2
+fitted_budget_ctr_threshold = 1 # how many times to fit the budget before terminating
+
+recently_cached_designs_queue_size = 20
+max_recently_seen_design_ctr = 2
+assert(recently_cached_designs_queue_size > max_recently_seen_design_ctr)
+# --------------------
+# DEBUGGING
+# --------------------
+NO_VIS = False # if set, no visualization is used. This speeds up everything
+DEBUG_SANITY = True # run sanity checks on the design
+DEBUG_FIX = False # de-randomize the flow (by fixing the random seed)
+VIS_GR_PER_GEN = False and not NO_VIS # visualize the graph per design point generation
+VIS_SIM_PER_GEN = False and not NO_VIS # if true, we visualize the simulation progression
+VIS_GR_PER_ITR = True and not NO_VIS # visualize the graph exploration per iteration
+VIS_PROFILE = True and not NO_VIS # visualize the profiling data
+VIS_FINAL_RES = False and not NO_VIS # see the final results
+VIS_ALL = False and not NO_VIS # visualize everything
+REPORT = True # report the stats (to the screen); draw plots.
+DATA_DELIVEYRY = "absolute" # [obfuscate, absolute]
+DEBUG_MISC = False # scenarios not covered above
+pa_conversion_file = "pa_conversion.txt"
+WARN = False
+data_folder = "data"
+PA_output_folder = data_folder+"/"+"PA_output"
+sim_progress_folder = data_folder+"/"+"sim_progress"
+RUN_VERIFICATION_PER_GEN = False # for every new design, generate the verification data
+RUN_VERIFICATION_PER_NEW_CONFIG = False
+RUN_VERIFICATION_PER_IMPROVMENT = False and not (RUN_VERIFICATION_PER_GEN or RUN_VERIFICATION_PER_NEW_CONFIG) # on every improvement, generate verification data;
+                                   # we don't want to double generate, hence the second
+                                   # predicate clause
+RUN_VERIFICATION_AT_ALL = RUN_VERIFICATION_PER_IMPROVMENT or RUN_VERIFICATION_PER_NEW_CONFIG or RUN_VERIFICATION_PER_GEN
+
+VIS_SIM_PROG = RUN_VERIFICATION_PER_GEN or RUN_VERIFICATION_PER_IMPROVMENT or RUN_VERIFICATION_PER_NEW_CONFIG # visualize the simulation progression
+
+verification_result_file = "verification_result_file.csv"
+# MOVES
+
+FARSI_memory_consumption = "high" # [low, high] if low is selected, we deactivate certain knobs to avoid using memory excessively
+FARSI_performance = "fast" # ["slow", "fast"] # if set to fast, we don't visualize as often and use certain fast versions
+                           # of functions to accomplish the tasks
+if FARSI_performance == "fast":
+    vis_reg_ctr_threshold = 60
+else:
+    vis_reg_ctr_threshold = 1
+
+DEBUG_MOVE = True and not NO_VIS # if true, we print/collect relevant info about moves
+regulate_move_tracking = (FARSI_memory_consumption == "low") # if true, we don't track and hence graph every move. This helps prevent memory pressure (and avoids getting killed by the OS)
+#vis_move_trail_ctr_threshold = 20 # how often to sample the moves (only applies if regulate_move_tracking is enabled)
+
+cache_seen_designs = False and not(FARSI_memory_consumption == "low") # if True, we cache the designs that we have seen. This way we won't simulate them unnecessarily.
+                                                                      # This should be set to False if memory is an issue
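+# Sketch (hypothetical helper) of the "exact" vs "prob" ranking modes above:
+# "exact" deterministically picks the item farthest from its goal, while "prob"
+# samples proportionally to each item's distance to the goal.
+import random as _random
+def _pick_by_ranking_mode_example(dist_to_goal, mode):
+    if mode == "exact":
+        return max(dist_to_goal, key=dist_to_goal.get)
+    names = list(dist_to_goal.keys())
+    return _random.choices(names, weights=[dist_to_goal[n] for n in names])[0]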
+VIS_MOVE_TRAIL = DEBUG_MOVE and not NO_VIS and False
+eval_mode = "statistical" # evaluation mode ["singular", "statistical"]. Note that singular is deprecated now
+statistical_reduction_mode = "avg"
+hw_sampling = {"mode":"exact", "population_size":1, "reduction":"avg"} # mode: ["error_integration", "exact"]. error_integration means that our IP library has an error that needs to be taken into account;
+                                                                      # exact means that (even if the IP library has an error) we treat the (most likely) value as the accurate value
+
+check_pointing_allowed = True
+check_point_list = ["ex", "db", "counters"] # choose from ["sim", "ex", "db", "counters"]
+
+
+use_slack_management_estimation = False and not (RUN_VERIFICATION_PER_GEN or RUN_VERIFICATION_PER_IMPROVMENT or RUN_VERIFICATION_PER_NEW_CONFIG) # only apply slack if we run verification; otherwise we get the wrong numbers
+jitter_population_size = 1 # for non-statistical evaluation
+if hw_sampling["mode"] == "exact":
+    hw_sampling["population_size"] = 1
+
+#dice_factor_list = range(1, 150, 50)
+#dice_factor_list = [1]
+sw_model = "gables_inspired_exact" # [gables_inspired_exact, gables_inspired, sequential]; the difference is that exact replicates the PEs to solve the PA DRVR preemption issue
+#sw_model = "sequential" # read, execute, write done in this order instead of simultaneously
+#if not sw_model == "gables_inspired":
+#    dice_factor_list = [1]
+
+if VIS_GR_PER_GEN: VIS_GR_PER_ITR = True
+
+if VIS_ALL:
+    DEBUG = True; DEBUG_FIX = True; VIS_GR_PER_GEN = True; VIS_GR_PER_ITR = True; VIS_PROFILE = True
+    VIS_FINAL_RES = True;
+
+# visualization
+hw_graphing_mode = "block_task" # block, block_extra, block_task
+stats_output = "stats"
+
+# clustering
+ic_mig_clustering = "data_sharing" # how to pick the (pe, mem) tuples; choose between ["data_sharing", "random"]
+tasks_clustering_data_sharing_method = "task_dep"
+
+# area
+statically_sized_blocks = ["gpp"] # these blocks' size is predetermined (as opposed to memories, buses and ips that
+                                  # require a dynamic size based on the mapping)
+zero_sized_blocks = ["ic"] # blocks to ignore for now # TODO: figure out ic size later
+
+
+
+DMA_mode = "serialized_read_write" # [serialized_read_write, parallel_read_write]
+#DMA_mode = "parallelized_read_write" # [serialized_read_write, parallelized_read_write]
+
+# power collection period (how often to divide energy). It is measured in seconds
+#PCP = .0001
+PCP = .01
+
+budget_fitting_coeff = .9999 # used to make sure the slack values are used in a way that brings the power and latency just beneath the budget
+
+# source, sink
+#souurce_memory_work = 81920 # in bytes
+proj_name = "simple_multiple_hops"
+proj_name = "SLAM"
+
+# PA (platform) conversion files
+hw_yaml = "hw.yml"
+pa_des_top_name = "my_top"
+pa_space_distance = 1
+pa_push_distance = 4
+max_budget_coeff = 50
+
+# check pointing and reading from checkpoints
+latest_visualization = os.path.join(home_dir, 'data_collection/data/latest_visualization')
+check_point_folder = os.path.join(home_dir, 'data_collection/data/check_points')
+replay_folder_base = os.path.join(home_dir, 'data_collection/data/replayer')
+database_csv_folder = os.path.join(home_dir, 'specs/database_csvs/') # where all the library inputs are located
+
+axis_unit = {"area": "mm2", "power": "mW", "latency": "s"}
+database_data_dir = os.path.join(home_dir, "specs", "database_data")
+transaction_base_simulation = False # do not set to True; it doesn't work
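+# Illustrative only: PCP is the window over which collected energy is divided
+# to report power. A hypothetical helper showing that relationship:
+def _avg_power_example(energy_joules_in_window, period=PCP):
+    return energy_joules_in_window / period  # watts over one collection period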
+
+
+# CACTI
+use_cacti = True and not RUN_VERIFICATION_AT_ALL # if True, use cacti. You have to have cacti installed.
+#use_cacti = True
+cact_bin_addr = CC.cact_bin_addr
+cacti_param_addr = CC.cacti_param_addr
+cacti_data_log_file = CC.cacti_data_log_file
+cacti_input_col_order = ["mem_subtype", "mem_size"]
+cacti_output_col_order = ["read_energy_per_byte", "write_energy_per_byte", "area"]
+cacti_min_memory_size_in_bytes = 2048 # below this value cacti errors out. We can play with burst size and page size to fix this though
+
+#ACC_coeff = 128 # how much to modify what we have parsed. This is just for some exploration purposes;
+                 # it should almost always be set to 1
+
+
+transformation_selection_mode = "random" # choose from {random, arch-aware}
+
+all_available_transformations = ["migrate", "swap", "split", "split_swap"]#, "transfer", "routing"]
+if RUN_VERIFICATION_AT_ALL:
+    all_available_transformations = ["migrate", "swap", "split", "split_swap"]
+
+min_mem_size = {"sram": 256000, "dram": 256000}
+
+dram_stacked = True
+parallelism_analysis = "dynamic" # choose from ["dynamic", "static"]; at the moment static is not working, something
+                                 # to do with task synchronization and reads being present after unloading
+
+
+heuristic_scaling_study = True
+print_info_regularly = False
+
+out_of_memory_percentage = 93
+default_cmd_queue_size = 16
+default_data_queue_size = 16
+#default_burst_size = 128
+default_burst_size = 256
+#default_cmd_queue_size = 16
+#default_data_queue_size = 16
+#default_burst_size = 64
+
+
+memory_conscious = True
diff --git a/Project_FARSI/settings/config_cacti.py b/Project_FARSI/settings/config_cacti.py
new file mode 100644
index 00000000..91cf36d9
--- /dev/null
+++ b/Project_FARSI/settings/config_cacti.py
@@ -0,0 +1,18 @@
+import os
+
+# get the base path of arch-gym
+base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
+
+
+cact_bin_addr = os.path.join(base_path, "Project_FARSI/cacti_for_FARSI/cacti")
+
+print(cact_bin_addr, os.path.exists(cact_bin_addr))
+
+cacti_param_addr = os.path.join(base_path, "Project_FARSI/cacti_for_FARSI/farsi_gen.cfg")
+
+print(cacti_param_addr, os.path.exists(cacti_param_addr))
+
+cacti_data_log_file = os.path.join(base_path, "Project_FARSI/cacti_for_FARSI/data_log.csv")
+
+print(cacti_data_log_file, os.path.exists(cacti_data_log_file))
+
diff --git a/Project_FARSI/settings/config_plotting.py b/Project_FARSI/settings/config_plotting.py
new file mode 100644
index 00000000..233ddd1d
--- /dev/null
+++ b/Project_FARSI/settings/config_plotting.py
@@ -0,0 +1,34 @@
+zoneNum = 4
+run_folder_name= "/media/reddi-rtx/KINGSTON/FARSI_results/1_1_1_for_paper_07-31"
+run_folder_name="/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_4/data_collection/data/simple_run/08-01_13-22_12"
+run_folder_name="/media/reddi-rtx/KINGSTON/FARSI_results/scaling_of_1_2_4_across_all_budgets_07-31"
+run_folder_name="/media/reddi-rtx/KINGSTON/FARSI_results/blind_study/blind_version/blind_combined/"
+run_folder_name="/media/reddi-rtx/KINGSTON/FARSI_results/blind_study/blind_version/blind_vs_arch_ware"
+run_folder_name="/media/reddi-rtx/KINGSTON/FARSI_results/blind_study_smart_krnel_selection/blind_vs_arch_ware"
+run_folder_name ="/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_4/data_collection/data/simple_run/aggregate_data"
+#run_folder_name ="/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_4/data_collection/data/simple_run/08-02_01-38_26/"
+#run_folder_name
="/media/reddi-rtx/KINGSTON/FARSI_results/optimal_budgetting_problem_08_1" +run_folder_name = "/Users/behzadboro/FARSI_related_stuff/FARSI_results/1_1_1_for_paper_07-31" +run_folder_name = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/final_03-04_19-58_52" +heuristic_comparison_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_TECS/Project_FARSI_6/data_collection/data/simple_run/blah_" +#heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/FARSI_results_5000" +heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/moos_with_ctr_greedy_collection" +#heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/FARSI_results" +#heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/SA_metropolis_criteria" +#heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/SA_modified_randomness" +#heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/FARSI_5000_all_the_way" +#run_folder_name = "/home/yingj4/Desktop/FARSI_results/arch-aware-vs-random/all_workloads_together" +#top_result_folder = "/home/yingj4/Desktop/FARSI_results/arch-aware-vs-random" +heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/compare_heuristics" # for comparing FARSI vs SA vs moos +heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/heuristic_scaling" # just FARSI for different task count/and budget +heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/compare_heuristic_worse_version_1" # for comparing FARSI vs SA vs greedy-ctr of moos +#heuristic_comparison_folder = "/media/reddi-rtx/Samsung USB/compare_heuristic_worse_version_2" # for comparing FARSI vs SA vs closest to SA of moos + +top_result_folder = "/home/reddi-rtx/FARSI_related_stuff/Project_FARSI_4/data_collection/data/simple_run" +ignore_file_names = ["README.md", "figures", "cross_workloads", "single_workload", "3D", "budget_optimality", "panda_study", "pie_chart", "verification"] +#plot_list = ["cross_workloads", "single_workload","plot_3d", "pandas_plots", "budget_optimality"] +plot_list = ["cross_workloads"]#"cross_workloads", "single_workload"] +plot_list = ["heuristic_comparison"]#"cross_workloads", "single_workload"] +#plot_list = ["heuristic_scaling"]#"cross_workloads", "single_workload"] +#plot_list = ["pareto_studies"]#"cross_workloads", "single_workload"] +draw_for_paper = False \ No newline at end of file diff --git a/Project_FARSI/specs/LW_cl.py b/Project_FARSI/specs/LW_cl.py new file mode 100644 index 00000000..a59441d6 --- /dev/null +++ b/Project_FARSI/specs/LW_cl.py @@ -0,0 +1,285 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +# this file contains light weight version of the design +# components. These classes are used in the database +from typing import List, Tuple +from design_utils.components.hardware import * +from design_utils.components.workload import * +#from design_utils.components.mapping import * +from design_utils.components.scheduling import * +from typing import List +from collections import defaultdict + +# This class emulates a hardware block. However, BlockL is a light weight class that directly talks to the database. 
+# This is later used by the Block class, which is much more involved.
+class BlockL:  # block light weight
+    def __init__(self, block_instance_name: str, block_type: str, block_subtype, peak_work_rate_distribution,
+                 work_over_energy_distribution, work_over_area_distribution, one_over_area_distribution,
+                 clock_freq, bus_width, loop_itr_cnt=1, loop_max_possible_itr_cnt=1,
+                 hop_latency=1, pipe_line_depth=1,
+                 leakage_power = "", power_knobs = ""):
+        self.block_instance_name = block_instance_name+"_"+block_type # block instance name
+        self.block_instance_name_without_type = block_instance_name # without type
+        self.block_type = block_type # type of the block (pe, mem, ic)
+        self.block_subtype = block_subtype # sub type of each block, e.g., for pe: ip or gpp
+        self.peak_work_rate_distribution = peak_work_rate_distribution # peak_work_rate: the fastest that a hardware block can do its work;
+                                                                       # note that the work definition varies depending on the hardware type,
+                                                                       # e.g., pe work = instructions, mem/ic work = bytes
+        self.work_over_energy_distribution = work_over_energy_distribution # how much energy is consumed per unit of work done
+        self.work_over_area_distribution = work_over_area_distribution # how much area is occupied per unit of work done
+        self.leakage_power = leakage_power # leakage power
+        self.power_knobs = power_knobs # all knobs used for power modulation
+        self.one_over_area_distribution = one_over_area_distribution
+        self.bus_width = bus_width
+        self.clock_freq = clock_freq
+        self.loop_itr_cnt = loop_itr_cnt
+        self.loop_max_possible_itr_cnt = loop_max_possible_itr_cnt
+        self.hop_latency = hop_latency
+        self.pipe_line_depth = pipe_line_depth
+
+
+# This class emulates the software tasks (e.g., glint detection) within an entire workload. Note that TaskL is a
+# light weight class that directly talks to the database. This is later used by the Task class, which is much more
+# involved.
+class TaskL:  # task light weight
+    def __init__(self, task_name: str, work: float, iteration=1, type = "latency_based", throughput_info = {}):
+        self.task_name = task_name
+        self.work = work # the amount of work associated with the task (at the moment, this is expressed for PEs
+                         # as the reference block, so work = number of instructions)
+        self.__task_children = [] # dependent tasks.
+        self.__self_to_children_work = {} # number of bytes that will be passed from this task to its children.
+        self.__self_task_work_distribution = [] # amount of work for this task as a distribution (for jitter modeling).
+        self.__self_to_child_task_work_distribution = {} # amount of bytes passed from this task to its children (as a distribution).
+ self.__children_nature_dict = dict() + self.burst_size = 256 + self.iteration = iteration + self.throughput_info = throughput_info # can be empty if the task is latency based + self.type = type + # hardcoding for testing + #self.iteration = 2 + """ + if "Smoothing" in self.task_name: + self.type = "throughput_based" + self.throughput_info = {"read": 100000000, "write":100000000, "clock_period": 10000} # clock period in ns + else: + self.type = "latency_based" + """ + + def get_throughput_info(self): + return self.throughput_info + + def get_type(self): + return self.type + + # ------------------------------ + # Functionality: + # adding a child task (i.e., a dependent task) + # Variables: + # taskL: child task + # work: amount of work (currently expressed in bytes) imposed on the child task + # ------------------------------ + def add_child(self, taskL, work, child_nature): + self.__task_children.append(taskL) + self.__self_to_children_work[taskL] = work + self.__children_nature_dict[taskL] = child_nature + + def get_child_nature(self, taskL): + return self.__children_nature_dict[taskL] + + def set_burst_size(self, burst_size): + self.burst_size = burst_size + + def get_burst_size(self): + return self.burst_size + + def get_children_nature(self): + return self.__children_nature_dict + + def set_children_nature(self, children_nature): + self.__children_nature_dict = children_nature + + # ------------------------------ + # Functionality: + # add work distribution (as opposed to a single value) for the task (for jitter modeling) + # Variables: + # work_dis: work distribution to add + # ------------------------------ + def add_task_work_distribution(self, work_dis): + self.__self_task_work_distribution = work_dis + + # ------------------------------ + # Functionality: + # add work distribution (as opposed to a single value) for the task's child (for jitter modeling) + # Variables: + # childL: child task + # work_dis: work distribution to add + # ------------------------------ + def add_task_to_child_work_distribution(self, childL, work_dist): + self.__self_to_child_task_work_distribution[childL] = work_dist + + + def set_self_to_children_work_distribution(self, task_to_child_work_distribution): + self.__self_to_child_task_work_distribution = task_to_child_work_distribution + + def get_self_to_children_work_distribution(self): + return self.__self_to_child_task_work_distribution + + # ------------------------------ + # Functionality: + # get the task's work distribution + # ------------------------------ + def get_task_work_distribution(self): + return self.__self_task_work_distribution + + def set_task_work_distribution(self, work_dis): + self.__self_task_work_distribution = work_dis + + + # ------------------------------ + # Functionality: + # get the task to child work distribution + # ------------------------------ + def get_task_to_child_task_work_distribution(self, childL): + return self.__self_to_child_task_work_distribution[childL] + + # get taskL's dependencies + def get_children(self): + return self.__task_children + + # set taskL's dependencies + def set_children(self, children): + self.__task_children = children + + # ------------------------------ + # Functionality: + # convert the light weight task (i.e., TaskL) to Task class + # ------------------------------ + def toTask(self): + return Task(self.task_name, self.work) + + # ------------------------------ + # Functionality: + # get work imposed on the child + # Variables: + # child: task's child + # 
------------------------------
+    def get_self_to_child_work(self, child):
+        assert(child in self.__self_to_children_work.keys())
+        return self.__self_to_children_work[child]
+
+    # ------------------------------
+    # Functionality:
+    #       get work imposed on the children
+    # ------------------------------
+    def get_self_to_children_work(self):
+        return self.__self_to_children_work
+
+    def set_self_to_children_work(self, self_to_children_work):
+        self.__self_to_children_work = self_to_children_work
+
+
+# This class emulates an SOC
+class SOCL:
+    def __init__(self, type, budget_dict, other_metrics_dict):
+        self.budget_dict = budget_dict
+        self.other_metrics_dict = other_metrics_dict
+        self.type = type
+        assert (sorted(list(budget_dict.keys())) == sorted(config.budgetted_metrics)), "budgetted metrics need to be the same"
+
+    # ------------------------------
+    # Functionality:
+    #       get the SOC budget
+    # Variables:
+    #       metric_name: metric (energy, power, area, latency) to get the budget for
+    # ------------------------------
+    def get_budget(self, metric_name):
+        for metric_name_, budget_value in self.budget_dict.items():
+            if metric_name_ == metric_name:
+                return budget_value
+        raise Exception("metric:" + metric_name + " is not budgetted in the design")
+
+
+    def set_budget(self, metric_name, metric_value):
+        for metric_name_, _ in self.budget_dict.items():
+            if metric_name_ == metric_name:
+                self.budget_dict[metric_name_] = metric_value
+                return
+        raise Exception("metric:" + metric_name + " is not budgetted in the design")
+
+    def set_other_metrics_ideal_values(self, metric, value):
+        for metric_name_, _ in self.other_metrics_dict.items():
+            if metric_name_ == metric:
+                self.other_metrics_dict[metric_name_] = value
+                return
+        raise Exception("metric:" + metric + " is not in the other values in the design")
+
+    def get_other_metrics_ideal_values(self, metric_name):
+        for metric_name_, ideal_value in self.other_metrics_dict.items():
+            if metric_name_ == metric_name:
+                return ideal_value
+        raise Exception("metric:" + metric_name + " is not in the other values in the design")
+
+    # ------------------------------
+    # Functionality:
+    #       get the name of all the metrics that have budgets.
+    # ------------------------------
+    def get_budgetted_metric_names(self):
+        return list(self.budget_dict.keys())
+
+
+    def get_other_metric_names(self):
+        return list(self.other_metrics_dict.keys())
+
+    # get type of the SOC
+    def get_type(self):
+        return self.type
+
+    def get_budget_dict(self):
+        return self.budget_dict
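+# Quick illustration of the SOCL budget API (hypothetical values); the budget
+# keys must match config.budgetted_metrics (["latency", "power", "area"]):
+#   soc = SOCL("glass", {"latency": 0.1, "power": 0.05, "area": 5e-6}, {"cost": 1.0})
+#   soc.get_budget("power")        # -> 0.05
+#   soc.set_budget("power", 0.04)  # overwrites the entry; raises if the metric is unknown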
+
+# This class models a mapping from a task to a block
+class TaskToPEBlockMapL:
+    def __init__(self, task_name: str, pe_block_instance_name):
+        self.task_name = task_name # task name
+        self.pe_block_instance_name = pe_block_instance_name+"_"+"pe" # block name
+        self.child_facing_work_ratio_dict = {} # the work ratio associated with a specific child (task). Note that
+                                               # since this is facing a child, it's a write (so write work ratio)
+        self.parent_facing_work_ratio_dict = {} # the work ratio associated with a specific parent (task);
+                                                # since this is facing a parent, it's a read (so read work ratio)
+
+        self.family_work_ratio = defaultdict(dict) # work ratio associated with the task and its family (parents/children) members.
+                                                   # Work_ratio:
+                                                   #   self: (PE) = 1
+                                                   #   parent to child: (mem) = bytes/insts
+        self.family_work_ratio['self'][self.task_name] = 1 # this is for PEs, so for every family member, we'll add
+                                                           # a work_ratio of 1 which is associated with the task's own work
+
+    # ------------------------------
+    # Functionality:
+    #       adding a family member for the task
+    # Variable:
+    #       list of family members and their work ratio
+    # ------------------------------
+    def add_family(self, family_list: List[Tuple[str, str, float]]):
+        for family in family_list:
+            family_member_name = family[1]
+            family_member_work_ratio = family[2]
+            relationship = family[0]
+            # this is the "mem" and "bus" work_ratio
+            self.family_work_ratio[relationship][family_member_name] = family_member_work_ratio
+
+    # ------------------------------
+    # Functionality:
+    #       getting the work ratio
+    # ------------------------------
+    def get_work_ratio_new(self):
+        return self.family_work_ratio
+
+# This class contains scheduling information for a task
+class TaskScheduleL:
+    def __init__(self, task_name: str, starting_time: float):
+        self.task_name = task_name
+        self.starting_time = starting_time
\ No newline at end of file
diff --git a/Project_FARSI/specs/__init__.py b/Project_FARSI/specs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/Project_FARSI/specs/data_base.py b/Project_FARSI/specs/data_base.py
new file mode 100644
index 00000000..baf7abb9
--- /dev/null
+++ b/Project_FARSI/specs/data_base.py
@@ -0,0 +1,925 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+from specs.LW_cl import *
+import time
+import random
+from operator import itemgetter, attrgetter
+from settings import config
+from datetime import datetime
+#from specs import database_input
+if config.simulation_method == "power_knobs":
+    from specs import database_input_powerKnobs as database_input
+elif config.simulation_method == "performance":
+    from specs import database_input
+else:
+    raise NameError("Simulation method unavailable")
+
+from design_utils.components.hardware import *
+
+
+# this class is used for populating the library.
+# It is a light weight class to help library population
+class IPLibraryElement():
+    def __init__(self):
+        self.blockL = "_"
+        self.task = "_"
+        self.PPAC_dict = {}
+
+    def set_task(self, task_):
+        self.task = task_
+
+    def set_blockL(self, blockL_):
+        self.blockL = blockL_
+
+    # generate the PPAC (performance, power, area and cost) dictionary
+    def generate(self):
+        self.PPAC_dict["latency"] = self.task.get_self_task_work()/self.blockL.peak_work_rate
+        self.PPAC_dict["energy"] = self.task.get_self_task_work()/self.blockL.work_over_energy
+        self.PPAC_dict["area"] = self.task.get_self_task_work()/self.blockL.work_over_area
+        self.PPAC_dict["power"] = self.PPAC_dict["energy"]/self.PPAC_dict["latency"]
+
+    def get_blockL(self):
+        return self.blockL
+
+    def get_task(self):
+        return self.task
+
+    def get_PPAC(self):
+        return self.PPAC_dict
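+# Worked example (illustrative only; _StubTask/_StubBlock are stand-ins, not
+# FARSI classes): generate() derives latency = work/peak_work_rate,
+# energy = work/work_over_energy, area = work/work_over_area, power = energy/latency.
+#   class _StubTask:
+#       def get_self_task_work(self): return 1e6   # instructions
+#   class _StubBlock:
+#       peak_work_rate = 1e9                       # instructions/s
+#       work_over_energy = 1e10                    # instructions/J
+#       work_over_area = 1e12                      # instructions/mm2
+#   elem = IPLibraryElement()
+#   elem.set_task(_StubTask()); elem.set_blockL(_StubBlock())
+#   elem.generate()
+#   elem.get_PPAC()  # {'latency': 1e-3, 'energy': 1e-4, 'area': 1e-6, 'power': 0.1}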
+
+
+# this class handles the input database data.
+# any class that wants to read the database data needs to talk to this class
+class DataBase:
+    def __init__(self, db_input, hw_sampling):
+        self.cached_blockL_to_block = {}
+        self.cached_block_to_blockL = {}
+        self.db_input = db_input
+        self.tasksL = self.db_input.tasksL # tasks
+        self.pe_mapsL = self.db_input.pe_mapsL # pe maps
+        self.blocksL = self.db_input.blocksL # blocks
+        self.pe_schedulesL = db_input.pe_schedeulesL # schedules
+        self.workloads_last_task = db_input.workloads_last_task
+        self.SOCsL = db_input.SOCsL # SOC
+        self.SOC_id = 0 # soc id.
+        self.hw_sampling = hw_sampling # how to sample hardware
+
+        # cluster blocks
+        self.ic_block_list = self.get_blocksL_by_type(block_type="ic") # list of ICs
+        self.mem_block_list = self.get_blocksL_by_type(block_type="mem") # list of memories
+        self.pe_block_list = self.get_blocksL_by_type(block_type="pe") # list of PEs
+
+        # generate tasks (TODO: this might need to move out of data_base if we consider a data_base for different workloads)
+        self.__tasks = self.parse_and_gen_tasks() # generate all the tasks (from tasks in the database_input.py)
+        self.__blocks = [self.cast(blockL) for blockL in self.get_all_BlocksL()]
+
+        # find mappable blocks for all the tasks
+        self.mappable_blocksL_to_tasks_s_name_dict:Dict[str:TaskL] = {}
+        self.populate_mappable_blocksL_to_tasks_s_name_dict()
+
+    def set_workloads_last_task(self, workloads_last_task):
+        self.workloads_last_task = workloads_last_task
+
+    # used to determine when a workload is done
+    def get_workloads_last_task(self):
+        return self.workloads_last_task
+
+    def set_hw_sampling(self, hw_sampling):
+        self.hw_sampling = hw_sampling
+
+    # ------------------------------
+    # Functionality:
+    #       input_database to database conversion
+    # ------------------------------
+    def cast(self, obj, *argv ):
+        if len(argv) == 0 and isinstance(obj, BlockL):
+            return Block(self.db_input, self.hw_sampling, obj.block_instance_name, obj.block_type, obj.block_subtype,
+                         self.get_block_peak_work_rate_distribution(obj), self.get_block_work_over_energy_distribution(obj),
+                         self.get_block_work_over_area_distribution(obj), self.get_block_one_over_area_distribution(obj),
+                         obj.clock_freq,
+                         obj.bus_width,
+                         obj.loop_itr_cnt,
+                         obj.loop_max_possible_itr_cnt,
+                         obj.hop_latency,
+                         obj.pipe_line_depth,
+                         self.get_block_leakage_power(obj),
+                         self.get_block_power_knobs(obj),)
+        elif len(argv) == 0 and isinstance(obj, TaskL):
+            return Task(obj.task_name, self.get_task_work(obj), self.get_task_iteration(obj), self.get_task_type(obj), self.get_task_throughput_info(obj))
+        elif len(argv) == 3 and isinstance(obj, Task) and isinstance(argv[0], Block):
+            raise Exception("this case is deprecated")
+            task = obj
+            pe_block = argv[0]
+            ic = argv[1]
+            mem = argv[2]
+
+            task_to_blocks_map = TaskToBlocksMap(task, {})
+            work_ratio_read = self.get_block_work_ratio_by_task(task, pe_block, "read")
+            # TODO: this needs to be fixed to react to different referencing and also different interconnects, and mem
+            task_to_blocks_map.block_workRatio_dict[pe_block] = 1
+            task_to_blocks_map.block_workRatio_dict[ic] = work_ratio_read
+            task_to_blocks_map.block_workRatio_dict[mem] = work_ratio_read
+            return task_to_blocks_map
+        elif len(argv) == 1 and isinstance(obj, Task) and isinstance(argv[0], str):
+            return TaskToPEBlockSchedule(obj, self.get_task_starting_time(obj))
+        else:
+            raise Exception("this casting is not acceptable: " + str(type(obj)))
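+    # Illustrative use of cast() (hypothetical objects): a BlockL from the
+    # database becomes a full Block, and a TaskL becomes a Task, e.g.
+    #   block = db.cast(some_blockL)   # BlockL -> Block
+    #   task  = db.cast(some_taskL)    # TaskL  -> Task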
+    # get the name of the budgeted metrics
+    def get_budgetted_metric_names_all_SOCs(self):
+        result = []
+        for SOCL in self.SOCsL:
+            result.extend(SOCL.get_budgetted_metric_names())
+
+        return np.unique(result).tolist()
+
+    def get_other_metric_names_all_SOCs(self):
+        result = []
+        for SOCL in self.SOCsL:
+            result.extend(SOCL.get_other_metric_names())
+
+        return np.unique(result).tolist()
+
+    # ------------------------------
+    # Functionality:
+    #       get the name of the budgeted metrics
+    # Variables:
+    #       type: SOC type to query
+    # -------------------------------------------
+    def get_budgetted_metric_names(self, type):
+        for SOCL in self.SOCsL:
+            if SOCL.type == type:
+                return SOCL.get_budgetted_metric_names()
+
+    # ------------------------------
+    # Functionality:
+    #       get the budget value for a specific metric.
+    # Variables:
+    #       type: SOC type
+    #       metric: metric to query.
+    # -------------------------------------------
+    def get_budget(self, metric, type):
+        for SOCL in self.SOCsL:
+            if SOCL.type == type:
+                if not metric in SOCL.get_budgetted_metric_names():
+                    print("this metric is not budgeted")
+                    exit(0)
+                return SOCL.get_budget(metric)
+
+    # the ideal or desired value is the value that we are targeting
+    def get_other_metrics_ideal_value(self, metric, type):
+        for SOCL in self.SOCsL:
+            if SOCL.type == type:
+                if not metric in SOCL.get_other_metric_names():
+                    print("this metric is not included in the other metrics")
+                    exit(0)
+                return SOCL.get_other_metrics_ideal_values(metric)
+
+    def set_ideal_metric_value(self, metric, type, value):
+        for SOCL in self.SOCsL:
+            if SOCL.type == type:
+                if metric in SOCL.get_budgetted_metric_names():
+                    return SOCL.set_budget(metric, value)
+                elif metric in SOCL.get_other_metric_names():
+                    return SOCL.set_other_metrics_ideal_values(metric, value)
+
+    # get the desired (basically budget) value for various metrics
+    def get_ideal_metric_value(self, metric, type):
+        for SOCL in self.SOCsL:
+            if SOCL.type == type:
+                if metric in SOCL.get_budgetted_metric_names():
+                    return SOCL.get_budget(metric)
+                elif metric in SOCL.get_other_metric_names():
+                    return SOCL.get_other_metrics_ideal_values(metric)
+
+    # ------------------------------
+    # Functionality:
+    #       get all the tasks for the database
+    # -------------------------------------------
+    def get_tasks(self):
+        return self.__tasks
+
+    def get_task_by_name(self, name):
+        for task in self.get_tasks():
+            if task.name == name:
+                return task
+        return None
+
+    def get_blocks(self):
+        return self.__blocks
+
+    def get_block_by_name(self, name):
+        for block in self.get_blocks():
+            if block.get_generic_instance_name() == name:
+                return block
+        return None
+
+    # ------------------------------
+    # Functionality:
+    #       get the list of blocks that a task can map to.
+    # Variables:
+    #       task_name: name of the task to query.
+    # -------------------------------------------
+    def get_task_s_mappable_blocksL_by_task_s_name(self, task_name):
+        return self.mappable_blocksL_to_tasks_s_name_dict[task_name]
+    # ------------------------------
+    # Functionality:
+    #       get the task (Task class) version of a taskL, if one has already been generated.
+    # Variables:
+    #       taskL: taskL object to compare against.
+    #       tasks: list of already generated tasks.
+    # -------------------------------------------
+    def get_tasks_from_taskL(self, taskL, tasks):
+        for task in tasks:
+            if taskL.task_name == task.name:
+                return task
+        return None
+
+    # ------------------------------
+    # Functionality:
+    #       find all the blocks that the input tasks can map to simultaneously. This means that all the tasks
+    #       should be mappable to each one of the blocks in the output
+    # Variables:
+    #       block_type: filter based on block type
+    #       tasks: list of tasks to find mappable blocks for
+    # -------------------------------------------
+    def find_all_mappable_blocksL_for_tasks(self, block_type, tasks=[]):
+        # this function works by having two buckets (shared_so_far and shared_so_far_temp)
+        # and swapping them
+        if block_type in ["ic"]:
+            return self.ic_block_list
+        elif block_type == "mem":
+            return self.mem_block_list
+        elif block_type == "pe":
+            shared_so_far_temp = self.get_task_s_mappable_blocksL_by_task_s_name(tasks[0].name)
+            found_common_block = False
+            for task in tasks[1:]:
+                blockL_list = self.get_task_s_mappable_blocksL_by_task_s_name(task.name)
+                shared_so_far = shared_so_far_temp # swap
+                shared_so_far_temp = []
+                for blockL in blockL_list:
+                    if blockL in shared_so_far:
+                        shared_so_far_temp.append(blockL)
+                        found_common_block = True
+
+            if len(tasks) == 1 or found_common_block:
+                return shared_so_far_temp
+            else:
+                raise Exception("no common block found")
+        else:
+            raise ValueError("block_type:" + block_type + " not supported")
+
+    # ------------------------------
+    # Functionality:
+    #       get all the blocksL
+    # Variables:
+    #       block_type: filter based on block_type
+    # -------------------------------------------
+    def get_blocksL_by_type(self, block_type):
+        return list(filter(lambda blockL: blockL.block_type == block_type, self.blocksL))
+
+    def get_all_BlocksL(self):
+        return self.blocksL
+
+    # ------------------------------
+    # Functionality:
+    #       get the block's work/energy (how much work is done per energy consumed)
+    # Variables:
+    #       blockL: blockL under query.
+    # -------------------------------------------
+    def get_block_work_over_energy_distribution(self, blockL):
+        blockL = list(filter(lambda block: block.block_instance_name == blockL.block_instance_name, self.blocksL))
+        assert(len(blockL) == 1)
+        return blockL[0].work_over_energy_distribution
+
+    def get_block_one_over_area_distribution(self, blockL):
+        blockL = list(filter(lambda block: block.block_instance_name == blockL.block_instance_name, self.blocksL))
+        assert(len(blockL) == 1)
+        return blockL[0].one_over_area_distribution
+
+    # ------------------------------
+    # Functionality:
+    #       get the block's leakage power
+    # Variables:
+    #       blockL: blockL under query.
+    # -------------------------------------------
+    def get_block_leakage_power(self, blockL):
+        blockL = list(filter(lambda block: block.block_instance_name == blockL.block_instance_name, self.blocksL))
+        assert(len(blockL) == 1)
+        return blockL[0].leakage_power
+
+    # ------------------------------
+    # Functionality:
+    #       get the block's power knobs
+    # Variables:
+    #       blockL: blockL under query.
+ # ------------------------------------------- + def get_block_power_knobs(self, blockL): + return 0 + blockL = list(filter(lambda block: block.block_instance_name == blockL.block_instance_name, self.blocksL)) + assert(len(blockL) == 1) + return blockL[0].power_knobs + + #def introduce_iteration(self, tasks): + + # ------------------------------ + # Functionality: + # parses taskL and generates Tasks objects (including their dependencies) + # ------------------------------------------- + def parse_and_gen_tasks(self): + tasks = [self.cast(taskL) for taskL in self.tasksL] + for task in tasks: + corresponding_taskL = self.get_taskL_from_task_name(task.name) + if config.eval_mode == "statistical": + task.add_task_work_distribution(corresponding_taskL.get_task_work_distribution()) + taskL_children = corresponding_taskL.get_children() + for taskL_child in taskL_children: + child_task = self.get_tasks_from_taskL(taskL_child, tasks) + taskL__ = self.get_taskL_from_task_name(task.name) + task.add_child(child_task, taskL__.get_self_to_child_work(taskL_child), corresponding_taskL.get_child_nature(taskL_child)) + task.set_burst_size(taskL__.get_burst_size()) + if config.eval_mode == "statistical": + task.add_task_to_child_work_distribution(child_task, taskL__.get_task_to_child_task_work_distribution(taskL_child)) + + return tasks + + # ------------------------------ + # Functionality: + # find all the compatible blocks for a specific task. This version is the fast version (using caching), however + # it's a bit less intuitive + # Variables: + # block_type: filter based on the type (pe, mem, ic) + # tasks: all the tasks that the blocks need to compatible for. + # ------------------------------------------- + def find_all_compatible_blocks_fast(self, block_type, tasks): + mappable_blocksL = self.find_all_mappable_blocksL_for_tasks(block_type, tasks) + assert(len(mappable_blocksL) > 0), "there should be at least one block for all the tasks to map to" + result = [] + for mappable_blockL in mappable_blocksL: + if mappable_blockL in self.cached_blockL_to_block.keys(): # to improve performance, we look into caches + result.append(self.cached_blockL_to_block[mappable_blockL]) + else: + casted = self.cast(mappable_blockL) + self.cached_blockL_to_block[mappable_blockL] = casted + self.cached_block_to_blockL[casted] = mappable_blockL + result.append(casted) + return result + + # ------------------------------ + # Functionality: + # find all the compatible blocks for a specific task + # Variables: + # block_type: filter based on the type (pe, mem, ic) + # tasks: all the tasks that the blocks need to compatible for. 
+ # ------------------------------------------- + def find_all_compatible_blocks(self, block_type, tasks): + mappable_blocksL = self.find_all_mappable_blocksL_for_tasks(block_type, tasks) + assert(len(mappable_blocksL) > 0), "there should be at least one block for all the tasks to map to" + return [self.cast(mappable_blockL) for mappable_blockL in mappable_blocksL] + + # get the block name (without type in name) generate a block object with it, and return it + # Note that without type means that the name does not contains the type, and we add the type when we generate + # BlockL so we can make certain deduction about the block for Platform architect purposes + def gen_one_block_by_name_without_type(self, blk_name_without_type): + # generate the block + blockL = [el for el in self.blocksL if el.block_instance_name_without_type == blk_name_without_type][0] + block = self.cast(blockL) + # set the SOC type: + # TODO: for now we just set it to the most superior SOC. Later, get an input for this + ordered_SOCsL = sorted(self.SOCsL, key=lambda SOC: SOC.get_budget("latency")) + block.set_SOC(ordered_SOCsL[0].type, self.SOC_id) + return block + + # get the block name and generate a block object with it, and return it + def gen_one_block_by_name(self, blk_name): + blk_name_refined = '_'.join( + blk_name.split("_")[:-1]) # fix the name by getting rid of the instance part of the name + suffix = blk_name.split("_")[-1] + # This is becuase BlockL do not have an instance name + # generate the block + blockL = [el for el in self.blocksL if el.block_instance_name_without_type == blk_name_refined][0] + block = self.cast(blockL) + block_name = block.instance_name + "_" + suffix + block.set_instance_name(block_name) + # set the SOC type: + # TODO: for now we just set it to the most superior SOC. 
Later, get an input for this
+        ordered_SOCsL = sorted(self.SOCsL, key=lambda SOC: SOC.get_budget("latency"))
+        block.set_SOC(ordered_SOCsL[0].type, self.SOC_id)
+        return block
+
+
+    # ------------------------------
+    # Functionality:
+    #       get taskL from the task name
+    # Variables:
+    #       task_name: name of the task
+    # -------------------------------------------
+    def get_taskL_from_task_name(self, task_name):
+        for taskL in self.tasksL:
+            if taskL.task_name == task_name:
+                return taskL
+        raise Exception("taskL:" + task_name + " does not exist")
+
+    # ------------------------------
+    # Functionality:
+    #       get the PEs for a taskL
+    # Variables:
+    #       taskL: taskL
+    # -------------------------------------------
+    def get_corresponding_pe_blocksL(self, taskL):
+        task_to_pe_blocks_mapL_list = list(filter(lambda pe_map: pe_map.task_name == taskL.task_name, self.pe_mapsL))
+        assert(len(task_to_pe_blocks_mapL_list) >= 1), taskL.task_name
+        pe_blocksL_list = []
+        for el in task_to_pe_blocks_mapL_list:
+            blockL_list = list(filter(lambda blockL: blockL.block_instance_name == el.pe_block_instance_name, self.blocksL))
+            assert(len(blockL_list) == 1)
+            pe_blocksL_list.append(blockL_list[0])
+        return pe_blocksL_list
+
+    def get_task_starting_time(self, task):
+        return list(filter(lambda pe_sch: pe_sch.task_name == task.name, self.pe_schedulesL))[0].starting_time
+
+    def get_block_peak_work_rate_distribution(self, blockL):
+        blockL = list(filter(lambda block: block.block_instance_name == blockL.block_instance_name, self.blocksL))
+        assert(len(blockL) == 1)
+        return blockL[0].peak_work_rate_distribution
+
+    def get_block_work_over_area_distribution(self, blockL):
+        blockL = list(filter(lambda block: block.block_instance_name == blockL.block_instance_name, self.blocksL))
+        assert(len(blockL) == 1)
+        return blockL[0].work_over_area_distribution
+
+    def get_blocks_immediate_superior(self, block):
+        return None
+
+    # ----------------------------------
+    # samplers
+    # ----------------------------------
+    # sample the database from a set of compatible blocks
+    def sample_blocks(self, all_compatible_blocks, mode="random"):
+        if (config.DEBUG_FIX):
+            random.seed(0)
+        else:
+            time.sleep(.00001)
+            random.seed(datetime.now().microsecond)
+
+        if not (all_compatible_blocks):
+            raise Exception("didn't find any compatible blocks")
+
+        if (mode == "random"):
+            block = random.choice(all_compatible_blocks)
+        elif (mode == "immediate_superior"):
+            block = sorted(all_compatible_blocks)[0]
+        else:
+            print("mode: " + mode + " for block sampling is not defined")
+            exit(0)
+        return block
+
+    def sample_all_blocks(self, tasks, block, mode="random"):
+        all_compatible_blocks = self.find_all_compatible_blocks(block.type, tasks)
+        new_block = self.sample_blocks(all_compatible_blocks, mode)
+        return new_block
+
+    def sample_similar_block(self, block):
+        for blocksL in self.get_all_BlocksL():
+            if blocksL.block_instance_name == block.get_generic_instance_name():
+                return self.cast(blocksL)
+
+        print("there should be at least one block that is similar ")
+        exit(0)
+
+    def sample_DMA(self, block_type="pe"):  # block_type was previously an undefined name; it needs to be a parameter
+        return list(filter(lambda blockL: blockL.block_type == block_type, self.blocksL))
+
+    def get_tasksL(self):
+        return self.tasksL
+
+    # ------------------------------
+    # Functionality:
+    #       get all the PEs that the task can run on
+    # ------------------------------
+    def get_corresponding_pe_blocksL(self, taskL):
+        task_to_pe_blocks_mapL_list = list(filter(lambda pe_map: pe_map.task_name == taskL.task_name, self.pe_mapsL))
+
assert(len(task_to_pe_blocks_mapL_list) >= 1), taskL.task_name + pe_blocksL_list = [] + for el in task_to_pe_blocks_mapL_list: + blockL_list = list(filter(lambda blockL: blockL.block_instance_name == el.pe_block_instance_name, self.blocksL)) + if not(len(blockL_list) ==1): + print("weird") + assert(len(blockL_list) == 1) + pe_blocksL_list.append(blockL_list[0]) + return pe_blocksL_list + + def get_task_starting_time(self, task): + return list(filter(lambda pe_sch: pe_sch.task_name == task.name, self.pe_schedulesL))[0].starting_time + + # ------------------------------ + # Functionality: + # get the block work ratio given the task and it's direction (read/write) + # Variables: + # task: task under query + # pe_block: pe block to get work ratio for + # dir: direction for work ratio (read/write) + # ------------------------------------------- + def get_block_work_ratio_by_task_dir(self, task, pe_block, dir): + work_ratio = [pe_map_.get_work_ratio_new() for pe_map_ in self.pe_mapsL if pe_map_.task_name == task.name and + pe_map_.pe_block_instance_name == pe_block.get_generic_instance_name()] + if dir == 'read': relationship = "parent" + elif dir == "write": relationship = "child" + elif dir == "loop_back": relationship = "self" + if "DMA" in task.name: + if relationship == "parent": return {task.get_parents()[0].name:1} + elif relationship == "child": return {task.get_children()[0].name:1} + elif relationship == "self": return {task.name:1} + else: + return work_ratio[0][relationship] + + # task work is reported in number of instructions + def get_task_work(self, taskL): + taskL = list(filter(lambda taskL_: taskL_.task_name == taskL.task_name, self.tasksL)) + assert(len(taskL) == 1) + return taskL[0].work + + def get_task_iteration(self, taskL): + taskL = list(filter(lambda taskL_: taskL_.task_name == taskL.task_name, self.tasksL)) + assert(len(taskL) == 1) + return taskL[0].iteration + + def get_task_throughput_info(self, taskL): + taskL = list(filter(lambda taskL_: taskL_.task_name == taskL.task_name, self.tasksL)) + assert(len(taskL) == 1) + return taskL[0].get_throughput_info() + + def get_task_type(self, taskL): + taskL = list(filter(lambda taskL_: taskL_.task_name == taskL.task_name, self.tasksL)) + assert(len(taskL) == 1) + return taskL[0].get_type() + + + + + + + """ + # find the block that is better than the current block that the tasks are running on + # superior means more power, area efficient or better latency + def get_blocks_immediate_superior_for_tasks(self, block, tasks): + mappable_blocksL = self.find_mappable_blocks_among_tasks(block.block_type, tasks) + superior_blockL = self.get_blocks_immediate_superior(block.block_subtype) + while(superior_blockL): + if superior_blockL in mappable_blocksL: + return self.cast(mappable_blocksL) + return block + """ + + # ------------------- + # generators (finder and convert) + # ------------------ + def populate_mappable_blocksL_to_tasks_s_name_dict(self): + for taskL in self.tasksL: + mappable_blocksL_to_task_list = self.get_corresponding_pe_blocksL(taskL) + self.mappable_blocksL_to_tasks_s_name_dict[taskL.task_name] = mappable_blocksL_to_task_list + + # get all the tasks to blocks mappings + def get_mappable_blocksL_to_tasks(self): + return self.mappable_blocksL_to_tasks_s_name_dict + + def sample_DMA_blocks(self): + DMA_blocks = [blockL for blockL in self.get_blocksL_by_type("pe") if "DMA" in blockL.block_instance_name] + random_blockL = random.choice(DMA_blocks) + return self.cast(random_blockL) + + # block type chosend from ["pe", 
"mem", "ic"] + def sample_all_blocks_by_type(self, mode="random", tasks=[], block_type="pe"): + all_compatible_blocks = self.find_all_compatible_blocks(block_type, tasks) + return self.sample_blocks(all_compatible_blocks, mode) + + def sample_most_inferior_blocks_by_type(self, mode="random", tasks=[], block_type="pe"): + all_compatible_blocks = self.find_all_compatible_blocks(block_type, tasks) + return sorted(all_compatible_blocks)[0] + + def sample_most_inferior_blocks_before_unrolling_by_type(self, mode="random", tasks=[], block_type="pe", block=""): + if not block.subtype == "ip": + return self.sample_similar_block(block) + else: + all_compatible_blocks = self.find_all_compatible_blocks(block_type, tasks) + sorted_blocks = sorted(all_compatible_blocks) + for block_ in sorted_blocks: + if block_.subtype == "ip" and block_.get_block_freq() == block.get_block_freq(): + return block_ + return sorted(all_compatible_blocks)[0] + + + # superior = better performant wise + # Variables: + # cur_blck: current block to find a superior for + # all_comtible_blcks: list of blocks to pick from + def find_superior_blocks(self, cur_blck, all_comptble_blcks:List[Block]): + srtd_comptble_blcks = sorted(all_comptble_blcks) + cur_blck_idx = 0 + for blck in srtd_comptble_blcks: + if cur_blck.get_generic_instance_name() == blck.get_generic_instance_name(): + break + cur_blck_idx +=1 + if (cur_blck_idx == len(srtd_comptble_blcks)-1): + # TODO: not good coding, should really fold this in the previous hierarchy + return [srtd_comptble_blcks[-1]] + else: + return srtd_comptble_blcks[cur_blck_idx+1:] + + # ------------------------------ + # Functionality: + # find the better SOC for the block under query + # Variables: + # metric_name: the metric name to pick a better SOC based off of. + # ------------------------------ + def find_superior_SOC(self, block, metric_name): + cur_SOC_type = block.SOC_type + ordered_SOCsL = sorted(self.SOCsL, key=lambda SOC: SOC.get_budget(metric_name)) + cur_SOC_idx = ordered_SOCsL.index(cur_SOC_type) + if cur_SOC_idx == len(ordered_SOCsL): return ordered_SOCsL[cur_SOC_idx] + else: return ordered_SOCsL[cur_SOC_idx +1] + + # ------------------------------ + # Functionality: + # find a superior (based on the metric_name) SOC. + # Variables: + # block: block to find a better SOC for. + # metric_name: name of the metric to choose a better SOC based off of. 
+ # ------------------------------ + def up_sample_SOC(self, block, metric_name): + superior_SOC_type = self.find_superior_SOC(block, metric_name) + block.set_SOC(superior_SOC_type, self.SOC_id) + return block + + def copy_SOC(self, block_copy_to, block_copy_from): + block_copy_to.SOC_type = block_copy_from.SOC_type + block_copy_to.SOC_id = block_copy_from.SOC_id + return block_copy_to + + # ------------------------------ + # Functionality: + # get the worse SOC (based on the metric name) + # ------------------------------ + def sample_most_inferior_SOC(self, block, metric_name): + ordered_SOCsL = sorted(self.SOCsL, key=lambda SOCL: SOCL.get_budget(metric_name)) + block.set_SOC(ordered_SOCsL[0].type, self.SOC_id) + return block + + # ------------------------------ + # Functionality: + # find a more superior block compatible with all the input tasks + # ------------------------------ + def up_sample_blocks(self, block, mode="random", tasks=[]): + all_compatible_blocks = self.find_all_compatible_blocks(block.type, tasks) + superior_blocks = self.find_superior_blocks(block, all_compatible_blocks) + return self.sample_blocks(superior_blocks, mode) + + # ------------------------------ + # Functionality: + # check if a block is superior comparing to another block + # ------------------------------ + def check_superiority(self, block_1, block_2, tasks=[]): + all_compatible_blocks = self.find_all_compatible_blocks(block_1.type, tasks) + superior_blocks_names = [block.get_generic_instance_name() for block in self.find_superior_blocks(block_1, all_compatible_blocks)] + if block_2.get_generic_instance_name() in superior_blocks_names: + return True + else: + return False + + + # ------------------------------ + # Functionality: + # find a block that is superior (from a specific metric perspective) comparing to the current block. This version + # is the fast version, where we use caching, however, it's a bit less intuitive + # Variables: + # blck_to_imprv: block to improve upon + # metric: metric to consider while choosing a block superiority. Chosen from power, area, performance. 
+    # ------------------------------
+    def up_sample_down_sample_block_fast(self, blck_to_imprv, metric, sampling_dir, tasks=[]):
+        all_compatible_blocks = self.find_all_compatible_blocks_fast(blck_to_imprv.type, tasks)
+        if metric == "latency":
+            metric_to_sort = 'peak_work_rate'
+        elif metric == "power":
+            #metric_to_sort = 'work_over_energy'
+            metric_to_sort = 'one_over_power'
+        elif metric == "area":
+            metric_to_sort = 'one_over_area'
+        else:
+            # fail fast; metric_to_sort would otherwise be undefined below
+            raise ValueError("metric: " + metric + " is not defined")
+
+        # sort best-first when sampling up; avoid shadowing the built-in reversed()
+        reverse_sort = sampling_dir > 0
+        srtd_comptble_blcks = sorted(all_compatible_blocks, key=attrgetter(metric_to_sort), reverse=reverse_sort)
+
+        # find the block
+        results = []
+        for blck in srtd_comptble_blcks:
+            #if (getattr(blck, metric_to_sort) == getattr(blck_to_imprv, metric_to_sort)):
+            if sampling_dir < 0:  # need to reduce
+                if getattr(blck, metric_to_sort) > getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+            elif sampling_dir > 0:  # need to increase
+                if getattr(blck, metric_to_sort) < getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+
+        # fall back to the block's own generic instance if nothing strictly better exists
+        if len(results) == 0:
+            for el in srtd_comptble_blcks:
+                if el.get_generic_instance_name() == blck_to_imprv.get_generic_instance_name():
+                    results = [el]
+                    break
+
+        return results
+
+    # ------------------------------
+    # Functionality:
+    #       find a block that is superior (from a specific metric's perspective) compared to the current block
+    # Variables:
+    #       blck_to_imprv: block to improve upon
+    #       metric: metric to consider while judging block superiority. Chosen from power, area, performance.
+    # ------------------------------
+    def up_sample_down_sample_block(self, blck_to_imprv, metric, sampling_dir, tasks=[]):
+        all_compatible_blocks = self.find_all_compatible_blocks(blck_to_imprv.type, tasks)
+        if metric == "latency":
+            metric_to_sort = 'peak_work_rate'
+        elif metric == "power":
+            #metric_to_sort = 'work_over_energy'
+            metric_to_sort = 'one_over_power'
+        elif metric == "area":
+            metric_to_sort = 'one_over_area'
+        else:
+            raise ValueError("metric: " + metric + " is not defined")
+
+        reverse_sort = sampling_dir > 0
+        srtd_comptble_blcks = sorted(all_compatible_blocks, key=attrgetter(metric_to_sort), reverse=reverse_sort)
+
+        # find the block
+        results = []
+        for blck in srtd_comptble_blcks:
+            #if (getattr(blck, metric_to_sort) == getattr(blck_to_imprv, metric_to_sort)):
+            if sampling_dir < 0:  # need to reduce
+                if getattr(blck, metric_to_sort) > getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+            elif sampling_dir > 0:  # need to increase
+                if getattr(blck, metric_to_sort) < getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+
+        if len(results) == 0:
+            results = [blck_to_imprv]
+        return results
+
+    def equal_sample_up_sample_down_sample_block(self, blck_to_imprv, metric, sampling_dir, tasks=[]):
+        all_compatible_blocks = self.find_all_compatible_blocks(blck_to_imprv.type, tasks)
+        if metric == "latency":
+            metric_to_sort = 'peak_work_rate'
+        elif metric == "power":
+            # metric_to_sort = 'work_over_energy'
+            metric_to_sort = 'one_over_power'
+        elif metric == "area":
+            metric_to_sort = 'one_over_area'
+        else:
+            raise ValueError("metric: " + metric + " is not defined")
+
+        reverse_sort = sampling_dir > 0
+        srtd_comptble_blcks = sorted(all_compatible_blocks, key=attrgetter(metric_to_sort), reverse=reverse_sort)
+
+        # find the block (non-strict comparisons: the current block itself also qualifies)
+        results = []
+        for blck in srtd_comptble_blcks:
+            if sampling_dir < 0:  # need to reduce
+                if getattr(blck, metric_to_sort) >= getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+            elif sampling_dir > 0:  # need to increase
+                if getattr(blck, metric_to_sort) <= getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+
+        return results
+
+    def up_sample_down_sample_block_multi_metric_fast(self, blck_to_imprv, sorted_metric_dir, tasks=[]):
+        all_compatible_blocks = self.find_all_compatible_blocks_fast(blck_to_imprv.type, tasks)
+
+        metrics_to_sort_reversed = []
+        for metric, direction in sorted_metric_dir.items():
+            if metric == "latency":
+                metric_to_sort = 'peak_work_rate'
+            elif metric == "power":
+                #metric_to_sort = 'work_over_energy'
+                metric_to_sort = 'one_over_power'
+            elif metric == "area":
+                metric_to_sort = 'one_over_area'
+            else:
+                raise ValueError("metric: " + metric + " is not defined")
+            metrics_to_sort_reversed.append((metric_to_sort, -1 * direction))
+
+        # sorted_metric_dir is ordered least to most important, so the last key matters most
+        most_important_metric = list(sorted_metric_dir.keys())[-1]
+        sampling_dir = sorted_metric_dir[most_important_metric]
+
+        # sort by all metrics, most important first (assumes latency, power and area are all present)
+        #srtd_comptble_blcks = sorted(all_compatible_blocks, key=attrgetter(metric_to_sort), reverse=reverse_sort)
+        srtd_comptble_blcks = sorted(all_compatible_blocks, key=lambda blk: (metrics_to_sort_reversed[2][1]*getattr(blk, metrics_to_sort_reversed[2][0]),
+                                                                             metrics_to_sort_reversed[1][1]*getattr(blk, metrics_to_sort_reversed[1][0]),
+                                                                             metrics_to_sort_reversed[0][1]*getattr(blk, metrics_to_sort_reversed[0][0])))
+
+        # find the block
+        results = []
+        """
+        # first make sure it can meet across all metrics
+        for blck in srtd_comptble_blcks:
+            if metrics_to_sort_reversed[2][1]*getattr(blck, metrics_to_sort_reversed[2][0]) > \
+                    metrics_to_sort_reversed[2][1]*getattr(blck_to_imprv, metrics_to_sort_reversed[2][0]):
+                if metrics_to_sort_reversed[1][1] * getattr(blck, metrics_to_sort_reversed[1][0]) >= \
+                        metrics_to_sort_reversed[1][1] * getattr(blck_to_imprv, metrics_to_sort_reversed[1][0]):
+                    if metrics_to_sort_reversed[0][1]*getattr(blck, metrics_to_sort_reversed[0][0]) >= \
+                            metrics_to_sort_reversed[0][1] * getattr(blck_to_imprv, metrics_to_sort_reversed[0][0]):
+                        results.append(blck)
+
+        # meet across two metrics
+        if len(results) == 0:
+            for blck in srtd_comptble_blcks:
+                if metrics_to_sort_reversed[2][1] * getattr(blck, metrics_to_sort_reversed[2][0]) > \
+                        metrics_to_sort_reversed[2][1] * getattr(blck_to_imprv, metrics_to_sort_reversed[2][0]):
+                    if metrics_to_sort_reversed[1][1] * getattr(blck, metrics_to_sort_reversed[1][0]) >= \
+                            metrics_to_sort_reversed[1][1] * getattr(blck_to_imprv, metrics_to_sort_reversed[1][0]):
+                        results.append(blck)
+        """
+        # meet across at least one metric (the most important one)
+        if len(results) == 0:
+            for blck in srtd_comptble_blcks:
+                if metrics_to_sort_reversed[2][1] * getattr(blck, metrics_to_sort_reversed[2][0]) > \
+                        metrics_to_sort_reversed[2][1] * getattr(blck_to_imprv, metrics_to_sort_reversed[2][0]):
+                    results.append(blck)
+
+        # we need a Pareto-front calculation here, but for now we do something simpler:
+        # keep whichever of the first two candidates looks better on the second metric
+        if len(results) > 1:
+            first_el = results[0]
+            second_el = results[1]
+            if metrics_to_sort_reversed[1][1] * getattr(first_el, metrics_to_sort_reversed[1][0]) >= \
+                    metrics_to_sort_reversed[1][1] * getattr(second_el, metrics_to_sort_reversed[1][0]):
+                results = [results[0]]
+            else:
+                results = [results[1]]
+
+# if len(results) > 0:
+#     self.check_weird_nests(results, blck_to_imprv, metrics_to_sort_reversed, srtd_comptble_blcks)
+        if len(results) == 0:
+            for el in srtd_comptble_blcks:
+                if el.get_generic_instance_name() == blck_to_imprv.get_generic_instance_name():
+                    results = [el]
+                    break
+        #if len(results) == 0:
+        #    results = [srtd_comptble_blcks[-1]]
+
+        return results
+
+    def equal_sample_up_sample_down_sample_block_fast(self, blck_to_imprv, metric, sampling_dir, tasks=[]):
+        all_compatible_blocks = self.find_all_compatible_blocks_fast(blck_to_imprv.type, tasks)
+        if metric == "latency":
+            metric_to_sort = 'peak_work_rate'
+        elif metric == "power":
+            #metric_to_sort = 'work_over_energy'
+            metric_to_sort = 'one_over_power'
+        elif metric == "area":
+            metric_to_sort = 'one_over_area'
+        else:
+            raise ValueError("metric: " + metric + " is not defined")
+
+        reverse_sort = sampling_dir > 0
+        srtd_comptble_blcks = sorted(all_compatible_blocks, key=attrgetter(metric_to_sort), reverse=reverse_sort)
+        #srtd_comptble_blcks = sorted(all_compatible_blocks, key=lambda blk: (getattr(blk, metrics_to_sort[0]), getattr(blk, metrics_to_sort[1]), getattr(blk, metrics_to_sort[2])), reverse=reverse_sort)
+
+        # find the block (non-strict comparisons: the current block itself also qualifies)
+        results = []
+        for blck in srtd_comptble_blcks:
+            #if (getattr(blck, metric_to_sort) == getattr(blck_to_imprv, metric_to_sort)):
+            if sampling_dir < 0:  # need to reduce
+                if getattr(blck, metric_to_sort) >= getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+            elif sampling_dir > 0:  # need to increase
+                if getattr(blck, metric_to_sort) <= getattr(blck_to_imprv, metric_to_sort):
+                    results.append(blck)
+
+        return results
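In short, each variant above maps the target metric onto an attribute that grows as the design improves (peak_work_rate, one_over_power, one_over_area), sorts the compatible blocks by it, and keeps the blocks on the requested side of the current block (strictly in the plain variants, non-strictly in the equal_sample_* ones). A minimal, self-contained sketch of that selection pattern, using a hypothetical Block namedtuple and made-up values in place of FARSI's block class:

# Toy sketch of the up/down-sampling selection; Block and its values are hypothetical.
from collections import namedtuple
from operator import attrgetter

Block = namedtuple("Block", ["name", "peak_work_rate", "one_over_power", "one_over_area"])

def pick_candidates(blocks, current, metric_attr, sampling_dir, strict=True):
    # Best-first when sampling up (sampling_dir > 0), mirroring reverse_sort above.
    ranked = sorted(blocks, key=attrgetter(metric_attr), reverse=sampling_dir > 0)
    cur = getattr(current, metric_attr)
    results = []
    for b in ranked:
        v = getattr(b, metric_attr)
        if sampling_dir < 0 and (v > cur or (not strict and v == cur)):    # need to reduce the raw metric
            results.append(b)
        elif sampling_dir > 0 and (v < cur or (not strict and v == cur)):  # need to increase
            results.append(b)
    return results

blocks = [Block("slow", 1, 4, 4), Block("mid", 2, 2, 2), Block("fast", 4, 1, 1)]
# "reduce power" (sampling_dir < 0) keeps blocks with a larger one_over_power than "mid":
print(pick_candidates(blocks, blocks[1], "one_over_power", sampling_dir=-1))  # -> [Block(name='slow', ...)]

Because the sorted attributes are inverses of the raw metrics, "reduce" selects larger attribute values, which is why the comparison signs in the FARSI code flip relative to the sampling direction.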
diff --git a/Project_FARSI/specs/database_data/generate/input.py b/Project_FARSI/specs/database_data/generate/input.py
new file mode 100644
index 00000000..61f4238d
--- /dev/null
+++ b/Project_FARSI/specs/database_data/generate/input.py
@@ -0,0 +1,55 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+from specs.LW_cl import *
+gen_config = {}
+
+# ---------------------------
+# global variables
+# ---------------------------
+# HW database
+sim_time_per_design = .25 / (3600)
+gen_config["DS_output_mode"] = "DB_DS_time"  # ["DS_size", "DS_time"]
+gen_config["DB_MAX_PE_CNT_range"] = range(1, 4)
+gen_config["DB_MAX_TASK_CNT_range"] = range(5, 10)
+gen_config["DB_MAX_PAR_TASK_CNT_range"] = range(5, 7)
+#gen_config["DB_DS_type"] = "exhaustive_naive"
+gen_config["DB_DS_type"] = "exhaustive_reduction_DB_semantics"
+
+# SW database
+gen_config["DB_MAX_TASK_CNT"] = 10  # includes souurce and siink
+gen_config["DB_MAX_PAR_TASK_CNT"] = 8  # souurce and siink would be automatically serialized
+
+gen_config["DB_MAX_PE_CNT"] = 4
+gen_config["DB_MAX_BUS_CNT"] = 4
+gen_config["DB_MAX_MEM_CNT"] = 4
+
+gen_config["DB_PE_list"] = ["A53_pe"]
+gen_config["DB_BUS_list"] = ["LMEM_ic_0_ic"]
+gen_config["DB_MEM_list"] = ["LMEM_0_mem"]
+
+gen_config["DB_MIN_PE_CNT"] = 4
+gen_config["DB_MIN_MEM_CNT"] = 4
+gen_config["DB_MIN_BUS_CNT"] = 4
+
+gen_config["DB_MAX_SYSTEM_to_investigate"] = 1
+assert (gen_config["DB_MIN_PE_CNT"] <= gen_config["DB_MAX_PE_CNT"])
+assert (gen_config["DB_MIN_BUS_CNT"] <= gen_config["DB_MAX_BUS_CNT"])
+assert (gen_config["DB_MIN_MEM_CNT"] <= gen_config["DB_MAX_MEM_CNT"])
+assert (gen_config["DB_MAX_PE_CNT"] >= gen_config["DB_MAX_MEM_CNT"])
+assert (gen_config["DB_MAX_PE_CNT"] >= gen_config["DB_MAX_BUS_CNT"])
+assert (gen_config["DB_MAX_PE_CNT"] <= gen_config["DB_MAX_TASK_CNT"] - 1)   # have to make sure no PE is isolated
+assert (gen_config["DB_MAX_MEM_CNT"] <= gen_config["DB_MAX_TASK_CNT"] - 1)  # have to make sure no memory is isolated
+assert (gen_config["DB_MAX_BUS_CNT"] <= gen_config["DB_MAX_TASK_CNT"] - 1)  # have to make sure no bus is isolated
+
+
+budgets_dict = defaultdict(dict)
+# some numbers for now. This doesn't matter at the moment.
+budgets_dict["glass"] = {} +budgets_dict["glass"]["latency"] = {"synthetic": .030} +budgets_dict["glass"]["power"] = 20*10**-3 +budgets_dict["glass"]["area"] = 15*10**-6 +other_values_dict = defaultdict(dict) +other_values_dict["glass"]["cost"] = 10**-9 # something really small \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/hardcoded_test.txt b/Project_FARSI/specs/database_data/hardcoded_test.txt new file mode 100644 index 00000000..9daeafb9 --- /dev/null +++ b/Project_FARSI/specs/database_data/hardcoded_test.txt @@ -0,0 +1 @@ +test diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Budget.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Budget.csv new file mode 100755 index 00000000..f9a0895f --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Budget.csv @@ -0,0 +1,3 @@ +Workload,latency,power,area,cost +participants,0.021,,, +all,,0.01,0.000003,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Hardware Graph.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Hardware Graph.csv new file mode 100755 index 00000000..c46ad46d --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Hardware Graph.csv @@ -0,0 +1,5 @@ +Block Name,A53_0,LMEM_ic_2_1_sr1,LMEM_2_1_ta10,LMEM_ic_2_1_sr2 +A53_0,,1,, +LMEM_ic_2_1_sr1,1,,,1 +LMEM_2_1_ta10,,,,1 +LMEM_ic_2_1_sr2,,1,1, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task Data Movement.csv new file mode 100755 index 00000000..b61c0190 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task Data Movement.csv @@ -0,0 +1,4 @@ +Task Name,source,participant0,siink +souurce,,6400, +participant0,,,6400 +siink,,, diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task Itr Count.csv new file mode 100755 index 00000000..cbfb5221 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task Itr Count.csv @@ -0,0 +1,4 @@ +Task Name,number of iterations +souurce,1 +participant0,1000000000 +siink,1 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Area.csv new file mode 100755 index 00000000..f0baba35 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Area.csv @@ -0,0 +1,4 @@ +Task Name,A53 +souurce,0 +participant0,10 +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Energy.csv new file mode 100755 index 00000000..81e4c06d --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Energy.csv @@ -0,0 +1,4 @@ +Task Name,A53 +souurce,0 +participant0, +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Performance.csv new file mode 100755 index 00000000..8c11a285 --- /dev/null +++ 
b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task PE Performance.csv @@ -0,0 +1,4 @@ +Task Name,A53 +souurce,0 +participant0,50 +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task to Hardware Mapping.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task to Hardware Mapping.csv new file mode 100755 index 00000000..61e14d68 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_2r_database - Task to Hardware Mapping.csv @@ -0,0 +1,4 @@ +A53_0,LMEM_ic_2_1_sr1,LMEM_2_1_ta10,LMEM_ic_2_2_sr2 +souurce,souurce -> participant0,souurce -> participant0,souurce -> participant0 +siink,participant0 -> siink,participant0 -> siink,participant0 -> siink +participant0,,, diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Budget.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Budget.csv new file mode 100755 index 00000000..f9a0895f --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Budget.csv @@ -0,0 +1,3 @@ +Workload,latency,power,area,cost +participants,0.021,,, +all,,0.01,0.000003,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Hardware Graph.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Hardware Graph.csv new file mode 100755 index 00000000..50dad6e6 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Hardware Graph.csv @@ -0,0 +1,4 @@ +Block Name,A53_0,LMEM_ic_2_1_sr1,LMEM_2_1_ta10 +A53_0,,1, +LMEM_ic_2_1_sr1,1,,1 +LMEM_2_1_ta10,,1, diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task Data Movement.csv new file mode 100755 index 00000000..b61c0190 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task Data Movement.csv @@ -0,0 +1,4 @@ +Task Name,source,participant0,siink +souurce,,6400, +participant0,,,6400 +siink,,, diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task Itr Count.csv new file mode 100755 index 00000000..cbfb5221 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task Itr Count.csv @@ -0,0 +1,4 @@ +Task Name,number of iterations +souurce,1 +participant0,1000000000 +siink,1 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Area.csv new file mode 100755 index 00000000..f0baba35 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Area.csv @@ -0,0 +1,4 @@ +Task Name,A53 +souurce,0 +participant0,10 +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Energy.csv new file mode 100755 index 00000000..81e4c06d --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Energy.csv @@ -0,0 +1,4 @@ +Task Name,A53 +souurce,0 +participant0, +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE 
Performance.csv new file mode 100755 index 00000000..daedcdb4 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task PE Performance.csv @@ -0,0 +1,4 @@ +Task Name,A53 +souurce,0 +participant0,1046 +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task To Hardware Mapping.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task To Hardware Mapping.csv new file mode 100755 index 00000000..870a2725 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_1p_database - Task To Hardware Mapping.csv @@ -0,0 +1,4 @@ +A53_0,LMEM_ic_2_1_sr1,LMEM_2_1_ta10 +souurce,souurce -> participant0,souurce -> participant0 +siink,participant0 -> siink,participant0 -> siink +participant0,, diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Budget.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Budget.csv new file mode 100755 index 00000000..f9a0895f --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Budget.csv @@ -0,0 +1,3 @@ +Workload,latency,power,area,cost +participants,0.021,,, +all,,0.01,0.000003,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Hardware Graph.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Hardware Graph.csv new file mode 100755 index 00000000..bd18a08a --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Hardware Graph.csv @@ -0,0 +1,11 @@ +Block Name,A53_0,A53_1,A53_2,A53_3,A53_4,A53_5,A53_6,A53_7,LMEM_ic_0_1_sr1,LMEM_0_1_ta10 +A53_0,,,,,,,,,1, +A53_1,,,,,,,,,1, +A53_2,,,,,,,,,1, +A53_3,,,,,,,,,1, +A53_4,,,,,,,,,1, +A53_5,,,,,,,,,1, +A53_6,,,,,,,,,1, +A53_7,,,,,,,,,1, +LMEM_ic_0_1_sr1,1,1,1,1,1,1,1,1,,1 +LMEM_0_1_ta10,,,,,,,,,1, diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task Data Movement.csv new file mode 100755 index 00000000..25c1bafb --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task Data Movement.csv @@ -0,0 +1,11 @@ +Task Name,source,participant0,participant1,participant2,participant3,participant4,participant5,participant6,participant7,siink +souurce,,67000,67000,67000,67000,67000,67000,67000,67000, +participant0,,,,,,,,,,67000 +participant1,,,,,,,,,,67000 +participant2,,,,,,,,,,67000 +participant3,,,,,,,,,,67000 +participant4,,,,,,,,,,67000 +participant5,,,,,,,,,,67000 +participant6,,,,,,,,,,67000 +participant7,,,,,,,,,,67000 +siink,,,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task Itr Count.csv new file mode 100755 index 00000000..4f297db3 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task Itr Count.csv @@ -0,0 +1,11 @@ +Task Name,number of iterations +souurce,1 +participant0,1000000 +participant1,1000000 +participant2,1000000 +participant3,1000000 +participant4,1000000 +participant5,1000000 +participant6,1000000 +participant7,1000000 +siink,1 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Area.csv new file mode 100755 index 00000000..c97eca15 --- /dev/null +++ 
b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Area.csv @@ -0,0 +1,11 @@ +Task Name,A53 +souurce,0 +participant0,10 +participant1,10 +participant2,10 +participant3,10 +participant4,10 +participant5,10 +participant6,10 +participant7,10 +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Energy.csv new file mode 100755 index 00000000..c05a3863 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Energy.csv @@ -0,0 +1,11 @@ +Task Name,A53 +souurce,0 +participant0, +participant1, +participant2, +participant3, +participant4, +participant5, +participant6, +participant7, +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Performance.csv new file mode 100755 index 00000000..5ca7d6b1 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task PE Performance.csv @@ -0,0 +1,11 @@ +Task Name,A53 +souurce,0 +participant0,1046025 +participant1,1046025 +participant2,1046025 +participant3,1046025 +participant4,1046025 +participant5,1046025 +participant6,1046025 +participant7,1046025 +siink,0 diff --git a/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task To Hardware Mapping.csv b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task To Hardware Mapping.csv new file mode 100755 index 00000000..393fd3d5 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/SOC_example_8p_database - Task To Hardware Mapping.csv @@ -0,0 +1,17 @@ +A53_0,A53_1,A53_2,A53_3,A53_4,A53_5,A53_6,A53_7,LMEM_ic_0_1_sr1,LMEM_0_1_ta10 +souurce,participant1,participant2,participant3,participant4,participant5,participant6,participant7,souurce -> participant0,souurce -> participant0 +siink,,,,,,,,souurce -> participant1,souurce -> participant1 +participant0,,,,,,,,souurce -> participant2,souurce -> participant2 +,,,,,,,,souurce -> participant3,souurce -> participant3 +,,,,,,,,souurce -> participant4,souurce -> participant4 +,,,,,,,,souurce -> participant5,souurce -> participant5 +,,,,,,,,souurce -> participant6,souurce -> participant6 +,,,,,,,,souurce -> participant7,souurce -> participant7 +,,,,,,,,participant0 -> siink,participant0 -> siink +,,,,,,,,participant1 -> siink,participant1 -> siink +,,,,,,,,participant2 -> siink,participant2 -> siink +,,,,,,,,participant3 -> siink,participant3 -> siink +,,,,,,,,participant4 -> siink,participant4 -> siink +,,,,,,,,participant5 -> siink,participant5 -> siink +,,,,,,,,participant6 -> siink,participant6 -> siink +,,,,,,,,participant7 -> siink,participant7 -> siink diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Budget.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Budget.csv new file mode 100644 index 00000000..47459114 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Budget.csv @@ -0,0 +1,3 @@ +Workload,latency,power,area,cost +audio_decoder,0.021,,, +all,,0.01,0.000003,0.000000001 \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Hardware Graph.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Hardware Graph.csv new file mode 100644 index 00000000..883ea0ee --- /dev/null +++ 
b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Hardware Graph.csv @@ -0,0 +1,4 @@ +Task Name,A53_0,LMEM_ic_0_0,LMEM_0_0 +A53_0,,1, +LMEM_ic_0_0,1,,1 +LMEM_0_0,,1, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task Data Movement.csv new file mode 100644 index 00000000..211149af --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task Data Movement.csv @@ -0,0 +1,19 @@ +Task Name,souurce,zoomProcess_fxp_cloned,psychoFilter_fxp_cloned,rotateOrder3_fxp_cloned,FIR_right_fxp_cloned,FIR_left_fxp_cloned,IFFT_right_fxp_cloned,IFFT_left_fxp_cloned,rotateOrder2_fxp_cloned,rotateOrder1_fxp_cloned,overlap_right_fxp_cloned,overlap_left_fxp_cloned,setAndFFT_right_fxp_cloned,setAndFFT_left_fxp_cloned,rotatorSet_fxp_cloned,zoomSet_fxp_cloned,dummy_last,siink +souurce,,,193576,,,,,,,,,,,,55856,248,, +zoomProcess_fxp_cloned,,248,,,,,,,,,,,131136,131136,,,, +psychoFilter_fxp_cloned,,131136,,84528,,,,,76336,68144,,,,,,,, +rotateOrder3_fxp_cloned,,28672,,,,,,,,,,,,,,,, +FIR_right_fxp_cloned,,,,,,,329352,,,,,,,,,,, +FIR_left_fxp_cloned,,,,,,,,329352,,,,,,,,,, +IFFT_right_fxp_cloned,,,,,,,,,,,329352,,,,,,, +IFFT_left_fxp_cloned,,,,,,,,,,,,329352,,,,,, +rotateOrder2_fxp_cloned,,20480,,,,,,,,,,,,,,,, +rotateOrder1_fxp_cloned,,12288,,,,,,,,,,,,,,,, +overlap_right_fxp_cloned,,,,,,,,,,,,,,,,,4096, +overlap_left_fxp_cloned,,,,,,,,,,,,,,,,,4096, +setAndFFT_right_fxp_cloned,,,,,329352,,,,,,,,,,,,, +setAndFFT_left_fxp_cloned,,,,,,329352,,,,,,,,,,,, +rotatorSet_fxp_cloned,,,55856,,,,,,,,,,,,,,, +zoomSet_fxp_cloned,,248,,,,,,,,,,,,,,,, +dummy_last,,,,,,,,,,,,,,,,,,1 +siink,,,,,,,,,,,,,,,,,, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task Itr Count.csv new file mode 100644 index 00000000..e3324ad9 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task Itr Count.csv @@ -0,0 +1,18 @@ +Task Name,number of iterations +souurce,1 +zoomProcess_fxp_cloned,1 +psychoFilter_fxp_cloned,1 +rotateOrder3_fxp_cloned,1024 +FIR_right_fxp_cloned,16400 +FIR_left_fxp_cloned,16400 +IFFT_right_fxp_cloned,1 +IFFT_left_fxp_cloned,1 +rotateOrder2_fxp_cloned,1024 +rotateOrder1_fxp_cloned,1024 +overlap_right_fxp_cloned,1 +overlap_left_fxp_cloned,1 +setAndFFT_right_fxp_cloned,1 +setAndFFT_left_fxp_cloned,1 +rotatorSet_fxp_cloned,1 +zoomSet_fxp_cloned,1 +siink,1 \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Area.csv new file mode 100644 index 00000000..2ceb2486 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Area.csv @@ -0,0 +1,19 @@ +Task Name,A53,IP0 +souurce,0, +rotatorSet_fxp_cloned,,51361 +psychoFilter_fxp_cloned,,78814 +rotateOrder1_fxp_cloned,,56493 +rotateOrder2_fxp_cloned,,191946 +rotateOrder3_fxp_cloned,,476887 +zoomSet_fxp_cloned,,4874 +zoomProcess_fxp_cloned,,54359 +setAndFFT_left_fxp_cloned,,458 +setAndFFT_right_fxp_cloned,,458 +FIR_left_fxp_cloned,,19496 +FIR_right_fxp_cloned,,19496 +IFFT_left_fxp_cloned,,5496 +IFFT_right_fxp_cloned,,5496 +overlap_left_fxp_cloned,,59160 +overlap_right_fxp_cloned,,59110 +siink,0, +dummy_last,0, \ No newline at end 
of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Energy.csv new file mode 100644 index 00000000..998500e6 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Energy.csv @@ -0,0 +1,19 @@ +Task Name,A53,IP0 +souurce,0, +rotatorSet_fxp_cloned,,103.026 +psychoFilter_fxp_cloned,,142.002 +rotateOrder1_fxp_cloned,,112.713 +rotateOrder2_fxp_cloned,,386.718 +rotateOrder3_fxp_cloned,,945.308 +zoomSet_fxp_cloned,,8.986 +zoomProcess_fxp_cloned,,100.54 +setAndFFT_left_fxp_cloned,,0.206 +setAndFFT_right_fxp_cloned,,0.206 +FIR_left_fxp_cloned,,38.597 +FIR_right_fxp_cloned,,38.597 +IFFT_left_fxp_cloned,,4.371 +IFFT_right_fxp_cloned,,4.371 +overlap_left_fxp_cloned,,103.253 +overlap_right_fxp_cloned,,103.235 +siink,0, +dummy_last,0, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Performance for 1000 blocks.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Performance for 1000 blocks.csv new file mode 100644 index 00000000..30c69f64 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Performance for 1000 blocks.csv @@ -0,0 +1,19 @@ +Task Name,A53,IP0,IP1,IP2,IP3 +souurce,0,,,, +rotatorSet_fxp_cloned,108000,3323000,,, +psychoFilter_fxp_cloned,1322969000,38005000,,, +rotateOrder1_fxp_cloned,109565000,6712000,3356000,1678000,839000 +rotateOrder2_fxp_cloned,331748000,6712000,3356000,1678000,839000 +rotateOrder3_fxp_cloned,805804000,7736000,3868000,1934000,967000 +zoomSet_fxp_cloned,27000,1661000,,, +zoomProcess_fxp_cloned,1459196000,104855000,,, +setAndFFT_left_fxp_cloned,329000,2417000,,, +setAndFFT_right_fxp_cloned,329000,2417000,,, +FIR_left_fxp_cloned,679206000,19261000,9630500,4815250,2407625 +FIR_right_fxp_cloned,679206000,19261000,9630500,4815250,2407625 +IFFT_left_fxp_cloned,416985000,17503000,,, +IFFT_right_fxp_cloned,416985000,17503000,,, +overlap_left_fxp_cloned,40958000,3060000,,, +overlap_right_fxp_cloned,40958000,3060000,,, +dummy_last,1000,,,, +siink,0,,,, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Performance.csv new file mode 100644 index 00000000..3986a939 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - Task PE Performance.csv @@ -0,0 +1,19 @@ +Task Name,A53,IP0 +souurce,0, +rotatorSet_fxp_cloned,108,3323 +psychoFilter_fxp_cloned,1322969,38005 +rotateOrder1_fxp_cloned,109565,6712 +rotateOrder2_fxp_cloned,331748,6712 +rotateOrder3_fxp_cloned,805804,7736 +zoomSet_fxp_cloned,27,1661 +zoomProcess_fxp_cloned,1459196,104855 +setAndFFT_left_fxp_cloned,329,2417 +setAndFFT_right_fxp_cloned,329,2417 +FIR_left_fxp_cloned,679206,19261 +FIR_right_fxp_cloned,679206,19261 +IFFT_left_fxp_cloned,416985,17503 +IFFT_right_fxp_cloned,416985,17503 +overlap_left_fxp_cloned,40958,3060 +overlap_right_fxp_cloned,40958,3060 +dummy_last,0, +siink,0, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/audio_decoder_database - misc.csv b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - misc.csv new file mode 100644 index 00000000..b760f119 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/audio_decoder_database - misc.csv @@ -0,0 +1,2 @@ 
+workload,last_task +audio_decoder,dummy_last \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task Data Movement.csv new file mode 100644 index 00000000..ca8f095f --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task Data Movement.csv @@ -0,0 +1,9 @@ +Task Name,souurce,gaussianSmoothing_1,laplacianEstimate_1,computeZeroCrossings_1,computeGradient_1,computeMaxGradientLeaf_1,rejectZeroCrossings_1,siink +souurce,,13107396,,,,,, +gaussianSmoothing_1,,,6553600,,6553600,,, +laplacianEstimate_1,,,,6553600,,,, +computeZeroCrossings_1,,,,,,,6553600, +computeGradient_1,,,,,,6553600,, +computeMaxGradientLeaf_1,,,,,,,4, +rejectZeroCrossings_1,,,,,,,,6553600 +siink,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task Itr Count.csv new file mode 100644 index 00000000..6df41c5c --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task Itr Count.csv @@ -0,0 +1,9 @@ +Task Name,number of iterations +souurce,1 +gaussianSmoothing_1,1638400 +laplacianEstimate_1,1638400 +computeZeroCrossings_1,1638400 +computeGradient_1,1638400 +computeMaxGradientLeaf_1,256 +rejectZeroCrossings_1,1638400 +siink,1 diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Area.csv new file mode 100644 index 00000000..6511a9e5 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Area.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing_1,,137341 +laplacianEstimate_1,,122700 +computeZeroCrossings_1,,118826 +computeGradient_1,,241114 +computeMaxGradientLeaf_1,,7148 +rejectZeroCrossings_1,,9669 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Energy.csv new file mode 100644 index 00000000..e6959942 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Energy.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0,IP1,IP2,IP3,IP4,IP5,IP6,IP7,IP8 +souurce,0,,,,,,,,, +gaussianSmoothing_1,,233984,467968,935936,1871872,3743744,7487488,14974976,29949952,59899904 +laplacianEstimate_1,,224815,449630,899260,1798520,3597040,7194080,14388160,28776320,57552640 +computeZeroCrossings_1,,218329,436658,873316,1746632,3493264,6986528,13973056,27946112,55892224 +computeGradient_1,,447176,894352,1788704,3577408,7154816,14309632,28619264,57238528,114477056 +computeMaxGradientLeaf_1,,9859,19718,39436,78872,157744,315488,630976,1261952,2523904 +rejectZeroCrossings_1,,17818,35636,71272,142544,285088,570176,1140352,2280704,4561408 +siink,0,,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Performance.csv new file mode 100644 index 00000000..23060c25 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_1_database - Task PE Performance.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing_1,1617100800,331284480 +laplacianEstimate_1,421068800,198574080 
+computeZeroCrossings_1,437452800,198574080 +computeGradient_1,427622400,189071360 +computeMaxGradientLeaf_1,14749184,3301888 +rejectZeroCrossings_1,376832000,177602560 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task Data Movement.csv new file mode 100644 index 00000000..c8d8d63a --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task Data Movement.csv @@ -0,0 +1,9 @@ +Task Name,souurce,gaussianSmoothing_2,laplacianEstimate_2,computeZeroCrossings_2,computeGradient_2,computeMaxGradientLeaf_2,rejectZeroCrossings_2,siink +souurce,,13107396,,,,,, +gaussianSmoothing_2,,,6553600,,6553600,,, +laplacianEstimate_2,,,,6553600,,,, +computeZeroCrossings_2,,,,,,,6553600, +computeGradient_2,,,,,,6553600,, +computeMaxGradientLeaf_2,,,,,,,4, +rejectZeroCrossings_2,,,,,,,,6553600 +siink,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task Itr Count.csv new file mode 100644 index 00000000..e09a0adc --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task Itr Count.csv @@ -0,0 +1,9 @@ +Task Name,number of iterations +souurce,1 +gaussianSmoothing_2,1638400 +laplacianEstimate_2,1638400 +computeZeroCrossings_2,1638400 +computeGradient_2,1638400 +computeMaxGradientLeaf_2,256 +rejectZeroCrossings_2,1638400 +siink,1 diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Area.csv new file mode 100644 index 00000000..c4861c18 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Area.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing_2,,137341 +laplacianEstimate_2,,122700 +computeZeroCrossings_2,,118826 +computeGradient_2,,241114 +computeMaxGradientLeaf_2,,7148 +rejectZeroCrossings_2,,9669 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Energy.csv new file mode 100644 index 00000000..9b2b8b59 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Energy.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0,IP1,IP2,IP3,IP4,IP5,IP6,IP7,IP8 +souurce,0,,,,,,,,, +gaussianSmoothing_2,,233984,467968,935936,1871872,3743744,7487488,14974976,29949952,59899904 +laplacianEstimate_2,,224815,449630,899260,1798520,3597040,7194080,14388160,28776320,57552640 +computeZeroCrossings_2,,218329,436658,873316,1746632,3493264,6986528,13973056,27946112,55892224 +computeGradient_2,,447176,894352,1788704,3577408,7154816,14309632,28619264,57238528,114477056 +computeMaxGradientLeaf_2,,9859,19718,39436,78872,157744,315488,630976,1261952,2523904 +rejectZeroCrossings_2,,17818,35636,71272,142544,285088,570176,1140352,2280704,4561408 +siink,0,,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Performance.csv new file mode 100644 index 00000000..9828913c --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_2_database - Task PE Performance.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 
+souurce,0, +gaussianSmoothing_2,1617100800,331284480 +laplacianEstimate_2,421068800,198574080 +computeZeroCrossings_2,437452800,198574080 +computeGradient_2,427622400,189071360 +computeMaxGradientLeaf_2,14749184,3301888 +rejectZeroCrossings_2,376832000,177602560 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task Data Movement.csv new file mode 100644 index 00000000..dacd1e95 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task Data Movement.csv @@ -0,0 +1,9 @@ +Task Name,souurce,gaussianSmoothing_3,laplacianEstimate_3,computeZeroCrossings_3,computeGradient_3,computeMaxGradientLeaf_3,rejectZeroCrossings_3,siink +souurce,,13107396,,,,,, +gaussianSmoothing_3,,,6553600,,6553600,,, +laplacianEstimate_3,,,,6553600,,,, +computeZeroCrossings_3,,,,,,,6553600, +computeGradient_3,,,,,,6553600,, +computeMaxGradientLeaf_3,,,,,,,4, +rejectZeroCrossings_3,,,,,,,,6553600 +siink,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task Itr Count.csv new file mode 100644 index 00000000..28ab97fc --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task Itr Count.csv @@ -0,0 +1,9 @@ +Task Name,number of iterations +souurce,1 +gaussianSmoothing_3,1638400 +laplacianEstimate_3,1638400 +computeZeroCrossings_3,1638400 +computeGradient_3,1638400 +computeMaxGradientLeaf_3,256 +rejectZeroCrossings_3,1638400 +siink,1 diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Area.csv new file mode 100644 index 00000000..d0ffa047 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Area.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing_3,,137341 +laplacianEstimate_3,,122700 +computeZeroCrossings_3,,118826 +computeGradient_3,,241114 +computeMaxGradientLeaf_3,,7148 +rejectZeroCrossings_3,,9669 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Energy.csv new file mode 100644 index 00000000..324adc1b --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Energy.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0,IP1,IP2,IP3,IP4,IP5,IP6,IP7,IP8 +souurce,0,,,,,,,,, +gaussianSmoothing_3,,233984,467968,935936,1871872,3743744,7487488,14974976,29949952,59899904 +laplacianEstimate_3,,224815,449630,899260,1798520,3597040,7194080,14388160,28776320,57552640 +computeZeroCrossings_3,,218329,436658,873316,1746632,3493264,6986528,13973056,27946112,55892224 +computeGradient_3,,447176,894352,1788704,3577408,7154816,14309632,28619264,57238528,114477056 +computeMaxGradientLeaf_3,,9859,19718,39436,78872,157744,315488,630976,1261952,2523904 +rejectZeroCrossings_3,,17818,35636,71272,142544,285088,570176,1140352,2280704,4561408 +siink,0,,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Performance.csv new file mode 100644 index 00000000..df976908 --- /dev/null +++ 
b/Project_FARSI/specs/database_data/parsing/edge_detection_3_database - Task PE Performance.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing_3,1617100800,331284480 +laplacianEstimate_3,421068800,198574080 +computeZeroCrossings_3,437452800,198574080 +computeGradient_3,427622400,189071360 +computeMaxGradientLeaf_3,14749184,3301888 +rejectZeroCrossings_3,376832000,177602560 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task Data Movement.csv new file mode 100644 index 00000000..86ef6e02 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task Data Movement.csv @@ -0,0 +1,9 @@ +Task Name,souurce,gaussianSmoothing_4,laplacianEstimate_4,computeZeroCrossings_4,computeGradient_4,computeMaxGradientLeaf_4,rejectZeroCrossings_4,siink +souurce,,13107396,,,,,, +gaussianSmoothing_4,,,6553600,,6553600,,, +laplacianEstimate_4,,,,6553600,,,, +computeZeroCrossings_4,,,,,,,6553600, +computeGradient_4,,,,,,6553600,, +computeMaxGradientLeaf_4,,,,,,,4, +rejectZeroCrossings_4,,,,,,,,6553600 +siink,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task Itr Count.csv new file mode 100644 index 00000000..72fecf50 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task Itr Count.csv @@ -0,0 +1,9 @@ +Task Name,number of iterations +souurce,1 +gaussianSmoothing_4,1638400 +laplacianEstimate_4,1638400 +computeZeroCrossings_4,1638400 +computeGradient_4,1638400 +computeMaxGradientLeaf_4,256 +rejectZeroCrossings_4,1638400 +siink,1 diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Area.csv new file mode 100644 index 00000000..249aa759 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Area.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing_4,,137341 +laplacianEstimate_4,,122700 +computeZeroCrossings_4,,118826 +computeGradient_4,,241114 +computeMaxGradientLeaf_4,,7148 +rejectZeroCrossings_4,,9669 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Energy.csv new file mode 100644 index 00000000..e1f33aef --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Energy.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0,IP1,IP2,IP3,IP4,IP5,IP6,IP7,IP8 +souurce,0,,,,,,,,, +gaussianSmoothing_4,,233984,467968,935936,1871872,3743744,7487488,14974976,29949952,59899904 +laplacianEstimate_4,,224815,449630,899260,1798520,3597040,7194080,14388160,28776320,57552640 +computeZeroCrossings_4,,218329,436658,873316,1746632,3493264,6986528,13973056,27946112,55892224 +computeGradient_4,,447176,894352,1788704,3577408,7154816,14309632,28619264,57238528,114477056 +computeMaxGradientLeaf_4,,9859,19718,39436,78872,157744,315488,630976,1261952,2523904 +rejectZeroCrossings_4,,17818,35636,71272,142544,285088,570176,1140352,2280704,4561408 +siink,0,,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database 
- Task PE Performance.csv new file mode 100644 index 00000000..a2e52f87 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_4_database - Task PE Performance.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing_4,1617100800,331284480 +laplacianEstimate_4,421068800,198574080 +computeZeroCrossings_4,437452800,198574080 +computeGradient_4,427622400,189071360 +computeMaxGradientLeaf_4,14749184,3301888 +rejectZeroCrossings_4,376832000,177602560 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task Data Movement.csv new file mode 100644 index 00000000..87a4e518 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task Data Movement.csv @@ -0,0 +1,9 @@ +Task Name,souurce,gaussianSmoothing,laplacianEstimate,computeZeroCrossings,computeGradient,computeMaxGradientLeaf,rejectZeroCrossings,siink +souurce,,13107396,,,,,, +gaussianSmoothing,,,6553600,,6553600,,, +laplacianEstimate,,,,6553600,,,, +computeZeroCrossings,,,,,,,6553600, +computeGradient,,,,,,6553600,, +computeMaxGradientLeaf,,,,,,,4, +rejectZeroCrossings,,,,,,,,6553600 +siink,,,,,,,, diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task Itr Count.csv new file mode 100644 index 00000000..82a86732 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task Itr Count.csv @@ -0,0 +1,9 @@ +Task Name,number of iterations +souurce,1 +gaussianSmoothing,1638400 +laplacianEstimate,1638400 +computeZeroCrossings,1638400 +computeGradient,1638400 +computeMaxGradientLeaf,256 +rejectZeroCrossings,1638400 +siink,1 \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Area.csv new file mode 100644 index 00000000..467cdd5a --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Area.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing,,137341 +laplacianEstimate,,122700 +computeZeroCrossings,,118826 +computeGradient,,241114 +computeMaxGradientLeaf,,7148 +rejectZeroCrossings,,9669 +siink,0, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Energy.csv new file mode 100644 index 00000000..2b421d43 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Energy.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0,IP1,IP2,IP3,IP4,IP5,IP6,IP7,IP8 +souurce,0,,,,,,,,, +gaussianSmoothing,,233984,467968,935936,1871872,3743744,7487488,14974976,29949952,59899904 +laplacianEstimate,,224815,449630,899260,1798520,3597040,7194080,14388160,28776320,57552640 +computeZeroCrossings,,218329,436658,873316,1746632,3493264,6986528,13973056,27946112,55892224 +computeGradient,,447176,894352,1788704,3577408,7154816,14309632,28619264,57238528,114477056 +computeMaxGradientLeaf,,9859,19718,39436,78872,157744,315488,630976,1261952,2523904 +rejectZeroCrossings,,17818,35636,71272,142544,285088,570176,1140352,2280704,4561408 +siink,0,,,,,,,,, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task 
PE Performance.csv b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Performance.csv new file mode 100644 index 00000000..b24aa034 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/edge_detection_database - Task PE Performance.csv @@ -0,0 +1,9 @@ +Task Name,A53,IP0 +souurce,0, +gaussianSmoothing,1617100800,331284480 +laplacianEstimate,421068800,198574080 +computeZeroCrossings,437452800,198574080 +computeGradient,427622400,189071360 +computeMaxGradientLeaf,14749184,3301888 +rejectZeroCrossings,376832000,177602560 +siink,0, diff --git a/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Hardware Graph.csv b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Hardware Graph.csv new file mode 100644 index 00000000..2bdff02e --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Hardware Graph.csv @@ -0,0 +1,4 @@ +Block Name,A53_0,LMEM_ic_0_0,LMEM_0_0 +A53_0,,1, +LMEM_ic_0_0,1,,1 +LMEM_0_0,,1, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task Data Movement.csv b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task Data Movement.csv new file mode 100644 index 00000000..a9022f96 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task Data Movement.csv @@ -0,0 +1,10 @@ +Task Name,souurce,scale_fxp_cloned,demosaic_fxp_cloned,denoise_fxp_cloned,transform_fxp_cloned,gamut_map_fxp_cloned,tone_map_fxp_cloned,descale_fxp_cloned,siink +souurce,,91053,,,,,,, +scale_fxp_cloned,,,364212,,,,,, +demosaic_fxp_cloned,,,,364212,,,,, +denoise_fxp_cloned,,,,,364212,,,, +transform_fxp_cloned,,,,,,364212,,, +gamut_map_fxp_cloned,,,,,,,364212,, +tone_map_fxp_cloned,,,,,,,,364212, +descale_fxp_cloned,,,,,,,,,91053 +siink,,,,,,,,, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task Itr Count.csv b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task Itr Count.csv new file mode 100644 index 00000000..d35514e1 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task Itr Count.csv @@ -0,0 +1,10 @@ +Task Name,number of iterations +source,1 +scale_fxp_cloned,151 +demosaic_fxp_cloned,151 +denoise_fxp_cloned,151 +transform_fxp_cloned,151 +gamut_map_fxp_cloned,151 +tone_map_fxp_cloned,151 +descale_fxp_cloned,151 +siink,1 \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Area.csv b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Area.csv new file mode 100644 index 00000000..108ebc02 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Area.csv @@ -0,0 +1,10 @@ +Task Name,A53,IP0 +souurce,0, +scale_fxp_cloned,,14601 +demosaic_fxp_cloned,,73580 +denoise_fxp_cloned,,44998 +transform_fxp_cloned,,39287 +gamut_map_fxp_cloned,,107726 +tone_map_fxp_cloned,,19425 +descale_fxp_cloned,,15030 +siink,0, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Energy.csv b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Energy.csv new file mode 100644 index 00000000..6089d4fb --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Energy.csv @@ -0,0 +1,10 @@ +Task Name,A53,IP0 +souurce,0, +scale_fxp_cloned,,26887 +demosaic_fxp_cloned,,131069 +denoise_fxp_cloned,,80897 +transform_fxp_cloned,,72148 
+gamut_map_fxp_cloned,,208042 +tone_map_fxp_cloned,,35786 +descale_fxp_cloned,,27098 +siink,0, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Performance.csv b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Performance.csv new file mode 100644 index 00000000..cf19d282 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task PE Performance.csv @@ -0,0 +1,10 @@ +Task Name,A53,IP0 +souurce,0, +scale_fxp_cloned,23110550,1390257 +demosaic_fxp_cloned,16764020,1585500 +denoise_fxp_cloned,348287144,9876910 +transform_fxp_cloned,47395880,2606562 +gamut_map_fxp_cloned,81304050420,3387311728 +tone_map_fxp_cloned,20683980,2606562 +descale_fxp_cloned,3122039760,186117617 +siink,0, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task To Hardware Mapping.csv b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task To Hardware Mapping.csv new file mode 100644 index 00000000..d4c80aca --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/hpvm_cava_database - Task To Hardware Mapping.csv @@ -0,0 +1,10 @@ +A53_0,LMEM_0_0,LMEM_ic_0_0 +souurce,souurce -> scale_fxp_cloned,souurce -> scale_fxp_cloned +scale_fxp_cloned,scale_fxp_cloned -> demosaic_fxp_cloned,scale_fxp_cloned -> demosaic_fxp_cloned +demosaic_fxp_cloned,demosaic_fxp_cloned -> denoise_fxp_cloned,demosaic_fxp_cloned -> denoise_fxp_cloned +denoise_fxp_cloned,denoise_fxp_cloned->transform_fxp_cloned,denoise_fxp_cloned->transform_fxp_cloned +transform_fxp_cloned,transform_fxp_cloned->gamut_map_fxp_cloned,transform_fxp_cloned->gamut_map_fxp_cloned +gamut_map_fxp_cloned,gamut_map_fxp_cloned->tone_map_fxp_cloned,gamut_map_fxp_cloned->tone_map_fxp_cloned +tone_map_fxp_cloned,tone_map_fxp_cloned->descale_fxp_cloned,tone_map_fxp_cloned->descale_fxp_cloned +descale_fxp_cloned,descale_fxp_cloned->siink,descale_fxp_cloned->siink +siink,, \ No newline at end of file diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics.csv new file mode 100644 index 00000000..33bd63af --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics.csv @@ -0,0 +1,21 @@ +Name,dhrystone_IPC,Freq,Ref,Type,Subtype,Inst_per_joul,Gpp_area,BitWidth,Byte_per_joul,Byte_per_m,hop_latency,pipe_line_depth +A53,2,1000000000,yes,pe,gpp,47019607843,6.5E-07,NA,NA,NA,1,8 +IP,NA,100000000,yes,pe,ip,,NA,NA,NA,NA,1,4 +LMEM_0,NA,100000000,yes,mem,sram,,NA,4,200,1,1,4 +LMEM_1,NA,100000000,no,mem,sram,,NA,8,200,1,1,8 +LMEM_2,NA,100000000,no,mem,sram,,NA,16,200,1,1,4 +LMEM_3,NA,100000000,no,mem,sram,,NA,32,200,1,1,8 +LMEM_4,NA,100000000,no,mem,sram,,NA,64,200,1,1,8 +LMEM_5,NA,100000000,no,mem,sram,,NA,128,200,1,1,8 +GMEM_0,NA,100000000,yes,mem,dram,,NA,4,1,2,1,4 +GMEM_1,NA,100000000,no,mem,dram,,NA,8,1,2,1,8 +GMEM_2,NA,100000000,no,mem,dram,,NA,16,1,2,1,4 +GMEM_3,NA,100000000,no,mem,dram,,NA,32,1,2,1,8 +GMEM_4,NA,100000000,no,mem,dram,,NA,64,1,2,1,8 +GMEM_5,NA,100000000,no,mem,dram,,NA,128,1,2,1,8 +LMEM_ic_0,NA,100000000,yes,ic,ic,,NA,4,1,1,1,8 +LMEM_ic_1,NA,100000000,no,ic,ic,,NA,8,1,1,1,8 +LMEM_ic_2,NA,100000000,no,ic,ic,,NA,16,1,1,1,4 +LMEM_ic_3,NA,100000000,no,ic,ic,,NA,32,1,1,1,8 +LMEM_ic_4,NA,100000000,no,ic,ic,,NA,64,1,1,1,8 +LMEM_ic_5,NA,100000000,no,ic,ic,,NA,128,1,1,1,8 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - 
Block Characteristics_for_PA.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics_for_PA.csv new file mode 100644 index 00000000..5db6aa96 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics_for_PA.csv @@ -0,0 +1,21 @@ +Name,dhrystone_IPC,Freq,Ref,Type,Subtype,Inst_per_joul,Gpp_area,BitWidth,Byte_per_joul,Byte_per_m,hop_latency,pipe_line_depth +A53,2,500000000,yes,pe,gpp,47019607843,6.5E-07,NA,NA,NA,1,8 +IP,NA,50000000,yes,pe,ip,,NA,NA,NA,NA,1,4 +LMEM_0,NA,50000000,yes,mem,sram,,NA,4,200,1,1,4 +LMEM_1,NA,50000000,no,mem,sram,,NA,8,200,1,1,8 +LMEM_2,NA,50000000,no,mem,sram,,NA,16,200,1,1,4 +LMEM_3,NA,50000000,no,mem,sram,,NA,32,200,1,1,8 +LMEM_4,NA,50000000,no,mem,sram,,NA,64,200,1,1,8 +LMEM_5,NA,50000000,no,mem,sram,,NA,128,200,1,1,8 +GMEM_0,NA,50000000,yes,mem,dram,,NA,4,1,2,1,4 +GMEM_1,NA,50000000,no,mem,dram,,NA,8,1,2,1,8 +GMEM_2,NA,50000000,no,mem,dram,,NA,16,1,2,1,4 +GMEM_3,NA,50000000,no,mem,dram,,NA,32,1,2,1,8 +GMEM_4,NA,50000000,no,mem,dram,,NA,64,1,2,1,8 +GMEM_5,NA,50000000,no,mem,dram,,NA,128,1,2,1,8 +LMEM_ic_0,NA,50000000,yes,ic,ic,,NA,4,1,1,1,8 +LMEM_ic_1,NA,50000000,no,ic,ic,,NA,8,1,1,1,8 +LMEM_ic_2,NA,50000000,no,ic,ic,,NA,16,1,1,1,4 +LMEM_ic_3,NA,50000000,no,ic,ic,,NA,32,1,1,1,8 +LMEM_ic_4,NA,50000000,no,ic,ic,,NA,64,1,1,1,8 +LMEM_ic_5,NA,50000000,no,ic,ic,,NA,128,1,1,1,8 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics_for_real.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics_for_real.csv new file mode 100644 index 00000000..33bd63af --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Block Characteristics_for_real.csv @@ -0,0 +1,21 @@ +Name,dhrystone_IPC,Freq,Ref,Type,Subtype,Inst_per_joul,Gpp_area,BitWidth,Byte_per_joul,Byte_per_m,hop_latency,pipe_line_depth +A53,2,1000000000,yes,pe,gpp,47019607843,6.5E-07,NA,NA,NA,1,8 +IP,NA,100000000,yes,pe,ip,,NA,NA,NA,NA,1,4 +LMEM_0,NA,100000000,yes,mem,sram,,NA,4,200,1,1,4 +LMEM_1,NA,100000000,no,mem,sram,,NA,8,200,1,1,8 +LMEM_2,NA,100000000,no,mem,sram,,NA,16,200,1,1,4 +LMEM_3,NA,100000000,no,mem,sram,,NA,32,200,1,1,8 +LMEM_4,NA,100000000,no,mem,sram,,NA,64,200,1,1,8 +LMEM_5,NA,100000000,no,mem,sram,,NA,128,200,1,1,8 +GMEM_0,NA,100000000,yes,mem,dram,,NA,4,1,2,1,4 +GMEM_1,NA,100000000,no,mem,dram,,NA,8,1,2,1,8 +GMEM_2,NA,100000000,no,mem,dram,,NA,16,1,2,1,4 +GMEM_3,NA,100000000,no,mem,dram,,NA,32,1,2,1,8 +GMEM_4,NA,100000000,no,mem,dram,,NA,64,1,2,1,8 +GMEM_5,NA,100000000,no,mem,dram,,NA,128,1,2,1,8 +LMEM_ic_0,NA,100000000,yes,ic,ic,,NA,4,1,1,1,8 +LMEM_ic_1,NA,100000000,no,ic,ic,,NA,8,1,1,1,8 +LMEM_ic_2,NA,100000000,no,ic,ic,,NA,16,1,1,1,4 +LMEM_ic_3,NA,100000000,no,ic,ic,,NA,32,1,1,1,8 +LMEM_ic_4,NA,100000000,no,ic,ic,,NA,64,1,1,1,8 +LMEM_ic_5,NA,100000000,no,ic,ic,,NA,128,1,1,1,8 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Budget.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Budget.csv new file mode 100644 index 00000000..ad5b4245 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Budget.csv @@ -0,0 +1,14 @@ +Workload,latency,power,area,cost +audio_decoder,0.021,0.008737,0.000017475,0.000000001 +edge_detection,.034,0.008737,0.000017475,0.000000001 +hpvm_cava,.034,0.008737,0.000017475,0.000000001 +synthetic,.034,0.008737,0.000017475,0.000000001 +edge_detection_1,.034,0.008737,0.000017475,0.000000001 +edge_detection_2,.034,0.008737,0.000017475,0.000000001 
+edge_detection_3,.034,0.008737,0.000017475,0.000000001 +edge_detection_4,.034,0.008737,0.000017475,0.000000001 +simple_all_parallel, .100,1,1, +SOC_example, .1,1,1, +partial_SOC_example_hard, .1,1,1, +SOC_example_1p_2r,.1,1,1, +all,,0.008737,0.000017475,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Budget_for_PA.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_for_PA.csv new file mode 100644 index 00000000..74be71a6 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_for_PA.csv @@ -0,0 +1,16 @@ +Workload,latency,power,area,cost +audio_decoder,0.002,1000000.008737,1000000.000017475,0.000000001 +edge_detection,0.017,1000000.008737,1000000.000017475,0.000000001 +hpvm_cava,0.034,1000000.008737,1000000.000017475,0.000000001 +synthetic,0.034,1000000.008737,1000000.000017475,0.000000001 +simple_all_parallel,0.1,1,1, +SOC_example,0.1,1,1, +SOC_example_2p,0.1,1,1, +partial_SOC_example_hard,0.1,1,1, +SOC_example_4p,0.1,1,1, +SOC_example_8p,0.1,1,1, +SOC_example_1p,0.1,1,1, +SOC_example_1p_2r,0.1,1,1, +SOC_example_1p_4r,0.1,1,1, +SOC_example_1p_8r,0.1,1,1, +all,,0.008737,0.000017475,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Budget_for_real.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_for_real.csv new file mode 100644 index 00000000..552f51b7 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_for_real.csv @@ -0,0 +1,9 @@ +Workload,latency,power,area,cost +audio_decoder,0.021,0.008737,0.000017475,0.000000001 +edge_detection,.034,0.008737,0.000017475,0.000000001 +hpvm_cava,.034,0.008737,0.000017475,0.000000001 +synthetic,.034,0.008737,0.000017475,0.000000001 +simple_all_parallel, .100,1,1, +SOC_example, .1,1,1, +partial_SOC_example_hard, .1,1,1, +all,,0.008737,0.000017475,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Budget_individiaul.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_individiaul.csv new file mode 100644 index 00000000..30af28a3 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_individiaul.csv @@ -0,0 +1,8 @@ +Workload,latency,power,area,cost +audio_decoder,0.021,.003097,.000006195, +edge_detection,.034,.003630,.000007260, +hpvm_cava,.034,.002010,.000004020, +simple_all_parallel, .100,1,1, +SOC_example, .1,1,1, +partial_SOC_example_hard, .1,1,1, +all,,0.008737,0.000017475,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Budget_individually.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_individually.csv new file mode 100644 index 00000000..3d171208 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Budget_individually.csv @@ -0,0 +1,15 @@ +Workload,latency,power,area,cost +audio_decoder,0.021,0.003097,0.000006195, +edge_detection,0.034,0.00363,0.00000726, +hpvm_cava,0.034,0.00201,0.00000402, +simple_all_parallel,0.1,1,1, +SOC_example,0.1,1,1, +SOC_example_2p,0.1,1,1, +SOC_example_p1,0.1,1,1, +SOC_example_p1_r2,0.1,1,1, +SOC_example_p1_r4,0.1,1,1, +SOC_example_p1_r8,0.1,1,1, +SOC_example_4p,0.1,1,1, +SOC_example_8p,0.1,1,1, +partial_SOC_example_hard,0.1,1,1, +all,,0.008737,0.000017475,0.000000001 diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Common Hardware.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Common Hardware.csv new file mode 100644 index 00000000..5b2ec712 --- /dev/null 
+++ b/Project_FARSI/specs/database_data/parsing/misc_database - Common Hardware.csv @@ -0,0 +1,15 @@ +variable_name,value +ref_gpp_dhrystone_value,2.3 +arm_power_over_clock,=40.8*10^-12 +arm_fix_area,=.71*10^-6 +dsp_speed_up_coef,5 +dsp_power_over_clock,60*10^-12 +dsp_fix_area,=.44*10^-6 +mac_area,=.0000892*(10^-6) +mac_allocated_per_mac_operations,=(10^-3) +ip_speed_up_coef,20 +ip_power_reduction_coef,15 +memory_block_size,=1048576 +ref_ic_width,=4 +ref_mem_width,=4 + diff --git a/Project_FARSI/specs/database_data/parsing/misc_database - Last Tasks.csv b/Project_FARSI/specs/database_data/parsing/misc_database - Last Tasks.csv new file mode 100644 index 00000000..b15edfd3 --- /dev/null +++ b/Project_FARSI/specs/database_data/parsing/misc_database - Last Tasks.csv @@ -0,0 +1,16 @@ +workload,last_task +audio_decoder,dummy_last +edge_detection,rejectZeroCrossings +hpvm_cava,descale_fxp_cloned +simple_all_parallel,siink +edge_detection_1,rejectZeroCrossings_1 +edge_detection_2,rejectZeroCrossings_2 +edge_detection_3,rejectZeroCrossings_3 +edge_detection_4,rejectZeroCrossings_4 +SOC_example,siink +SOC_example_2p,siink +SOC_example_1p,siink +SOC_example_8p,siink +SOC_example_4p,siink +SOC_example_1p_2r,siink +partial_SOC_example_hard,siink diff --git a/Project_FARSI/specs/database_input.py b/Project_FARSI/specs/database_input.py new file mode 100644 index 00000000..9130debe --- /dev/null +++ b/Project_FARSI/specs/database_input.py @@ -0,0 +1,317 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +from specs.LW_cl import * +from specs.parse_libraries.parse_library import * +import importlib +from specs.gen_synthetic_data import * + + +# ----------------- +# This class helps populating the database. It uses the sw_hw_database_population mode to do so. 
+# the mode can be hardcoded, parse (parse a csv) or generate (synthetic generation) +# ----------------- +class database_input_class(): + def append_tasksL(self, tasksL_list): + # one tasks absorbs another one + def absorb(ref_task, task): + ref_task.set_children_nature({**ref_task.get_children_nature(), **task.get_children_nature()}) + ref_task.set_self_to_children_work_distribution({**ref_task.get_self_to_children_work_distribution(), + **task.get_self_to_children_work_distribution()}) + ref_task.set_self_to_children_work({**ref_task.get_self_to_children_work(), + **task.get_self_to_children_work()}) + ref_task.set_children(ref_task.get_children() + task.get_children()) + + for el in tasksL_list: + refine = False + for el_ in self.tasksL: + if el_.task_name == el.task_name: + refine = True + break + if refine: + absorb(el_, el) + else: + self.tasksL.append(el) + + def append_blocksL(self, blocksL_list): + for el in blocksL_list: + append = True + for el_ in self.blocksL: + if el_.block_instance_name == el.block_instance_name: + append = False + break + if append: + self.blocksL.append(el) + + def append_pe_mapsL(self, pe_mapsL): + for el in pe_mapsL: + append = True + for el_ in self.pe_mapsL: + if el_.pe_block_instance_name == el.pe_block_instance_name and el_.task_name == el.task_name: + append = False + break + if append: + self.pe_mapsL.append(el) + + def append_pe_scheduels(self, pe_scheduels): + for el in pe_scheduels: + append = True + for el_ in self.pe_schedeulesL: + if el_.task_name == el.task_name: + append = False + break + if append: + self.pe_schedeulesL.append(el) + + + + + #glass_constraints = {"power": config.budget_dict["glass"]["power"]} #keys:power, area, latency example: {"area": area_budget} + def __init__(self, + sw_hw_database_population={"db_mode":"hardcoded", "hw_graph_mode":"generated_from_scratch", + "workloads":{},"misc_knobs":{}}): + # some sanity checks first + assert(sw_hw_database_population["db_mode"] in ["hardcoded", "generate", "parse"]) + assert(sw_hw_database_population["hw_graph_mode"] in ["generated_from_scratch", "generated_from_check_point", "parse", "hardcoded", "hop_mode", "star_mode"]) + if sw_hw_database_population["hw_graph_mode"] == "parse": + assert(sw_hw_database_population["db_mode"] == "parse") + + # start parsing/generating + if sw_hw_database_population["db_mode"] == "hardcoded": + lib_relative_addr = config.database_data_dir.replace(config.home_dir, "") + lib_relative_addr_pythony_fied = lib_relative_addr.replace("/",".") + files_to_import = [lib_relative_addr_pythony_fied+".hardcoded."+workload+".input" for workload in sw_hw_database_population["workloads"]] + imported_databases = [importlib.import_module(el) for el in files_to_import] + elif sw_hw_database_population["db_mode"] == "generate": + lib_relative_addr = config.database_data_dir.replace(config.home_dir, "") + lib_relative_addr_pythony_fied = lib_relative_addr.replace("/",".") + files_to_import = [lib_relative_addr_pythony_fied+".generate."+"input" for workload in sw_hw_database_population["workloads"]] + imported_databases = [importlib.import_module(el) for el in files_to_import] + + self.blocksL:List[BlockL] = [] # collection of all the blocks + self.pe_schedeulesL:List[TaskScheduleL] = [] + self.tasksL: List[TaskL] = [] + self.pe_mapsL: List[TaskToPEBlockMapL] = [] + self.souurce_memory_work = {} + self.workloads_last_task = [] + self.workload_tasks = {} + self.task_workload = {} + self.misc_data = {} + self.parallel_task_names = {} + self.hoppy_task_names = [] + 
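        # NOTE: the fields below start out empty/"NA"; whichever db_mode branch
+        # actually runs (hardcoded, parse, or generate) fills in the ones it applies to.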
+        self.hardware_graph = ""
+        self.task_to_hardware_mapping = ""
+        self.parallel_task_count = "NA"
+        self.serial_task_count = "NA"
+        self.memory_boundedness_ratio = "NA"
+        self.datamovement_scaling_ratio = "NA"
+        self.parallel_task_type = "NA"
+        self.num_of_hops = "NA"
+        self.num_of_NoCs = "NA"
+
+        # using the input files, populate the task graph and possible blocks and the mapping of tasks to blocks
+        if sw_hw_database_population["db_mode"] == "hardcoded":
+            if len(imported_databases) > 1:
+                print("we have to fix the budgets_dict collection here. support it and run")
+                exit(0)
+            #self.workload_tasks[sw_hw_database_population["workloads"][0]] = []
+            for imported_database in imported_databases:
+                self.blocksL.extend(imported_database.blocksL)
+                self.tasksL.extend(imported_database.tasksL)
+                self.pe_mapsL.extend(imported_database.pe_mapsL)
+                self.pe_schedeulesL.extend(imported_database.pe_schedeulesL)
+                self.workloads_last_task = imported_database.workloads_last_task
+                self.budgets_dict = imported_database.budgets_dict
+                self.other_values_dict = imported_database.other_values_dict
+                self.souurce_memory_work = imported_database.souurce_memory_work
+                self.misc_data["same_ip_tasks_list"] = imported_database.same_ip_tasks_list
+            self.workload_tasks[list(sw_hw_database_population["workloads"])[0]] = [el.task_name for el in self.tasksL]
+            for el in self.tasksL:
+                self.task_workload[el.task_name] = list(sw_hw_database_population["workloads"])[0]
+
+            self.sw_hw_database_population = sw_hw_database_population
+
+        elif sw_hw_database_population["db_mode"] == "parse":
+            for workload in sw_hw_database_population["workloads"]:
+                tasksL_, data_movement = gen_task_graph(os.path.join(config.database_data_dir, "parsing"), workload+"_database - ", sw_hw_database_population["misc_knobs"])
+                blocksL_, pe_mapsL_, pe_schedulesL_ = gen_hardware_library(os.path.join(config.database_data_dir, "parsing"), workload+"_database - ", workload, sw_hw_database_population["misc_knobs"])
+                self.sw_hw_database_population = sw_hw_database_population
+                self.append_tasksL(copy.deepcopy(tasksL_))
+                self.append_blocksL(copy.deepcopy(blocksL_))
+                self.append_pe_mapsL(copy.deepcopy(pe_mapsL_))
+                self.append_pe_scheduels(copy.deepcopy(pe_schedulesL_))
+                self.souurce_memory_work.update(data_movement['souurce'])
+                self.workload_tasks[workload] = [el.task_name for el in tasksL_]
+                for el in tasksL_:
+                    self.task_workload[el.task_name] = workload
+
+            #self.souurce_memory_work += sum([sum(list(data_movement[task].values())) for task in data_movement.keys() if task == "souurce"])
+
+            self.workloads_last_task = collect_last_task(sw_hw_database_population["workloads"], os.path.join(config.database_data_dir, "parsing"), "misc_database - ")
+            self.budgets_dict, self.other_values_dict = collect_budgets(sw_hw_database_population["workloads"], sw_hw_database_population["misc_knobs"], os.path.join(config.database_data_dir, "parsing"), "misc_database - ")
+            if config.heuristic_scaling_study:
+                for metric in self.budgets_dict['glass'].keys():
+                    if metric == "latency":
+                        continue
+                    self.budgets_dict['glass'][metric] *= len(sw_hw_database_population["workloads"])
+
+            self.misc_data["same_ip_tasks_list"] = []
+        elif sw_hw_database_population["db_mode"] == "generate":
+            if len(imported_databases) > 1:
+                print("we have to fix the budgets_dict collection here. 
support it and run") + exit(0) + self.gen_config = imported_databases[0].gen_config + self.gen_config['parallel_task_cnt'] = sw_hw_database_population["misc_knobs"]['task_spawn']["parallel_task_cnt"] + self.gen_config['serial_task_cnt'] = sw_hw_database_population["misc_knobs"]['task_spawn']["serial_task_cnt"] + self.parallel_task_type = sw_hw_database_population["misc_knobs"]['task_spawn']["parallel_task_type"] + self.parallel_task_count =self.gen_config['parallel_task_cnt'] + self.serial_task_count =self.gen_config['serial_task_cnt'] + self.datamovement_scaling_ratio = sw_hw_database_population["misc_knobs"]['task_spawn']["boundedness"][1] + self.memory_boundedness_ratio = sw_hw_database_population["misc_knobs"]['task_spawn']["boundedness"][2] + self.num_of_hops = sw_hw_database_population["misc_knobs"]['num_of_hops'] + self.budgets_dict = imported_databases[0].budgets_dict + self.other_values_dict= imported_databases[0].other_values_dict + self.num_of_NoCs = sw_hw_database_population["misc_knobs"]["num_of_NoCs"] + + other_task_count = 7 + + total_task_cnt = other_task_count + max(self.gen_config["parallel_task_cnt"]-1, 0) + self.gen_config["serial_task_cnt"] + max(self.num_of_NoCs -2, 0) + + #intensity_params = ["memory_intensive", 1] + intensity_params = sw_hw_database_population["misc_knobs"]['task_spawn']["boundedness"] + + + tasksL_, data_movement, task_work_dict, parallel_task_names, hoppy_task_names = generate_synthetic_task_graphs_for_asymetric_graphs(total_task_cnt, other_task_count, self.gen_config["parallel_task_cnt"], self.gen_config["serial_task_cnt"], self.parallel_task_type, intensity_params, self.num_of_NoCs) # memory_intensive, comp_intensive + blocksL_, pe_mapsL_, pe_schedulesL_ = generate_synthetic_hardware_library(task_work_dict, os.path.join(config.database_data_dir, "parsing"), "misc_database - Block Characteristics.csv") + self.tasksL.extend(tasksL_) + self.blocksL.extend(copy.deepcopy(blocksL_)) + self.pe_mapsL.extend(pe_mapsL_) + self.pe_schedeulesL.extend(pe_schedulesL_) + for task in data_movement.keys(): + if task == "synthetic_souurce": + #self.souurce_memory_work = ["synthetic_souurce"] = sum(list(data_movement[task].values())) + self.souurce_memory_work = data_movement[task] + + self.workloads_last_task = {"synthetic" : [taskL.task_name for taskL in tasksL_ if len(taskL.get_children()) == 0][0]} + self.gen_config["full_potential_tasks_list"] = list(task_work_dict.keys()) + self.misc_data["same_ip_tasks_list"] = [] + self.parallel_task_names = parallel_task_names + self.hoppy_task_names = hoppy_task_names + self.workload_tasks[list(sw_hw_database_population["workloads"])[0]] = [el.task_name for el in self.tasksL] + self.sw_hw_database_population = sw_hw_database_population + + pass + else: + print("db_mode:" + sw_hw_database_population["db_mode"] + " is not supported" ) + exit(0) + + # get the hardware graph if need be + if sw_hw_database_population["hw_graph_mode"] == "parse": + self.hardware_graph = gen_hardware_graph(os.path.join(config.database_data_dir, "parsing"), + workload + "_database - ") + self.task_to_hardware_mapping = gen_task_to_hw_mapping(os.path.join(config.database_data_dir, "parsing"), + workload + "_database - ") + else: + self.hardware_graph = "" + self.task_to_hardware_mapping = "" + + # set the budget values + config.souurce_memory_work = self.souurce_memory_work + self.SOCsL = [] + self.SOCL0_budget_dict = {"latency": self.budgets_dict["glass"]["latency"], "area":self.budgets_dict["glass"]["area"], + "power": 
self.budgets_dict["glass"]["power"]} #keys:power, area, latency example: {"area": area_budget} + + self.SOCL0_other_metrics_dict = {"cost": self.other_values_dict["glass"]["cost"]} + self.SOCL0 = SOCL("glass", self.SOCL0_budget_dict, self.SOCL0_other_metrics_dict) + self.SOCsL.append(self.SOCL0) + + # ----------------- + # HARDWARE DATABASE + # ----------------- + self.porting_effort = {} + self.porting_effort["arm"] = 1 + self.porting_effort["dsp"] = 10 + self.porting_effort["ip"] = 100 + self.porting_effort["mem"] = .1 + self.porting_effort["ic"] = .1 + + df = pd.read_csv(os.path.join(config.database_data_dir, "parsing", "misc_database - Common Hardware.csv")) + + # eval the expression + def evaluate(value): + replaced_value_1 = value.replace("^", "**") + replaced_value_2 = replaced_value_1.replace("=","") + return eval(replaced_value_2) + + for index, row in df.iterrows(): + temp_dict = row.to_dict() + self.misc_data[list(temp_dict.values())[0]] = evaluate(list(temp_dict.values())[1]) + + self.proj_name = config.proj_name + # simple models to FARSI_fy the database + self.misc_data["byte_error_margin"] = 100 # since we use work ratio, we might calculate the bytes wrong (this effect area calculation) + self.misc_data["area_error_margin"] = 2.1739130434782608e-10 + #self.misc_data["byte_error_margin"]/ self.misc_data["ref_mem_work_over_area"] # to tolerate the error caused by work_ratio + # # (use byte_error_margin for this calculation) + + arm_clock =[el.clock_freq for el in self.blocksL if el.block_subtype == "gpp"][0] + self.misc_data["arm_work_over_energy"] = self.misc_data["ref_gpp_dhrystone_value"]/self.misc_data["arm_power_over_clock"] + self.misc_data["ref_gpp_work_rate"] = self.misc_data["arm_work_rate"] = self.misc_data["ref_gpp_dhrystone_value"] * arm_clock + self.misc_data["dsp_work_rate"] = self.misc_data["dsp_speed_up_coef"] * self.misc_data["ref_gpp_work_rate"] + self.misc_data["ip_work_rate"] = self.misc_data["ip_speed_up_coef"]*self.misc_data["ref_gpp_work_rate"] + self.misc_data["dsp_work_over_energy"] = self.misc_data["dsp_speed_up_coef"] * self.misc_data["ref_gpp_dhrystone_value"] / self.misc_data["dsp_power_over_clock"] + self.misc_data["ip_work_over_energy"] = 8*self.misc_data["dsp_work_over_energy"] # multiply by 5 sine we assume only 1/5 th of the instructions are MACs + self.misc_data["ip_work_over_area"] = 5/(self.misc_data["mac_area"]*self.misc_data["mac_allocated_per_mac_operations"]) + #self.misc_data["ref_ic_work_rate"] = self.misc_data["ref_ic_width"] * self.misc_data["ref_ic_clock"] + #self.misc_data["ref_mem_work_rate"] = self.misc_data["ref_mem_width"] * self.misc_data["ref_mem_clock"] + + + def set_workloads_last_task(self, workloads_last_task): + self.workloads_last_task = workloads_last_task + + def get_parsed_hardware_graph(self): + return self.hardware_graph + + def get_parsed_task_to_hw_mapping(self): + return self.task_to_hardware_mapping + + # setting the budgets dict directly + # This needs to be done with caution + def set_budgets_dict_directly(self, budget_dicts): + self.budgets_dict = budget_dicts + self.SOCsL = [] + self.SOCL0_budget_dict = {"latency": self.budgets_dict["glass"]["latency"], "area":self.budgets_dict["glass"]["area"], + "power": self.budgets_dict["glass"]["power"]} #keys:power, area, latency example: {"area": area_budget} + + self.SOCL0_other_metrics_dict = {"cost": self.other_values_dict["glass"]["cost"]} + self.SOCL0 = SOCL("glass", self.SOCL0_budget_dict, self.SOCL0_other_metrics_dict) + self.SOCsL.append(self.SOCL0) + + # 
get the budget values for the SOC + def get_budget_dict(self, SOC_name): + for SOC in self.SOCsL: + if SOC.type == SOC_name: + return SOC.get_budget_dict() + + print("SOC_name:" + SOC_name + "not defined") + exit(0) + + # get non budget values (e.g., cost) + def get_other_values_dict(self): + return self.other_values_dict + + # update the budget values + def update_budget_dict(self, budgets_dict): + self.budgets_dict = budgets_dict + + def update_other_values_dict(self, other_values_dict): + self.other_values_dict = other_values_dict + + # set the porting effort (porting effort is a coefficient that determines how hard it is to port the + # task for a non general purpose processor) + def set_porting_effort_for_block(self, block_type, porting_effort): + self.porting_effort[block_type] = porting_effort + assert(self.porting_effort["arm"] < self.porting_effort["dsp"]), "porting effort for arm needs to be smaller than dsp" + assert(self.porting_effort["dsp"] < self.porting_effort["ip"]), "porting effort for dsp needs to be smaller than ip" diff --git a/Project_FARSI/specs/gen_synthetic_data.py b/Project_FARSI/specs/gen_synthetic_data.py new file mode 100644 index 00000000..59f9038f --- /dev/null +++ b/Project_FARSI/specs/gen_synthetic_data.py @@ -0,0 +1,564 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +from specs.LW_cl import * +from specs.parse_libraries.parse_library import * + + + +def gen_tg_with_hops(task_name_intensity_type, others_task_cnt, parallel_task_cnt, serial_task_cnt, parallel_task_type, num_of_NoCs): + task_names = [task_name for task_name, intensity in task_name_intensity_type] + task_name_position = [] + name = "synthetic_" + parents = [name + "souurce"] + task_name_position.append((task_names[0], parents)) + + parents = [name + "0"] + task_name_position.append((task_names[1], parents)) + last_task = 2 + + """ + parents = [name + "0"] + task_name_position.append((task_names[2], parents)) + last_task = 3 + + parents = [name + "1", name + "2"] + task_name_position.append((task_names[3], parents)) + last_task =4 + """ + + for i in range(0, others_task_cnt + serial_task_cnt - 3): + parents = [name + str(last_task - 1)] + task_name_position.append((task_names[last_task], parents)) + last_task +=1 + + for i in range(0, num_of_NoCs-2): + parents = [name + str(last_task - 1)] + task_name_position.append((task_names[last_task], parents)) + last_task += 1 + return task_name_position,"_" + + +def gen_tg_core_improved(task_name_intensity_type, others_task_cnt, parallel_task_cnt, serial_task_cnt, parallel_task_type, mode, num_of_NoCs): + task_names = [task_name for task_name, intensity in task_name_intensity_type] + task_name_position = [] + name = "synthetic_" + #parallel_task_type = "edge_detection" + #parallel_task_type = "audio" + hoppy_tasks = [] + + for i in range(0,2): + if i == 0: + parents = [name+"souurce"] + task_name_position.append((task_names[i],parents)) + if i in [1]: + parents = [name +str(0)] + task_name_position.append((task_names[i], parents)) + + + num_of_NoCs_added_tasks = max(num_of_NoCs-2, 0) + + # spacial case + if parallel_task_cnt == 0: + last_task = 2 + for i in range(0, others_task_cnt - 4 - num_of_NoCs_added_tasks): + parents = [name + str(last_task- 1)] + task_name_position.append((task_names[last_task], parents)) + last_task +=1 + + if mode == "hop": + for i in range(0, num_of_NoCs_added_tasks): + parents = [name + 
str(last_task-1)] + task_name_position.append((task_names[last_task], parents)) + hoppy_tasks.append(task_names[last_task]) + last_task += 1 + + return task_name_position,{},hoppy_tasks + + + + + parallel_task_cnt -=1 # by default has one + # set up parallel type of audio + parallel_task_names = {} + parallel_task_names[0] = [] + parallel_task_names[0].append(name + str(1)) + last_task = 2 + if parallel_task_type == "audio_style": + parallel_offset =1 + last_serial = serial_offset = 1 + for i in range(0, parallel_task_cnt): + parents = [name + str(parallel_offset-1)] + task_name_position.append((task_names[1 + i+1], parents)) + parallel_task_names[0].append(name + str(i+2)) + last_task +=1 + + # set up serial + if parallel_task_type == "edge_detection_style": + serial_task_cnt += int(parallel_task_cnt/2) + + + for i in range(0, serial_task_cnt): + if i ==0: + parents = [el for el in parallel_task_names[0]] + else: + parents = [name + str(last_task - 1)] + task_name_position.append((task_names[last_task], parents)) + last_task +=1 + + if serial_task_cnt == 0: + parents = [el for el in parallel_task_names[0]] + else: + parents = [name + str(last_task-1)] + task_name_position.append((task_names[last_task], parents)) + last_right = last_task + + # set up parallel type of edge + parents = [name+str(0)] + last_task +=1 + task_name_position.append((task_names[last_task], parents)) + left_begin = last_task + parallel_task_names[0].append(task_names[left_begin]) + + if parallel_task_type == "edge_detection_style": + for i in range(0, int(parallel_task_cnt/2)): + parents = [name + str(last_task)] + last_task += 1 + task_name_position.append((task_names[last_task], parents)) + last_left = last_task + else: + last_left = last_task + + # set up last node + parents = [name + str(last_left), name + str(last_right)] + last_task +=1 + task_name_position.append((task_names[last_task], parents)) + + + if parallel_task_type == "edge_detection_style": + idx =1 + for i in range(2, left_begin): + parallel_task_names[idx] = [] + parallel_task_names[idx].append(name+str(idx)) + idx +=1 + idx = 1 + for i in range(left_begin, last_left): + parallel_task_names[idx].append(name + str(i)) + idx+=1 + + + if mode == "hop": + for i in range(0, num_of_NoCs_added_tasks): + parents = [name + str(last_task)] + task_name_position.append((task_names[last_task+1], parents)) + hoppy_tasks.append(task_names[last_task+1]) + last_task += 1 + + return task_name_position,parallel_task_names, hoppy_tasks + + + + +def gen_tg_core(task_name_intensity_type, others_task_cnt, parallel_task_cnt, serial_task_cnt): + task_names = [task_name for task_name, intensity in task_name_intensity_type] + task_name_position = [] + name = "synthetic_" + for i in range(0,2): + if i == 0: + parents = [name+"souurce"] + task_name_position.append((task_names[i],parents)) + if i in [1]: + parents = [name +str(0)] + task_name_position.append((task_names[i], parents)) + + parallel_offset =1 + last_serial = serial_offset = 1 + parallel_task_names = [] + parallel_task_names.append(name + str(1)) + for i in range(0, parallel_task_cnt): + parents = [name + str(parallel_offset-1)] + task_name_position.append((task_names[1 + i+1], parents)) + parallel_task_names.append(name + str(i+2)) + + for i in range(0, serial_task_cnt): + if i == 0: + parents = [el for el in parallel_task_names] + else: + parents = [name + str(1 + i + parallel_task_cnt)] + task_name_position.append((task_names[1 + i + 1 + parallel_task_cnt], parents)) + last_serial = serial_offset + i+1 + + + 
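    # tail of the graph: one node closes the serial chain, a side branch hangs
+    # off node 0, and a final node joins the two just before the sink
+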
parents = [name + str(1 + 1 + parallel_task_cnt + serial_task_cnt-1)] + task_name_position.append((task_names[1 + 1 + parallel_task_cnt + serial_task_cnt], parents)) + + + parents = [name+str(0)] + task_name_position.append((task_names[1 + 2 + parallel_task_cnt + serial_task_cnt], parents)) + + parents = [name + str(1 + 2 + parallel_task_cnt + serial_task_cnt), name + str(1 + 2 + parallel_task_cnt + serial_task_cnt-1)] + task_name_position.append((task_names[1 + 2 + parallel_task_cnt + serial_task_cnt+1], parents)) + + + + return task_name_position + + +# ----------- +# Functionality: +# split tasks to what can run in parallel and +# what should run in serial +# Variables: +# task_name_intensity_type: list of tuples : (task name, intensity (memory intensive, comp intensive)) +# ----------- +def cluster_tasks(task_name_intensity_type, avg_parallelism): + task_names = [task_name for task_name, intensity in task_name_intensity_type] + task_name_position = [] + state = "par" + pos_ctr = 0 + state_ctr = avg_parallelism + for idx, task in enumerate(task_names): + if state == "par" and state_ctr > 0: + state_ctr -= 1 + elif state == "par" and state_ctr == 0: + state = "ser" + pos_ctr += 1 + elif state == "ser": + state = "par" + state_ctr = avg_parallelism + state_ctr -= 1 + pos_ctr += 1 + task_name_position.append((task, pos_ctr)) + return task_name_position + + +# ----------- +# Functionality: +# generate synthetic work (instructions, bytes) for tasks +# Variables: +# task_name_intensity_type: list of tuples : (task name, intensity (memory intensive, comp intensive)) +# ----------- +def generate_synthetic_work(task_exec_intensity_type, general_task_type_char): + + work_dict = {} + for task, intensity in task_exec_intensity_type: + if "siink" in task or "souurce" in task: + work_dict[task] = 0 + else: + work_dict[task] = general_task_type_char[intensity]["exec"] + + return work_dict + + +def generate_synthetic_datamovement_asymetric_tg(task_exec_intensity_type, others_task_cnt, parallel_task_cnt, serial_task_cnt , parallel_task_type, exec_intensity_scaling_factor, num_of_NoCs): + # hardcoded data. 
+    # TODO: move these into a config file later
+    general_task_type_char = {}
+    general_task_type_char["memory_intensive"] = {}
+    general_task_type_char["comp_intensive"] = {}
+    general_task_type_char["dummy_intensive"] = {}
+    general_task_type_char["memory_intensive"]["read_bytes"] = exec_intensity_scaling_factor*500000
+    general_task_type_char["memory_intensive"]["write_bytes"] = exec_intensity_scaling_factor*500000 # hardcoded, delete later
+    general_task_type_char["comp_intensive"]["read_bytes"] = exec_intensity_scaling_factor*80000
+    general_task_type_char["comp_intensive"]["write_bytes"] = exec_intensity_scaling_factor*80000
+
+    general_task_type_char["memory_intensive"]["exec"] = exec_intensity_scaling_factor*50000
+    general_task_type_char["comp_intensive"]["exec"] = exec_intensity_scaling_factor*10*50000*3
+
+    general_task_type_char["dummy_intensive"]["read_bytes"] = 64
+    general_task_type_char["dummy_intensive"]["write_bytes"] = 64
+    general_task_type_char["dummy_intensive"]["exec"] = 64
+
+    #general_task_type_char["memory_intensive"]["read_bytes"] = math.floor((exec_intensity_scaling_factor * 500000)/256)*256
+    #general_task_type_char["memory_intensive"]["write_bytes"] = math.floor((exec_intensity_scaling_factor * 500000)/256)*256
+
+    # find a family task (parent, child or sibling)
+    def find_family(task_name_position, task_name, relationship):
+        task_position = [position for task, position in task_name_position if task == task_name][0]
+        if relationship == "parent":
+            parents = [task_name for task_name, position in task_name_position if position ==(task_position -1)]
+            return parents
+        elif relationship == "child":
+            children = [task_name for task_name, position in task_name_position if position ==(task_position +1)]
+            return children
+        else:
+            print('relationship ' + relationship + " is not defined")
+            exit(0)
+
+    def find_family_asymetric_tg(task_name_position, task_name, relationship):
+        task_parents = [parents for task, parents in task_name_position if task == task_name]
+        if relationship == "parent":
+            return task_parents[0]
+        elif relationship == "child":
+            children = []
+            for task_name_, parents in task_name_position:
+                for parent in parents:
+                    if parent == task_name:
+                        children.append(task_name_)
+            return children
+        else:
+            print('relationship ' + relationship + " is not defined")
+            exit(0)
+
+    data_movement = {}
+    if num_of_NoCs == 1:
+        task_name_position, parallel_task_names, hoppy_tasks = gen_tg_core_improved(task_exec_intensity_type, others_task_cnt, parallel_task_cnt, serial_task_cnt, parallel_task_type, "non_hop", 1)
+    else:
+        task_name_position, parallel_task_names, hoppy_tasks = gen_tg_core_improved(task_exec_intensity_type, others_task_cnt, parallel_task_cnt, serial_task_cnt, parallel_task_type, "hop", num_of_NoCs)
+        #task_name_position, parallel_task_names = gen_tg_with_hops(task_exec_intensity_type, others_task_cnt, parallel_task_cnt, serial_task_cnt, parallel_task_type, "hop", num_of_NoCs)
+
+    # pin a synthetic source at the head of the graph and a sink after the highest-numbered task
+    task_name_position.insert(0,("synthetic_souurce", [""]))
+    all_idx = []
+    for el, pos in task_name_position:
+        if "souurce" in el:
+            continue
+        all_idx.append(int(el.split('_')[1]))
+    last_task_name = max(all_idx)
+    parent = ["synthetic_"+str(last_task_name)]
+    task_name_position.append(("synthetic_siink", parent))
+    task_exec_intensity_type.insert(0, ("synthetic_souurce", (task_exec_intensity_type[0])[1]))
+    task_exec_intensity_type.append(("synthetic_siink", task_exec_intensity_type[0][1]))
+
+    for task,_ in task_name_position:
+        data_movement[task] = {}
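+        # each parent->child edge initially carries the parent's write_bytes; hoppy
+        # and sink edges are later overridden to 64 bytes when TaskL objects are built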
+        tasks_parents = find_family_asymetric_tg(task_name_position, task, "parent")
+        tasks_children = find_family_asymetric_tg(task_name_position, task, "child")
+        for child in tasks_children:
+            exec_intensity = [intensity for task_, intensity in task_exec_intensity_type if task_== task][0]
+            #if child == "synthetic_siink": # hardcoded delete later
+            #    data_movement[task][child] = 1
+            #else:
+            data_movement[task][child] = general_task_type_char[exec_intensity]["write_bytes"]
+    return data_movement,general_task_type_char, parallel_task_names,hoppy_tasks
+
+
+
+# -----------
+# Functionality:
+#       generate synthetic data movement between tasks
+# Variables:
+#       task_exec_intensity_type: memory- or compute-intensive
+#       avg_parallelism: average number of tasks that can run in parallel
+# -----------
+def generate_synthetic_datamovement(task_exec_intensity_type, avg_parallelism):
+    # hardcoded data.
+    # TODO: move these into a config file later
+    general_task_type_char = {}
+    general_task_type_char["memory_intensive"] = {}
+    general_task_type_char["comp_intensive"] = {}
+    general_task_type_char["memory_intensive"]["read_bytes"] = 100000
+    general_task_type_char["memory_intensive"]["write_bytes"] = 100000 # hardcoded, delete later
+    general_task_type_char["comp_intensive"]["read_bytes"] = 100
+    general_task_type_char["comp_intensive"]["write_bytes"] = 100
+
+    general_task_type_char["memory_intensive"]["exec"] = 100
+    general_task_type_char["comp_intensive"]["exec"] = 10000000
+
+    # find a family task (parent, child or sibling)
+    def find_family(task_name_position, task_name, relationship):
+        task_position = [position for task, position in task_name_position if task == task_name][0]
+        if relationship == "parent":
+            parents = [task_name for task_name, position in task_name_position if position ==(task_position -1)]
+            return parents
+        elif relationship == "child":
+            children = [task_name for task_name, position in task_name_position if position ==(task_position +1)]
+            return children
+        else:
+            print('relationship ' + relationship + " is not defined")
+            exit(0)
+
+
+
+    data_movement = {}
+    task_name_position = cluster_tasks(task_exec_intensity_type, avg_parallelism)
+    task_name_position.insert(0,("synthetic_souurce", -1))
+    task_name_position.append(("synthetic_siink", max([pos for el, pos in task_name_position])+1))
+    task_exec_intensity_type.insert(0, ("synthetic_souurce", (task_exec_intensity_type[0])[1]))
+    task_exec_intensity_type.append(("synthetic_siink", task_exec_intensity_type[0][1]))
+
+    for task,_ in task_name_position:
+        data_movement[task] = {}
+        tasks_parents = find_family(task_name_position, task, "parent")
+        tasks_children = find_family(task_name_position, task, "child")
+        for child in tasks_children:
+            exec_intensity = [intensity for task_, intensity in task_exec_intensity_type if task_== task][0]
+            #if child == "synthetic_siink": # hardcoded delete later
+            #    data_movement[task][child] = 1
+            #else:
+            data_movement[task][child] = general_task_type_char[exec_intensity]["write_bytes"]
+
+    # also hand back the per-intensity characteristics so callers can derive work values
+    return data_movement, general_task_type_char
+
+
+# we generate very simple scenarios for now.
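+# (the asymmetric builder below produces the graph shape, per-task work, and
+# parent->child traffic in one pass; its knobs cover task counts, parallel
+# fan-out style, serial chain length, and the number of NoCs/hops)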
+def generate_synthetic_task_graphs_for_asymetric_graphs(num_of_tasks, others_task_cnt, parallel_task_cnt, serial_task_cnt, parallel_task_type, intensity_params, num_of_NoCs): + + exec_intensity = intensity_params[0] + exec_intensity_scaling_factor = intensity_params[1] # scaling with respect to the referece (amount of data movement) + intensity_ratio = intensity_params[2] # what percentage of the tasks are memory intensive + + #---------------------- + # assigning memory bounded ness or compute boundedness to tasks + #---------------------- + num_of_NoCs_added_tasks = max(num_of_NoCs-2, 0) + opposite_intensity_task_cnt = int((num_of_tasks - 2 - num_of_NoCs_added_tasks)*(1-intensity_ratio)) + opposite_intensity = list({"memory_intensive", "comp_intensive"}.difference(set([exec_intensity])))[0] + + last_idx = 0 + task_exec_intensity = [] + for idx in range(0, num_of_tasks - 2 - opposite_intensity_task_cnt - num_of_NoCs_added_tasks): + task_exec_intensity.append(("synthetic_"+str(idx), exec_intensity)) + last_idx = idx+1 + + for idx in range(0, opposite_intensity_task_cnt): + task_exec_intensity.append(("synthetic_"+str(last_idx), opposite_intensity)) + last_idx +=1 + + # for dummy tasks taht are used for hops + for idx in range(0, num_of_NoCs_added_tasks): + task_exec_intensity.append(("synthetic_"+str(last_idx + idx), "dummy_intensive")) + + + # generate task graph and data movement + task_graph_dict, general_task_type_char, parallel_task_names, hoppy_task_names = generate_synthetic_datamovement_asymetric_tg(task_exec_intensity, num_of_tasks, parallel_task_cnt, serial_task_cnt, parallel_task_type, exec_intensity_scaling_factor, num_of_NoCs) + + # collect number of instructions for each tasks + work_dict = generate_synthetic_work(task_exec_intensity, general_task_type_char) + for task,work in work_dict.items(): + intensity_ = "none" + for task_, intensity__ in task_exec_intensity: + if task_ == task: + intensity_ = intensity__ + break + if intensity_ == "comp_intensive": + children_cnt = len(list(task_graph_dict[task].values())) + work_dict[task] = work_dict[task]*(children_cnt+1) + + + + tasksL: List[TaskL] = [] + for task_name_, values in task_graph_dict.items(): + task_ = TaskL(task_name=task_name_, work=work_dict[task_name_]) + task_.add_task_work_distribution([(work_dict[task_name_], 1)]) + tasksL.append(task_) + + for task_name_, values in task_graph_dict.items(): + task_ = [taskL for taskL in tasksL if taskL.task_name == task_name_][0] + for child_task_name, data_movement in values.items(): + if child_task_name in hoppy_task_names or child_task_name in ["synthetic_siink"]: + data_movement = 64 + child_task = [taskL for taskL in tasksL if taskL.task_name == child_task_name][0] + task_.add_child(child_task, data_movement, "real") # eye_tracking_soource t glint_mapping + task_.add_task_to_child_work_distribution(child_task, [(data_movement, 1)]) # eye_tracking_soource t glint_mapping + + return tasksL,task_graph_dict, work_dict, parallel_task_names,hoppy_task_names + + +#generate_synthetic_task_graphs_for_asymetric_graphs(10, 3, "memory_intensive") + + +# we generate very simple scenarios for now. 
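+# (symmetric variant: cluster_tasks alternates groups of avg_parallelism parallel
+# tasks with single serial tasks instead of using an explicit fork/join shape)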
+def generate_synthetic_task_graphs(num_of_tasks, avg_parallelism, exec_intensity):
+    assert(num_of_tasks > avg_parallelism)
+
+    tasksL: List[TaskL] = []
+
+    # generate a list of task names and their exec intensity (i.e., compute or memory intensive)
+    task_exec_intensity = []
+    for idx in range(0, num_of_tasks - 2):
+        task_exec_intensity.append(("synthetic_"+str(idx), exec_intensity))
+
+    # collect data movement data (along with the per-intensity characteristics)
+    task_graph_dict, general_task_type_char = generate_synthetic_datamovement(task_exec_intensity, avg_parallelism)
+
+    # collect number of instructions for each task
+    work_dict = generate_synthetic_work(task_exec_intensity, general_task_type_char)
+
+    for task_name_, values in task_graph_dict.items():
+        task_ = TaskL(task_name=task_name_, work=work_dict[task_name_])
+        task_.add_task_work_distribution([(work_dict[task_name_], 1)])
+        tasksL.append(task_)
+
+    for task_name_, values in task_graph_dict.items():
+        task_ = [taskL for taskL in tasksL if taskL.task_name == task_name_][0]
+        for child_task_name, data_movement in values.items():
+            child_task = [taskL for taskL in tasksL if taskL.task_name == child_task_name][0]
+            task_.add_child(child_task, data_movement, "real") # eye_tracking_soource t glint_mapping
+            task_.add_task_to_child_work_distribution(child_task, [(data_movement, 1)]) # eye_tracking_soource t glint_mapping
+
+    return tasksL,task_graph_dict, work_dict
+
+# generate a synthetic hardware library to generate systems from
+def generate_synthetic_hardware_library(task_work_dict, library_dir, Block_char_file_name):
+
+    blocksL: List[BlockL] = [] # collection of all the blocks
+    pe_mapsL: List[TaskToPEBlockMapL] = []
+    pe_schedulesL: List[TaskScheduleL] = []
+
+    gpps = parse_block_based_on_types(library_dir, Block_char_file_name, ("pe", "gpp"))
+    gpp_names = list(gpps.keys())
+    mems = parse_block_based_on_types(library_dir, Block_char_file_name, ("mem", "sram"))
+    mems.update(parse_block_based_on_types(library_dir, Block_char_file_name, ("mem", "dram")))
+    ics = parse_block_based_on_types(library_dir, Block_char_file_name, ("ic", "ic"))
+
+
+    hardware_library_dict = {}
+    for task_name in task_work_dict.keys():
+        for IP_name in gpp_names:
+            if IP_name in hardware_library_dict.keys():
+                hardware_library_dict[IP_name]["mappable_tasks"].append(task_name)
+                continue
+            hardware_library_dict[IP_name] = {}
+            if IP_name in gpps:
+                hardware_library_dict[IP_name]["work_rate"] = float(gpps[IP_name]['Freq'])*float(gpps[IP_name]["dhrystone_IPC"])
+                hardware_library_dict[IP_name]["work_over_energy"] = float(gpps[IP_name]['Inst_per_joul'])
+                hardware_library_dict[IP_name]["work_over_area"] = 1/float(gpps[IP_name]['Gpp_area'])
+                hardware_library_dict[IP_name]["mappable_tasks"] = [task_name]
+                hardware_library_dict[IP_name]["type"] = "pe"
+                hardware_library_dict[IP_name]["sub_type"] = "gpp"
+                hardware_library_dict[IP_name]["clock_freq"] = gpps[IP_name]['Freq']
+                hardware_library_dict[IP_name]["bus_width"] = "NA"
+                #print("taskname: " + str(task_name) + ", subtype: gpp, power is"+ str(hardware_library_dict[IP_name]["work_rate"]/hardware_library_dict[IP_name]["work_over_energy"] ))
+
+    for blck_name, blck_value in mems.items():
+        hardware_library_dict[blck_value['Name']] = {}
+        hardware_library_dict[blck_value['Name']]["work_rate"] = float(blck_value['BitWidth'])*float(blck_value['Freq'])
+        hardware_library_dict[blck_value['Name']]["work_over_energy"] = float(blck_value['Byte_per_joul'])
+        hardware_library_dict[blck_value['Name']]["work_over_area"] = float(blck_value['Byte_per_m'])
+
hardware_library_dict[blck_value['Name']]["mappable_tasks"] = 'all' + hardware_library_dict[blck_value['Name']]["type"] = "mem" + hardware_library_dict[blck_value['Name']]["sub_type"] = blck_value['Subtype'] + hardware_library_dict[blck_value['Name']]["clock_freq"] = blck_value['Freq'] + hardware_library_dict[blck_value['Name']]["bus_width"] = blck_value['BitWidth']*8 + + for blck_name, blck_value in ics.items(): + hardware_library_dict[blck_value['Name']] = {} + hardware_library_dict[blck_value['Name']]["work_rate"] = float(blck_value['BitWidth'])*float(blck_value['Freq']) + hardware_library_dict[blck_value['Name']]["work_over_energy"] = float(blck_value['Byte_per_joul']) + hardware_library_dict[blck_value['Name']]["work_over_area"] = float(blck_value['Byte_per_m']) + hardware_library_dict[blck_value['Name']]["mappable_tasks"] = 'all' + hardware_library_dict[blck_value['Name']]["type"] = "ic" + hardware_library_dict[blck_value['Name']]["sub_type"] = "ic" + hardware_library_dict[blck_value['Name']]["clock_freq"] = blck_value['Freq'] + hardware_library_dict[blck_value['Name']]["bus_width"] = blck_value['BitWidth']*8 + + + block_suptype = "gpp" # default. + for IP_name, values in hardware_library_dict.items(): + block_subtype = values['sub_type'] + block_type = values['type'] + blocksL.append( + BlockL(block_instance_name=IP_name, block_type=block_type, block_subtype=block_subtype, + peak_work_rate_distribution = {hardware_library_dict[IP_name]["work_rate"]:1}, + work_over_energy_distribution = {hardware_library_dict[IP_name]["work_over_energy"]:1}, + work_over_area_distribution = {hardware_library_dict[IP_name]["work_over_area"]:1}, + one_over_area_distribution = {1/hardware_library_dict[IP_name]["work_over_area"]:1}, + clock_freq=hardware_library_dict[IP_name]["clock_freq"], bus_width=hardware_library_dict[IP_name]['bus_width'])) + + if block_type == "pe": + for mappable_tasks in hardware_library_dict[IP_name]["mappable_tasks"]: + task_to_block_map_ = TaskToPEBlockMapL(task_name=mappable_tasks, pe_block_instance_name=IP_name) + pe_mapsL.append(task_to_block_map_) + + return blocksL, pe_mapsL, pe_schedulesL \ No newline at end of file diff --git a/Project_FARSI/specs/parse_libraries/__init__.py b/Project_FARSI/specs/parse_libraries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/specs/parse_libraries/parse_library.py b/Project_FARSI/specs/parse_libraries/parse_library.py new file mode 100644 index 00000000..b5f8f2c0 --- /dev/null +++ b/Project_FARSI/specs/parse_libraries/parse_library.py @@ -0,0 +1,801 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
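For orientation, the parsers below consume spreadsheet-exported CSVs. As one illustrative (not repo-verbatim) sketch, parse_task_graph_data_movement expects an adjacency-matrix CSV of roughly this shape, where each row is a parent task and a non-empty cell is a parent-to-child edge weighted in bytes (task names and byte counts here are made up):

    Task Name,task_a,task_b,task_c
    task_a,,1024,
    task_b,,,2048
    task_c,,,

Empty cells and the diagonal (self-movement) are skipped, and the remaining weights are read back as floats.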
+ +import csv +import os +from collections import defaultdict +from design_utils.components.hardware import * +from SIM_utils.SIM import * +from specs.LW_cl import * +import pandas as pd +import itertools + +# ------------------------------ +# Functionality: +# parse the csv file and return a dictionary containing the hardware graph +# ------------------------------ +def parse_hardware_graph(hardware_graph_file): + if not os.path.exists(hardware_graph_file): + return "" + + reader = csv.DictReader(open(hardware_graph_file, 'r')) + dict_list = [] + for line in reader: + dict_list.append(line) + + # generate the task graph dictionary, a 2d dictionary where + # the first coordinate is the parent and the second coordinate is the child + hardware_graph_dict = defaultdict(dict) + table_name = "Block Name" + for dict_ in dict_list: + block_name = [dict_[key] for key in dict_.keys() if key == table_name][0] + hardware_graph_dict[block_name] = {} + for child_block_name, data_movement in dict_.items(): + if child_block_name == table_name: # skip the table_name entry + continue + elif child_block_name == block_name: # for now skip the self movement (Requires scratch pad) + continue + elif data_movement == "": + continue + else: + hardware_graph_dict[block_name][child_block_name] = float(data_movement) + + return hardware_graph_dict + + +# ------------------------------ +# Functionality: +# parse the csv file and return a dictionary containing the task to hardware mapping +# ------------------------------ +def parse_task_to_hw_mapping(task_to_hw_mapping_file): + if not os.path.exists(task_to_hw_mapping_file): + return "" + reader = csv.DictReader(open(task_to_hw_mapping_file, 'r')) + dict_list = [] + for line in reader: + dict_list.append(line) + + # generate the task graph dictionary, a 2d dictionary where + # the first coordinate is the parent and the second coordinate is the child + hardware_graph_dict = defaultdict(dict) + table_name = "Block Name" + block_names = [] + for dict_ in dict_list: + block_names.extend([key for key in dict_.keys()]) + block_names = list(set(block_names)) + + block_mapping = {} + for dict_ in dict_list: + for blk, task in dict_.items(): + if task == "": + continue + elif ("->") in task: + task = [el.strip() for el in task.split("->")] # split the producer consumer and get rid of extra spaces + else: + task = [task, task] + + if blk not in block_mapping.keys(): + block_mapping[blk] = [task] + else: + block_mapping[blk].append(task) + + return block_mapping + +# ------------------------------ +# Functionality: +# parse the csv file and return a dictionary containing the task graph +# ------------------------------ +def parse_task_graph_data_movement(task_graph_file_addr): + reader = csv.DictReader(open(task_graph_file_addr, 'r')) + dict_list = [] + for line in reader: + dict_list.append(line) + + # generate the task graph dictionary, a 2d dictionary where + # the first coordinate is the parent and the second coordinate is the child + task_graph_data_movement_dict = defaultdict(dict) + table_name = "Task Name" + for dict_ in dict_list: + task_name = [dict_[key] for key in dict_.keys() if key == table_name][0] + task_graph_data_movement_dict[task_name] = {} + for child_task_name, data_movement in dict_.items(): + if child_task_name == table_name: # skip the table_name entry + continue + elif child_task_name == task_name: # for now skip the self movement (Requires scratch pad) + continue + elif data_movement == "": + continue + else: + 
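                # non-empty cell: record the parent->child edge weight in bytes
+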
task_graph_data_movement_dict[task_name][child_task_name] = float(data_movement) + + return task_graph_data_movement_dict + + +# ------------------------------ +# Functionality: +# file finding helper +# ------------------------------ +def get_full_file_name(partial_name, file_list): + for file_name in file_list: + if partial_name == file_name: + return file_name + print("file with the name of :" + partial_name + " doesnt exist") + + +def get_block_clock_freq(library_dir, input_file_name): + input_file_addr = os.path.join(library_dir, input_file_name) + + df = pd.read_csv(input_file_addr) + + # eval the expression + def evaluate(value): + replaced_value_1 = value.replace("^", "**") + replaced_value_2 = replaced_value_1.replace("=", "") + return eval(replaced_value_2) + + misc_data = {} + for index, row in df.iterrows(): + temp_dict = row.to_dict() + misc_data[list(temp_dict.values())[0]] = evaluate(list(temp_dict.values())[1]) + + hardware_sub_type = ["sram", "dram", "ip", "gpp", "ic"] + block_sub_type_clock_freq = {} + for key in misc_data.keys() : + if "clock" in key: + for type in hardware_sub_type: + if type in key: + block_sub_type_clock_freq[type] = misc_data[key] + + return block_sub_type_clock_freq + +# ------------------------------ +# Functionality: +# parse the task graph csv and generate FARSI digestible task graph +# ------------------------------ +def gen_task_graph(library_dir, prefix, misc_knobs): + tasksL: List[TaskL] = [] + + # get files + file_list = [f for f in os.listdir(library_dir) if os.path.isfile(os.path.join(library_dir, f))] + data_movement_file_name = get_full_file_name(prefix + "Task Data Movement.csv", file_list) + IP_perf_file_name = get_full_file_name(prefix + "Task PE Performance.csv", file_list) + Block_char_file_name = get_full_file_name("misc_database - "+ "Block Characteristics.csv", file_list) + + # collect data movement data + data_movement_file_addr = os.path.join(library_dir, data_movement_file_name) + task_graph_dict = parse_task_graph_data_movement(data_movement_file_addr) + + # collect number of instructions for each tasks + work_dict = gen_task_graph_work_dict(library_dir, IP_perf_file_name, Block_char_file_name, misc_knobs) + """ + for task_name, work in work_dict.items(): + print(task_name+","+str(work)) + exit(0) + """ + + if "burst_size_options" in misc_knobs: + universal_burst_size = misc_knobs["burst_size_options"][0] # this will be tuned by the DSE + else: + universal_burst_size = config.default_burst_size + + for task_name_, values in task_graph_dict.items(): + task_ = TaskL(task_name=task_name_, work=work_dict[task_name_]) + task_.set_burst_size(universal_burst_size) + task_.add_task_work_distribution([(work_dict[task_name_], 1)]) + tasksL.append(task_) + + for task_name_, values in task_graph_dict.items(): + task_ = [taskL for taskL in tasksL if taskL.task_name == task_name_][0] + for child_task_name, data_movement in values.items(): + child_task = [taskL for taskL in tasksL if taskL.task_name == child_task_name][0] + task_.add_child(child_task, data_movement, "real") # eye_tracking_soource t glint_mapping + task_.add_task_to_child_work_distribution(child_task, [(data_movement, 1)]) # eye_tracking_soource t glint_mapping + + return tasksL,task_graph_dict + + +# ------------------------------ +# Functionality: +# get teh reference gpp (general purpose processor) data (properties) +# ------------------------------ +def get_ref_gpp_data(dict_list): + for dict_ in dict_list: + # get all the properties + speedup = [float(value) for key, 
value in dict_.items() if key == "speed up"][0] + if speedup == 1: + return dict_ + print("couldn't find the reference gpp") + exit(0) + + +# ------------------------------ +# Functionality: +# get number of iterations for a task +# ------------------------------ +def parse_task_itr_cnt(library_dir, task_itr_cnt_file_name): + if task_itr_cnt_file_name == None : + print(" Task Itr Cnt for the workload is not provided") + return {} + elif not os.path.exists(os.path.join(library_dir, task_itr_cnt_file_name)): + print(" Task Itr Cnt for the workload is not provided") + return {} + + reader = csv.DictReader(open(os.path.join(library_dir, task_itr_cnt_file_name), 'r')) + dict_list = [] + task_itr_cnt = {} + for line in reader: + dict_list.append(line) + + task_metric_dict = {} + for dict_ in dict_list: + task_name = [dict_[key] for key in dict_.keys() if key == "Task Name"][0] + itr_cnt = int([dict_[key] for key in dict_.keys() if key == "number of iterations"][0]) + task_itr_cnt[task_name] = itr_cnt + + return task_itr_cnt + +# ------------------------------ +# Functionality: +# get each tasks PPA (perf, pow, area) data +# ------------------------------ +def parse_task_PPA(library_dir, IP_file_name, gpp_names): + reader = csv.DictReader(open(os.path.join(library_dir, IP_file_name), 'r')) + dict_list = [] + for line in reader: + dict_list.append(line) + + task_metric_dict = {} + for dict_ in dict_list: + task_name = [dict_[key] for key in dict_.keys() if key == "Task Name"][0] + task_metric_dict[task_name] = {} + for key, value in dict_.items(): + if key == "Task Name" or value == "": + continue + if key not in gpp_names: + ip_name_modified = task_name + "_" + key + else: + ip_name_modified = key + + task_metric_dict[task_name][ip_name_modified] = float(value) + return task_metric_dict + + +# ------------------------------ +# Functionality: +# get the performance of the task on the gpp +# ------------------------------ +def parse_task_perf_on_ref_gpp(library_dir, IP_perf_file_name, Block_char_file_name): + reader = csv.DictReader(open(os.path.join(library_dir, IP_perf_file_name), 'r')) + dict_list = [] + for line in reader: + dict_list.append(line) + + ref_gpp_dict = parse_ref_block_values(library_dir, Block_char_file_name, ("pe", "gpp")) + + task_perf_on_ref_dict = defaultdict(dict) + for dict_ in dict_list: + task_name = [dict_[key] for key in dict_.keys() if key == "Task Name"][0] + task_perf_on_ref_dict[task_name] = {} + time = [dict_[key] for key in dict_.keys() if ref_gpp_dict['Name'] in key][0] + task_perf_on_ref_dict[task_name] = float(time) + return task_perf_on_ref_dict + + +# ------------------------------ +# Functionality: +# parse and find the hardware blocks of a certain type +# ------------------------------ +def parse_block_based_on_types(library_dir, Block_char_file_name, type_sub_type): + type = type_sub_type[0] + sub_type = type_sub_type[1] + ref_gpp_dict = {} + reader = csv.DictReader(open(os.path.join(library_dir, Block_char_file_name), 'r')) + blck_dict_list = [] + for line in reader: + blck_dict_list.append(line) + + blck_dict = {} + for dict_ in blck_dict_list: + if dict_['Type'] == type and dict_['Subtype'] == sub_type: + blck_dict[dict_['Name']] = {} + for key, value in dict_.items(): + if value.isdigit(): + blck_dict[dict_['Name']][key] = float(value) + else: + blck_dict[dict_['Name']][key] = value + return blck_dict + + +# ------------------------------ +# Functionality: +# parse reference block blocks values +# ------------------------------ +def 
parse_ref_block_values(library_dir, Block_char_file_name, type_sub_type): + blck_dict = parse_block_based_on_types(library_dir, Block_char_file_name, type_sub_type) + for blck_name, blck_dict_ in blck_dict.items(): + for key_, value_ in blck_dict_.items(): + if key_ == "Ref" and value_ == "yes": + return blck_dict_ + + print("need to at least have one ref gpp") + exit(0) + + +# ------------------------------ +# Functionality: +# generate task graph and populate it with work values +# ------------------------------ +def gen_task_graph_work_dict(library_dir, IP_perf_file_name, Block_char_file_name, misc_knobs): + #correction_values = gen_correction_values(workmisc_knobs) + + gpp_file_addr = os.path.join(library_dir, Block_char_file_name) + IP_perf_file_addr = os.path.join(library_dir, IP_perf_file_name) + gpp_perf_file_addr = os.path.join(library_dir, Block_char_file_name) + + # parse the file and collect in a dictionary + reader = csv.DictReader(open(IP_perf_file_addr, 'r')) + dict_list = [] + for line in reader: + dict_list.append(line) + + # collect data on ref gpp (from ref_gpp perspective) + ref_gpp_dict = parse_ref_block_values(library_dir, Block_char_file_name, ("pe", "gpp")) + # calculate the work (basically number of instructions per task) + task_perf_on_ref_gpp_dict = parse_task_perf_on_ref_gpp(library_dir, IP_perf_file_name, Block_char_file_name) + + work_dict = {} # per task, what is the work (number of instruction processed) + + # right now, the data is reported in cycles + for task_name, time in task_perf_on_ref_gpp_dict.items(): + # the following two lines can be omitted once the time is reported in seconds + cycles = time # we have to do this, because at the moment, the data is reported in cycles + time = cycles/float(ref_gpp_dict["Freq"]) # we have to do this, because at the moment, the data is reported in cycles + + # don't need to use frequency correction values for work since frequency cancels out + work_dict[task_name] = time* float(ref_gpp_dict["dhrystone_IPC"])*float(ref_gpp_dict["Freq"]) + + return work_dict + + +# ------------------------------ +# Functionality: +# find all the ips (accelerators) for a task +# ------------------------------ +def deduce_IPs(task_PEs, gpp_names): + ip_dict = {} + for task_PE in task_PEs: + for PE, cycles in task_PE.items(): + if PE in gpp_names or PE in ip_dict.keys(): + continue + ip_dict[PE] = {} + ip_dict[PE]["Freq"] = 100000000 + + return ip_dict + + +def convert_energy_to_power(task_PPA_dict): + task_names = task_PPA_dict["perf"].keys() + task_PPA_dict["power"] = {} + for task_name in task_names: + task_PPA_dict["power"][task_name] = {} + blocks = task_PPA_dict["perf"][task_name].keys() + for block in blocks: + if task_PPA_dict["perf"][task_name][block] == 0 or (block not in task_PPA_dict["energy"][task_name].keys()): + task_PPA_dict["power"][task_name][block] = 0 + else: + task_PPA_dict["power"][task_name][block] = task_PPA_dict["energy"][task_name][block]/task_PPA_dict["perf"][task_name][block] + + +# based on various knobs correct for the parsed data. 
+# if the input doesn't require any correction, then no changes are applied +def gen_correction_values(workload, misc_knobs): + # instantiate the dictionary + correction_dict = {} + correction_dict["ip"] = {} + correction_dict["gpp"] = {} + correction_dict["dram"] = {} + correction_dict["sram"] = {} + correction_dict["ic"] = {} + + # initilize the correction values (in case misc knobs do not contain these values) + ip_freq_correction_ratio = 1 + gpp_freq_correction_ratio = 1 + dram_freq_correction_ratio = 1 + sram_freq_correction_ratio = 1 + ic_freq_correction_ratio = 1 + tech_node_SF = {} + tech_node_SF["perf"] =1 + tech_node_SF["energy"] = {"gpp":1, "non_gpp":1} + tech_node_SF["area"] = {"mem":1, "non_mem":1, "gpp":1} + + # if any of hte above values found in misc_knobs, over write + if "ip_freq_correction_ratio" in misc_knobs.keys(): + ip_freq_correction_ratio = misc_knobs["ip_freq_correction_ratio"] + + if "gpp_freq_correction_ratio" in misc_knobs.keys(): + gpp_freq_correction_ratio = misc_knobs["gpp_freq_correction_ratio"] + + if "dram_freq_correction_ratio" in misc_knobs.keys(): + dram_freq_correction_ratio = misc_knobs["dram_freq_correction_ratio"] + + if "sram_freq_correction_ratio" in misc_knobs.keys(): + sram_freq_correction_ratio = misc_knobs["sram_freq_correction_ratio"] + + if "sram_freq_correction_ratio" in misc_knobs.keys(): + ic_freq_correction_ratio = misc_knobs["ic_freq_correction_ratio"] + + if "tech_node_SF" in misc_knobs.keys(): + tech_node_SF = misc_knobs["tech_node_SF"] + + # populate the correction dictionary + correction_dict["ip"]["work_rate"] = (1/tech_node_SF["perf"])*ip_freq_correction_ratio + correction_dict["gpp"]["work_rate"] = (1/tech_node_SF["perf"])*gpp_freq_correction_ratio + correction_dict["dram"]["work_rate"] = (1/tech_node_SF["perf"])*dram_freq_correction_ratio + correction_dict["sram"]["work_rate"] = (1/tech_node_SF["perf"])*sram_freq_correction_ratio + correction_dict["ic"]["work_rate"] = (1/tech_node_SF["perf"])*ic_freq_correction_ratio + + correction_dict["ip"]["work_over_energy"] = (1/tech_node_SF["energy"]["non_gpp"])*1 + correction_dict["gpp"]["work_over_energy"] = (1/tech_node_SF["energy"]["gpp"])*1 + correction_dict["sram"]["work_over_energy"] = (1/tech_node_SF["energy"]["non_gpp"])*1 + correction_dict["dram"]["work_over_energy"] = (1/tech_node_SF["energy"]["non_gpp"])*1 + correction_dict["ic"]["work_over_energy"] = (1/tech_node_SF["energy"]["non_gpp"])*1 + + + correction_dict["ip"]["work_over_area"] = (1/tech_node_SF["area"]["non_mem"])*1 + correction_dict["gpp"]["work_over_area"] = (1/tech_node_SF["area"]["gpp"])*1 + correction_dict["sram"]["work_over_area"] = (1/tech_node_SF["area"]["mem"])*1 + correction_dict["dram"]["work_over_area"] = (1/tech_node_SF["area"]["mem"])*1 + correction_dict["ic"]["work_over_area"] = (1/tech_node_SF["area"]["non_mem"])*1 + + correction_dict["ip"]["one_over_area"] = (1 / tech_node_SF["area"]["non_mem"]) * 1 + correction_dict["gpp"]["one_over_area"] = (1 / tech_node_SF["area"]["gpp"]) * 1 + correction_dict["sram"]["one_over_area"] = (1 / tech_node_SF["area"]["mem"]) * 1 + correction_dict["dram"]["one_over_area"] = (1 / tech_node_SF["area"]["mem"]) * 1 + correction_dict["ic"]["one_over_area"] = (1 / tech_node_SF["area"]["non_mem"]) * 1 + + return correction_dict + +# ------------------------------ +# Functionality: +# parse the hardware library +# ------------------------------ +def parse_hardware_library(library_dir, IP_perf_file_name, + IP_energy_file_name, IP_area_file_name, + Block_char_file_name, 
+                           task_itr_cnt_file_name, workload, misc_knobs):
+
+    def gen_freq_range(misc_knobs, block_sub_type):
+        assert(block_sub_type in ["ip", "mem", "ic"])
+        if block_sub_type+"_spawn" not in misc_knobs.keys():
+            result = [1]
+        else:
+            spawn = misc_knobs[block_sub_type+"_spawn"]
+            result = spawn[block_sub_type+"_freq_range"]
+            #upper_bound = spawn[block_sub_type+"_freq_range"]["upper_bound"]
+            #incr = spawn[block_sub_type+"_freq_range"]["incr"]
+            #result = list(range(1, int(upper_bound), int(incr)))
+        return result
+
+
+    def gen_loop_itr_range(task_name, task_itr_cnt, misc_knobs):
+        max_num_itr = 1
+        max_spawn_ip_by_loop_itr = 1
+        loop_itr_incr = 1
+
+        # base cases
+        if task_name not in task_itr_cnt:
+            return range(1, 2)
+        else:
+            max_num_itr = task_itr_cnt[task_name]
+            if max_num_itr == 1:
+                return range(1, 2)
+        if "ip_spawn" not in misc_knobs.keys():
+            return range(1, 2)
+
+        # sanity check
+        assert misc_knobs["ip_spawn"]["ip_loop_unrolling"]["spawn_mode"] in ["arithmetic", "geometric"]
+        if misc_knobs["ip_spawn"]["ip_loop_unrolling"]["spawn_mode"] == "geometric":
+            assert(misc_knobs["ip_spawn"]["ip_loop_unrolling"]["incr"] > 1)
+
+        # get parameters
+        if "incr" in misc_knobs["ip_spawn"]["ip_loop_unrolling"].keys():
+            loop_itr_incr = misc_knobs["ip_spawn"]["ip_loop_unrolling"]["incr"]
+
+        if "max_spawn_ip" in misc_knobs["ip_spawn"]["ip_loop_unrolling"].keys():
+            max_spawn_ip_by_loop_itr = misc_knobs["ip_spawn"]["ip_loop_unrolling"]["max_spawn_ip"]
+        else:
+            max_spawn_ip_by_loop_itr = max_num_itr
+
+        # use an arithmetic or geometric progression to spawn IPs,
+        # e.g. max_num_itr = 16 with incr = 2 gives [1, 2, 4, 8, 16] geometrically
+        if misc_knobs["ip_spawn"]["ip_loop_unrolling"]["spawn_mode"] == "arithmetic":
+            result = list(range(1, int(max_num_itr), int(loop_itr_incr)))
+        elif misc_knobs["ip_spawn"]["ip_loop_unrolling"]["spawn_mode"] == "geometric":
+            num_ips_perspective_2 = int(math.log(max_num_itr, loop_itr_incr))
+            result = [loop_itr_incr**n for n in range(0, num_ips_perspective_2 + 1)]
+
+        # cap the result by the maximum spawn count
+        if len(result) > max_spawn_ip_by_loop_itr:
+            result = copy.deepcopy(result[:max_spawn_ip_by_loop_itr-1])
+
+        # add the maximum as well
+        if max_num_itr not in result:
+            result.append(max_num_itr)
+        return result
+
+    # add corrections
+    correction_values = gen_correction_values(workload, misc_knobs)
+
+    hardware_library_dict = {}
+    # parse the blocks
+    gpps = parse_block_based_on_types(library_dir, Block_char_file_name, ("pe", "gpp"))
+    ip_template = parse_block_based_on_types(library_dir, Block_char_file_name, ("pe", "ip"))  # only used to collect the clock freq (and similar defaults) for IPs
+    srams = parse_block_based_on_types(library_dir, Block_char_file_name, ("mem", "sram"))
+    drams = parse_block_based_on_types(library_dir, Block_char_file_name, ("mem", "dram"))
+    mems = {**drams, **srams}
+    ics = parse_block_based_on_types(library_dir, Block_char_file_name, ("ic", "ic"))
+    task_work_dict = gen_task_graph_work_dict(library_dir, IP_perf_file_name, Block_char_file_name, misc_knobs)
+
+    task_PPA_dict = {}
+    gpp_names = list(gpps.keys())
+    task_PPA_dict["perf_in_cycles"] = parse_task_PPA(library_dir, IP_perf_file_name, gpp_names)  # provided in cycles at the moment
+
+    ips = deduce_IPs(list(task_PPA_dict["perf_in_cycles"].values()), gpp_names)
+
+    # convert cycles to seconds using each block's clock frequency
+    task_PPA_dict["perf"] = copy.deepcopy(task_PPA_dict["perf_in_cycles"])
+    for task, task_PE in task_PPA_dict["perf_in_cycles"].items():
+        for PE, cycles in task_PE.items():
+            if PE in ips:
+                block_freq = ips[PE]["Freq"]
+            elif PE in gpps:
+                block_freq = gpps[PE]["Freq"]
+            task_PPA_dict["perf"][task][PE] = float(cycles)/block_freq
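+    # e.g. (hypothetical numbers): a task measured at 5e5 cycles on an IP whose
+    # deduced Freq is 1e8 Hz ends up with perf = 5e5/1e8 = 5e-3 seconds.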
+    task_PPA_dict["energy"] = parse_task_PPA(library_dir, IP_energy_file_name, gpp_names)
+
+    # generate power here
+    convert_energy_to_power(task_PPA_dict)
+    task_PPA_dict["area"] = parse_task_PPA(library_dir, IP_area_file_name, gpp_names)
+    task_itr_cnt = parse_task_itr_cnt(library_dir, task_itr_cnt_file_name)
+
+    for task_name in task_work_dict.keys():
+        IP_perfs = task_PPA_dict["perf"][task_name]
+        IP_energy = task_PPA_dict["energy"][task_name]  # reported in milliwatts at the moment
+        IP_area = task_PPA_dict["area"][task_name]
+        IP_names = list(task_PPA_dict["perf"][task_name].keys())
+        for IP_name in IP_names:
+            if IP_name in hardware_library_dict.keys():
+                hardware_library_dict[IP_name]["mappable_tasks"].append(task_name)
+                continue
+            if IP_name in gpps:
+                hardware_library_dict[IP_name] = {}
+                hardware_library_dict[IP_name]["work_rate"] = correction_values["gpp"]["work_rate"]*float(gpps[IP_name]['Freq'])*float(gpps[IP_name]["dhrystone_IPC"])
+                hardware_library_dict[IP_name]["work_over_energy"] = correction_values["gpp"]["work_over_energy"]*float(gpps[IP_name]['Inst_per_joul'])
+                hardware_library_dict[IP_name]["work_over_area"] = correction_values["gpp"]["work_over_area"]*(1.0)/float(gpps[IP_name]['Gpp_area'])
+                hardware_library_dict[IP_name]["one_over_area"] = correction_values["gpp"]["one_over_area"]*(1.0)/float(gpps[IP_name]['Gpp_area'])  # convention: work_over_area is 1/area for fixed areas (like IPs and GPPs)
+                hardware_library_dict[IP_name]["mappable_tasks"] = [task_name]
+                hardware_library_dict[IP_name]["type"] = "pe"
+                hardware_library_dict[IP_name]["sub_type"] = "gpp"
+                hardware_library_dict[IP_name]["clock_freq"] = gpps[IP_name]["Freq"]
+                hardware_library_dict[IP_name]["BitWidth"] = gpps[IP_name]["BitWidth"]
+                hardware_library_dict[IP_name]["loop_itr_cnt"] = 0
+                hardware_library_dict[IP_name]["loop_max_possible_itr_cnt"] = 0
+                hardware_library_dict[IP_name]["hop_latency"] = gpps[IP_name]["hop_latency"]
+                hardware_library_dict[IP_name]["pipe_line_depth"] = gpps[IP_name]["pipe_line_depth"]
+                #print("taskname: " + str(task_name) + ", subtype: gpp, power is " + str(hardware_library_dict[IP_name]["work_rate"]/hardware_library_dict[IP_name]["work_over_energy"]))
+            else:
+                loop_itr_range_ = gen_loop_itr_range(task_name, task_itr_cnt, misc_knobs)
+                ip_freq_range = gen_freq_range(misc_knobs, "ip")
+                for loop_itr_cnt, ip_freq in itertools.product(loop_itr_range_, ip_freq_range):
+                    IP_name_refined = IP_name + "_" + str(loop_itr_cnt) + "_" + str(ip_freq)
+                    hardware_library_dict[IP_name_refined] = {}
+                    hardware_library_dict[IP_name_refined]["work_rate"] = (ip_freq*loop_itr_cnt*correction_values["ip"]["work_rate"])*(task_work_dict[task_name]/(IP_perfs[IP_name]))
+                    hardware_library_dict[IP_name_refined]["work_over_energy"] = (correction_values["ip"]["work_over_energy"]/loop_itr_cnt)*(task_work_dict[task_name]/(float(IP_energy[IP_name]*float((10**-15)))))
+                    hardware_library_dict[IP_name_refined]["work_over_area"] = (correction_values["ip"]["work_over_area"]/loop_itr_cnt)*(task_work_dict[task_name])/(IP_area[IP_name]*(10**-12))
+                    hardware_library_dict[IP_name_refined]["one_over_area"] = (correction_values["ip"]["one_over_area"]/loop_itr_cnt)*(1.0)/(IP_area[IP_name]*(10**-12))  # convention: work_over_area is 1/area for fixed areas (like IPs and GPPs)
+                    hardware_library_dict[IP_name_refined]["mappable_tasks"] = [task_name]
+                    hardware_library_dict[IP_name_refined]["type"] = "pe"
hardware_library_dict[IP_name_refined]["sub_type"] = "ip" + hardware_library_dict[IP_name_refined]["clock_freq"] = ip_template["IP"]["Freq"]*ip_freq + hardware_library_dict[IP_name_refined]["BitWidth"] = ip_template["IP"]["BitWidth"] + hardware_library_dict[IP_name_refined]["loop_itr_cnt"] = loop_itr_cnt + hardware_library_dict[IP_name_refined]["loop_max_possible_itr_cnt"] = task_itr_cnt[task_name] + hardware_library_dict[IP_name_refined]["hop_latency"] = ip_template["IP"]["hop_latency"] + hardware_library_dict[IP_name_refined]["pipe_line_depth"] = ip_template["IP"]["pipe_line_depth"] + #print("taskname: " + str(task_name) + ", subtype: ip, power is"+ str(hardware_library_dict[IP_name]["work_rate"]/hardware_library_dict[IP_name]["work_over_energy"] )) + + for blck_name, blck_value in mems.items(): + mem_freq_range = gen_freq_range(misc_knobs, "mem") + for freq in mem_freq_range: + IP_name_refined = blck_value['Name']+ "_" + str(freq) + hardware_library_dict[IP_name_refined] = {} + #hardware_library_dict[blck_value['Name']] = {} + hardware_library_dict[IP_name_refined]["work_rate"] = freq*correction_values[blck_value["Subtype"]]["work_rate"]*float(blck_value['BitWidth'])*float(blck_value['Freq']) + hardware_library_dict[IP_name_refined]["work_over_energy"] = correction_values[blck_value["Subtype"]]["work_over_energy"]*float(blck_value['Byte_per_joul']) + hardware_library_dict[IP_name_refined]["work_over_area"] = correction_values[blck_value["Subtype"]]["work_over_area"]*float(blck_value['Byte_per_m']) + hardware_library_dict[IP_name_refined]["one_over_area"] = correction_values[blck_value["Subtype"]]["one_over_area"]*float(blck_value['Byte_per_m']) # not gonna be used so doesn't matter how to populate + hardware_library_dict[IP_name_refined]["mappable_tasks"] = 'all' + hardware_library_dict[IP_name_refined]["type"] = "mem" + hardware_library_dict[IP_name_refined]["sub_type"] = blck_value['Subtype'] + hardware_library_dict[IP_name_refined]["clock_freq"] = freq*blck_value["Freq"] + hardware_library_dict[IP_name_refined]["BitWidth"] = blck_value["BitWidth"] + hardware_library_dict[IP_name_refined]["loop_itr_cnt"] = 0 + hardware_library_dict[IP_name_refined]["loop_max_possible_itr_cnt"] = 0 + hardware_library_dict[IP_name_refined]["hop_latency"] = blck_value["hop_latency"] + hardware_library_dict[IP_name_refined]["pipe_line_depth"] = blck_value["pipe_line_depth"] + + for blck_name, blck_value in ics.items(): + ic_freq_range = gen_freq_range(misc_knobs, "ic") + for freq in ic_freq_range: + IP_name_refined = blck_value['Name']+ "_" + str(freq) + hardware_library_dict[IP_name_refined] = {} + hardware_library_dict[IP_name_refined]["work_rate"] = freq*correction_values[blck_value["Subtype"]]["work_rate"]*float(blck_value['BitWidth'])*float(blck_value['Freq']) + hardware_library_dict[IP_name_refined]["work_over_energy"] = correction_values[blck_value["Subtype"]]["work_over_energy"]*float(blck_value['Byte_per_joul']) + hardware_library_dict[IP_name_refined]["work_over_area"] = correction_values[blck_value["Subtype"]]["work_over_area"]*float(blck_value['Byte_per_m']) + hardware_library_dict[IP_name_refined]["one_over_area"] = correction_values[blck_value["Subtype"]]["one_over_area"]*float(blck_value['Byte_per_m']) # not gonna be used so doesn't matter how to populate + hardware_library_dict[IP_name_refined]["mappable_tasks"] = 'all' + hardware_library_dict[IP_name_refined]["type"] = "ic" + hardware_library_dict[IP_name_refined]["sub_type"] = "ic" + hardware_library_dict[IP_name_refined]["clock_freq"] = 
freq*blck_value["Freq"] + hardware_library_dict[IP_name_refined]["BitWidth"] = blck_value["BitWidth"] + hardware_library_dict[IP_name_refined]["loop_itr_cnt"] = 0 + hardware_library_dict[IP_name_refined]["loop_max_possible_itr_cnt"] = 0 + hardware_library_dict[IP_name_refined]["hop_latency"] = blck_value["hop_latency"] + hardware_library_dict[IP_name_refined]["pipe_line_depth"] = blck_value["pipe_line_depth"] + + return hardware_library_dict + +# collect budget values for each workload +def collect_budgets(workloads_to_consider, budget_misc_knobs, library_dir, prefix=""): + if "base_budget_scaling" not in budget_misc_knobs.keys(): + base_budget_scaling = {"latency":1, "power":1, "area":1} + else: + base_budget_scaling = budget_misc_knobs["base_budget_scaling"] + + # get files + file_list = [f for f in os.listdir(library_dir) if os.path.isfile(os.path.join(library_dir, f))] + misc_file_name = get_full_file_name(prefix + "Budget.csv", file_list) + + # get the time profile + df = pd.read_csv(os.path.join(library_dir, misc_file_name)) + workloads = df['Workload'] + workload_last_task = {} + + budgets_dict = {} + budgets_dict = defaultdict(dict) + other_values_dict = defaultdict(dict) + other_values_dict["glass"] = {} + budgets_dict["glass"]["latency"] = {} + + for metric in config.budgetted_metrics: + if metric in ["power", "area"] and not len(workloads_to_consider) == 1: + budgets_dict["glass"][metric] = (df.loc[df['Workload'] == "all"])[metric].values[0] + budgets_dict["glass"][metric] *= float(base_budget_scaling[metric]) + # this is a hack for now. change later. + # but used for budget sweep for now + #budgets_dict["glass"][metric] = config.budget_dict["glass"][metric] + elif metric in ["latency"] or len(workloads_to_consider)==1: + for idx in range(0, len(workloads)): + workload_name = workloads[idx] + if workload_name == "all" or workload_name not in workloads_to_consider: + continue + if metric == "latency": + budgets_dict["glass"][metric][workload_name] = (df.loc[df['Workload'] == workload_name])[metric].values[0] + budgets_dict["glass"][metric][workload_name] *= float(base_budget_scaling[metric]) + else: + budgets_dict["glass"][metric] = (df.loc[df['Workload'] == workload_name])[metric].values[0] + budgets_dict["glass"][metric] *= float(base_budget_scaling[metric]) + + for metric in config.other_metrics: + other_values_dict["glass"][metric] = (df.loc[df['Workload'] == "all"])[metric].values[0] + + return budgets_dict, other_values_dict + +# get the last task for each workload +def collect_last_task(workloads_to_consider, library_dir, prefix=""): + blocksL: List[BlockL] = [] # collection of all the blocks + pe_mapsL: List[TaskToPEBlockMapL] = [] + pe_schedulesL: List[TaskScheduleL] = [] + + + # get files + file_list = [f for f in os.listdir(library_dir) if os.path.isfile(os.path.join(library_dir, f))] + misc_file_name = get_full_file_name(prefix + "Last Tasks.csv", file_list) + + # get the time profile + df = pd.read_csv(os.path.join(library_dir, misc_file_name)) + workloads = df['workload'] + last_tasks = df['last_task'] + workload_last_task = {} + for idx in range(0, len(workloads)): + workload = workloads[idx] + if workload not in workloads_to_consider: + continue + workload_last_task[workloads[idx]] = last_tasks[idx] + + return workload_last_task + +# ------------------------------ +# Functionality: +# generate the hardware graph with light libraries +# ------------------------------ +def gen_hardware_graph(library_dir, prefix = "") : + # get files + file_list = [f for f in 
+
+
+# ------------------------------
+# Functionality:
+#       generate the hardware graph with light libraries
+# ------------------------------
+def gen_hardware_graph(library_dir, prefix=""):
+    # get files
+    file_list = [f for f in os.listdir(library_dir) if os.path.isfile(os.path.join(library_dir, f))]
+    hardware_graph_file_name = get_full_file_name(prefix + "Hardware Graph.csv", file_list)
+    hardware_graph_file_addr = os.path.join(library_dir, hardware_graph_file_name)
+    hardware_graph_dict = parse_hardware_graph(hardware_graph_file_addr)
+
+    return hardware_graph_dict
+
+
+# ------------------------------
+# Functionality:
+#       generate the task to hardware mapping with light libraries
+# ------------------------------
+def gen_task_to_hw_mapping(library_dir, prefix=""):
+    # get files
+    file_list = [f for f in os.listdir(library_dir) if os.path.isfile(os.path.join(library_dir, f))]
+    task_to_hardware_mapping_file_name = get_full_file_name(prefix + "Task To Hardware Mapping.csv", file_list)
+    task_to_hardware_mapping_file_addr = os.path.join(library_dir, task_to_hardware_mapping_file_name)
+    task_to_hardware_mapping = parse_task_to_hw_mapping(task_to_hardware_mapping_file_addr)
+
+    return task_to_hardware_mapping
+
+
+# ------------------------------
+# Functionality:
+#       generate the hardware library
+# ------------------------------
+def gen_hardware_library(library_dir, prefix, workload, misc_knobs={}):
+    blocksL: List[BlockL] = []  # collection of all the blocks
+    pe_mapsL: List[TaskToPEBlockMapL] = []
+    pe_schedulesL: List[TaskScheduleL] = []
+
+    # get files
+    file_list = [f for f in os.listdir(library_dir) if os.path.isfile(os.path.join(library_dir, f))]
+    data_movement_file_name = get_full_file_name(prefix + "Task Data Movement.csv", file_list)
+    IP_perf_file_name = get_full_file_name(prefix + "Task PE Performance.csv", file_list)
+    IP_energy_file_name = get_full_file_name(prefix + "Task PE Energy.csv", file_list)
+    IP_area_file_name = get_full_file_name(prefix + "Task PE Area.csv", file_list)
+    task_itr_cnt_file_name = get_full_file_name(prefix + "Task Itr Count.csv", file_list)
+    Block_char_file_name = get_full_file_name("misc_database - " + "Block Characteristics.csv", file_list)
+    common_block_char_file_name = get_full_file_name("misc_database - " + "Common Hardware.csv", file_list)
+
+    # get the time profile
+    #task_perf_file_addr = os.path.join(library_dir, IP_perf_file_name)
+    #ref_gpp_dict = parse_ref_block_values(library_dir, Block_char_file_name, ("pe", "gpp"))
+    # calculate the work (basically the number of instructions per task)
+    task_perf_on_ref_gpp_dict = parse_task_perf_on_ref_gpp(library_dir, IP_perf_file_name, Block_char_file_name)
+    # set up the schedules
+    for task_name, _ in task_perf_on_ref_gpp_dict.items():
+        pe_schedulesL.append(TaskScheduleL(task_name=task_name, starting_time=0))
+
+    # get the mapping and IP library
+    hardware_library_dict = parse_hardware_library(library_dir, IP_perf_file_name,
+                                                   IP_energy_file_name, IP_area_file_name,
+                                                   Block_char_file_name, task_itr_cnt_file_name, workload, misc_knobs)
+    block_subtype = "gpp"  # default
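+    # Each hardware-library entry below becomes a light block description (BlockL);
+    # the *_distribution arguments are degenerate ({value: 1}) because the
+    # characterization data provides point estimates rather than distributions.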
+    for IP_name, values in hardware_library_dict.items():
+        block_subtype = values['sub_type']
+        block_type = values['type']
+        blocksL.append(
+            BlockL(block_instance_name=IP_name, block_type=block_type, block_subtype=block_subtype,
+                   peak_work_rate_distribution={hardware_library_dict[IP_name]["work_rate"]: 1},
+                   work_over_energy_distribution={hardware_library_dict[IP_name]["work_over_energy"]: 1},
+                   work_over_area_distribution={hardware_library_dict[IP_name]["work_over_area"]: 1},
+                   one_over_area_distribution={hardware_library_dict[IP_name]["one_over_area"]: 1},
+                   clock_freq=hardware_library_dict[IP_name]["clock_freq"], bus_width=hardware_library_dict[IP_name]["BitWidth"],
+                   loop_itr_cnt=hardware_library_dict[IP_name]["loop_itr_cnt"], loop_max_possible_itr_cnt=hardware_library_dict[IP_name]["loop_max_possible_itr_cnt"],
+                   hop_latency=hardware_library_dict[IP_name]["hop_latency"], pipe_line_depth=hardware_library_dict[IP_name]["pipe_line_depth"],))
+
+        if block_type == "pe":
+            for mappable_task in hardware_library_dict[IP_name]["mappable_tasks"]:
+                task_to_block_map_ = TaskToPEBlockMapL(task_name=mappable_task, pe_block_instance_name=IP_name)
+                pe_mapsL.append(task_to_block_map_)
+
+    return blocksL, pe_mapsL, pe_schedulesL
diff --git a/Project_FARSI/top/main_FARSI.py b/Project_FARSI/top/main_FARSI.py new file mode 100644 index 00000000..108e8d6b --- /dev/null +++ b/Project_FARSI/top/main_FARSI.py @@ -0,0 +1,194 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+from SIM_utils.SIM import *
+from DSE_utils.design_space_exploration_handler import *
+from specs.database_input import *
+import psutil
+
+# Run an instance of FARSI, the exploration framework
+# Variables:
+#       unique_number: a unique number used for check pointing
+#       db_input: input database to use for design generation
+#       hw_sampling: hardware sampling mode. This specifies: 1. what percentage of
+#       error the database has, 2. what the population size for each design should be,
+#       and 3. the statistical reduction mode (avg, max, min)
+#       starting_exploration_mode: whether to start from scratch or from an existing check-pointed design
+def run_FARSI(result_folder, unique_number, case_study, db_input, hw_sampling, starting_exploration_mode="generated_from_scratch"):
+    if config.use_cacti:
+        print("*****************************")
+        print("***** YOU ASKED TO USE CACTI FOR POWER/AREA MODELING OF MEMORY SUBSYSTEM. MAKE SURE YOU HAVE CACTI INSTALLED ****")
+        print("*****************************")
+
+    with warnings.catch_warnings():
+        warnings.simplefilter(config.warning_mode)
+        assert starting_exploration_mode in ["generated_from_scratch", "generated_from_check_point", "parse", "hardcoded"]
+
+        config.dse_type = "hill_climbing"
+        if not (config.dse_type == "hill_climbing"):
+            print("this main is only suitable for the hill_climbing search")
+            exit(0)
+
+        # init relevant configs and objects
+        num_SOCs = 1  # how many SOCs to spread the design across
+        so_far_best_ex_dp = None
+        boost_SOC = False  # specify whether to stick with the old SOC type or boost it
+        best_design_sim_this_itr = None
+
+        # set up the design handler and the design explorer
+        dse_handler = DSEHandler(result_folder)
+        # First copy the DataBase information (SoC, Blocks (modules), tasks, mappings, scheduling),
+        # then choose among the DSE algorithms (hill climbing) and initialize it
+        dse_handler.setup_an_explorer(db_input, hw_sampling)
+
+        # Use the check-pointed design, a parsed design, or generate a simple design point
+        dse_handler.prepare_for_exploration(boost_SOC, starting_exploration_mode)
+
+        # iterate until a design meets the constraints, or terminate if none is found
+        while True:
+            # run the simulation for the design points (performance, energy, and area core calculations)
+            dse_handler.explore()
+            dse_handler.check_point_best_design(unique_number)  # check point
+            dse_handler.write_data(unique_number, result_folder, case_study, 0, 1, 1)
+
+            # iterate if the budget is not met and we have seen improvements
+            best_design_sim_last_itr = best_design_sim_this_itr
+            best_design_sim_this_itr = dse_handler.dse.so_far_best_sim_dp
+
+            if dse_handler.dse.reason_to_terminate == "out_of_memory" or dse_handler.dse.reason_to_terminate == "exploration (total itr_ctr) iteration threshold reached":
+                return dse_handler
+            elif not dse_handler.dse.found_any_improvement and config.heuristic_type == "FARSI":
+                return dse_handler
+            elif not dse_handler.dse.found_any_improvement and dse_handler.dse.reason_to_terminate == "met the budget":
+                return dse_handler
+            else:
+                dse_handler.dse.reset_ctrs()
+                dse_handler.dse.init_ex_dp = dse_handler.dse.so_far_best_ex_dp
+            """
+            elif not best_design_sim_last_itr == None and \
+                (best_design_sim_this_itr.dp_rep.get_hardware_graph().get_SOC_design_code() ==
+                 best_design_sim_last_itr.dp_rep.get_hardware_graph().get_SOC_design_code()):
+                return dse_handler
+            """
+
+    #if stat_result.fits_budget(1) get_SOC_design_code
+    #return dse_handler
+
+
+def set_up_FARSI_with_arch_gym(result_folder, unique_number, case_study, db_input, hw_sampling, starting_exploration_mode="generated_from_scratch"):
+    starting_exploration_mode = "FARSI_des_passed_in"
+    if config.use_cacti:
+        print("*****************************")
+        print("***** YOU ASKED TO USE CACTI FOR POWER/AREA MODELING OF MEMORY SUBSYSTEM. MAKE SURE YOU HAVE CACTI INSTALLED ****")
+        print("*****************************")
+
+    with warnings.catch_warnings():
+        warnings.simplefilter(config.warning_mode)
+        assert starting_exploration_mode in ["generated_from_scratch", "generated_from_check_point", "parse", "hardcoded", "FARSI_des_passed_in"]
+
+        config.dse_type = "simple_greedy_one_sample"
+
+        # init relevant configs and objects
+        num_SOCs = 1  # how many SOCs to spread the design across
+        so_far_best_ex_dp = None
+        boost_SOC = False  # specify whether to stick with the old SOC type or boost it
+        best_design_sim_this_itr = None
+
+        # set up the design handler and the design explorer
+        dse_handler = DSEHandler(result_folder)
+        # First copy the DataBase information (SoC, Blocks (modules), tasks, mappings, scheduling),
+        # then choose among the DSE algorithms (hill climbing) and initialize it
+        dse_handler.setup_an_explorer(db_input, hw_sampling)
+
+        # generate an initial design point (right now, the simplest design)
+        dse_handler.dse.gen_init_ex_dp()
+        return dse_handler
+
+# Run FARSI only to simulate one design (parsed, generated, or from a check point)
+# Variables:
+#       unique_number: a unique number used for check pointing
+#       db_input: input database to use for design generation
+#       hw_sampling: hardware sampling mode. This specifies: 1. what percentage of
+#       error the database has, 2. what the population size for each design should be,
+#       and 3. the statistical reduction mode (avg, max, min)
+#       starting_exploration_mode: whether to start from scratch or from an existing check-pointed design
+def run_FARSI_only_simulation(result_folder, unique_number, db_input, hw_sampling, starting_exploration_mode="from_scratch"):
+    if config.use_cacti:
+        print("*****************************")
+        print("***** YOU ASKED TO USE CACTI FOR POWER/AREA MODELING OF MEMORY SUBSYSTEM. MAKE SURE YOU HAVE CACTI INSTALLED ****")
+        print("*****************************")
+
+    with warnings.catch_warnings():
+        warnings.simplefilter(config.warning_mode)
+
+        config.dse_type = "hill_climbing"
+        if not (config.dse_type == "hill_climbing"):
+            print("this main is only suitable for the hill_climbing search")
+            exit(0)
+
+        # init relevant configs and objects
+        num_SOCs = 1  # how many SOCs to spread the design across
+        so_far_best_ex_dp = None
+        boost_SOC = False  # specify whether to stick with the old SOC type or boost it
+
+        # set up the design handler and the design explorer
+        dse_handler = DSEHandler(result_folder)
+        # First copy the DataBase information (SoC, Blocks (modules), tasks, mappings, scheduling),
+        # then choose among the DSE algorithms (hill climbing) and initialize it
+        dse_handler.setup_an_explorer(db_input, hw_sampling)
+
+        # Use the check-pointed design, a parsed design, or generate a simple design point
+        dse_handler.prepare_for_exploration(boost_SOC, starting_exploration_mode)
+        # run the simulation for the design point (performance, energy, and area core calculations)
+        dse_handler.explore_one_design()
+        dse_handler.check_point_best_design(unique_number)  # check point
+        return dse_handler
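+
+# Hypothetical usage sketch (argument values are illustrative, not part of this
+# module): simulating a single check-pointed design might look like
+#   dse = run_FARSI_only_simulation("results", "0_0_0", db_input, hw_sampling,
+#                                   starting_exploration_mode="generated_from_check_point")
+#   best = dse.dse.so_far_best_sim_dp  # inspect the simulated design point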
+
+# main function. If this file is run in isolation,
+# the main function is called (which itself calls run_FARSI, defined above)
+if __name__ == "__main__":
+
+    run_ctr = 0
+    case_study = "simple_run"
+    result_home_dir = os.path.join(os.getcwd(), "data_collection/data/" + case_study)
+    home_dir = os.getcwd()
+    date_time = datetime.now().strftime('%m-%d_%H-%M_%S')
+    result_folder = os.path.join(result_home_dir, date_time)
+
+    current_process_id = 0
+    total_process_cnt = 0
+    starting_exploration_mode = config.exploration_mode
+    print('case study:' + case_study)
+
+    # -------------------------------------------
+    # set parameters
+    # -------------------------------------------
+    experiment_repetition_cnt = 1
+    reduction = "most_likely"
+    workloads = {"SLAM"}
+    sw_hw_database_population = {"db_mode": "hardcoded", "hw_graph_mode": "generated_from_scratch",
+                                 "workloads": workloads}
+
+    accuracy_percentage = {}
+    accuracy_percentage["mem"] = accuracy_percentage["ic"] = accuracy_percentage["gpp"] = accuracy_percentage["ip"] = \
+        {"latency": 1,
+         "energy": 1,
+         "area": 1,
+         "one_over_area": 1}
+    hw_sampling = {"mode": "exact", "population_size": 1, "reduction": reduction,
+                   "accuracy_percentage": accuracy_percentage}
+    db_input = database_input_class(sw_hw_database_population)
+    print("hw_sampling:" + str(hw_sampling))
+    print("budget set to:" + str(db_input.get_budget_dict("glass")))
+    unique_suffix = str(total_process_cnt) + "_" + str(current_process_id) + "_" + str(run_ctr)
+    exploration_start_time = time.time()
+    dse_hndlr = run_FARSI(result_folder, unique_suffix, case_study, db_input, hw_sampling, sw_hw_database_population["hw_graph_mode"])
+    if config.REPORT:
+        dse_hndlr.dse.report(exploration_start_time)
+        dse_hndlr.dse.plot_data()
diff --git a/Project_FARSI/visualization_utils/Iulian_plots/error_analysis_per_app.py b/Project_FARSI/visualization_utils/Iulian_plots/error_analysis_per_app.py new file mode 100644 index 00000000..205d0dc5 --- /dev/null +++ b/Project_FARSI/visualization_utils/Iulian_plots/error_analysis_per_app.py @@ -0,0 +1,14 @@
+#Author: Iulian Brumar
+#Script to analyze the error per application
+import pandas as pd
+import seaborn as sns
+import sys
+import matplotlib.pyplot as plt
+import numpy as np
+
+data = pd.read_csv(sys.argv[1])
+
+print(data["error"])
+print(data["app"])
+
+sns.barplot(data=data, x="app", y="error")
+plt.savefig('error_analysis.png')
diff --git a/Project_FARSI/visualization_utils/Iulian_plots/plot_err_vs_arch.py b/Project_FARSI/visualization_utils/Iulian_plots/plot_err_vs_arch.py new file mode 100644 index 00000000..15e7e958 --- /dev/null +++ b/Project_FARSI/visualization_utils/Iulian_plots/plot_err_vs_arch.py @@ -0,0 +1,75 @@
+#Author: Iulian Brumar
+#Script that plots the average error across different architectural parameters
+
+import pandas as pd
+import seaborn as sns
+import sys
+import matplotlib.pyplot as plt
+import numpy as np
+sys.path.append("..")
+from plot_validations import *
+from sklearn.linear_model import LinearRegression
+
+def abline(slope, intercept, color):
+    """Plot a line from slope and intercept"""
+    axes = plt.gca()
+    x_vals = np.array(axes.get_xlim())
+    y_vals = intercept + slope * x_vals
+    plt.plot(x_vals, y_vals, '--', color=color)
+
+
+data = pd.read_csv(sys.argv[1])
+error = list(data["error"])
+blk_cnt = list(data["blk_cnt"])
+pe_cnt = list(data["pe_cnt"])
+mem_cnt = list(data["mem_cnt"])
+bus_cnt = list(data["bus_cnt"])
+channel_cnt = list(data["channel_cnt"])
+pa_sim_time = list(data["PA simulation time"])
+farsi_sim_time = list(data["FARSI simulation time"])
+
+# reshape from wide to long: one (Counts, ArchParam, Error) row per count column
+num_counts_cols = 5
+tmp_reformatted_df_data = [blk_cnt+pe_cnt+mem_cnt+bus_cnt+channel_cnt, ["Block Counts"]*len(blk_cnt)+["PE Counts"]*len(blk_cnt) + ["Mem Counts"]*len(blk_cnt) + ["Bus Counts"]*len(bus_cnt) + ["Channel Counts"]*len(bus_cnt), error*num_counts_cols]
+reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in range(len(blk_cnt)*num_counts_cols)]
+reformatted_df = pd.DataFrame(reformatted_df_data, columns=["Counts", "ArchParam", "Error"])
+print(reformatted_df.tail())
+
+df_avg = get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name="Counts", y_coord_name="Error", hue_col="ArchParam")
+
+color_per_hue = {"Bus Counts": "green", "Mem Counts": "orange", "PE Counts": "blue", "Block Counts": "red", "Channel Counts": "pink"}
+#df_avg = df_avg.loc[df_avg["ArchParam"] != "Bus Counts"]
+splot = sns.scatterplot(data=df_avg, y="Error", x="Counts", hue="ArchParam", palette=color_per_hue)
+#splot.set(yscale = "log")
+
+# fit and draw one regression line per architectural parameter
+hues = set(list(df_avg["ArchParam"]))
+for hue in hues:
+    # x is required to be in matrix format in sklearn
+    print(np.isnan(df_avg["Error"]))
+    xs_hue = [[x] for x in list(df_avg.loc[(df_avg["ArchParam"] == hue) & (df_avg["Error"].notnull())]["Counts"])]
+    ys_hue = np.array(list(df_avg.loc[(df_avg["ArchParam"] == hue) & (df_avg["Error"].notnull())]["Error"]))
+    print("xs_hue")
+    print(xs_hue)
+    print("ys_hue")
+    print(ys_hue)
+    reg = LinearRegression().fit(xs_hue, ys_hue)
+    m = reg.coef_[0]
+    n = reg.intercept_
+    abline(m, n, color_per_hue[hue])
+#plt.set_ylim(top = 10)
+plt.savefig('error_vs_c_reg.png')
diff --git a/Project_FARSI/visualization_utils/Iulian_plots/plot_sim_vs_lat.py b/Project_FARSI/visualization_utils/Iulian_plots/plot_sim_vs_lat.py new file mode 100644 index 00000000..c3809f0b --- /dev/null +++ b/Project_FARSI/visualization_utils/Iulian_plots/plot_sim_vs_lat.py @@ -0,0 +1,28 @@
+#Author: Iulian Brumar
+#Script that plots Simulation time vs PA predicted latency
+
+import pandas as pd
+import seaborn as sns
+import sys
+import matplotlib.pyplot as plt
+import numpy as np
+
+data = pd.read_csv(sys.argv[1])
+blk_cnt = list(data["blk_cnt"])
+pe_cnt = list(data["pe_cnt"])
+mem_cnt = list(data["mem_cnt"])
+bus_cnt = list(data["bus_cnt"])
+pa_sim_time = list(data["PA simulation time"])
+farsi_sim_time = list(data["FARSI simulation time"])
+pa_predicted_lat = list(data["PA_predicted_latency"])
+
+# wide-to-long reshape: pair each predicted latency with both simulators' times
+tmp_reformatted_df_data = [pa_predicted_lat*2, pa_sim_time+farsi_sim_time, ["PA"]*len(blk_cnt)+["FARSI"]*len(blk_cnt)]
+reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in range(len(blk_cnt)*2)]
+reformatted_df = pd.DataFrame(reformatted_df_data, columns=["PA Predicted Latency", "Simulation Time", "FARSI or PA"])
+print(reformatted_df.head())
+splot = sns.scatterplot(data=reformatted_df, x="PA Predicted Latency", y="Simulation Time", hue="FARSI or PA")
+splot.set(yscale="log")
+plt.savefig('pa_lat_vs_simtime_logy.png')
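+
+# A more direct equivalent of the manual wide-to-long reshape above, assuming the
+# same column names, could use pandas' melt (untested sketch):
+#   long_df = data.melt(id_vars=["PA_predicted_latency"],
+#                       value_vars=["PA simulation time", "FARSI simulation time"],
+#                       var_name="FARSI or PA", value_name="Simulation Time")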
diff --git a/Project_FARSI/visualization_utils/Iulian_plots/plot_validations.py b/Project_FARSI/visualization_utils/Iulian_plots/plot_validations.py new file mode 100644 index 00000000..5e50b4ba --- /dev/null +++ b/Project_FARSI/visualization_utils/Iulian_plots/plot_validations.py @@ -0,0 +1,86 @@
+#Author: Iulian Brumar
+#Script that plots average simulation time (PA/FARSI hue) vs different architectural parameters
+
+import pandas as pd
+import seaborn as sns
+import sys
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+# average the y column for each x coordinate, separately per hue value
+# (hue_col defaults to "FARSI or PA", so callers that only pass x work unchanged)
+def get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name, y_coord_name="Simulation Time", hue_col="FARSI or PA"):
+    hues = set(list(reformatted_df[hue_col]))
+    avg_df_lst = []
+    for x_coord in set(reformatted_df[x_coord_name]):
+        for hue in hues:
+            selectedy_hue = list(reformatted_df.loc[(reformatted_df[x_coord_name] == x_coord) & (reformatted_df[hue_col] == hue)][y_coord_name])
+            avg_df_lst.append([np.average(selectedy_hue), hue, x_coord])
+    return pd.DataFrame(avg_df_lst, columns=[y_coord_name, hue_col, x_coord_name])
+
+if __name__ == "__main__":
+    data = pd.read_csv(sys.argv[1])
+    blk_cnt = list(data["blk_cnt"])
+    pe_cnt = list(data["pe_cnt"])
+    mem_cnt = list(data["mem_cnt"])
+    bus_cnt = list(data["bus_cnt"])
+    pa_sim_time = list(data["PA simulation time"])
+    farsi_sim_time = list(data["FARSI simulation time"])
+    tmp_reformatted_df_data = [blk_cnt*2, pe_cnt*2, mem_cnt*2, bus_cnt*2, pa_sim_time+farsi_sim_time, ["PA"]*len(blk_cnt)+["FARSI"]*len(blk_cnt)]
+    reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in range(len(blk_cnt)*2)]
+    reformatted_df = pd.DataFrame(reformatted_df_data, columns=["Block counts", "PE counts", "Mem counts", "Bus counts", "Simulation Time", "FARSI or PA"])
+    print(reformatted_df.head())
+
+    df_blk_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Block counts")
+    df_pe_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "PE counts")
+    df_mem_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Mem counts")
+    df_bus_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Bus counts")
counts") + + print("Bola") + print(df_blk_avg) + + splot = sns.scatterplot(data=df_blk_avg, x = "Block counts", y = "Simulation Time", hue = "FARSI or PA") + splot.set(yscale = "log") + plt.savefig('block_counts_vs_simtime.png') + + plt.clf() + splot = sns.scatterplot(data=df_pe_avg, x = "PE counts", y = "Simulation Time", hue = "FARSI or PA") + splot.set(yscale = "log") + plt.savefig('PE_counts_vs_simtime.png') + + + plt.clf() + splot = sns.scatterplot(data=df_mem_avg, x = "Mem counts", y = "Simulation Time", hue = "FARSI or PA") + splot.set(yscale = "log") + plt.savefig('MEM_counts_vs_simtime.png') + + + plt.clf() + splot = sns.scatterplot(data=df_bus_avg, x = "Bus counts", y = "Simulation Time", hue = "FARSI or PA") + splot.set(yscale = "log") + plt.savefig('BUS_counts_vs_simtime.png') diff --git a/Project_FARSI/visualization_utils/Iulian_plots/simtime_vs_lat.py b/Project_FARSI/visualization_utils/Iulian_plots/simtime_vs_lat.py new file mode 100644 index 00000000..1965d893 --- /dev/null +++ b/Project_FARSI/visualization_utils/Iulian_plots/simtime_vs_lat.py @@ -0,0 +1,15 @@ +#Author: Iulian Brumar +#Script that plots latency error per application + +import pandas as pd +import seaborn as sns +import sys +import matplotlib.pyplot as plt +import numpy as np +data = pd.read_csv(sys.argv[1]) + +print(data["error"]) +print(data["app"]) + +sns.barplot(data=data, x = "app", y = "error") +plt.savefig('error_analysis.png') diff --git a/Project_FARSI/visualization_utils/__init__.py b/Project_FARSI/visualization_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Project_FARSI/visualization_utils/plot.py b/Project_FARSI/visualization_utils/plot.py new file mode 100644 index 00000000..9fc3f5a5 --- /dev/null +++ b/Project_FARSI/visualization_utils/plot.py @@ -0,0 +1,566 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
+
+import numpy as np
+from settings import config
+import matplotlib.pyplot as plt
+from copy import deepcopy
+import copy
+import matplotlib.cbook as cbook
+import _pickle as cPickle
+
+if config.simulation_method == "power_knobs":
+    from specs.database_input_powerKnobs import *
+elif config.simulation_method == "performance":
+    from specs.database_input import *
+else:
+    raise NameError("Simulation method unavailable")
+
+
+# ------------------------------
+# Functionality:
+#       plot move stats
+# ------------------------------
+def move_profile_plot(move_lists_):
+    move_lists = [move_ for move_ in move_lists_ if not(move_.get_metric() == "cost")]  # cost moves are filtered out for now
+    move_on_metric_freq = {}
+    for metric in config.budgetted_metrics:
+        move_on_metric_freq[metric] = [0]
+
+    for move in move_lists:
+        metric = move.get_metric()
+        move_on_metric_freq[metric] = [move_on_metric_freq[metric][0] + 1]
+
+    labels = ['Metric']
+    x = np.arange(len(labels))  # the label locations
+    width = 0.2  # the width of the bars
+
+    fig, ax = plt.subplots()
+    rects1 = ax.bar(x - 1 * (width), move_on_metric_freq["latency"], width, label='perf moves', color="orange")
+    rects2 = ax.bar(x, move_on_metric_freq["power"], width, label='power moves', color="mediumpurple")
+    rects3 = ax.bar(x + 1 * width, move_on_metric_freq["area"], width, label='area moves', color="brown")
+    ax.set_ylabel('frequency', fontsize=15)
+    ax.set_title('Move frequency', fontsize=15)
+    # ax.set_ylabel('Sim time (s)', fontsize=25)
+    # ax.set_title('Sim time across system complexity.', fontsize=24)
+    ax.set_xticks(x)
+    ax.set_xticklabels(labels, fontsize=4)
+    ax.legend(prop={'size': 15})
+    fig.savefig(os.path.join(config.latest_visualization, "move_freq_breakdown.pdf"))
+
+
+# ------------------------------
+# Functionality:
+#       visualize the sequence of moves made
+# Variables:
+#       des_trail_list: design trail, i.e., the list of designs made in chronological order
+#       move_profile: list of moves made
+#       des_per_iteration: number of designs tried per iteration (in each iteration, we can be looking at a host of designs
+#       depending on the depth and breadth of the search at that point)
+# ------------------------------
+def des_trail_plot(des_trail_list, move_profile, des_per_iteration):
+    metric_bounds = {}
+    for metric in config.budgetted_metrics:
+        metric_bounds[metric] = (+10000, -10000)
+    metric_ref_des_dict = {}
+    metric_trans_des_dict = {}
+    for metric in config.budgetted_metrics:
+        metric_ref_des_dict[metric] = []
+        metric_trans_des_dict[metric] = []
+
+    # contains all the results
+    res_list = []
+    for ref_des, transformed_des in des_trail_list:
+        ref_des_metrics = []
+        transformed_des_metrics = []
+        # get the metrics, normalized as the percentage of headroom left against the budget
+        for metric in config.budgetted_metrics:
+            if isinstance(ref_des.get_dp_stats().get_system_complex_metric(metric), dict):  # must be latency
+                system_complex_metric = max(list(ref_des.get_dp_stats().get_system_complex_metric(metric).values()))
+                system_complex_budget = max(list(ref_des.database.get_budget(metric, "glass").values()))
+                ref_des_metric_value = 100 * (1 - system_complex_metric/system_complex_budget)
+                trans_des_metric_value = 100 * (1 - system_complex_metric/system_complex_budget) - ref_des_metric_value  # need to subtract since the second one needs to be a magnitude
+            else:
+                ref_des_metric_value = 100*(1 - ref_des.get_dp_stats().get_system_complex_metric(metric)/ref_des.database.get_budget(metric, "glass"))
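+                # e.g. (hypothetical numbers): a latency of 8e-3 s against a 1e-2 s
+                # budget gives 100*(1 - 8e-3/1e-2) = 20, i.e. 20% headroom remains;
+                # negative values mean the budget is exceeded.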
+                trans_des_metric_value = \
+                    100*(1 - transformed_des.get_dp_stats().get_system_complex_metric(metric)/
+                         ref_des.database.get_budget(metric, "glass")) - ref_des_metric_value  # need to subtract since the second one needs to be a magnitude
+
+            ref_des_metrics.append(ref_des_metric_value)
+            transformed_des_metrics.append(trans_des_metric_value)
+            metric_bounds[metric] = (min(metric_bounds[metric][0], ref_des_metric_value, trans_des_metric_value),
+                                     max(metric_bounds[metric][1], ref_des_metric_value, trans_des_metric_value))
+            metric_ref_des_dict[metric].append(ref_des_metric_value)
+            metric_trans_des_dict[metric].append((ref_des_metric_value + trans_des_metric_value))
+
+        #res_list.append(copy.deepcopy(ref_des_metrics + transformed_des_metrics))
+        res_list.append(cPickle.loads(cPickle.dumps(ref_des_metrics + transformed_des_metrics, -1)))
+
+    # soa = np.array([[0, 0, 0, 1, 3, 1], [1,1,1, 3,3,3]])
+    soa = np.array(res_list)
+
+    des_per_iteration.append(len(res_list))
+    des_iteration_unflattened = [(des_per_iteration[itr+1]-des_per_iteration[itr])*[itr+1] for itr, el in enumerate(des_per_iteration[:-1])]
+    des_iteration = [j for sub in des_iteration_unflattened for j in sub]
+
+    X, Y, Z, U, V, W = zip(*soa)
+    fig = plt.figure()
+    ax = fig.add_subplot(111, projection='3d')
+    c_ = list(range(0, len(X)))
+    for id in range(0, len(c_)):
+        c_[id] = str((1 - c_[id]/(len(X)))/2)  # grayscale: later moves are drawn darker
+    ax.quiver(X, Y, Z, U, V, W, arrow_length_ratio=.04, color=c_)
+    ax.scatter(X[-1], Y[-1], Z[-1], c="red")
+    ax.scatter(0, 0, 0, c="green")
+    #ax.quiver(X, Y, Z, U, V, W)
+
+    ax.set_xlim([-1 + metric_bounds[config.budgetted_metrics[0]][0], 1 + max(metric_bounds[config.budgetted_metrics[0]][1], 1)])
+    ax.set_ylim([-1 + metric_bounds[config.budgetted_metrics[1]][0], 1 + max(1.1*metric_bounds[config.budgetted_metrics[1]][1], 1)])
+    ax.set_zlim([-1 + metric_bounds[config.budgetted_metrics[2]][0], 1 + max(1.1*metric_bounds[config.budgetted_metrics[2]][1], 1)])
+    ax.set_title("normalized distance to budget")
+    ax.set_xlabel(config.budgetted_metrics[0])
+    ax.set_ylabel(config.budgetted_metrics[1])
+    ax.set_zlabel(config.budgetted_metrics[2])
+    fig.savefig(os.path.join(config.latest_visualization, "DSE_trail.pdf"))
+
+    des_iteration_move_markers = {}
+    des_iteration_move_markers["latency"] = []
+    des_iteration_move_markers["power"] = []
+    des_iteration_move_markers["area"] = []
+    des_iteration_move_markers["energy"] = []
+    for move in move_profile:
+        for metric in metric_ref_des_dict.keys():
+            if move.get_metric() == metric:
+                des_iteration_move_markers[metric].append(1)
+            else:
+                des_iteration_move_markers[metric].append(2)
+
+    # progression per metric
+    for metric in metric_ref_des_dict.keys():
+        fig, ax = plt.subplots()
+        ax.set_title("normalized distance to budget VS iteration")
+        ax.scatter(np.array(des_iteration)[np.array(des_iteration_move_markers[metric])==1], np.array(metric_ref_des_dict[metric])[np.array(des_iteration_move_markers[metric])==1], color="red", label="orig des", marker="*")
+        ax.scatter(np.array(des_iteration)[np.array(des_iteration_move_markers[metric])==2], np.array(metric_ref_des_dict[metric])[np.array(des_iteration_move_markers[metric])==2], color="red", label="orig des", marker=".")
+        ax.scatter(np.array(des_iteration)[np.array(des_iteration_move_markers[metric])==1], np.array(metric_trans_des_dict[metric])[np.array(des_iteration_move_markers[metric])==1], color="green", label="trans des", marker="*", alpha=.05)
+        ax.scatter(np.array(des_iteration)[np.array(des_iteration_move_markers[metric])==2], np.array(metric_trans_des_dict[metric])[np.array(des_iteration_move_markers[metric])==2], color="green", label="trans des", marker=".", alpha=0.05)
+        ax.legend()
+        ax.set_xlabel("iteration count")
+        ax.set_ylabel(metric + " norm dist to budget (%)")
+        fig.savefig(os.path.join(config.latest_visualization, metric + "_distance_to_budget_itr.png"))
+
+    trans_des_dist_to_goal = []  # overall distance to goal for each transformed design
+    for _, transformed_des in des_trail_list:
+        trans_des_dist_to_goal.append(
+            100*transformed_des.get_dp_stats().dist_to_goal(["latency", "power", "area"], "simple")/len(metric_ref_des_dict.keys()))
+
+    fig, ax = plt.subplots()
+    ax.set_title("normalized distance to all budgets VS iteration")
+    ax.scatter(des_iteration, trans_des_dist_to_goal)
+    ax.set_xlabel("iteration count")
+    ax.set_ylabel("avg norm dist to all budgets (%)")
+    fig.savefig(os.path.join(config.latest_visualization, "avg_budgets_distance_itr.pdf"))
+
+    plt.close('all')
+    barplot_moves(move_profile)
+
+# ------------------------------
+# Functionality:
+#       plot move stats
+# ------------------------------
+def barplot_moves(move_profile):
+    # moves to plot (only show depth/breadth == 0, for simplicity)
+    move_to_plot = []
+    for move in move_profile:
+        if config.regulate_move_tracking:
+            if move.get_breadth() == 0 and move.get_depth() == 0 and move.get_mini_breadht() == 0 and move.is_valid():
+                move_to_plot.append(move)
+        elif move.is_valid():
+            move_to_plot.append(move)
+
+    # draw the metrics
+    metric_dict = {}
+    metric_names = ["latency", "power", "area"]
+    for metric_name in metric_names:
+        metric_dict[metric_name] = []
+    metric_dict["cost"] = []
+    height_list = []
+    for move_ in move_to_plot:
+        # get metric values
+        metrics = move_.get_logs("metrics")
+        for metric_name in metric_names:
+            metric_dict[metric_name].append(metrics[metric_name])
+        selected_metric = move_.get_metric()
+
+        # find the height at which to mark the metric of interest
+        height = 0
+        for metric in metric_names:
+            if metric == selected_metric:
+                height += metric_dict[metric][-1]/2
+                height_list.append(height)
+                break
+            height += metric_dict[metric][-1]
+        if selected_metric == "cost":
+            height_list.append(height)
+            metric_dict["cost"].append(1)
+            for metric in metric_names:
+                metric_dict[metric][-1] = 0
+        else:
+            metric_dict["cost"].append(0)
+    labels = [str(i) for i in list(range(1, len(metric_dict["latency"])+1))]
+    power_plus_latency = [metric_dict["latency"][i] + metric_dict["power"][i] for i in range(len(labels))]
+    power_plus_latency_plus_area = [metric_dict["latency"][i] + metric_dict["power"][i] + metric_dict["area"][i] for i in range(len(labels))]
+
+    x = np.arange(len(labels))  # the label locations
+    width = 0.3  # the width of the bars
+
+    fig, ax = plt.subplots(figsize=(9.4, 4.8))
+    rects1 = ax.bar(x - .5*(width), metric_dict["latency"], width, label='perf', color="gold")
+    rects2 = ax.bar(x - .5*(width), metric_dict["power"], width, bottom=metric_dict["latency"],
+                    label='power', color="orange")
+    rects3 = ax.bar(x - .5*(width), metric_dict["area"], width, bottom=power_plus_latency,
+                    label='area', color="red")
+    rects4 = ax.bar(x - .5*(width), metric_dict["cost"], width, bottom=power_plus_latency_plus_area,
+                    label='cost', color="purple")
+
+    plt.plot(x, height_list, marker='o', linewidth=.3, color="red", ms=1)
+    ax.set_ylabel('Metrics contribution (%)', fontsize=15)
+    ax.set_xlabel('Iteration ', fontsize=15)
+    ax.set_title('Metric Selection', fontsize=15)
+    ax.set_xticks(x)
+    ax.set_xticklabels(labels, fontsize=4)
+    ax.legend(prop={'size': 15})
+    plt.savefig(os.path.join(config.latest_visualization, "Metric_Selection"))
+    plt.close('all')
+
+    # draw the metrics, scaled by the distance to goal
+    metric_dict = {}
+    metric_names = ["latency", "power", "area"]
+    for metric_name in metric_names:
+        metric_dict[metric_name] = []
+    metric_dict["cost"] = []
+    height_list = []
+    for move_ in move_to_plot:
+        # get metric values
+        dist_to_goal = move_.get_logs("ref_des_dist_to_goal_non_cost")
+        metrics = move_.get_logs("metrics")
+        for metric_name in metric_names:
+            metric_dict[metric_name].append(metrics[metric_name]*dist_to_goal*100)
+        selected_metric = move_.get_metric()
+
+        # find the height at which to mark the metric of interest
+        height = 0
+        for metric in metric_names:
+            if metric == selected_metric:
+                height += metric_dict[metric][-1] / 2
+                height_list.append(height)
+                break
+            height += metric_dict[metric][-1]
+        if selected_metric == "cost":
+            height_list.append(height)
+            metric_dict["cost"].append(1*dist_to_goal*100)
+            for metric in metric_names:
+                metric_dict[metric][-1] = 0
+        else:
+            metric_dict["cost"].append(0)
+    labels = [str(i) for i in list(range(1, len(metric_dict["latency"]) + 1))]
+    power_plus_latency = [metric_dict["latency"][i] + metric_dict["power"][i] for i in range(len(labels))]
+    power_plus_latency_plus_area = [metric_dict["latency"][i] + metric_dict["power"][i] + metric_dict["area"][i] for i in range(len(labels))]
+
+    x = np.arange(len(labels))  # the label locations
+    width = 0.4  # the width of the bars
+
+    fig, ax = plt.subplots(figsize=(9.4, 4.8))
+    rects1 = ax.bar(x - .5 * (width), metric_dict["latency"], width, label='perf', color="gold")
(width), metric_dict["latency"], width, label='perf', color="gold") + rects2 = ax.bar(x - .5 * (width), metric_dict["power"], width, bottom=metric_dict["latency"], + label='power', color="orange") + rects3 = ax.bar(x - .5 * (width), metric_dict["area"], width, bottom=power_plus_latency, + label='area', color="red") + rects4 = ax.bar(x - .5 * (width), metric_dict["cost"], width, bottom=power_plus_latency_plus_area, + label='cost', color="purple") + + plt.yscale("log") + plt.plot(x, height_list, marker='>', linewidth=.6, color="green", label="targeted metric", ms= 1) + ax.set_ylabel('Distance to Goal(Budget) (%)', fontsize=15) + ax.set_xlabel('Iteration ', fontsize=15) + ax.set_title('Metric Selection', fontsize=15) + ax.set_xticks(x) + ax.set_xticklabels(labels, fontsize=4) + ax.legend(prop={'size': 13}, ncol = 4, bbox_to_anchor = (-0.3, 1.4), loc = 'upper left') + fig.tight_layout() + plt.savefig(os.path.join(config.latest_visualization,"Metric_Selection_with_distance")) + plt.close('all') + + # plot kernels + task_dict = {} + task_names = [krnl.get_task().name for krnl in move_profile[0].get_logs("kernels")] + + for task_name in task_names: + task_dict[task_name] = [] + #metric_dict["cost"] = [] + height_list = [] + for move_ in move_to_plot: + # get metric values + krnl_prob_dict = move_.get_logs("kernels") + for krnl, value in krnl_prob_dict.items(): + task_dict[krnl.get_task().name].append(100*value) + selected_krnl = move_.get_kernel_ref() + + # find the height that you'd like to mark to specify the metric of interest + height = 0 + for task_name in task_names: + if task_name == selected_krnl.get_task().name: + height += task_dict[task_name][-1]/2 + height_list.append(height) + break + height += task_dict[task_name][-1] + + selected_metric = move_.get_metric() + if selected_metric =="cost": + for task in task_names: + task_dict[task][-1] = 0 + height_list[-1] = 1 + task_dict[(move_.get_tasks()[0]).name][-1] = 100 + #if move_.dist_to_goal < .05: + # print("what done now") + #else: + # task_dict["cost"].append(0) + labels = [str(i) for i in list(range(1, len(task_dict[task_names[0]])+1))] + """ + sum = 0 + for task in task_names: + sum += task_dict[task][0] + """ + try: + x = np.arange(len(labels)) # the label locations + except: + print("what") + width = 0.4 # the width of the bars + + + #my_cmap = plt.get_cmap("Set3") + my_cmap = ["bisque", "darkorange", "tan", "gold", "olive", "greenyellow", "darkgreen", "turquoise", "crimson", + "lightblue", "yellow", + "chocolate", "hotpink", "darkorchid"] + fig, ax = plt.subplots(figsize=(9.4, 4.8)) + rects1 = ax.bar(x - .5*(width), task_dict[task_names[0]], width, label=task_names[0], color = my_cmap[0]) + prev_offset = len(x)*[0] + rects = [] + prev_task = task_names[0] + for idx, task_name in enumerate(task_names[1:]): + for idx_, y in enumerate(task_dict[prev_task]): + prev_offset[idx_] += y + prev_task = task_name + rects.append(ax.bar(x - .5*(width), task_dict[task_name], width, label=task_name, bottom=prev_offset, color = my_cmap[(idx+1)%len(my_cmap)])) + + plt.plot(x, height_list, marker='>', linewidth=1.5, color="green", label="Targeted Kernel", ms=1) + + ax.set_ylabel('Kernel Contribution (%)', fontsize=15) + ax.set_xlabel('Iteration ', fontsize=15) + ax.set_title('Kernel Selection', fontsize=15) + ax.set_xticks(x) + ax.set_xticklabels(labels, fontsize=2) + ax.legend(prop={'size': 9}, ncol=1, bbox_to_anchor = (1.01, 1), loc = 'upper left') + fig.tight_layout() + plt.savefig(os.path.join(config.latest_visualization,"Kernel_Selection")) + 
+    plt.close('all')
+
+    # plot blocks
+    block_dict = {}
+    block_names = ["pe", "mem", "ic"]
+
+    for block_name in block_names:
+        block_dict[block_name] = []
+    height_list = []
+    for move_ in move_to_plot:
+        selected_metric = move_.get_metric()
+        # get per-block values
+        block_prob_dict = move_.get_logs("blocks")
+        seen_blocks = []
+        for block, value in block_prob_dict:
+            if block.type in seen_blocks:  # can have multiple memories
+                if selected_metric == "latency":
+                    block_dict[block.type][-1] = max(value, block_dict[block.type][-1])
+                else:
+                    block_dict[block.type][-1] += value
+            else:
+                block_dict[block.type].append(value)
+                seen_blocks.append(block.type)
+        selected_block = move_.get_block_ref().type
+
+        # find the height at which to mark the block of interest
+        height = 0
+        seen_blocks = []
+        for block in block_names:
+            if block in seen_blocks:
+                continue
+            if block == selected_block:
+                height += block_dict[block][-1]/2
+                height_list.append(min(height, 100))
+                break
+            height += block_dict[block][-1]
+            seen_blocks.append(block)
+
+        if selected_metric == "cost":
+            for block in block_names:
+                block_dict[block][-1] = 0
+            height_list[-1] = 1
+            block_dict[move_.get_block_ref().type][-1] = 100
+
+    labels = [str(i) for i in list(range(1, len(block_dict[block_names[0]])+1))]
+
+    x = np.arange(len(labels))  # the label locations
+    width = 0.2  # the width of the bars
+
+    my_cmap = ["orange", "blue", "red"]
+    fig, ax = plt.subplots(figsize=(9.4, 4.8))
+    rects1 = ax.bar(x - .5*(width), block_dict[block_names[0]], width, label=block_names[0].upper(), color=my_cmap[0])
+    prev_offset = len(x)*[0]
+    rects = []
+    prev_block_name = block_names[0]
+    for idx, block_name in enumerate(block_names[1:]):
+        for idx_, y in enumerate(block_dict[prev_block_name]):
+            prev_offset[idx_] += y
+        rects.append(ax.bar(x - .5 * (width), block_dict[block_name], width, label=block_name.upper(), bottom=prev_offset,
+                            color=my_cmap[idx + 1]))
+        prev_block_name = block_name
+    plt.plot(x, height_list, marker='>', linewidth=.6, color="green", ms=1, label="Targeted Block")
+
+    ax.set_ylabel('Block contribution (%)', fontsize=15)
+    ax.set_xlabel('Iteration ', fontsize=15)
+    ax.set_title('Block Selection', fontsize=15)
+    ax.set_xticks(x)
+    ax.set_xticklabels(labels, fontsize=4)
+    ax.legend(prop={'size': 10}, ncol=1, bbox_to_anchor=(1.01, 1), loc='upper left')
+
+    fig.tight_layout()
+    plt.savefig(os.path.join(config.latest_visualization, "Block_Selection"))
+    plt.close('all')
+
+    # transformations
+    fig, ax = plt.subplots(figsize=(9.4, 4.8))
+    transformation_decoding = ["split", "migrate", "swap", "split_swap", "cleanup", "identity", "dram_fix", "transfer", "routing"]
+    y = []
+
+    for move_ in move_to_plot:
+        y.append(transformation_decoding.index(move_.get_transformation_name())+1)
+    x = np.arange(len(y))  # the label locations
+    plt.yticks(list(range(1, len(transformation_decoding)+1)), transformation_decoding, fontsize=15)
+    ax.set_ylabel('Transformation ', fontsize=15)
+    ax.set_xlabel('Iteration ', fontsize=15)
+    ax.set_title('Transformation Selection', fontsize=15)
+    ax.set_xticks(x)
+    ax.set_xticklabels(labels, fontsize=4)
+    fig.tight_layout()
+    plt.plot(list(range(0, len(y))), y, marker='o', linewidth=.6, color="green", label="Transformation", ms=1)
+
+    fig, ax = plt.subplots(figsize=(9.4, 4.8))
+    y = []
+    for move_ in move_to_plot:
+        y.append(move_.get_logs("kernel_rnk_to_consider"))
+    x = np.arange(len(y))  # the label locations
+    ax.set_ylabel('Task Rank To Consider', fontsize=15)
+    ax.set_xlabel('Iteration', fontsize=15)
+    ax.set_title('Kernel Rank to Consider', fontsize=15)
+    ax.set_xticks(x)
+    ax.set_xticklabels(labels, fontsize=4)
+    plt.plot(list(range(0, len(y))), y, marker='o', linewidth=.6, color="green", label="Kernel Rank", ms=1)
+    ax.legend(prop={'size': 10})
+    plt.savefig(os.path.join(config.latest_visualization, "Kernel_rank_selection"))
+    plt.close('all')
+
+    # distance to goal
+    fig, ax = plt.subplots(figsize=(9.4, 4.8))
+    y = []
+    for move_ in move_to_plot:
+        y.append(100 * move_.get_logs("ref_des_dist_to_goal_non_cost"))
+    x = np.arange(len(y))  # the label locations
+    ax.set_ylabel('(Normalized) Distance to Goal (%)', fontsize=15)
+    ax.set_xlabel('Iteration', fontsize=15)
+    ax.set_title('Convergence to Goal', fontsize=15)
+    ax.set_xticks(x)
+    ax.set_xticklabels(labels, fontsize=4)
+    plt.plot(list(range(0, len(y))), y, marker='>', linewidth=.6, color="green", label="Distance", ms=1)
+    ax.legend(prop={'size': 10})
+    plt.savefig(os.path.join(config.latest_visualization, "distance_to_goal"))
+    plt.close('all')
+
+    # cost, normalized to the initial design's cost
+    fig, ax = plt.subplots(figsize=(9.4, 4.8))
+    y = []
+    for move_ in move_to_plot:
+        y.append(move_.get_logs("cost") / move_to_plot[0].get_logs("cost"))
+    x = range(1, len(y) + 1)  # the label locations
+    ax.set_ylabel('Cost (Development and Silicon)', fontsize=15)
+    ax.set_xlabel('Iteration', fontsize=15)
+    ax.set_title('Cost', fontsize=15)
+    ax.set_xticks(x)
+    plt.plot(list(range(1, len(y) + 1)), y, marker='>', linewidth=.6, color="green", label="cost", ms=1)
+    ax.legend(prop={'size': 10})
+    plt.savefig(os.path.join(config.latest_visualization, "cost_"))
+    plt.close('all')
+
+
+def scatter_plot(x, y, axis_name, database):
+    fig, ax = plt.subplots()
+    ax.scatter(x, y, marker="x")
+    ax.set_xlabel(axis_name[0] + " count")
+    if config.DATA_DELIVEYRY == "obfuscate":
+        axis_name_ = "normalized " + axis_name[1] + " (normalized to simple single core design)"
+    else:
+        axis_name_ = axis_name[1] + " (" + config.axis_unit[axis_name[1]] + ")"
+    ax.set_ylabel(axis_name_)
+    ax.set_title(axis_name[0] + " V.S. " + axis_name[1])
+    fig.savefig(os.path.join(config.latest_visualization, axis_name[0] + "_" + axis_name[1] + ".png"))
\ No newline at end of file
diff --git a/Project_FARSI/visualization_utils/plot_arrows.py b/Project_FARSI/visualization_utils/plot_arrows.py
new file mode 100644
index 00000000..720025d0
--- /dev/null
+++ b/Project_FARSI/visualization_utils/plot_arrows.py
@@ -0,0 +1,48 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
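+
+# Note (added, illustrative): matplotlib's 3D ax.quiver(X, Y, Z, U, V, W)
+# anchors each arrow's tail at (X, Y, Z) and points it along the vector with
+# components (U, V, W); arrow_length_ratio sets the head size as a fraction of
+# the arrow length. plot3d_arrow() below is a self-contained demo of that call.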
+
+import matplotlib.pyplot as plt
+import numpy as np
+from mpl_toolkits.mplot3d import Axes3D
+from mpl_toolkits.mplot3d import proj3d
+from matplotlib.patches import FancyArrowPatch
+from copy import copy
+from settings import config
+
+
+# fine-tune how we plot arrows
+def plot3d_arrow():
+    input_list = [[0, 0, 0, 1, 3, 1], [1, 1, 1, 3, 3, 3]]
+    min_bounds = {}
+    max_bounds = {}
+
+    # prepare the bounds
+    for el in ["x", "y", "z"]:
+        min_bounds[el] = 1000
+        max_bounds[el] = -1000
+
+    for arrow_coords in input_list:
+        x_0, y_0, z_0, x_1, y_1, z_1 = arrow_coords
+        min_bounds["x"] = min(min_bounds["x"], x_0, x_1)
+        min_bounds["y"] = min(min_bounds["y"], y_0, y_1)
+        min_bounds["z"] = min(min_bounds["z"], z_0, z_1)
+        max_bounds["x"] = max(max_bounds["x"], x_0, x_1)
+        max_bounds["y"] = max(max_bounds["y"], y_0, y_1)
+        max_bounds["z"] = max(max_bounds["z"], z_0, z_1)
+
+    soa = np.array([[1, 1, 1, 4, 4, 4], [4, 4, 4, 1, 1, 1]])
+
+    # plot
+    X, Y, Z, U, V, W = zip(*soa)
+    fig = plt.figure()
+    ax = fig.add_subplot(111, projection='3d')
+    ax.quiver(X, Y, Z, U, V, W, arrow_length_ratio=.05)
+    ax.set_xlim(.8 * min_bounds["x"], 1.4 * max_bounds["x"])
+    ax.set_ylim(.8 * min_bounds["y"], 1.4 * max_bounds["y"])
+    ax.set_zlim(.8 * min_bounds["z"], 1.4 * max_bounds["z"])
+    plt.show()
+
+
+if __name__ == "__main__":
+    plot3d_arrow()
diff --git a/Project_FARSI/visualization_utils/plotting-ying.py b/Project_FARSI/visualization_utils/plotting-ying.py
new file mode 100644
index 00000000..a32df0ba
--- /dev/null
+++ b/Project_FARSI/visualization_utils/plotting-ying.py
@@ -0,0 +1,3081 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+import itertools
+import copy
+import csv
+import os
+import matplotlib.pyplot as plt
+import pandas as pd
+import numpy as np
+import shutil
+from settings import config_plotting
+import time
+
+def get_column_name_number(dir_addr, mode):
+    column_name_number_dic = {}
+    if mode == "all":
+        file_name = "result_summary/FARSI_simple_run_0_1_all_reults.csv"
+    else:
+        file_name = "result_summary/FARSI_simple_run_0_1.csv"
+
+    file_full_addr = os.path.join(dir_addr, file_name)
+    with open(file_full_addr) as f:
+        resultReader = csv.reader(f, delimiter=',', quotechar='|')
+        for row in resultReader:
+            # the first row is the header: map each column name to its index
+            for idx, el_name in enumerate(row):
+                column_name_number_dic[el_name] = idx
+            break
+    return column_name_number_dic
+
+
+# the function to get the column index of the given category
+def columnNum(dirName, fileName, cate, result):
+    if result == "all":
+        with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 0:
+                    for j in range(0, len(row)):
+                        if row[j] == cate:
+                            return j
+                    raise Exception("No such category in the list! Check the name: " + cate)
+    elif result == "simple":
+        with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1.csv", newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 0:
+                    for j in range(0, len(row)):
+                        if row[j] == cate:
+                            return j
+                    raise Exception("No such category in the list! Check the name: " + cate)
+    else:
+        raise Exception("No such result file! Check the result type! It should be either \"all\" or \"simple\"")
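+
+# Hedged alternative (added, illustrative only -- not used elsewhere in this
+# file): csv.DictReader can replace the manual header scans above by keying
+# each row on the header names directly. The helper name and the assumption
+# that the result-summary CSV uses the same delimiter/quotechar are ours.
+def first_row_value(csv_full_addr, category):
+    with open(csv_full_addr, newline='') as csvfile:
+        for row in csv.DictReader(csvfile, delimiter=',', quotechar='|'):
+            # DictReader consumes the header itself; each row maps column
+            # names to this row's values
+            return row.get(category)
+    return None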
+
+# the function to plot the frequency of all comm_comp moves in a pie chart
+def plotCommCompAll(dirName, fileName, all_res_column_name_number):
+    colNum = all_res_column_name_number["comm_comp"]
+    trueNum = all_res_column_name_number["move validity"]
+
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        commNum = 0
+        compNum = 0
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+            if i > 1:
+                if row[colNum] == "comm":
+                    commNum += 1
+                elif row[colNum] == "comp":
+                    compNum += 1
+                else:
+                    raise Exception("comm_comp is not giving comm or comp! The new type: " + row[colNum])
+
+    plt.figure()
+    plt.pie([commNum, compNum], labels=["comm", "comp"])
+    plt.title("comm_comp: Frequency")
+    plt.savefig(dirName + fileName + "/comm-compFreq-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot the frequency of all high-level optimizations in a pie chart
+def plothighLevelOptAll(dirName, fileName, all_res_column_name_number):
+    colNum = all_res_column_name_number["high level optimization name"]
+    trueNum = all_res_column_name_number["move validity"]
+
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        topoNum = 0
+        tunNum = 0
+        mapNum = 0
+        idenOptNum = 0
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+            if i > 1:
+                if row[colNum] == "topology":
+                    topoNum += 1
+                elif row[colNum] == "customization":
+                    tunNum += 1
+                elif row[colNum] == "mapping":
+                    mapNum += 1
+                elif row[colNum] == "identity":
+                    idenOptNum += 1
+                else:
+                    raise Exception("high level optimization name is not giving topology or customization or mapping or identity! The new type: " + row[colNum])
+
+    plt.figure()
+    plt.pie([topoNum, tunNum, mapNum, idenOptNum], labels=["topology", "customization", "mapping", "identity"])
+    plt.title("High Level Optimization: Frequency")
+    plt.savefig(dirName + fileName + "/highLevelOpt-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot the frequency of all architectural variables to improve in a pie chart
+def plotArchVarImpAll(dirName, fileName, colNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        parazNum = 0
+        custNum = 0
+        localNum = 0
+        idenImpNum = 0
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+            if i > 1:
+                if row[colNum] == "parallelization":
+                    parazNum += 1
+                elif row[colNum] == "customization":
+                    custNum += 1
+                elif row[colNum] == "locality":
+                    localNum += 1
+                elif row[colNum] == "identity":
+                    idenImpNum += 1
+                else:
+                    raise Exception("architectural principle is not parallelization or customization or locality or identity! The new type: " + row[colNum])
+
+    plt.figure()
+    plt.pie([parazNum, custNum, localNum, idenImpNum], labels=["parallelization", "customization", "locality", "identity"])
+    plt.title("Architectural Principle: Frequency")
+    plt.savefig(dirName + fileName + "/archVarImp-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
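+
+# Hedged sketch (added, illustrative only): the three frequency plots above
+# share one shape -- count the occurrences of a categorical column among valid
+# moves, then draw a pie chart. A generic helper along these lines could fold
+# them together; the function name and the `labels` parameter (the set of
+# admissible category values) are our assumptions, not part of FARSI.
+def plotCategoryFreq(dirName, fileName, colNum, trueNum, labels, title):
+    from collections import Counter
+    counts = Counter()
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        for i, row in enumerate(csv.reader(csvfile, delimiter=',', quotechar='|')):
+            # same filtering as the specialized plotters: skip headers and invalid moves
+            if i > 1 and row[trueNum] == "True" and row[colNum] in labels:
+                counts[row[colNum]] += 1
+    plt.figure()
+    plt.pie([counts[label] for label in labels], labels=labels)
+    plt.title(title + ": Frequency")
+    plt.savefig(dirName + fileName + "/" + title + "-" + fileName + ".png")
+    plt.close('all')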
+# the function to plot simulation time vs. system block count
+def plotSimTimeVSblk(dirName, fileName, blkColNum, simColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        sysBlkCount = []
+        simTime = []
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+            if i > 1:
+                sysBlkCount.append(int(row[blkColNum]))
+                simTime.append(float(row[simColNum]))
+
+    plt.figure()
+    plt.plot(sysBlkCount, simTime)
+    plt.xlabel("System Block Count")
+    plt.ylabel("Simulation Time")
+    plt.title("Simulation Time vs. System Block Count")
+    plt.savefig(dirName + fileName + "/simTimeVSblk-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot move generation time vs. system block count
+def plotMoveGenTimeVSblk(dirName, fileName, blkColNum, movColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        sysBlkCount = []
+        moveGenTime = []
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+            if i > 1:
+                sysBlkCount.append(int(row[blkColNum]))
+                moveGenTime.append(float(row[movColNum]))
+
+    plt.figure()
+    plt.plot(sysBlkCount, moveGenTime)
+    plt.xlabel("System Block Count")
+    plt.ylabel("Move Generation Time")
+    plt.title("Move Generation Time vs. System Block Count")
+    plt.savefig(dirName + fileName + "/moveGenTimeVSblk-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+def get_experiments_workload(all_res_column_name):
+    # the "latency budget" cell looks like "wl_0=v_0;wl_1=v_1;...;" -- strip the
+    # trailing ";" and keep only the workload names
+    latency_budget = all_res_column_name["latency budget"][:-1]
+    workload_latencies = latency_budget.split(";")
+    workloads = []
+    for workload_latency in workload_latencies:
+        workloads.append(workload_latency.split("=")[0])
+    return workloads
+
+def get_experiments_name(file_full_addr, all_res_column_name_number):
+    with open(file_full_addr, newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+        row1 = next(resultReader)  # skip the header row
+        row2 = next(resultReader)
+        latency_budget = row2[all_res_column_name_number["latency_budget"]]
+        power_budget = row2[all_res_column_name_number["power_budget"]]
+        area_budget = row2[all_res_column_name_number["area_budget"]]
+        try:
+            transformation_selection_mode = row2[all_res_column_name_number["transformation_selection_mode"]]
+        except (KeyError, IndexError):
+            transformation_selection_mode = ""
+
+        workload_latencies = latency_budget[:-1].split(';')
+        latency_budget_refined = ""
+        for workload_latency in workload_latencies:
+            latency_budget_refined += "_" + (workload_latency.split("=")[0][0] + workload_latency.split("=")[1])
+
+        return latency_budget_refined[1:] + "_" + power_budget + "_" + area_budget + "_" + transformation_selection_mode
+
+def get_all_col_values_of_a_file(file_full_addr, all_res_column_name_number, column_name):
+    column_number = all_res_column_name_number[column_name]
+    all_values = []
+    with open(file_full_addr, newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+        for i, row in enumerate(resultReader):
+            if i > 1:
+                if not row[column_number] == '':
+                    value = row[column_number]
+                    values = value.split(";")  # the cell may hold multiple values
+                    for val in values:
+                        if "=" in val:
+                            val_splitted = val.split("=")
+                            all_values.append(val_splitted[0])
+                        else:
+                            all_values.append(val)
+
+    return all_values
+
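+
+# Hedged sketch (added, illustrative only): several helpers above rely on the
+# same cell convention -- ";"-separated entries, each optionally a "name=value"
+# pair, with a trailing ";" (hence the [:-1] slices). The helper name below and
+# the sample cell string are our assumptions, shown only to document the format.
+def split_budget_cell(cell):
+    """Parse e.g. 'audio=0.001;cava=0.002;' into {'audio': 0.001, 'cava': 0.002}."""
+    result = {}
+    for entry in cell.rstrip(";").split(";"):
+        if "=" in entry:
+            name, value = entry.split("=")
+            result[name] = float(value)
+    return result
+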
+def get_all_col_values_of_a_folders(input_dir_names, input_all_res_column_name_number, column_name): + all_values = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + all_values.extend(get_all_col_values_of_a_file(file_full_addr, input_all_res_column_name_number, column_name)) + + # get rid of duplicates + all_values_rid_of_duplicates = list(set(all_values)) + return all_values_rid_of_duplicates + +def extract_latency_values(values_): + print("") + + +def plot_codesign_rate_efficacy_cross_workloads_updated(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + move_name_number = all_res_column_name_number["move name"] + + # experiment_names + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + + axis_font = {'fontname': 'Arial', 'size': '4'} + x_column_name = "iteration cnt" + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + y_column_name_list = ["exact optimization name", "architectural principle", "comm_comp", "workload"] + + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + + + + column_co_design_cnt = {} + column_non_co_design_cnt = {} + column_co_design_rate = {} + column_non_co_design_rate = {} + column_co_design_efficacy_avg = {} + column_non_co_design_efficacy_rate = {} + column_non_co_design_efficacy = {} + column_co_design_dist= {} + column_co_design_dist_avg= {} + column_co_design_improvement = {} + experiment_name_list = [] + last_col_val = "" + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_name_list.append(experiment_name) + column_co_design_dist_avg[experiment_name] = {} + column_co_design_efficacy_avg[experiment_name] = {} + + column_co_design_cnt = {} + for y_column_name in y_column_name_list: + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + + + dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + column_co_design_cnt[y_column_name] = [] + column_non_co_design_cnt[y_column_name] = [] + + column_non_co_design_efficacy[y_column_name] = [] + column_co_design_dist[y_column_name] = [] + column_co_design_improvement[y_column_name] = [] + column_co_design_rate[y_column_name] = [] + + all_values = get_all_col_values_of_a_folders(input_dir_names, all_res_column_name_number, y_column_name) + + last_row_change = "" + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + rows = list(resultReader) + for i, row in enumerate(rows): + if i >= 1: + last_row = rows[i - 1] + if row[y_column_number] not in all_values or row[move_name_number]=="identity": + continue + + col_value = row[y_column_number] + col_values = col_value.split(";") + for idx, col_val in enumerate(col_values): + + + # only for improvement + if float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]) < 0: + 
continue + + delta_x_column = (float(row[x_column_number]) - float(last_row[x_column_number]))/len(col_values) + delta_improvement = (float(last_row[dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/(float(last_row[dis_to_goal_column_number])*len(col_values)) + + + if not col_val == last_col_val and i > 1: + if not last_row_change == "": + distance_from_last_change = float(last_row[x_column_number]) - float(last_row_change[x_column_number]) + idx * delta_x_column + column_co_design_dist[y_column_name].append(distance_from_last_change) + improvement_from_last_change = (float(last_row[dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/float(last_row[dis_to_goal_column_number]) + idx *delta_improvement + column_co_design_improvement[y_column_name].append(improvement_from_last_change) + + last_row_change = copy.deepcopy(last_row) + + + last_col_val = col_val + + + + # co_des cnt + # we ignore the first element as the first element distance is always zero + co_design_dist_sum = 0 + co_design_efficacy_sum = 0 + avg_ctr = 1 + co_design_dist_selected = column_co_design_dist[y_column_name] + co_design_improvement_selected = column_co_design_improvement[y_column_name] + for idx,el in enumerate(column_co_design_dist[y_column_name]): + if idx == len(co_design_dist_selected) - 1: + break + co_design_dist_sum += (column_co_design_dist[y_column_name][idx] + column_co_design_dist[y_column_name][idx+1]) + co_design_efficacy_sum += (column_co_design_improvement[y_column_name][idx] + column_co_design_improvement[y_column_name][idx+1]) + #/(column_co_design_dist[y_column_name][idx] + column_co_design_dist[y_column_name][idx+1]) + avg_ctr+=1 + + column_co_design_improvement = {} + column_co_design_dist_avg[experiment_name][y_column_name]= co_design_dist_sum/avg_ctr + column_co_design_efficacy_avg[experiment_name][y_column_name] = co_design_efficacy_sum/avg_ctr + + #result = {"rate":{}, "efficacy":{}} + #rate_column_co_design = {} + + plt.figure() + plotdata = pd.DataFrame(column_co_design_dist_avg, index=y_column_name_list) + fontSize = 10 + plotdata.plot(kind='bar', fontsize=fontSize) + plt.xticks(fontsize=fontSize, rotation=6) + plt.yticks(fontsize=fontSize) + plt.xlabel("co design parameter", fontsize=fontSize) + plt.ylabel("co design distance", fontsize=fontSize) + plt.title("co desgin distance of different parameters", fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir,"_".join(experiment_name_list) +"_"+"co_design_avg_dist"+'_'.join(y_column_name_list)+".png")) + plt.close('all') + + + plt.figure() + plotdata = pd.DataFrame(column_co_design_efficacy_avg, index=y_column_name_list) + fontSize = 10 + plotdata.plot(kind='bar', fontsize=fontSize) + plt.xticks(fontsize=fontSize, rotation=6) + plt.yticks(fontsize=fontSize) + plt.xlabel("co design parameter", fontsize=fontSize) + plt.ylabel("co design dis", fontsize=fontSize) + plt.title("co desgin efficacy of different parameters", fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir,"_".join(experiment_name_list) 
+"_"+"co_design_efficacy"+'_'.join(y_column_name_list)+".png")) + plt.close('all') + +def plot_codesign_rate_efficacy_per_workloads(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + move_name_number = all_res_column_name_number["move name"] + + # experiment_names + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + + axis_font = {'fontname': 'Arial', 'size': '4'} + x_column_name = "iteration cnt" + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + y_column_name_list = ["exact optimization name", "architectural principle", "comm_comp", "workload"] + + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + + + + column_co_design_cnt = {} + column_non_co_design_cnt = {} + column_co_design_rate = {} + column_non_co_design_rate = {} + column_co_design_efficacy_rate = {} + column_non_co_design_efficacy_rate = {} + column_non_co_design_efficacy = {} + column_co_design_efficacy= {} + last_col_val = "" + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + column_co_design_cnt = {} + for y_column_name in y_column_name_list: + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + + + dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + column_co_design_cnt[y_column_name] = [] + column_non_co_design_cnt[y_column_name] = [] + + column_non_co_design_efficacy[y_column_name] = [] + column_co_design_efficacy[y_column_name] = [] + + all_values = get_all_col_values_of_a_folders(input_dir_names, all_res_column_name_number, y_column_name) + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + rows = list(resultReader) + for i, row in enumerate(rows): + if i >= 1: + last_row = rows[i - 1] + if row[y_column_number] not in all_values or row[trueNum] == "False" or row[move_name_number]=="identity": + continue + + col_value = row[y_column_number] + col_values = col_value.split(";") + for idx, col_val in enumerate(col_values): + delta_x_column = (float(row[x_column_number]) - float(last_row[x_column_number]))/len(col_values) + + value_to_add_1 = (float(last_row[x_column_number]) + idx * delta_x_column, 1) + value_to_add_0 = (float(last_row[x_column_number]) + idx * delta_x_column, 0) + + # only for improvement + if float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]) < 0: + continue + + if not col_val == last_col_val: + + column_co_design_cnt[y_column_name].append(value_to_add_1) + column_non_co_design_cnt[y_column_name].append(value_to_add_0) + column_co_design_efficacy[y_column_name].append((float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/float(row[ref_des_dis_to_goal_column_number])) + column_non_co_design_efficacy[y_column_name].append(0) + else: + column_co_design_cnt[y_column_name].append(value_to_add_0) + 
column_non_co_design_cnt[y_column_name].append(value_to_add_1) + column_co_design_efficacy[y_column_name].append(0) + column_non_co_design_efficacy[y_column_name].append((float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/float(row[ref_des_dis_to_goal_column_number])) + + last_col_val = col_val + + + + # co_des cnt + x_values_co_design_cnt = [el[0] for el in column_co_design_cnt[y_column_name]] + y_values_co_design_cnt = [el[1] for el in column_co_design_cnt[y_column_name]] + y_values_co_design_cnt_total =sum(y_values_co_design_cnt) + total_iter = x_values_co_design_cnt[-1] + + # non co_des cnt + x_values_non_co_design_cnt = [el[0] for el in column_non_co_design_cnt[y_column_name]] + y_values_non_co_design_cnt = [el[1] for el in column_non_co_design_cnt[y_column_name]] + y_values_non_co_design_cnt_total =sum(y_values_non_co_design_cnt) + + column_co_design_rate[y_column_name] = y_values_co_design_cnt_total/total_iter + column_non_co_design_rate[y_column_name] = y_values_non_co_design_cnt_total/total_iter + + # co_des efficacy + y_values_co_design_efficacy = column_co_design_efficacy[y_column_name] + y_values_co_design_efficacy_total =sum(y_values_co_design_efficacy) + + + # non co_des efficacy + y_values_non_co_design_efficacy = column_non_co_design_efficacy[y_column_name] + y_values_non_co_design_efficacy_total =sum(y_values_non_co_design_efficacy) + + column_co_design_efficacy_rate[y_column_name] = y_values_co_design_efficacy_total/(y_values_non_co_design_efficacy_total + y_values_co_design_efficacy_total) + column_non_co_design_efficacy_rate[y_column_name] = y_values_non_co_design_efficacy_total/(y_values_non_co_design_efficacy_total + y_values_co_design_efficacy_total) + + + result = {"rate":{}, "efficacy":{}} + rate_column_co_design = {} + + result["rate"] = {"co_design":column_co_design_rate, "non_co_design": column_non_co_design_rate} + result["efficacy_rate"] = {"co_design":column_co_design_efficacy_rate, "non_co_design": column_non_co_design_efficacy_rate} + # prepare for plotting and plot + + + plt.figure() + plotdata = pd.DataFrame(result["rate"], index=y_column_name_list) + fontSize = 10 + plotdata.plot(kind='bar', fontsize=fontSize, stacked=True) + plt.xticks(fontsize=fontSize, rotation=6) + plt.yticks(fontsize=fontSize) + plt.xlabel("co design parameter", fontsize=fontSize) + plt.ylabel("co design rate", fontsize=fontSize) + plt.title("co desgin rate of different parameters", fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir,experiment_name +"_"+"co_design_rate_"+'_'.join(y_column_name_list)+".png")) + plt.close('all') + + + plt.figure() + plotdata = pd.DataFrame(result["efficacy_rate"], index=y_column_name_list) + fontSize = 10 + plotdata.plot(kind='bar', fontsize=fontSize, stacked=True) + plt.xticks(fontsize=fontSize, rotation=6) + plt.yticks(fontsize=fontSize) + plt.xlabel("co design parameter", fontsize=fontSize) + plt.ylabel("co design efficacy rate", fontsize=fontSize) + plt.title("co design efficacy rate of different parameters", fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + 
plt.savefig(os.path.join(output_dir,experiment_name+"_"+"co_design_efficacy_rate_"+'_'.join(y_column_name_list)+".png")) + plt.close('all') + + + +def plot_codesign_progression_per_workloads(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '20'} + x_column_name = "iteration cnt" + y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + + + experiment_column_value = {} + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + for y_column_name in y_column_name_list: + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + experiment_column_value[experiment_name] = [] + all_values = get_all_col_values_of_a_folders(input_dir_names, all_res_column_name_number, y_column_name) + all_values_encoding = {} + for idx, val in enumerate(all_values): + all_values_encoding[val] = idx + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + rows = list(resultReader) + for i, row in enumerate(rows): + #if row[trueNum] != "True": + # continue + if i >= 1: + if row[y_column_number] not in all_values: + continue + + col_value = row[y_column_number] + col_values = col_value.split(";") + for idx, col_val in enumerate(col_values): + last_row = rows[i-1] + delta_x_column = (float(row[x_column_number]) - float(last_row[x_column_number]))/len(col_values) + value_to_add = (float(last_row[x_column_number])+ idx*delta_x_column, col_val) + experiment_column_value[experiment_name].append(value_to_add) + + + + # prepare for plotting and plot + axis_font = {'size': '20'} + fontSize = 20 + + fig = plt.figure(figsize=(12, 8)) + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + x_values = [el[0] for el in experiment_column_value[experiment_name]] + #y_values = [all_values_encoding[el[1]] for el in experiment_column_value[experiment_name]] + y_values = [el[1] for el in experiment_column_value[experiment_name]] + + #ax.set_title("experiment vs system implicaction") + ax.tick_params(axis='both', which='major', labelsize=fontSize, rotation=60) + ax.set_xlabel(x_column_name, fontsize=20) + ax.set_ylabel(y_column_name, fontsize=20) + ax.plot(x_values, y_values, label=y_column_name, linewidth=2) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/progression") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + fig.savefig(os.path.join(output_dir,experiment_name+"_progression_"+'_'.join(y_column_name_list)+".png")) + # plt.show() + plt.close('all') + + fig = plt.figure(figsize=(12, 8)) + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + x_values = [el[0] for el in 
experiment_column_value[experiment_name]] + # y_values = [all_values_encoding[el[1]] for el in experiment_column_value[experiment_name]] + y_values = [el[1] for el in experiment_column_value[experiment_name]] + + # ax.set_title("experiment vs system implicaction") + ax.tick_params(axis='both', which='major', labelsize=fontSize, rotation=60) + ax.set_xlabel(x_column_name, fontsize=20) + ax.set_ylabel(y_column_name, fontsize=20) + ax.plot(x_values, y_values, label=y_column_name, linewidth=2) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/progression") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + fig.savefig(os.path.join(output_dir, experiment_name + "_progression_" + y_column_name + ".png")) + # plt.show() + plt.close('all') + + +def plot_3d(input_dir_names, res_column_name_number): + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '10'} + fontSize = 10 + column_value = {} + # initialize the dictionary + column_name_list = ["budget_scaling_power", "budget_scaling_area","budget_scaling_latency"] + + under_study_vars =["iteration cnt", + "local_bus_avg_theoretical_bandwidth", "local_bus_max_actual_bandwidth", + "local_bus_avg_actual_bandwidth", + "system_bus_avg_theoretical_bandwidth", "system_bus_max_actual_bandwidth", + "system_bus_avg_actual_bandwidth", "global_total_traffic", "local_total_traffic", + "global_memory_total_area", "local_memory_total_area", "ips_total_area", + "gpps_total_area","ip_cnt", "max_accel_parallelism", "avg_accel_parallelism", + "gpp_cnt", "max_gpp_parallelism", "avg_gpp_parallelism"] + + + + + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + for column_name in column_name_list + under_study_vars: + if column_name not in column_value.keys() : + column_value[column_name] = [] + column_number = res_column_name_number[column_name] + col_value = row[column_number] + col_values = col_value.split(";") + if "=" in col_values[0]: + column_value[column_name].append(float((col_values[0]).split("=")[1])) + else: + column_value[column_name].append(float(col_values[0])) + + + for idx,under_study_var in enumerate(under_study_vars): + fig_budget_blkcnt = plt.figure(figsize=(12, 12)) + plt.rc('font', **axis_font) + ax_blkcnt = fig_budget_blkcnt.add_subplot(projection='3d') + img = ax_blkcnt.scatter3D(column_value["budget_scaling_power"], column_value["budget_scaling_area"], column_value["budget_scaling_latency"], + c=column_value[under_study_var], cmap="bwr", s=80, label="System Block Count") + for idx,_ in enumerate(column_value[under_study_var]): + coordinate = column_value[under_study_var][idx] + coord_in_scientific_notatio = "{:.2e}".format(coordinate) + + ax_blkcnt.text(column_value["budget_scaling_power"][idx], 
column_value["budget_scaling_area"][idx], column_value["budget_scaling_latency"][idx], '%s' % coord_in_scientific_notatio, size=fontSize) + + ax_blkcnt.set_xlabel("Power Budget", fontsize=fontSize) + ax_blkcnt.set_ylabel("Area Budget", fontsize=fontSize) + ax_blkcnt.set_zlabel("Latency Budget", fontsize=fontSize) + ax_blkcnt.legend() + cbar = fig_budget_blkcnt.colorbar(img, aspect=40) + cbar.set_label("System Block Count", rotation=270) + #plt.title("{Power Budget, Area Budget, Latency Budget} VS System Block Count: " + subDirName) + plt.tight_layout() + + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "3D/case_studies") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.savefig(os.path.join(output_dir, under_study_var+ ".png")) + # plt.show() + plt.close('all') + + +def plot_convergence_per_workloads(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + move_name_number = all_res_column_name_number["move name"] + + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + color_values = ["r","b","y","black","brown","purple"] + column_name_color_val_dict = {"best_des_so_far_power":"purple", "power_budget":"purple","best_des_so_far_area_non_dram":"blue", "area_budget":"blue", + "latency_budget_hpvm_cava":"orange", "latency_budget_audio_decoder":"yellow", "latency_budget_edge_detection":"red", + "best_des_so_far_latency_hpvm_cava":"orange", "best_des_so_far_latency_audio_decoder": "yellow","best_des_so_far_latency_edge_detection": "red", + "latency_budget":"white" + } + + axis_font = {'size': '20'} + fontSize = 20 + x_column_name = "iteration cnt" + y_column_name_list = ["power", "area_non_dram", "latency", "latency_budget", "power_budget","area_budget"] + + experiment_column_value = {} + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_column_value[experiment_name] = {} + for y_column_name in y_column_name_list: + if "budget" in y_column_name: + prefix = "" + else: + prefix = "best_des_so_far_" + y_column_name = prefix+y_column_name + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + #dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + #ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + + if not y_column_name == prefix+"latency": + experiment_column_value[experiment_name][y_column_name] = [] + + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i > 1: + if row[trueNum] == "FALSE" or row[move_name_number]=="identity": + continue + col_value = row[y_column_number] + if ";" in col_value: + col_value = col_value[:-1] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + val_splitted = col_val.split("=") + value_to_add = (float(row[x_column_number]), (val_splitted[0], 
val_splitted[1])) + else: + value_to_add = (float(row[x_column_number]), col_val) + + if y_column_name in [prefix+"latency", prefix+"latency_budget"] : + new_tuple = (value_to_add[0], 1000*float(value_to_add[1][1])) + if y_column_name+"_"+value_to_add[1][0] not in experiment_column_value[experiment_name].keys(): + experiment_column_value[experiment_name][y_column_name + "_" + value_to_add[1][0]] = [] + experiment_column_value[experiment_name][y_column_name+"_"+value_to_add[1][0]].append(new_tuple) + if y_column_name in [prefix+"power", prefix+"power_budget"]: + new_tuple = (value_to_add[0], float(value_to_add[1])*1000) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + elif y_column_name in [prefix+"area_non_dram", prefix+"area_budget"]: + new_tuple = (value_to_add[0], float(value_to_add[1]) * 1000000) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + + # prepare for plotting and plot + fig = plt.figure(figsize=(15, 8)) + ax = fig.add_subplot(111) + for column, values in experiment_column_value[experiment_name].items(): + x_values = [el[0] for el in values] + y_values = [el[1] for el in values] + ax.set_yscale('log') + if "budget" in column: + marker = 'x' + alpha_ = .3 + else: + marker = "_" + alpha_ = 1 + ax.plot(x_values, y_values, label=column, c=column_name_color_val_dict[column], marker=marker, alpha=alpha_) + + #ax.set_title("experiment vs system implicaction") + ax.set_xlabel(x_column_name, fontsize=fontSize) + y_axis_name = "_".join(list(experiment_column_value[experiment_name].keys())) + ax.set_ylabel(y_axis_name, fontsize=fontSize) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/convergence") + if not os.path.exists(output_dir): + os.mkdir(output_dir) + fig.savefig(os.path.join(output_dir,experiment_name+"_convergence.png")) + # plt.show() + plt.close('all') + +def plot_convergence_vs_time(input_dir_names, res_column_name_number): + PA_time_scaling_factor = 10 + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '15'} + fontSize = 20 + x_column_name = "exploration_plus_simulation_time" + y_column_name_list = ["best_des_so_far_dist_to_goal_non_cost"] + + PA_column_experiment_value = {} + FARSI_column_experiment_value = {} + + #column_name = "move name" + for k, file_full_addr in enumerate(file_full_addr_list): + for y_column_name in y_column_name_list: + # get all possible the values of interest + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + PA_column_experiment_value[y_column_name] = [] + FARSI_column_experiment_value[y_column_name] = [] + PA_last_time = 0 + FARSI_last_time = 0 + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( 
file_full_addr, res_column_name_number) + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + FARSI_last_time += float(row[x_column_number]) + FARSI_value_to_add = (float(FARSI_last_time), row[y_column_number]) + FARSI_column_experiment_value[y_column_name].append(FARSI_value_to_add) + + PA_last_time = FARSI_last_time*PA_time_scaling_factor + PA_value_to_add = (float(PA_last_time), row[y_column_number]) + PA_column_experiment_value[y_column_name].append(PA_value_to_add) + + # prepare for plotting and plot + fig = plt.figure(figsize=(12, 12)) + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + fontSize = 20 + #plt.tight_layout() + x_values = [el[0] for el in FARSI_column_experiment_value[y_column_name]] + y_values = [str(float(el[1]) * 100 // 1 / 100.0) for el in FARSI_column_experiment_value[y_column_name]] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="FARSI time to completion", marker="_") + # ax.set_yscale('log') + + x_values = [el[0] for el in PA_column_experiment_value[y_column_name]] + y_values = [str(float(el[1]) * 100 // 1 / 100.0) for el in PA_column_experiment_value[y_column_name]] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="PA time to completion", marker="_") + #ax.set_xscale('log') + + #ax.set_title("experiment vs system implicaction") + ax.legend(loc="upper right")#bbox_to_anchor=(1, 1), loc="upper left") + ax.set_xlabel(x_column_name, fontsize=fontSize) + ax.set_ylabel(y_column_name, fontsize=fontSize) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/progression") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir,str(k)+"_" + y_column_name+"_vs_"+x_column_name+"_FARSI_vs_PA.png")) + #plt.show() + plt.close('all') + + +def plot_convergence_cross_workloads(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '20'} + x_column_name = "iteration cnt" + y_column_name_list = ["best_des_so_far_dist_to_goal_non_cost", "dist_to_goal_non_cost"] + + column_experiment_value = {} + #column_name = "move name" + for y_column_name in y_column_name_list: + # get all possible the values of interest + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + + column_experiment_value[y_column_name] = {} + # initialize the dictionary + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) + column_experiment_value[y_column_name][experiment_name] = [] + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + value_to_add = 
(float(row[x_column_number]), max(float(row[y_column_number]),.01)) + column_experiment_value[y_column_name][experiment_name].append(value_to_add) + + # prepare for plotting and plot + fig = plt.figure() + ax = fig.add_subplot(111) + #plt.tight_layout() + for experiment_name, values in column_experiment_value[y_column_name].items(): + x_values = [el[0] for el in values] + y_values = [el[1] for el in values] + ax.scatter(x_values, y_values, label=experiment_name[1:]) + + #ax.set_title("experiment vs system implicaction") + ax.set_yscale('log') + ax.legend(bbox_to_anchor=(1, 1), loc="best") + ax.set_xlabel(x_column_name) + ax.set_ylabel(y_column_name) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/convergence") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir,x_column_name+"_"+y_column_name+".png")) + # plt.show() + plt.close('all') + +def plot_system_implication_analysis(input_dir_names, res_column_name_number, case_study): + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '10'} + + column_name_list = list(case_study.values())[0] + + column_experiment_value = {} + #column_name = "move name" + for column_name in column_name_list: + # get all possible the values of interest + column_number = res_column_name_number[column_name] + + column_experiment_value[column_name] = {} + # initialize the dictionary + column_experiment_number_dict = {} + experiment_number_dict = {} + + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + col_value = row[column_number] + col_values = col_value.split(";") + for col_val in col_values: + column_experiment_value[column_name][experiment_name] = float(col_val) + + # prepare for plotting and plot + # plt.figure() + index = experiment_names + plotdata = pd.DataFrame(column_experiment_value, index=index) + if list(case_study.keys())[0] in ["bandwidth_analysis","traffic_analysis"]: + plotdata.plot(kind='bar', fontsize=9, rot=5, log=True) + else: + plotdata.plot(kind='bar', fontsize=9, rot=5) + + plt.legend(loc="best", fontsize="9") + plt.xlabel("experiments", fontsize="10") + plt.ylabel("system implication") + #plt.title("experiment vs system implicaction") + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/system_implications") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + #plt.tight_layout()list(case_study.keys())[0] + if "re_use" in list(case_study.keys())[0] or "speedup" in list(case_study.keys())[0]: + plt.yscale('log') + plt.savefig(os.path.join(output_dir,list(case_study.keys())[0]+".png")) + plt.close('all') + + + +def plot_co_design_nav_breakdown_post_processing(input_dir_names, 
column_column_value_experiment_frequency_dict): + column_name_list = [("exact optimization name", "neighbouring design space size", "div")] + #column_name = "move name" + for n, column_name_tuple in enumerate(column_name_list): + first_column = column_name_tuple[0] + second_column = column_name_tuple[1] + operation = column_name_tuple[2] + new_column_name = first_column+"_"+operation+"_"+second_column + + first_column_value_experiment_frequency_dict = column_column_value_experiment_frequency_dict[first_column] + second_column_value_experiment_frequency_dict = column_column_value_experiment_frequency_dict[second_column] + modified_column_value_experiment_frequency_dict = {} + + experiment_names = [] + for column_val, experiment_freq in first_column_value_experiment_frequency_dict.items(): + if column_val == "unknown": + continue + modified_column_value_experiment_frequency_dict[column_val] = {} + for experiment, freq in experiment_freq.items(): + if(second_column_value_experiment_frequency_dict[column_val][experiment]) < .000001: + modified_column_value_experiment_frequency_dict[column_val][experiment] = 0 + else: + modified_column_value_experiment_frequency_dict[column_val][experiment] = first_column_value_experiment_frequency_dict[column_val][experiment]/max(second_column_value_experiment_frequency_dict[column_val][experiment],.0000000000001) + experiment_names.append(experiment) + + axis_font = {'size': '22'} + fontSize = 22 + experiment_names = list(set(experiment_names)) + # prepare for plotting and plot + # plt.figure(n) + plt.rc('font', **axis_font) + index = experiment_names + plotdata = pd.DataFrame(modified_column_value_experiment_frequency_dict, index=index) + plotdata.plot(kind='bar', stacked=True, figsize=(13, 8)) + plt.xlabel("experiments", **axis_font) + plt.ylabel(new_column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + new_column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # plt.tight_layout() + plt.savefig(os.path.join(output_dir,'_'.join(new_column_name.split(" "))+".png"), bbox_inches='tight') + plt.tight_layout() + # plt.show() + plt.close('all') + + + +# navigation breakdown +def plot_codesign_nav_breakdown_per_workload(input_dir_names, input_all_res_column_name_number): + trueNum = input_all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, input_all_res_column_name_number) + experiment_names.append(experiment_name) + + + axis_font = {'size': '20'} + fontSize = 20 + column_name_list = ["transformation_metric", "comm_comp", "workload"]#, "architectural principle", "high level optimization name", "exact optimization name"] + #column_name_list = ["architectural principle", "exact optimization name"] + + #column_name = "move name" + # initialize the dictionary + column_column_value_experiment_frequency_dict = {} + for file_full_addr in file_full_addr_list: + column_column_value_frequency_dict = {} + for column_name in 
column_name_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name(file_full_addr, input_all_res_column_name_number) + #column_column_value_frequency_dict[column_name] = {} + # get all possible the values of interest + all_values = get_all_col_values_of_a_folders(input_dir_names, input_all_res_column_name_number, column_name) + columne_number = all_res_column_name_number[column_name] + for column in all_values: + column_column_value_frequency_dict[column] = {} + column_column_value_frequency_dict[column][column_name] = 0 + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + if i > 1: + col_value = row[columne_number] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + val_splitted = col_val.split("=") + column_column_value_frequency_dict[val_splitted[0]][column_name] += float(val_splitted[1]) + else: + column_column_value_frequency_dict[col_val][column_name] += 1 + + index = column_name_list + total_cnt = 0 + for val in column_column_value_frequency_dict[column].values(): + total_cnt += val + + for col_val, column_name_val in column_column_value_frequency_dict.items(): + for column_name, val in column_name_val.items(): + column_column_value_frequency_dict[col_val][column_name] /= max(total_cnt,1) + + plotdata = pd.DataFrame(column_column_value_frequency_dict, index=index) + plotdata.plot(kind='bar', stacked=True, figsize=(10, 10)) + plt.rc('font', ** axis_font) + plt.xlabel("experiments", **axis_font) + plt.ylabel(column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,"__".join(column_name_list)+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + #column_column_value_experiment_frequency_dict[column_name] = copy.deepcopy(column_column_value_frequency_dict) + + return column_column_value_experiment_frequency_dict + + + + +def plot_codesign_nav_breakdown_cross_workload(input_dir_names, input_all_res_column_name_number): + trueNum = input_all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, input_all_res_column_name_number) + """ + Ying: the following lines are added to make the names clearer in the plottings + """ + if experiment_name[0] == 'a': + experiment_name = "Audio" + elif experiment_name[0] == 'h': + experiment_name = "CAVA" + elif experiment_name[0] == 'e': + experiment_name = "ED" + """ + Ying: adding finished + """ + experiment_names.append(experiment_name) + + axis_font = {'size': '25'} + fontSize = 25 + column_name_list = ["transformation_metric", "transformation_block_type", "move name", "comm_comp", "architectural principle", "high level optimization name", "exact optimization name", "neighbouring design space size"] + #column_name_list 
= ["transformation_metric", "move name"]#, "comm_comp", "architectural principle", "high level optimization name", "exact optimization name", "neighbouring design space size"] + #column_name = "move name" + # initialize the dictionary + column_column_value_experiment_frequency_dict = {} + for column_name in column_name_list: + column_value_experiment_frequency_dict = {} + # get all possible the values of interest + all_values = get_all_col_values_of_a_folders(input_dir_names, input_all_res_column_name_number, column_name) + columne_number = all_res_column_name_number[column_name] + for column in all_values: + """ + Ying: the following lines are added for "IC", "Mem", and "PE" + """ + if column_name == "transformation_block_type": + if column == "ic": + column = "IC" + elif column == "mem": + column = "Mem" + elif column == "pe": + column = "PE" + + if column_name == "architectural principle": + if column == "identity" or column == "spatial_locality": + continue + elif column == "task_level_parallelism": + column = "TLP" + elif column == "loop_level_parallelism": + column = "LLP" + """ + Ying: adding finished + """ + column_value_experiment_frequency_dict[column] = {} + + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, input_all_res_column_name_number) + """ + Ying: the following lines are added to make the names clearer in the plottings + """ + if experiment_name[0] == 'a': + experiment_name = "Audio" + elif experiment_name[0] == 'h': + experiment_name = "CAVA" + elif experiment_name[0] == 'e': + experiment_name = "ED" + """ + Ying: adding finished + """ + for column_value in all_values: + """ + Ying: the following lines are added for "IC", "Mem", and "PE" + """ + if column_name == "transformation_block_type": + if column_value == "ic": + column_value = "IC" + elif column_value == "mem": + column_value = "Mem" + elif column_value == "pe": + column_value = "PE" + + if column_name == "architectural principle": + if column_value == "identity" or column_value == "spatial_locality": + continue + elif column_value == "task_level_parallelism": + column_value = "TLP" + elif column_value == "loop_level_parallelism": + column_value = "LLP" + """ + Ying: adding finished + """ + column_value_experiment_frequency_dict[column_value][experiment_name] = 0 + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i > 1: + try: + + # the following for workload awareness + #if row[all_res_column_name_number["move name"]] == "identity": + # continue + #if row[all_res_column_name_number["architectural principle"]] == "spatial_locality": + # continue + + + col_value = row[columne_number] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + val_splitted = col_val.split("=") + column_value_experiment_frequency_dict[val_splitted[0]][experiment_name] += float(val_splitted[1]) + else: + """ + Ying: the following lines are added for "IC", "Mem", and "PE" + """ + if column_name == "transformation_block_type": + if col_val == "ic": + col_val = "IC" + elif col_val == "mem": + col_val = "Mem" + elif col_val == "pe": + col_val = "PE" + + if column_name == "architectural principle": + if col_val == "identity" or col_val == "spatial_locality": + continue + elif col_val == "task_level_parallelism": + col_val = "TLP" + elif col_val == "loop_level_parallelism": + col_val 
= "LLP" + """ + Ying: adding finished + """ + column_value_experiment_frequency_dict[col_val][experiment_name] += 1 + except: + print("what") + + total_cnt = {} + for el in column_value_experiment_frequency_dict.values(): + for exp, values in el.items(): + if exp not in total_cnt.keys(): + total_cnt[exp] = 0 + total_cnt[exp] += values + + for col_val, exp_vals in column_value_experiment_frequency_dict.items(): + for exp, values in exp_vals.items(): + column_value_experiment_frequency_dict[col_val][exp] = column_value_experiment_frequency_dict[col_val][exp] + if column_name != "architectural principle" and column_name != "comm_comp" and total_cnt[exp] != 0: # Ying: add to get rid of normalization for the two plottings + column_value_experiment_frequency_dict[col_val][exp] /= total_cnt[exp] # normalize + + # prepare for plotting and plot + # plt.figure(figsize=(6, 6)) + index = experiment_names + plotdata = pd.DataFrame(column_value_experiment_frequency_dict, index=index) + plotdata.plot(kind='bar', stacked=True, figsize=(8, 8)) + plt.rc('font', **axis_font) + plt.xlabel("Workloads", **axis_font) + # plt.ylabel(column_name, **axis_font) # Ying: replace with the following lines + """ + Ying: set the ylabel acordingly + """ + if column_name != "comm_comp": + if column_name == "architectural principle" or column_name == "comm_comp": + plt.ylabel("Iteration Count", **axis_font) + else: + plt.ylabel("Normalized Iteration Portion", **axis_font) + """ + Ying: adding finished + """ + plt.xticks(fontsize=fontSize, rotation=0) # Ying: the original one was 45) + plt.yticks(fontsize=fontSize) + # plt.title("experiment vs " + column_name, **axis_font) # Ying: comment it out as discussed + # plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) # Ying: replaced with the following line + plt.legend(bbox_to_anchor=(0.5, 1.15), loc='upper center', fontsize=fontSize, ncol=3) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,'_'.join(column_name.split(" "))+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + column_column_value_experiment_frequency_dict[column_name] = copy.deepcopy(column_value_experiment_frequency_dict) + + """ + # multi-stack plot here + index = experiment_names + plotdata = pd.DataFrame(column_column_value_experiment_frequency_dict, index=index) + + df_g = plotdata.groupby(["transformation_metric", "move name"]) + plotdata.plot(kind='bar', stacked=True, figsize=(12, 10)) + plt.rc('font', **axis_font) + plt.xlabel("experiments", **axis_font) + plt.ylabel(column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,'column____'.join(column_name.split(" "))+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + """ + return column_column_value_experiment_frequency_dict + + + + +# the function to plot distance to goal vs. 
iteration cnt +def plotDistToGoalVSitr(input_dir_names, all_res_column_name_number): + itrColNum = all_res_column_name_number["iteration cnt"] + distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + experiment_itr_dist_to_goal_dict = {} + # iterate through directories, get data and store in a dictionary + for dir_name in input_dir_names: + itr = [] + distToGoal = [] + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name(file_full_addr, all_res_column_name_number) + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + if i > 1: + itr.append(int(row[itrColNum])) + distToGoal.append(float(row[distColNum])) + + experiment_itr_dist_to_goal_dict[experiment_name] = (itr[:], distToGoal[:]) + + plt.figure() + # iterate and plot + for experiment_name, value in experiment_itr_dist_to_goal_dict.items(): + itr, distToGoal = value[0], value[1] + if len(itr) == 0 or len(distToGoal) == 0: # no valid move + continue + plt.plot(itr, distToGoal, label=experiment_name) + plt.xlabel("Iteration Cnt") + plt.ylabel("Distance to Goal") + plt.title("Distance to Goal vs. Iteration Cnt") + + # decide on the output dir + if len(input_dir_names) == 1: + output_dir = input_dir_names[0] + else: + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir, "distToGoalVSitr.png")) + # plt.show() + plt.close('all') + + +# the function to plot distance to goal vs. iteration cnt +def plotRefDistToGoalVSitr(dirName, fileName, itrColNum, refDistColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + itr = [] + refDistToGoal = [] + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + + if i > 1: + itr.append(int(row[itrColNum])) + refDistToGoal.append(float(row[refDistColNum])) + + plt.figure() + plt.plot(itr, refDistToGoal) + plt.xlabel("Iteration Cnt") + plt.ylabel("Reference Design Distance to Goal") + plt.title("Reference Design Distance to Goal vs. Iteration Cnt") + plt.savefig(dirName + fileName + "/refDistToGoalVSitr-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to do the zonal partitioning +def zonalPartition(comparedValue, zoneNum, maxValue): + unit = maxValue / zoneNum + + if comparedValue > maxValue: + return zoneNum - 1 + + if comparedValue < 0: + return 0 + + for i in range(0, zoneNum): + if comparedValue <= unit * (i + 1): + return i + + raise Exception("zonalPartition is fed by a strange value! maxValue: " + str(maxValue) + "; comparedValue: " + str(comparedValue)) + +# the function to plot simulation time vs. 
move name in a zonal format +def plotSimTimeVSmoveNameZoneDist(dirName, fileName, zoneNum, moveColNum, distColNum, simColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + splitSwapSim = np.zeros(zoneNum, dtype = float) + splitSim = np.zeros(zoneNum, dtype = float) + migrateSim = np.zeros(zoneNum, dtype = float) + swapSim = np.zeros(zoneNum, dtype = float) + tranSim = np.zeros(zoneNum, dtype = float) + routeSim = np.zeros(zoneNum, dtype = float) + identitySim = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + # print('"' + row[trueNum] + '"\t"' + row[moveColNum] + '"\t"' + row[distColNum] + '"\t"' + row[simColNum] + '"') + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[moveColNum] == "split_swap": + splitSwapSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "split": + splitSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "migrate": + migrateSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "swap": + swapSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "transfer": + tranSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "routing": + routeSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "identity": + identitySim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + else: + raise Exception("move name is not split_swap or split or migrate or swap or transfer or routing or identity! The new type: " + row[moveColNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "split_swap":splitSwapSim, + "split":splitSim, + "migrate":migrateSim, + "swap":swapSim, + "transfer":tranSim, + "routing":routeSim, + "identity":identitySim + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Simulation Time") + plt.title("Simulation Time in Each Zone based on Move Name") + plt.savefig(dirName + fileName + "/simTimeVSmoveNameZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot move generation time vs. 
move name in a zonal format
+def plotMovGenTimeVSmoveNameZoneDist(dirName, fileName, zoneNum, moveColNum, distColNum, movGenColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        splitSwapMov = np.zeros(zoneNum, dtype = float)
+        splitMov = np.zeros(zoneNum, dtype = float)
+        migrateMov = np.zeros(zoneNum, dtype = float)
+        swapMov = np.zeros(zoneNum, dtype = float)
+        tranMov = np.zeros(zoneNum, dtype = float)
+        routeMov = np.zeros(zoneNum, dtype = float)
+        identityMov = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            # print('"' + row[trueNum] + '"\t"' + row[moveColNum] + '"\t"' + row[distColNum] + '"\t"' + row[movGenColNum] + '"')
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[moveColNum] == "split_swap":
+                    splitSwapMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[moveColNum] == "split":
+                    splitMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[moveColNum] == "migrate":
+                    migrateMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[moveColNum] == "swap":
+                    swapMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[moveColNum] == "transfer":
+                    tranMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[moveColNum] == "routing":
+                    routeMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[moveColNum] == "identity":
+                    identityMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                else:
+                    raise Exception("move name is not split_swap or split or migrate or swap or transfer or routing or identity! The new type: " + row[moveColNum])
+
+    # plt.figure()
+    plotdata = pd.DataFrame({
+        "split_swap": splitSwapMov,
+        "split": splitMov,
+        "migrate": migrateMov,
+        "swap": swapMov,
+        "transfer": tranMov,
+        "routing": routeMov,
+        "identity": identityMov
+        }, index = index
+    )
+    plotdata.plot(kind = 'bar', stacked = True)
+    plt.xlabel("Zone decided by the max distance to goal")
+    plt.ylabel("Move Generation Time")
+    plt.title("Move Generation Time in Each Zone based on Move Name")
+    plt.savefig(dirName + fileName + "/movGenTimeVSmoveNameZoneDist-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
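+# The two move-name plotters above keep seven parallel arrays, one per move
+# type. An equivalent, more compact formulation keeps a single dictionary
+# keyed by move name; this is an illustrative refactor sketch (unused by the
+# script), assuming rows have already been header-skipped and filtered on
+# "move validity":
+def _accumulate_by_move_sketch(rows, zoneNum, maxDist, moveColNum, distColNum, valColNum):
+    move_names = ["split_swap", "split", "migrate", "swap", "transfer", "routing", "identity"]
+    sums = {name: np.zeros(zoneNum, dtype=float) for name in move_names}
+    for row in rows:
+        if row[moveColNum] not in sums:
+            raise Exception("unknown move name: " + row[moveColNum])
+        zone = zonalPartition(float(row[distColNum]), zoneNum, maxDist)
+        sums[row[moveColNum]][zone] += float(row[valColNum])
+    # the result can feed pd.DataFrame(sums, index=range(zoneNum)) directly
+    return sums
+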
+# the function to plot simulation time vs. comm_comp in a zonal format
+def plotSimTimeVScommCompZoneDist(dirName, fileName, zoneNum, commcompColNum, distColNum, simColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        commSim = np.zeros(zoneNum, dtype = float)
+        compSim = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[commcompColNum] == "comm":
+                    commSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                elif row[commcompColNum] == "comp":
+                    compSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                else:
+                    raise Exception("comm_comp is not giving comm or comp! The new type: " + row[commcompColNum])
+
+    # plt.figure()
+    plotdata = pd.DataFrame({
+        "comm": commSim,
+        "comp": compSim
+        }, index = index
+    )
+    plotdata.plot(kind = 'bar', stacked = True)
+    plt.xlabel("Zone decided by the max distance to goal")
+    plt.ylabel("Simulation Time")
+    plt.title("Simulation Time in Each Zone based on comm_comp")
+    plt.savefig(dirName + fileName + "/simTimeVScommCompZoneDist-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot move generation time vs. comm_comp in a zonal format
+def plotMovGenTimeVScommCompZoneDist(dirName, fileName, zoneNum, commcompColNum, distColNum, movGenColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        commMov = np.zeros(zoneNum, dtype = float)
+        compMov = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[commcompColNum] == "comm":
+                    commMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[commcompColNum] == "comp":
+                    compMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                else:
+                    raise Exception("comm_comp is not giving comm or comp! The new type: " + row[commcompColNum])
+
+    # plt.figure()
+    plotdata = pd.DataFrame({
+        "comm": commMov,
+        "comp": compMov
+        }, index = index
+    )
+    plotdata.plot(kind = 'bar', stacked = True)
+    plt.xlabel("Zone decided by the max distance to goal")
+    plt.ylabel("Move Generation Time")
+    plt.title("Move Generation Time in Each Zone based on comm_comp")
+    plt.savefig(dirName + fileName + "/movGenTimeVScommCompZoneDist-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
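+# All of the zonal plotters rely on zonalPartition (defined earlier) to bin a
+# distance-to-goal into zoneNum equal-width zones of the max distance. A few
+# illustrative checks of its behavior (our examples, kept in an uncalled
+# helper so the script's behavior is unchanged):
+def _zonal_partition_examples():
+    # with maxValue = 100 and zoneNum = 4, each zone spans 25 units
+    assert zonalPartition(10, 4, 100) == 0    # 10 <= 25
+    assert zonalPartition(25, 4, 100) == 0    # zone edges fall into the lower zone
+    assert zonalPartition(60, 4, 100) == 2    # 50 < 60 <= 75
+    assert zonalPartition(999, 4, 100) == 3   # above maxValue -> last zone
+    assert zonalPartition(-5, 4, 100) == 0    # below zero -> first zone
+
+
+# the function to plot simulation time vs. 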
high level optimization name in a zonal format +def plotSimTimeVShighLevelOptZoneDist(dirName, fileName, zoneNum, optColNum, distColNum, simColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + topoSim = np.zeros(zoneNum, dtype = float) + tunSim = np.zeros(zoneNum, dtype = float) + mapSim = np.zeros(zoneNum, dtype = float) + idenOptSim = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[optColNum] == "topology": + topoSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[optColNum] == "customization": + tunSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[optColNum] == "mapping": + mapSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[optColNum] == "identity": + idenOptSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + else: + raise Exception("high level optimization name is not giving topology or customization or mapping or identity! The new type: " + row[optColNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "topology":topoSim, + "customization":tunSim, + "mapping":mapSim, + "identity":idenOptSim + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Simulation Time") + plt.title("Simulation Time in Each Zone based on Optimation Name") + plt.savefig(dirName + fileName + "/simTimeVShighLevelOptZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot simulation time vs. high level optimization name in a zonal format +def plotMovGenTimeVShighLevelOptZoneDist(dirName, fileName, zoneNum, optColNum, distColNum, movGenColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + topoMov = np.zeros(zoneNum, dtype = float) + tunMov = np.zeros(zoneNum, dtype = float) + mapMov = np.zeros(zoneNum, dtype = float) + idenOptMov = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[optColNum] == "topology": + topoMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[optColNum] == "customization": + tunMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[optColNum] == "mapping": + mapMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[optColNum] == "identity": + idenOptMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + else: + raise Exception("high level optimization name is not giving topology or customization or mapping or identity! 
The new type: " + row[optColNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "topology":topoMov, + "customization":tunMov, + "mapping":mapMov, + "identity":idenOptMov + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Transformation Generation Time") + plt.title("Transformation Generation Time in Each Zone based on Optimization Name") + plt.savefig(dirName + fileName + "/movGenTimeVShighLevelOptZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot simulation time vs. architectural principle in a zonal format +def plotSimTimeVSarchVarImpZoneDist(dirName, fileName, zoneNum, archColNum, distColNum, simColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + paraSim = np.zeros(zoneNum, dtype = float) + custSim = np.zeros(zoneNum, dtype = float) + localSim = np.zeros(zoneNum, dtype = float) + idenImpSim = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[archColNum] == "parallelization": + paraSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[archColNum] == "customization": + custSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[archColNum] == "locality": + localSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[archColNum] == "identity": + idenImpSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + else: + raise Exception("architectural principle is not giving parallelization or customization or locality or identity! The new type: " + row[archColNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "parallelization":paraSim, + "customization":custSim, + "locality":localSim, + "identity":idenImpSim + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Simulation Time") + plt.title("Simulation Time in Each Zone based on Architectural Principle") + plt.savefig(dirName + fileName + "/simTimeVSarchVarImpZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot simulation time vs. 
architectural principle in a zonal format
+def plotMovGenTimeVSarchVarImpZoneDist(dirName, fileName, zoneNum, archColNum, distColNum, movGenColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        paraMov = np.zeros(zoneNum, dtype = float)
+        custMov = np.zeros(zoneNum, dtype = float)
+        localMov = np.zeros(zoneNum, dtype = float)
+        idenImpMov = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[archColNum] == "parallelization":
+                    paraMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[archColNum] == "customization":
+                    custMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[archColNum] == "locality":
+                    localMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[archColNum] == "identity":
+                    idenImpMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                else:
+                    raise Exception("architectural principle is not giving parallelization or customization or locality or identity! The new type: " + row[archColNum])
+
+    # plt.figure()
+    plotdata = pd.DataFrame({
+        "parallelization": paraMov,
+        "customization": custMov,
+        "locality": localMov,
+        "identity": idenImpMov
+        }, index = index
+    )
+    plotdata.plot(kind = 'bar', stacked = True)
+    plt.xlabel("Zone decided by the max distance to goal")
+    plt.ylabel("Transformation Generation Time")
+    plt.title("Transformation Generation Time in Each Zone based on Architectural Principle")
+    plt.savefig(dirName + fileName + "/movGenTimeVSarchVarImpZoneDist-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
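+# plotBudgets3d below reads the per-workload latency budget as a single CSV
+# cell that packs "workload=value" pairs separated by ";" with a trailing
+# separator (hence the [:-1] slice it applies). A small sketch of that parsing
+# step; the sample string and values are illustrative only:
+def _parse_latency_budget_sketch(cell="audio=0.001;cava=0.033;edge_detection=0.011;"):
+    pairs = dict(item.split("=") for item in cell[:-1].split(";"))
+    # e.g. {"audio": 0.001, "cava": 0.033, "edge_detection": 0.011}
+    return {workload: float(value) for workload, value in pairs.items()}
+
+
+# the function to plot convergence vs. 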
iteration cnt, system block count, and routing complexity in 3d +def plotBudgets3d(dirName, subDirName): + newDirName = dirName + "/"+ subDirName + "/" + if os.path.exists(newDirName + "/figures"): + shutil.rmtree(newDirName + "/figures") + resultList = os.listdir(newDirName) + latBudgets = [] + powBudgets = [] + areaBudgets = [] + itrValues = [] + cntValues = [] + routingValues = [] + workloads = [] + for j, fileName in enumerate(resultList): + with open(newDirName + fileName + "/result_summary/FARSI_simple_run_0_1.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i == 1: + itrValues.append(int(row[columnNum(newDirName, fileName, "iteration cnt", "simple")])) + cntValues.append(int(row[columnNum(newDirName, fileName, "system block count", "simple")])) + routingValues.append(float(row[columnNum(newDirName, fileName, "routing complexity", "simple")])) + powBudgets.append(float(row[columnNum(newDirName, fileName, "power_budget", "simple")])) + areaBudgets.append(float(row[columnNum(newDirName, fileName, "area_budget", "simple")])) + lat = row[int(columnNum(newDirName, fileName, "latency_budget", "simple"))][:-1] + latDict = dict(item.split("=") for item in lat.split(";")) + if j == 0: + for k in range(0, len(latDict)): + latBudgets.append([]) + workloads.append(list(latDict.keys())[k]) + latList = list(latDict.values()) + for k in range(0, len(latList)): + latBudgets[k].append(float(latList[k])) + + m = ['o', 'x', '^', 's', 'd', '+', 'v', '<', '>'] + axis_font = {'size': '10'} + fontSize = 10 + os.mkdir(newDirName + "figures") + fig_budget_itr = plt.figure(figsize=(12, 12)) + plt.rc('font', **axis_font) + ax_itr = fig_budget_itr.add_subplot(projection='3d') + for i in range(0, len(latBudgets)): + img = ax_itr.scatter3D(powBudgets, areaBudgets, latBudgets[i], c=itrValues, cmap="bwr", marker=m[i], s=80, label='{0}'.format(workloads[i])) + for j in range(0, len(latBudgets[i])): + coordinate = str(itrValues[j]) + ax_itr.text(powBudgets[j], areaBudgets[j], latBudgets[i][j], '%s' % coordinate, size=fontSize) + break + ax_itr.set_xlabel("Power Budget") + ax_itr.set_ylabel("Area Budget") + ax_itr.set_zlabel("Latency Budget") + ax_itr.legend() + cbar_itr = fig_budget_itr.colorbar(img, aspect = 40) + cbar_itr.set_label("Number of Iterations", rotation = 270) + plt.title("{Power Budget, Area Budget, Latency Budget} VS Iteration Cnt: " + subDirName) + plt.tight_layout() + plt.savefig(newDirName + "figures/budgetVSitr-" + subDirName + ".png") + # plt.show() + plt.close('all') + + fig_budget_blkcnt = plt.figure(figsize=(12, 12)) + plt.rc('font', **axis_font) + ax_blkcnt = fig_budget_blkcnt.add_subplot(projection='3d') + for i in range(0, len(latBudgets)): + img = ax_blkcnt.scatter3D(powBudgets, areaBudgets, latBudgets[i], c=cntValues, cmap="bwr", marker=m[i], s=80, label='{0}'.format(workloads[i])) + for j in range(0, len(latBudgets[i])): + coordinate = str(cntValues[j]) + ax_blkcnt.text(powBudgets[j], areaBudgets[j], latBudgets[i][j], '%s' % coordinate, size=fontSize) + break + ax_blkcnt.set_xlabel("Power Budget") + ax_blkcnt.set_ylabel("Area Budget") + ax_blkcnt.set_zlabel("Latency Budget") + ax_blkcnt.legend() + cbar = fig_budget_blkcnt.colorbar(img, aspect=40) + cbar.set_label("System Block Count", rotation=270) + plt.title("{Power Budget, Area Budget, Latency Budget} VS System Block Count: " + subDirName) + plt.tight_layout() + plt.savefig(newDirName + "figures/budgetVSblkcnt-" + subDirName + ".png") + # 
plt.show() + plt.close('all') + + fig_budget_routing = plt.figure(figsize=(12, 12)) + plt.rc('font', **axis_font) + ax_routing = fig_budget_routing.add_subplot(projection='3d') + for i in range(0, len(latBudgets)): + img = ax_routing.scatter3D(powBudgets, areaBudgets, latBudgets[i], c=cntValues, cmap="bwr", marker=m[i], s=80, label='{0}'.format(workloads[i])) + for j in range(0, len(latBudgets[i])): + coordinate = str(routingValues[j]) + ax_routing.text(powBudgets[j], areaBudgets[j], latBudgets[i][j], '%s' % coordinate, size=fontSize) + break + ax_routing.set_xlabel("Power Budget") + ax_routing.set_ylabel("Area Budget") + ax_routing.set_zlabel("Latency Budget") + ax_routing.legend() + cbar = fig_budget_routing.colorbar(img, aspect=40) + cbar.set_label("System Block Count", rotation=270) + plt.title("{Power Budget, Area Budget, Latency Budget} VS System Block Count: " + subDirName) + plt.tight_layout() + plt.savefig(newDirName + "figures/budgetVSroutingComplexity-" + subDirName + ".png") + # plt.show() + plt.close('all') + +def get_experiment_dir_list(run_folder_name): + workload_set_folder_list = os.listdir(run_folder_name) + + experiment_full_addr_list = [] + # iterate and generate plots + for workload_set_folder in workload_set_folder_list: + # ignore irelevant files + if workload_set_folder in config_plotting.ignore_file_names: + continue + + # get experiment folder + workload_set_full_addr = os.path.join(run_folder_name,workload_set_folder) + folder_list = os.listdir(workload_set_full_addr) + for experiment_name_relative_addr in folder_list: + if experiment_name_relative_addr in config_plotting.ignore_file_names: + continue + experiment_full_addr_list.append(os.path.join(workload_set_full_addr, experiment_name_relative_addr)) + + return experiment_full_addr_list + + +def find_the_most_recent_directory(top_dir): + dirs = [os.path.join(top_dir, el) for el in os.listdir(top_dir)] + dirs = list(filter(os.path.isdir, dirs)) + dirs.sort(key=lambda x: os.path.getmtime(x), reverse=True) + return dirs + + +def get_experiment_full_file_addr_list(experiment_full_dir_list): + file_name = "result_summary/FARSI_simple_run_0_1.csv" + results = [] + for el in experiment_full_dir_list: + results.append(os.path.join(el, file_name)) + + return results + +######### RADHIKA PANDAS PLOTS ############ + +def grouped_barplot_varying_x(df, metric, metric_ylabel, varying_x, varying_x_labels, ax): + # [[bar heights, errs for varying_x1], [heights, errs for varying_x2]...] + grouped_stats_list = [] + for x in varying_x: + grouped_x = df.groupby([x]) + stats = grouped_x[metric].agg([np.mean, np.std]) + grouped_stats_list.append(stats) + + start_loc = 0 + bar_width = 0.15 + offset = 0 # Ying: original: 0.03 + # [[bar locations for varying_x1], [bar locs for varying_x2]...] 
+ grouped_bar_locs_list = [] + for x in varying_x: + n_unique_varying_x = df[x].nunique() + bound = (n_unique_varying_x-1) * (bar_width+offset) + end_loc = start_loc+bound + bar_locs = np.linspace(start_loc, end_loc, n_unique_varying_x) + grouped_bar_locs_list.append(bar_locs) + start_loc = end_loc + 2*bar_width + + # print(grouped_bar_locs_list) # Ying: comment out for WTF + + color = ["red", "orange", "green"] + ctr = 0 + for x_i,x in enumerate(varying_x): + ax.bar( + grouped_bar_locs_list[x_i], + grouped_stats_list[x_i]["mean"], + width=bar_width, + # yerr=grouped_stats_list[x_i]["std"], + color = color, + # label=metric_ylabel + ) + ctr +=1 + cat_xticks = [] + cat_xticklabels = [] + + xticks = [] + xticklabels = [] + """ + Ying: add the following lines to get rid of the numbers on the x-axis + """ + for i in range(0, 9): + xticklabels.append(' ') + """ + Ying: adding finished + """ + for x_i,x in enumerate(varying_x): + # xticklabels.extend(grouped_stats_list[x_i].index.astype(float)) # Ying: comment out and leave them for legends + xticks.extend(grouped_bar_locs_list[x_i]) + + xticks_cat = grouped_bar_locs_list[x_i] + xticks_cat_start = xticks_cat[0] + xticks_cat_end = xticks_cat[-1] + xticks_cat_mid = xticks_cat_start + (xticks_cat_end - xticks_cat_start) / 2 + + cat_xticks.append(xticks_cat_mid) + cat_xticklabels.append(varying_x_labels[x_i]) # Ying: the original code was: "\n\n" + varying_x_labels[x_i]) + + fontSize = 20 + axis_font = {'size': '20'} + xticks.extend(cat_xticks) + xticklabels.extend(cat_xticklabels) + + ax.set_ylabel(metric_ylabel, fontsize=fontSize) # Ying: add fontsize + #ax.set_xlabel(xlabel) + ax.set_xticks(xticks) + # ax.legend(loc="upper center") # Ying: test the way to add legends + ax.set_xticklabels(xticklabels, fontsize=fontSize) # Ying: add fontsize + + return ax + + +def pandas_plots(input_dir_names, all_results_files, metric): + df = pd.concat((pd.read_csv(f) for f in all_results_files)) + + #df = raw_df.loc[(raw_df["move validity"] == True)] + #df["dist_to_goal_non_cost_delta"] = df["ref_des_dist_to_goal_non_cost"] - df["dist_to_goal_non_cost"] + #df["local_traffic_ratio"] = np.divide(df["local_total_traffic"], df["local_total_traffic"] + df["global_total_traffic"]) + #metric = "global_memory_avg_freq" + #metric_ylabel = "Global memory avg freq" + #metric = "local_traffic_ratio" + + # metric_ylabel = metric #"Local traffic ratio" # Ying: replaced the underscores with whitespaces; the new code is the following line + metric_ylabel = ' '.join(metric.split('_')) + """ + Ying: add the following lines just in case we need them + """ + if metric == "ip_cnt": + metric_ylabel = "IP Count" + elif metric == "local_bus_cnt": + metric_ylabel = "NoC Count" + elif metric == "local_bus_avg_freq": + metric_ylabel = "NoC Avg Frequency" + elif metric == "local_channel_count_per_bus_coeff_var": + metric_ylabel = "Link Variation" + elif metric == "local_memory_area_coeff_var": + metric_ylabel = "Memory Aria Variation" + elif metric == "local_bus_freq_coeff_var": + metric_ylabel = "NoC Frequency Variation" + elif metric == "local_total_traffic_reuse_no_read_in_bytes_per_cluster_avg": + metric_ylabel = "Memory Reuse" + elif metric == "avg_accel_parallelism": + metric_ylabel = "Average Accelerator Parallelism" + elif metric == "local_bus_avg_actual_bandwidth": + metric_ylabel = "Link Bandwidth" + """ + Ying: adding finished + """ + + varying_x = [ + "budget_scaling_latency", + "budget_scaling_power", + "budget_scaling_area", + ] + varying_x_labels = [ + "latency", + "power", 
+ "area", + ] + + axis_font = {'size': "20"} + fig, ax = plt.subplots(1, figsize=(7, 7)) # Ying: add the figure size + grouped_barplot_varying_x( + df, + metric, metric_ylabel, + varying_x, varying_x_labels, + ax + ) + plt.rc('font', **axis_font) + plt.tight_layout() + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "panda_study/") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir, metric+".png")) + #plt.show() + plt.close('all') + + + + #fig.tight_layout(rect=[0, 0, 1, 1]) + #fig.savefig("/Users/behzadboro/Project_FARSI_dir/Project_FARSI_with_channels/data_collection/data/simple_run/27_point_coverage_zad/bleh.png") + #plt.close(fig) + +def get_budget_optimality_advanced(input_dir_names,all_result_files, summary_res_column_name_number): + def points_exceed_one_of_the_budgets(point, base_budget, budget_scaling_to_consider): + power = point[0] + area = point[1] + if power > base_budgets["power"] * budget_scale_to_consider and area > base_budgets[ + "area"] * budget_scale_to_consider: + return True + return False + + workload_results = {} + + system_char_to_keep_track_of = {"memory_total_area", "local_memory_total_area","pe_total_area", "ip_cnt", "ips_total_area"} + + # budget scaling to consider + budget_scale_to_consider = .5 + # get budget first + base_budgets = {} + for file in all_result_files: + with open(file, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i == 1: + if float(row[summary_res_column_name_number["budget_scaling_latency"]]) == 1 and\ + float(row[summary_res_column_name_number["budget_scaling_power"]]) == 1 and \ + float(row[summary_res_column_name_number["budget_scaling_area"]]) == 1: + base_budgets["power"] = float(row[summary_res_column_name_number["power_budget"]]) + base_budgets["area"] = float(row[summary_res_column_name_number["area_budget"]]) + break + + + for file in all_result_files: + with open(file, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i == 1: + workload_set_name = row[summary_res_column_name_number["workload_set"]] + if workload_set_name not in workload_results.keys(): + workload_results[workload_set_name] = [] + latency = ((row[summary_res_column_name_number["latency"]].split(";"))[0].split("="))[1] + latency_budget = ((row[summary_res_column_name_number["latency_budget"]].split(";"))[0].split("="))[1] + if float(latency) > float(latency_budget): + continue + + #workload_results[workload_set_name].append((float(power),float(area), float(system_complexity))) + + area= float(row[summary_res_column_name_number["area"]]) + power = float(row[summary_res_column_name_number["power"]]) + system_char = {} + for el in system_char_to_keep_track_of: + system_char[el] = float(row[summary_res_column_name_number[el]]) + point_system_char = {(power, area): system_char} + workload_results[workload_set_name].append(point_system_char) + + workload_pareto_points = {} + for workload, points_ in workload_results.items(): + points = [list(el.keys())[0] for el in points_] + pareto_points= find_pareto_points(list(set(points))) + workload_pareto_points[workload] = [] + for point in pareto_points: + keys = [list(el.keys())[0] for el in workload_results[workload]] + idx = keys.index(point) + workload_pareto_points[workload].append({point:(workload_results[workload])[idx]}) + + + """" + # 
combine the results + combined_area_power = [] + for results_combined in itertools.product(*list(workload_pareto_points.values())): + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + combined_area_power.append(combined_power_area_tuple[:]) + """ + + + all_points_in_isolation = [] + all_points_cross_workloads = [] + + workload_in_isolation = {} + for workload, points in workload_results.items(): + #points = [list(el.keys())[0] for el in points_] + if "cava" in workload and "audio" in workload and "edge_detection" in workload: + for point in points: + all_points_cross_workloads.append(point) + else: + workload_in_isolation[workload] = points + + + ctr = 0 + workload_in_isolation_pareto = {} + for workload, points_ in workload_in_isolation.items(): + workload_in_isolation_pareto[workload] = [] + points = [list(el.keys())[0] for el in points_] + pareto_points = find_pareto_points(list(set(points))) + for point in pareto_points: + keys = [list(el.keys())[0] for el in workload_in_isolation[workload]] + idx = keys.index(point) + workload_in_isolation_pareto[workload].append({point:(workload_in_isolation[workload])[idx]}) + + + + combined_area_power_in_isolation= [] + s = time.time() + + workload_in_isolation_pareto_only_area_power = {} + for key, val in workload_in_isolation_pareto.items(): + workload_in_isolation_pareto_only_area_power[key] = [] + for el in val: + for k,v in el.items(): + workload_in_isolation_pareto_only_area_power[key].append(k) + + + for results_combined in itertools.product(*list(workload_in_isolation_pareto_only_area_power.values())): + # add up all the charactersitics + system_chars = {} + for el in system_char_to_keep_track_of: + system_chars[el] = 0 + + # add up area,power + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + + for point in results_combined: + keys = [list(point_.keys())[0] for point_ in workload_in_isolation_pareto[workload]] + idx = keys.index(point) + for el in system_char.keys(): + system_char[el] += workload_in_isolation + + system_chars[workload].append({point: (workload_in_isolation[workload])[idx]}) + + + #combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1], combined_power_area_tuple[2])) + combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1])) + + combined_area_power_in_isolation_filtered = [] + for point in combined_area_power_in_isolation: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + combined_area_power_in_isolation_filtered.append(point) + combined_area_power_pareto = find_pareto_points(list(set(combined_area_power_in_isolation_filtered))) + + + all_points_cross_workloads_filtered = [] + for point in all_points_cross_workloads: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + all_points_cross_workloads_filtered.append(point) + all_points_cross_workloads_area_power_pareto = find_pareto_points(list(set(all_points_cross_workloads_filtered))) + + + # prepare for plotting and plot + fig = plt.figure(figsize=(12, 12)) + #plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + fontSize = 20 + + x_values = [el[0] for el in combined_area_power_in_isolation_filtered] + y_values = [el[1] for el in combined_area_power_in_isolation_filtered] + x_values.reverse() + y_values.reverse() + 
ax.scatter(x_values, y_values, label="isolated design methodology",marker=".") + + + # plt.tight_layout() + x_values = [el[0] for el in combined_area_power_pareto] + y_values = [el[1] for el in combined_area_power_pareto] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="isolated design methodology pareto front",marker="x") + + + x_values = [el[0] for el in all_points_cross_workloads_filtered] + y_values = [el[1] for el in all_points_cross_workloads_filtered] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="cross workload methodology",marker="8") + ax.legend(loc="upper right") # bbox_to_anchor=(1, 1), loc="upper left") + + x_values = [el[0] for el in all_points_cross_workloads_area_power_pareto] + y_values = [el[1] for el in all_points_cross_workloads_area_power_pareto] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="cross workload pareto front",marker="o") + #for idx,_ in enumeate(x_values): + # plt.text(x_values[idx], y_values[idx], s=) + + #plt.text([ for el in x) + ax.legend(loc="upper right") # bbox_to_anchor=(1, 1), loc="upper left") + + + ax.set_xlabel("power", fontsize=fontSize) + ax.set_ylabel("area", fontsize=fontSize) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "budget_optimality/") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir, "budget_optimality.png")) + #plt.show() + plt.close('all') + + +def get_budget_optimality(input_dir_names,all_result_files, summary_res_column_name_number): + + def get_equivalent_total(charac): + if charac == "ips_avg_freq": + return "ip_cnt" + elif charac == "avg_accel_parallelism": + return "ip_cnt" + elif charac in ["local_memory_avg_freq"]: + return "local_mem_cnt" + elif charac in ["local_bus_avg_actual_bandwidth", "local_bus_avg_theoretical_bandwidth", "local_bus_avg_bus_width", "avg_freq"]: + return "local_bus_count" + else: + return charac + + def find_sys_char(power,area, results_with_sys_char): + for vals in results_with_sys_char: + for power_area , sys_chars in vals.items(): + power_ = power_area[0] + area_ = power_area[1] + if power == power_ and area_ == area: + return sys_chars + + def points_exceed_one_of_the_budgets(point, base_budget, budget_scaling_to_consider): + power = point[0] + area = point[1] + if power > base_budgets["power"] * budget_scale_to_consider and area > base_budgets[ + "area"] * budget_scale_to_consider: + return True + return False + + workload_results = {} + results_with_sys_char = [] + + system_char_to_keep_track_of = {"memory_total_area", "local_memory_total_area","pe_total_area", "ip_cnt","ips_total_area", "ips_avg_freq", "local_mem_cnt", + "local_bus_avg_actual_bandwidth", "local_bus_avg_theoretical_bandwidth", "local_memory_avg_freq", "local_bus_count", "local_bus_avg_bus_width", "avg_freq", "local_total_traffic", + "global_total_traffic","local_memory_avg_freq", "global_memory_avg_freq", "gpps_total_area", "avg_gpp_parallelism", "avg_accel_parallelism"} + #system_char_to_show = ["local_memory_total_area"] + #system_char_to_show = ["avg_accel_parallelism"] + #system_char_to_show = ["avg_gpp_parallelism"] + #system_char_to_show = ["local_bus_avg_actual_bandwidth"] + #system_char_to_show = ["avg_freq"] # really is buses avg freq + #system_char_to_show = ["local_memory_avg_freq"] # really is buses avg freq + #system_char_to_show = ["ips_avg_freq"] + 
#system_char_to_show = ["gpps_total_area"] + #system_char_to_show = ["local_bus_avg_bus_width"] + system_char_to_show = ["local_memory_avg_freq"] + #system_char_to_show = ["ips_total_area"] + #system_char_to_show = ["ip_cnt"] + #system_char_to_show = ["local_mem_cnt"] + #system_char_to_show = ["global_memory_avg_freq"] + #system_char_to_show = ["local_bus_avg_theoretical_bandwidth"] + #system_char_to_show = ["local_memory_avg_freq"] + #system_char_to_show = ["local_total_traffic"] + #system_char_to_show = ["global_total_traffic"] + + # budget scaling to consider + budget_scale_to_consider = .5 + # get budget first + base_budgets = {} + for file in all_result_files: + with open(file, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i == 1: + print("file"+file) + if float(row[summary_res_column_name_number["budget_scaling_latency"]]) == 1 and\ + float(row[summary_res_column_name_number["budget_scaling_power"]]) == 1 and \ + float(row[summary_res_column_name_number["budget_scaling_area"]]) == 1: + base_budgets["power"] = float(row[summary_res_column_name_number["power_budget"]]) + base_budgets["area"] = float(row[summary_res_column_name_number["area_budget"]]) + break + + + for file in all_result_files: + with open(file, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i == 1: + workload_set_name = row[summary_res_column_name_number["workload_set"]] + if workload_set_name not in workload_results.keys(): + workload_results[workload_set_name] = [] + latency = ((row[summary_res_column_name_number["latency"]].split(";"))[0].split("="))[1] + latency_budget = ((row[summary_res_column_name_number["latency_budget"]].split(";"))[0].split("="))[1] + if float(latency) > float(latency_budget): + continue + + power = float(row[summary_res_column_name_number["power"]]) + area = float(row[summary_res_column_name_number["area"]]) + + system_complexity = row[summary_res_column_name_number["ip_cnt"]] # + row[summary_res_column_name_number["gpp_cnt"]] + #workload_results[workload_set_name].append((float(power),float(area), float(system_complexity))) + workload_results[workload_set_name].append((power,area)) + system_char = {} + for el in system_char_to_keep_track_of: + #if "latency" == el: + # system_char[el] = row[summary_res_column_name_number[el]] + #else: + system_char[el] = float(row[summary_res_column_name_number[el]]) + point_system_char = {(power, area): system_char} + results_with_sys_char.append(point_system_char) + + + + workload_pareto_points = {} + for workload, points in workload_results.items(): + workload_pareto_points[workload] = find_pareto_points(list(set(points))) + + """" + # combine the results + combined_area_power = [] + for results_combined in itertools.product(*list(workload_pareto_points.values())): + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + combined_area_power.append(combined_power_area_tuple[:]) + """ + + + all_points_in_isolation = [] + all_points_cross_workloads = [] + + workload_in_isolation = {} + for workload, points in workload_results.items(): + if "cava" in workload and "audio" in workload and "edge_detection" in workload: + for point in points: + all_points_cross_workloads.append(point) + else: + workload_in_isolation[workload] = points + + + ctr = 0 + workload_in_isolation_pareto = {} + for 
workload, points in workload_in_isolation.items(): + optimal_points = find_pareto_points(list(set(points))) + workload_in_isolation_pareto[workload] = optimal_points + + + combined_area_power_in_isolation= [] + combined_area_power_in_isolation_with_sys_char = [] + + s = time.time() + for results_combined in itertools.product(*list(workload_in_isolation_pareto.values())): + # add up all the charactersitics + combined_sys_chars = {} + for el in system_char_to_keep_track_of: + combined_sys_chars[el] = (0,0) + + # add up area,power + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + + sys_char = find_sys_char(el[0], el[1], results_with_sys_char) + for el_,val_ in sys_char.items(): + if "avg" in el_: + total = sys_char[get_equivalent_total(el_)] + coeff = total + else: + coeff = 1 + #if "latency" in el_: + # combined_sys_chars[el_] = (combined_sys_chars[el_][0]+coeff, str(combined_sys_chars[el_][1])+"_"+val_) + #else: + combined_sys_chars[el_] = (combined_sys_chars[el_][0]+coeff, combined_sys_chars[el_][1]+coeff*float(val_)) + + for key, values in combined_sys_chars.items(): + if "avg" in key: + combined_sys_chars[key] = values[1] /max(values[0],.00000000000000000000000000000001) + else: + combined_sys_chars[key] = values[1] + + #combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1], combined_power_area_tuple[2])) + combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1])) + combined_area_power_in_isolation_with_sys_char.append({(combined_power_area_tuple[0],combined_power_area_tuple[1]): combined_sys_chars}) + + #if len(combined_area_power_in_isolation)%100000 == 0: + # print("time passed is" + str(time.time()-s)) + + combined_area_power_in_isolation_filtered = [] + for point in combined_area_power_in_isolation: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + combined_area_power_in_isolation_filtered.append(point) + combined_area_power_pareto = find_pareto_points(list(set(combined_area_power_in_isolation_filtered))) + + + all_points_cross_workloads_filtered = [] + for point in all_points_cross_workloads: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + all_points_cross_workloads_filtered.append(point) + all_points_cross_workloads_area_power_pareto = find_pareto_points(list(set(all_points_cross_workloads_filtered))) + + + # prepare for plotting and plot + fig = plt.figure(figsize=(12, 12)) + #plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + fontSize = 20 + + x_values = [el[0] for el in combined_area_power_in_isolation_filtered] + y_values = [el[1] for el in combined_area_power_in_isolation_filtered] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="isolated design methodology",marker=".") + + + # plt.tight_layout() + x_values = [el[0] for el in combined_area_power_pareto] + y_values = [el[1] for el in combined_area_power_pareto] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="isolated design methodology pareto front",marker="x") + for idx, _ in enumerate(x_values) : + power= x_values[idx] + area = y_values[idx] + sys_char = find_sys_char(power, area, combined_area_power_in_isolation_with_sys_char) + value_to_show = 0 + value_to_show = sys_char[system_char_to_show[0]] + #for el in system_char_to_show: + # value_to_show += sys_char[el] + + #if 
system_char_to_show[0] == "latency":
+        #    value_in_scientific_notation = value_to_show
+        #else:
+        #value_to_show = sys_char["local_total_traffic"]/(sys_char["local_memory_total_area"]*4*10**12)
+        value_in_scientific_notation = "{:.2e}".format(value_to_show)
+        plt.text(power, area, value_in_scientific_notation)
+
+    x_values = [el[0] for el in all_points_cross_workloads_area_power_pareto]
+    y_values = [el[1] for el in all_points_cross_workloads_area_power_pareto]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="cross workload pareto front", marker="o")
+    ax.legend(loc="upper right")  # bbox_to_anchor=(1, 1), loc="upper left")
+
+    for idx, _ in enumerate(x_values):
+        power = x_values[idx]
+        area = y_values[idx]
+        sys_char = find_sys_char(power, area, results_with_sys_char)
+
+        value_to_show = sys_char[system_char_to_show[0]]
+
+        #if system_char_to_show[0] == "latency":
+        #    value_in_scientific_notation = value_to_show
+        #else:
+        #value_to_show = sys_char["local_total_traffic"]/sys_char["local_memory_total_area"]
+        #value_to_show = sys_char["local_total_traffic"]/(sys_char["local_memory_total_area"]*4*10**12)
+        value_in_scientific_notation = "{:.2e}".format(value_to_show)
+        plt.text(power, area, value_in_scientific_notation)
+        #plt.text(power, area, sys_char[system_char_to_show[0]])
+
+    ax.set_xlabel("power", fontsize=fontSize)
+    ax.set_ylabel("area", fontsize=fontSize)
+    plt.tight_layout()
+
+    # dump in the top folder
+    output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+    output_dir = os.path.join(output_base_dir, "budget_optimality/")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    ax.set_title(system_char_to_show[0] + " for FARSI vs in isolation")
+    #ax.set_title("memory_reuse for FARSI vs in isolation")
+    fig.savefig(os.path.join(output_dir, system_char_to_show[0] + "_budget_optimality.png"))
+
+    #plt.show()
+    plt.close('all')
+
+
+def find_pareto_points(points):
+    efficients = is_pareto_efficient_dumb(np.array(points))
+    pareto_points_array = [points[idx] for idx, el in enumerate(efficients) if el]
+    return pareto_points_array
+
+
+def is_pareto_efficient_dumb(costs):
+    is_efficient = np.ones(costs.shape[0], dtype = bool)
+    for i, c in enumerate(costs):
+        is_efficient[i] = np.all(np.any(costs[:i] > c, axis=1)) and np.all(np.any(costs[i+1:] > c, axis=1))
+    return is_efficient
+
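+# Usage sketch for the Pareto helpers above, on toy (power, area) points where
+# both objectives are minimized: (1, 3), (2, 2), and (3, 1) are mutually
+# non-dominated, while (3, 3) is dominated by (2, 2). Kept in an uncalled
+# helper; the points are ours, not FARSI output.
+def _pareto_examples():
+    pts = [(1, 3), (2, 2), (3, 1), (3, 3)]
+    assert sorted(find_pareto_points(pts)) == [(1, 3), (2, 2), (3, 1)]
+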
+
+
+###########################################
+
+# the main function. comment out the plots if you do not need them
+if __name__ == "__main__":
+    # populate parameters
+    run_folder_name = config_plotting.run_folder_name
+    if config_plotting.run_folder_name == "":
+        run_folder_name = find_the_most_recent_directory(config_plotting.top_result_folder)[0]
+
+    zoneNum = config_plotting.zoneNum
+    # get all the experiments under the run folder
+    print(run_folder_name)
+    experiment_full_addr_list = get_experiment_dir_list(run_folder_name)
+
+    # according to the plot type, plot
+    all_res_column_name_number = get_column_name_number(experiment_full_addr_list[0], "all")
+    all_results_files = get_experiment_full_file_addr_list(experiment_full_addr_list)
+    summary_res_column_name_number = get_column_name_number(experiment_full_addr_list[0], "simple")
+    case_studies = {}
+    case_studies["bandwidth_analysis"] = ["local_bus_avg_theoretical_bandwidth",
+                                          "local_bus_max_actual_bandwidth",
+                                          "local_bus_avg_actual_bandwidth",
+                                          "system_bus_avg_theoretical_bandwidth",
+                                          "system_bus_max_actual_bandwidth",
+                                          "system_bus_avg_actual_bandwidth",
+                                          "local_channel_avg_actual_bandwidth",
+                                          "local_channel_max_actual_bandwidth"
+                                          ]
+
+
+    case_studies["freq_analysis"] = [
+        "global_memory_avg_freq", "local_memory_avg_freq", "local_bus_avg_freq",]
+
+    case_studies["bus_width_analysis"] = [
+        "global_memory_avg_bus_width","local_memory_avg_bus_width","local_bus_avg_bus_width"]
+
+    case_studies["traffic_analysis"] = ["global_total_traffic", "local_total_traffic",
+                                        "local_memory_traffic_per_mem_avg",
+                                        "locality_in_bytes",
+                                        "local_bus_traffic_avg",
+                                        ]
+
+
+    case_studies["local_mem_re_use"] =[
+        "local_total_traffic_reuse_no_read_ratio",
+        "local_total_traffic_reuse_no_read_in_bytes",
+        "local_total_traffic_reuse_no_read_in_size",
+        "local_total_traffic_reuse_with_read_ratio",
+        "local_total_traffic_reuse_with_read_in_bytes",
+        "local_total_traffic_reuse_with_read_in_size",
+        "local_total_traffic_reuse_no_read_in_bytes_per_cluster_avg",
+        "local_total_traffic_reuse_no_read_in_size_per_cluster_avg",
+        "local_total_traffic_reuse_with_read_in_bytes_per_cluster_avg",
+        "local_total_traffic_reuse_with_read_in_size_per_cluster_avg"
+    ]
+
+    case_studies["global_mem_re_use"] =[
+        "global_total_traffic_reuse_no_read_ratio",
+        "global_total_traffic_reuse_with_read_ratio",
+        "global_total_traffic_reuse_with_read_in_bytes",
+        "global_total_traffic_reuse_with_read_in_size",
+        "global_total_traffic_reuse_no_read_in_bytes",
+        "global_total_traffic_reuse_no_read_in_size",
+    ]
+
+
+    case_studies["area_analysis"] = ["global_memory_total_area", "local_memory_total_area", "ips_total_area",
+                                     "gpps_total_area",
+                                     ]
+    case_studies["area_in_bytes_analysis"] = ["global_memory_total_bytes", "local_memory_total_bytes", "local_memory_bytes_avg"
+                                     ]
+
+    case_studies["accel_paral_analysis"] = ["ip_cnt","max_accel_parallelism", "avg_accel_parallelism",
+                                            "gpp_cnt", "max_gpp_parallelism", "avg_gpp_parallelism"]
+    case_studies["system_complexity"] = ["system block count", "routing complexity", "system PE count",
+                                         "local_mem_cnt", "local_bus_cnt","local_channel_count_per_bus_avg", "channel_cnt",
+                                         "loop_itr_ratio_avg",
+                                         ] # , "channel_cnt"]
+
+
+    case_studies["heterogeneity_var_system_complexity"] = [
+        "local_channel_count_per_bus_coeff_var",
+        "loop_itr_ratio_var",
+        # "cluster_pe_cnt_coeff_var"
+    ]
+
+    case_studies["heterogeneity_std_system_complexity"] = [
+        "local_channel_count_per_bus_std",
+        "loop_itr_ratio_std" #Ying: comment out: , "cluster_pe_cnt_std"
+    ]
+
+
+    """
+
+    case_studies["speedup"] = [
+        "customization_first_speed_up_avg",
+        "customization_second_speed_up_avg",
+        "parallelism_first_speed_up_avg",
+        "parallelism_second_speed_up_avg",
+        "interference_degradation_avg",
+        "customization_first_speed_up_full_system",
+        "customization_second_speed_up_full_system",
+        "parallelism_first_speed_up_full_system",
+        "parallelism_second_speed_up_full_system",
+    ]
+    """
+
+
+
+    case_studies["heterogeneity_area"] = [
+        "local_memory_area_coeff_var",
+        "ips_area_coeff_var",
+        "pes_area_coeff_var",
+
+    ]
+
+
+    case_studies["heterogeneity_std_re_use"] = [
+        "local_total_traffic_reuse_no_read_in_bytes_per_cluster_std",
+        "local_total_traffic_reuse_no_read_in_size_per_cluster_std",
+        "local_total_traffic_reuse_with_read_in_bytes_per_cluster_std",
+        "local_total_traffic_reuse_with_read_in_size_per_cluster_std",
+    ]
+
+    case_studies["heterogeneity_var_re_use"] = [
+        "local_total_traffic_reuse_no_read_in_bytes_per_cluster_var",
+        "local_total_traffic_reuse_no_read_in_size_per_cluster_var",
+        "local_total_traffic_reuse_with_read_in_bytes_per_cluster_var",
+        "local_total_traffic_reuse_with_read_in_size_per_cluster_var",
+    ]
+
+    case_studies["heterogeneity_var_freq"] =[
+        "local_bus_freq_coeff_var",
+        "local_memory_freq_coeff_var",
+        "ips_freq_coeff_var",
+        "pes_freq_coeff_var"]
+
+    case_studies["heterogeneity_std_freq"] =[
+        "local_memory_freq_std",
+        "local_bus_freq_std",
+]
+
+
+
+    case_studies["heterogeneity_std_bus_width"] =[
+        "local_memory_bus_width_std",
+        "local_bus_bus_width_std",
+    ]
+
+    case_studies["heterogeneity_var_bus_width"] =[
+        "local_memory_bus_width_coeff_var",
+        "local_bus_bus_width_coeff_var",
+    ]
+
+
+
+
+    case_studies["heterogeneity_std_bandwidth"]=[
+        "local_bus_actual_bandwidth_std",
+        "local_channel_actual_bandwidth_std"]
+
+    case_studies["heterogeneity_var_bandwidth"]=[
+        "local_bus_actual_bandwidth_coeff_var",
+        "local_channel_actual_bandwidth_coeff_var"]
+
+
+
+    case_studies["heterogeneity_std_traffic"] =[
+        "local_memory_bytes_std",
+        "local_memory_traffic_per_mem_coeff_var",
+        "local_bus_traffic_coeff_var",
+    ]
+
+
+    case_studies["heterogeneity_var_traffic"] =[
+        "local_memory_bytes_coeff_var",
+        "local_memory_traffic_per_mem_coeff_var",
+        "local_bus_traffic_coeff_var",
+    ]
+
+
+
+
+    if "budget_optimality" in config_plotting.plot_list:
+        #get_budget_optimality_advanced(experiment_full_addr_list, all_results_files, summary_res_column_name_number)
+        get_budget_optimality(experiment_full_addr_list, all_results_files, summary_res_column_name_number)
+
+    if "cross_workloads" in config_plotting.plot_list:  # Ying: from for_paper/workload_awareness
+        # get column orders (assumption is that the column order doesn't change between experiments)
+        # plot_convergence_cross_workloads(experiment_full_addr_list, all_res_column_name_number)
+        column_column_value_experiment_frequency_dict = plot_codesign_nav_breakdown_cross_workload(experiment_full_addr_list, all_res_column_name_number)
+
+        for key, val in case_studies.items():
+            case_study = {key:val}
+            # plot_system_implication_analysis(experiment_full_addr_list, summary_res_column_name_number, case_study) # Ying: comment out because of "KeyError: 'cluster_pe_cnt_coeff_var'"
+        # plot_co_design_nav_breakdown_post_processing(experiment_full_addr_list, column_column_value_experiment_frequency_dict)
+        # plot_codesign_rate_efficacy_cross_workloads_updated(experiment_full_addr_list, all_res_column_name_number)
+
+    if "single_workload" in
config_plotting.plot_list: + #single workload + plot_codesign_progression_per_workloads(experiment_full_addr_list, all_res_column_name_number) + _ = plot_codesign_nav_breakdown_per_workload(experiment_full_addr_list, all_res_column_name_number) + plot_convergence_per_workloads(experiment_full_addr_list, all_res_column_name_number) + plot_convergence_vs_time(experiment_full_addr_list, all_res_column_name_number) + + if "plot_3d" in config_plotting.plot_list: + plot_3d(experiment_full_addr_list, summary_res_column_name_number) + + if "pandas_plots" in config_plotting.plot_list: # Ying: from scaling_of_1_2_4_07-31 + #pandas_case_studies = {} + case_studies["system_complexity"] = ["system block count", "routing complexity", "system PE count", + "local_mem_cnt", "local_bus_cnt" , "channel_cnt", "ip_cnt", "gpp_cnt"] + + case_studies["pe_parallelism"] = ["max_accel_parallelism", "avg_accel_parallelism", "avg_gpp_parallelism", "max_gpp_parallelism"] + + case_studies["ip_frequency"] = ["ips_avg_freq", "gpps_avg_freq", "ips_freq_std", "pes_freq_std", + "ips_freq_coeff_var", "pes_freq_coeff_var"] + + case_studies["pe_area"] = ["ips_total_area", "gpps_total_area", "ips_area_std", "pes_area_std", + "ips_area_coeff_var", "pes_area_coeff_var"] + + case_studies["mem_frequency"] = ["local_memory_avg_freq", "global_memory_avg_freq", + "local_memory_freq_std","local_memory_freq_coeff_var"] + + case_studies["mem_area"] = ["local_memory_total_area", "global_memory_total_area", "local_memory_area_std", + "local_memory_area_coeff_var"] + + case_studies["traffic"] = ["local_total_traffic", "global_total_traffic"] + + + case_studies["bus_width"] = ["local_bus_avg_bus_width", + "system_bus_avg_bus_width"] + + + case_studies["bus_bandwidth"] = ["local_bus_avg_actual_bandwidth", "system_bus_avg_actual_bandwidth", + "local_bus_avg_theoretical_bandwidth", "system_bus_avg_theoretical_bandwidth", + "local_bus_max_actual_bandwidth", "system_bus_max_actual_bandwidth"] + + + + for case_study_name, metrics in case_studies.items(): + for metric in metrics: + pandas_plots(experiment_full_addr_list, all_results_files, metric) + + # get the the workload_set folder + # each workload_set has a bunch of experiments underneath it + workload_set_folder_list = os.listdir(run_folder_name) + + # iterate and generate plots + for workload_set_folder in workload_set_folder_list: + # ignore irelevant files + if workload_set_folder in config_plotting.ignore_file_names: + continue + + # start plotting + #plotBudgets3d(run_folder_name, workload_set_folder) + + + """ + # get experiment folder + workload_set_full_addr = os.path.join(run_folder_name,workload_set_folder) + folder_list = os.listdir(workload_set_full_addr) + for experiment_name_relative_addr in folder_list: + print(experiment_name_relative_addr) + if experiment_name_relative_addr in config_plotting.ignore_file_names: + continue + experiment_full_addr = os.path.join(workload_set_full_addr, experiment_name_relative_addr) + + all_res_column_name_number = get_column_name_number(experiment_full_addr, "all") + summary_res_column_name_number = get_column_name_number(experiment_full_addr, "simple") + + workload_set_full_addr +="/" # this is because you didn't use join + commcompColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "comm_comp", "all") + trueNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "move validity", "all") + optColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "high level optimization name", "all") + 
archColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "architectural principle", "all") + sysBlkNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "system block count", "all") + simColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "simulation time", "all") + movGenColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "transformation generation time", "all") + movColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "move name", "all") + itrNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "iteration cnt", "all") + distColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "dist_to_goal_non_cost", "all") + refDistColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "ref_des_dist_to_goal_non_cost", "all") + latNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "latency", "all") + powNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "power", "all") + areaNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "area", "all") + + # comment or uncomment the following functions for your plottings + plotDistToGoalVSitr([experiment_full_addr], all_res_column_name_number) + plotCommCompAll(workload_set_full_addr, experiment_name_relative_addr, all_res_column_name_number) + plothighLevelOptAll(workload_set_full_addr, experiment_name_relative_addr, all_res_column_name_number) + plotArchVarImpAll(workload_set_full_addr, experiment_name_relative_addr, archColNum, trueNum) + plotSimTimeVSblk(workload_set_full_addr, experiment_name_relative_addr, sysBlkNum, simColNum, trueNum) + plotMoveGenTimeVSblk(workload_set_full_addr, experiment_name_relative_addr, sysBlkNum, movGenColNum, trueNum) + plotRefDistToGoalVSitr(workload_set_full_addr, experiment_name_relative_addr, itrNum, refDistColNum, trueNum) + plotSimTimeVSmoveNameZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, movColNum, distColNum, simColNum, trueNum) + plotMovGenTimeVSmoveNameZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, movColNum, distColNum, movGenColNum, trueNum) + plotSimTimeVScommCompZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, commcompColNum, distColNum, simColNum, trueNum) + plotMovGenTimeVScommCompZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, commcompColNum, distColNum, movGenColNum, trueNum) + plotSimTimeVShighLevelOptZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, optColNum, distColNum, simColNum, trueNum) + plotMovGenTimeVShighLevelOptZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, optColNum, distColNum, movGenColNum, trueNum) + plotSimTimeVSarchVarImpZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, archColNum, distColNum, simColNum, trueNum) + plotMovGenTimeVSarchVarImpZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, archColNum, distColNum, movGenColNum, trueNum) + """ diff --git a/Project_FARSI/visualization_utils/plotting.py b/Project_FARSI/visualization_utils/plotting.py new file mode 100644 index 00000000..9c5cd1a2 --- /dev/null +++ b/Project_FARSI/visualization_utils/plotting.py @@ -0,0 +1,5179 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. 
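+# Plotting utilities for FARSI design-space-exploration results. Most functions
+# below read the per-experiment CSVs under result_summary/ (note that the
+# "FARSI_simple_run_0_1_all_reults.csv" spelling matches the file name the
+# simulator actually emits) and look up columns by header name. A minimal
+# usage sketch, assuming an experiment directory produced by a FARSI run
+# ("path/to/experiment" and the "latency" column are illustrative):
+#
+#     cols = get_column_name_number("path/to/experiment", "all")
+#     latency_idx = cols["latency"]   # header name -> column index
+#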
+import math
+import itertools
+import copy
+import csv
+import os
+import matplotlib.pyplot as plt
+import matplotlib
+import pandas as pd
+import numpy as np
+import shutil
+from settings import config_plotting
+import time
+import re
+from collections import OrderedDict
+from pygmo import *
+def get_column_name_number(dir_addr, mode):
+    column_name_number_dic = {}
+    try:
+        if mode == "all":
+            file_name = "result_summary/FARSI_simple_run_0_1_all_reults.csv"
+        elif mode =="aggregate":
+            file_name = "result_summary/aggregate_all_results.csv"
+        else:
+            file_name = "result_summary/FARSI_simple_run_0_1.csv"
+
+        file_full_addr = os.path.join(dir_addr, file_name)
+        with open(file_full_addr) as f:
+            resultReader = csv.reader(f, delimiter=',', quotechar='|')
+            for row in resultReader:
+                for idx, el_name in enumerate(row):
+                    column_name_number_dic[el_name] = idx
+                break
+        return column_name_number_dic
+    except Exception as e:
+        raise e
+
+
+
+#
+
+
+# the function to get the column information of the given category
+def columnNum(dirName, fileName, cate, result):
+    if result == "all":
+        with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+            for i, row in enumerate(resultReader):
+                if i == 0:
+                    for j in range(0, len(row)):
+                        if row[j] == cate:
+                            return j
+                    raise Exception("No such category in the list! Check the name: " + cate)
+                break
+    elif result == "simple":
+        with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1.csv", newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+            for i, row in enumerate(resultReader):
+                if i == 0:
+                    for j in range(0, len(row)):
+                        if row[j] == cate:
+                            return j
+                    raise Exception("No such category in the list! Check the name: " + cate)
+                break
+    else:
+        raise Exception("No such result file! Check the result type! It should be either \"all\" or \"simple\"")
+
+# the function to plot the frequency of all comm_comp in the pie chart
+def plotCommCompAll(dirName, fileName, all_res_column_name_number):
+    colNum = all_res_column_name_number["comm_comp"]
+    trueNum = all_res_column_name_number["move validity"]
+
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        commNum = 0
+        compNum = 0
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i > 1:
+                if row[colNum] == "comm":
+                    commNum += 1
+                elif row[colNum] == "comp":
+                    compNum += 1
+                else:
+                    raise Exception("comm_comp is not giving comm or comp! The new type: " + row[colNum])
+
+    plt.figure()
+    plt.pie([commNum, compNum], labels = ["comm", "comp"])
+    plt.title("comm_comp: Frequency")
+    plt.savefig(dirName + fileName + "/comm-compFreq-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot the frequency of all high level optimizations in the pie chart
+def plothighLevelOptAll(dirName, fileName, all_res_column_name_number):
+    colNum = all_res_column_name_number["high level optimization name"]
+    trueNum = all_res_column_name_number["move validity"]
+
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        topoNum = 0
+        tunNum = 0
+        mapNum = 0
+        idenOptNum = 0
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i > 1:
+                if row[colNum] == "topology":
+                    topoNum += 1
+                elif row[colNum] == "customization":
+                    tunNum += 1
+                elif row[colNum] == "mapping":
+                    mapNum += 1
+                elif row[colNum] == "identity":
+                    idenOptNum += 1
+                else:
+                    raise Exception("high level optimization name is not giving topology or customization or mapping or identity! The new type: " + row[colNum])
+
+    plt.figure()
+    plt.pie([topoNum, tunNum, mapNum, idenOptNum], labels = ["topology", "customization", "mapping", "identity"])
+    plt.title("High Level Optimization: Frequency")
+    plt.savefig(dirName + fileName + "/highLevelOpt-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot the frequency of all architectural variables to improve in the pie chart
+def plotArchVarImpAll(dirName, fileName, colNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        parazNum = 0
+        custNum = 0
+        localNum = 0
+        idenImpNum = 0
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i > 1:
+                if row[colNum] == "parallelization":
+                    parazNum += 1
+                elif row[colNum] == "customization":
+                    custNum += 1
+                elif row[colNum] == "locality":
+                    localNum += 1
+                elif row[colNum] == "identity":
+                    idenImpNum += 1
+                else:
+                    raise Exception("architectural principle is not parallelization or customization or locality or identity! The new type: " + row[colNum])
+
+    plt.figure()
+    plt.pie([parazNum, custNum, localNum, idenImpNum], labels = ["parallelization", "customization", "locality", "identity"])
+    plt.title("Architectural Principle: Frequency")
+    plt.savefig(dirName + fileName + "/archVarImp-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot simulation time vs. system block count
+def plotSimTimeVSblk(dirName, fileName, blkColNum, simColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        sysBlkCount = []
+        simTime = []
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i > 1:
+                sysBlkCount.append(int(row[blkColNum]))
+                simTime.append(float(row[simColNum]))
+
+    plt.figure()
+    plt.plot(sysBlkCount, simTime)
+    plt.xlabel("System Block Count")
+    plt.ylabel("Simulation Time")
+    plt.title("Simulation Time vs. System Block Count")
+    plt.savefig(dirName + fileName + "/simTimeVSblk-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+# the function to plot move generation time vs. system block count
+def plotMoveGenTimeVSblk(dirName, fileName, blkColNum, movColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        sysBlkCount = []
+        moveGenTime = []
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i > 1:
+                sysBlkCount.append(int(row[blkColNum]))
+                moveGenTime.append(float(row[movColNum]))
+
+    plt.figure()
+    plt.plot(sysBlkCount, moveGenTime)
+    plt.xlabel("System Block Count")
+    plt.ylabel("Move Generation Time")
+    plt.title("Move Generation Time vs. System Block Count")
+    plt.savefig(dirName + fileName + "/moveGenTimeVSblk-" + fileName + ".png")
+    # plt.show()
+    plt.close('all')
+
+def get_experiments_workload(all_res_column_name_number):
+    latency_budget = all_res_column_name_number["latency budget"][:-1]
+    workload_latencies = latency_budget.split(";")
+    workloads = []
+    for workload_latency in workload_latencies:
+        workloads.append(workload_latency.split("=")[0])
+    return workloads
+
+def get_experiments_name(file_full_addr, all_res_column_name_number):
+    with open(file_full_addr, newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+        row1 = next(resultReader)
+        row2 = next(resultReader)
+        latency_budget = row2[all_res_column_name_number["latency_budget"]]
+        power_budget = row2[all_res_column_name_number["power_budget"]]
+        area_budget = row2[all_res_column_name_number["area_budget"]]
+        try:
+            transformation_selection_mode = row2[all_res_column_name_number["transformation_selection_mode"]]
+        except:
+            transformation_selection_mode = ""
+
+
+        workload_latencies = latency_budget[:-1].split(';')
+        latency_budget_refined =""
+        for workload_latency in workload_latencies:
+            latency_budget_refined +="_" + (workload_latency.split("=")[0][0]+workload_latency.split("=")[1])
+
+        return latency_budget_refined[1:]+"_" + power_budget + "_" + area_budget+"_"+transformation_selection_mode
+
+def get_all_col_values_of_a_file(file_full_addr, all_res_column_name_number, column_name):
+    column_number = all_res_column_name_number[column_name]
+    all_values = []
+    with open(file_full_addr, newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+        experiment_name = get_experiments_name(file_full_addr, all_res_column_name_number)
+        for i, row in enumerate(resultReader):
+            if i > 1:
+                if not row[column_number] == '':
+                    value =row[column_number]
+                    values = value.split(";") # if multiple values
+                    for val in values:
+                        if "=" in val:
+                            val_splitted = val.split("=")
+                            all_values.append(val_splitted[0])
+                        else:
+                            all_values.append(val)
+
+    return all_values
+
+def get_all_col_values_of_a_folders(input_dir_names, input_all_res_column_name_number, column_name):
+    all_values = []
+    for dir_name in input_dir_names:
+        file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv")
+        all_values.extend(get_all_col_values_of_a_file(file_full_addr, input_all_res_column_name_number, column_name))
+
+    # get rid of duplicates
+    all_values_rid_of_duplicates = list(set(all_values))
+    return all_values_rid_of_duplicates
+
+def extract_latency_values(values_):
+    print("")
+
+
+def plot_codesign_rate_efficacy_cross_workloads_updated(input_dir_names, res_column_name_number):
+    #itrColNum = all_res_column_name_number["iteration cnt"]
+    #distColNum = all_res_column_name_number["dist_to_goal_non_cost"]
+    trueNum = all_res_column_name_number["move validity"]
+    move_name_number = all_res_column_name_number["move name"]
+
+    # experiment_names
+    file_full_addr_list = []
+    for dir_name in input_dir_names:
+        file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv")
+        file_full_addr_list.append(file_full_addr)
+
+    axis_font = {'fontname': 'Arial', 'size': '4'}
+    x_column_name = "iteration cnt"
+    #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"]
+    y_column_name_list = ["exact optimization name", "architectural principle", "comm_comp", "workload"]
+
+    #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"]
+
+
+
+    column_co_design_cnt = {}
+    column_non_co_design_cnt = {}
+    column_co_design_rate = {}
+    column_non_co_design_rate = {}
+    column_co_design_efficacy_avg = {}
+    column_non_co_design_efficacy_rate = {}
+    column_non_co_design_efficacy = {}
+    column_co_design_dist= {}
+    column_co_design_dist_avg= {}
+    column_co_design_improvement = {}
+    experiment_name_list = []
+    last_col_val = ""
+    for file_full_addr in file_full_addr_list:
+        experiment_name = get_experiments_name(file_full_addr, res_column_name_number)
+        experiment_name_list.append(experiment_name)
+        column_co_design_dist_avg[experiment_name] = {}
+        column_co_design_efficacy_avg[experiment_name] = {}
+
+        column_co_design_cnt = {}
+        for y_column_name in y_column_name_list:
+            y_column_number = res_column_name_number[y_column_name]
+            x_column_number = res_column_name_number[x_column_name]
+
+
+            dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"]
+            ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"]
+            column_co_design_cnt[y_column_name] = []
+            column_non_co_design_cnt[y_column_name] = []
+
+            column_non_co_design_efficacy[y_column_name] = []
+            column_co_design_dist[y_column_name] = []
+            column_co_design_improvement[y_column_name] = []
+            column_co_design_rate[y_column_name] = []
+
+            all_values = get_all_col_values_of_a_folders(input_dir_names, all_res_column_name_number, y_column_name)
+
+            last_row_change = ""
+            with open(file_full_addr, newline='') as csvfile:
+                resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+                rows = list(resultReader)
+                for i, row in enumerate(rows):
+                    if i > 1:
+                        last_row = rows[i - 1]
+                        if row[y_column_number] not in all_values or row[move_name_number]=="identity":
+                            continue
+
+                        col_value = row[y_column_number]
+                        col_values = col_value.split(";")
+                        for idx, col_val in enumerate(col_values):
+
+
+                            # only for improvement
+                            if float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]) < 0:
+                                continue
+
+                            try:
+                                delta_x_column = (float(row[x_column_number]) - float(last_row[x_column_number]))/len(col_values)
+                            except:
+                                delta_x_column = 0  # fall back if the iteration count is missing or non-numeric
+
+                            delta_improvement = (float(last_row[dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/(float(last_row[dis_to_goal_column_number])*len(col_values))
+
+
+                            if not col_val == last_col_val and i > 1:
+                                if not last_row_change == "":
+                                    distance_from_last_change = float(last_row[x_column_number]) - float(last_row_change[x_column_number]) + idx * delta_x_column
+                                    column_co_design_dist[y_column_name].append(distance_from_last_change)
+                                    improvement_from_last_change = (float(last_row[dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/float(last_row[dis_to_goal_column_number]) + idx *delta_improvement
+                                    column_co_design_improvement[y_column_name].append(improvement_from_last_change)
+
+                                last_row_change = copy.deepcopy(last_row)
+
+
+                            last_col_val = col_val
+
+
+
+            # co_des cnt
+            # we ignore the first element as the first element distance is always zero
+            co_design_dist_sum = 0
+            co_design_efficacy_sum = 0
+            avg_ctr = 1
+            co_design_dist_selected = column_co_design_dist[y_column_name]
+            co_design_improvement_selected = column_co_design_improvement[y_column_name]
+            for idx,el in enumerate(column_co_design_dist[y_column_name]):
+                if idx == len(co_design_dist_selected) - 1:
+                    break
+                co_design_dist_sum += 1/(column_co_design_dist[y_column_name][idx] + column_co_design_dist[y_column_name][idx+1])
+                co_design_efficacy_sum += (column_co_design_improvement[y_column_name][idx] + column_co_design_improvement[y_column_name][idx+1])
+                #/(column_co_design_dist[y_column_name][idx] + column_co_design_dist[y_column_name][idx+1])
+                avg_ctr+=1
+
+            column_co_design_improvement = {}
+            column_co_design_dist_avg[experiment_name][y_column_name]= co_design_dist_sum/avg_ctr
+            column_co_design_efficacy_avg[experiment_name][y_column_name] = co_design_efficacy_sum/avg_ctr
+
+    #result = {"rate":{}, "efficacy":{}}
+    #rate_column_co_design = {}
+
+    plt.figure()
+    plotdata = pd.DataFrame(column_co_design_dist_avg, index=y_column_name_list)
+    #plotdata_2 = pd.DataFrame(column_co_design_efficacy_avg, index=y_column_name_list)
+    fontSize = 10
+    plotdata.plot(kind='bar', fontsize=fontSize)
+    #plotdata_2.plot(kind='bar', fontsize=fontSize)
+    plt.xticks(fontsize=fontSize, rotation=6)
+    plt.yticks(fontsize=fontSize)
+    plt.xlabel("co design parameter", fontsize=fontSize)
+    plt.ylabel("co design distance", fontsize=fontSize)
+    plt.title("co design distance of different parameters", fontsize=fontSize)
+
+    # dump in the top folder
+    output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+    output_dir = os.path.join(output_base_dir, "cross_workloads/co_design_rate")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    plt.savefig(os.path.join(output_dir,"_".join(experiment_name_list) +"_"+"co_design_avg_dist"+'_'.join(y_column_name_list)+".png"))
+    plt.close('all')
+
+
+    plt.figure()
+    plotdata = pd.DataFrame(column_co_design_efficacy_avg, index=y_column_name_list)
+    fontSize = 10
+    plotdata.plot(kind='bar', fontsize=fontSize)
+    plt.xticks(fontsize=fontSize, rotation=6)
+    plt.yticks(fontsize=fontSize)
+    plt.xlabel("co design parameter", fontsize=fontSize)
+    plt.ylabel("co design efficacy", fontsize=fontSize)
+    plt.title("co design efficacy of different parameters", fontsize=fontSize)
+
+    # dump in the top folder
+    output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+    output_dir = os.path.join(output_base_dir, "cross_workloads/co_design_rate")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    plt.savefig(os.path.join(output_dir,"_".join(experiment_name_list) +"_"+"co_design_efficacy"+'_'.join(y_column_name_list)+".png"))
+    plt.close('all')
+
+def plot_codesign_rate_efficacy_cross_workloads_updated_for_paper(input_dir_names, res_column_name_number):
+    #itrColNum = all_res_column_name_number["iteration cnt"]
+    #distColNum = all_res_column_name_number["dist_to_goal_non_cost"]
+    trueNum = all_res_column_name_number["move validity"]
+    move_name_number = all_res_column_name_number["move name"]
+
+    # experiment_names
+    file_full_addr_list = []
+    for dir_name in input_dir_names:
+        file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv")
+
file_full_addr_list.append(file_full_addr) + + axis_font = {'fontname': 'Arial', 'size': '4'} + x_column_name = "iteration cnt" + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + y_column_name_list = ["exact optimization name", "architectural principle", "comm_comp", "workload"] + + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + + + + column_co_design_cnt = {} + column_non_co_design_cnt = {} + column_co_design_rate = {} + column_non_co_design_rate = {} + column_co_design_efficacy_avg = {} + column_non_co_design_efficacy_rate = {} + column_non_co_design_efficacy = {} + column_co_design_dist= {} + column_co_design_dist_avg= {} + column_co_design_improvement = {} + experiment_name_list = [] + last_col_val = "" + ctr_ = 0 + for file_full_addr in file_full_addr_list: + if ctr_ == 1: + continue + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_name_list.append(experiment_name) + column_co_design_dist_avg[experiment_name] = {} + column_co_design_efficacy_avg[experiment_name] = {} + + column_co_design_cnt = {} + for y_column_name in y_column_name_list: + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + + + dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + column_co_design_cnt[y_column_name] = [] + column_non_co_design_cnt[y_column_name] = [] + + column_non_co_design_efficacy[y_column_name] = [] + column_co_design_dist[y_column_name] = [] + column_co_design_improvement[y_column_name] = [] + column_co_design_rate[y_column_name] = [] + + all_values = get_all_col_values_of_a_folders(input_dir_names, all_res_column_name_number, y_column_name) + + last_row_change = "" + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + rows = list(resultReader) + for i, row in enumerate(rows): + if i > 1: + last_row = rows[i - 1] + if row[y_column_number] not in all_values or row[move_name_number]=="identity": + continue + + col_value = row[y_column_number] + col_values = col_value.split(";") + for idx, col_val in enumerate(col_values): + + + # only for improvement + if float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]) < 0: + continue + + delta_x_column = (float(row[x_column_number]) - float(last_row[x_column_number]))/len(col_values) + delta_improvement = (float(last_row[dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/(float(last_row[dis_to_goal_column_number])*len(col_values)) + + + if not col_val == last_col_val and i > 1: + if not last_row_change == "": + distance_from_last_change = float(last_row[x_column_number]) - float(last_row_change[x_column_number]) + idx * delta_x_column + column_co_design_dist[y_column_name].append(distance_from_last_change) + improvement_from_last_change = (float(last_row[dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/float(last_row[dis_to_goal_column_number]) + idx *delta_improvement + column_co_design_improvement[y_column_name].append(improvement_from_last_change) + + last_row_change = copy.deepcopy(last_row) + + + last_col_val = col_val + + + + # co_des cnt + # we ignore the first element as the first element distance is always zero + co_design_dist_sum = 0 + 
co_design_efficacy_sum = 0 + avg_ctr = 1 + co_design_dist_selected = column_co_design_dist[y_column_name] + co_design_improvement_selected = column_co_design_improvement[y_column_name] + for idx,el in enumerate(column_co_design_dist[y_column_name]): + if idx == len(co_design_dist_selected) - 1: + break + co_design_dist_sum += 1/(column_co_design_dist[y_column_name][idx] + column_co_design_dist[y_column_name][idx+1]) + co_design_efficacy_sum += (column_co_design_improvement[y_column_name][idx] + column_co_design_improvement[y_column_name][idx+1]) + #/(column_co_design_dist[y_column_name][idx] + column_co_design_dist[y_column_name][idx+1]) + avg_ctr+=1 + + column_co_design_improvement = {} + column_co_design_dist_avg[experiment_name][y_column_name]= co_design_dist_sum/avg_ctr + column_co_design_efficacy_avg[experiment_name][y_column_name] = co_design_efficacy_sum/avg_ctr + ctr_ +=1 + #result = {"rate":{}, "efficacy":{}} + #rate_column_co_design = {} + + plt.figure() + y_column_name_list_rep = ["L.Opt", "H.Opt", "CM", "WL"] + y_column_name_list_rep_rep = [re.sub("(.{5})", "\\1\n", label, 0, re.DOTALL) for label in y_column_name_list_rep] + plotdata = pd.DataFrame(column_co_design_dist_avg, index=y_column_name_list) + fontSize = 26 + # print(plotdata) + # print(y_column_name_list) + color_list={'a0.021_e0.034_h0.034_0.008737_1.7475e-05_arch-aware': 'green', 'a0.021_e0.034_h0.034_0.008737_1.7475e-05_random': 'red'} # Ying: uncomment for blind_study_all_dumb_versions/blind_vs_arch_aware (T.B.M), also the one below + ax = plotdata.plot(kind='bar', fontsize=fontSize, figsize=(6.6, 6.6), color=color_list) # Ying: uncomment for blind_study_all_dumb_versions/blind_vs_arch_aware (T.B.M), also the one below + ax.set_xticklabels(y_column_name_list_rep_rep, rotation=0) + # Ying: hardcode here + ax.set_xlabel("Co-design Parameter", fontsize=fontSize, labelpad=-25) + ax.set_ylabel("Co-design Rate", fontsize=fontSize) + for experiment_name, value in column_co_design_dist_avg.items(): + # print(experiment_name[-6:]) + if experiment_name[-6:] == "random": + ax.legend(['SA', 'FARSI'], bbox_to_anchor=(0.5, 1.23), loc="upper center", fontsize=fontSize - 2, ncol=2) + else: + ax.legend(['FARSI', 'SA'], bbox_to_anchor=(0.5, 1.23), loc="upper center", fontsize=fontSize - 2, ncol=2) + break + # Ying: hardcode finished + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir,"co_design_avg_dist"+'_'.join(y_column_name_list)+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + + + plt.figure() + plotdata = pd.DataFrame(column_co_design_efficacy_avg, index=y_column_name_list) + fontSize = 26 + ax = plotdata.plot(kind='bar', fontsize=fontSize, figsize=(6.6, 6.6), color=color_list) # Ying: uncomment for blind_study_all_dumb_versions/blind_vs_arch_aware (T.B.M), also the one above + ax.set_xticklabels(y_column_name_list_rep_rep, rotation=0) + # Ying: hardcode here + ax.set_xlabel("Co-design Parameter", fontsize=fontSize, labelpad=-25) + ax.set_ylabel("Co-design Improvement", fontsize=fontSize) + for experiment_name, value in column_co_design_dist_avg.items(): + # print(experiment_name[-6:]) + if experiment_name[-6:] == "random": + ax.legend(['SA', 'FARSI'], bbox_to_anchor=(0.5, 1.23), loc="upper center", fontsize=fontSize - 2, ncol=2) + else: + ax.legend(['FARSI', 'SA'], 
bbox_to_anchor=(0.5, 1.23), loc="upper center", fontsize=fontSize - 2, ncol=2) + break + # Ying: hardcode finished + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir,"co_design_efficacy"+'_'.join(y_column_name_list)+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + +""" +def plot_codesign_rate_efficacy_per_workloads(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + move_name_number = all_res_column_name_number["move name"] + + # experiment_names + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + + axis_font = {'fontname': 'Arial', 'size': '4'} + x_column_name = "iteration cnt" + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + y_column_name_list = ["exact optimization name", "architectural principle", "comm_comp", "workload"] + + #y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + + + + column_co_design_cnt = {} + column_non_co_design_cnt = {} + column_co_design_rate = {} + column_non_co_design_rate = {} + column_co_design_efficacy_rate = {} + column_non_co_design_efficacy_rate = {} + column_non_co_design_efficacy = {} + column_co_design_efficacy= {} + last_col_val = "" + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + column_co_design_cnt = {} + for y_column_name in y_column_name_list: + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + + + dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + column_co_design_cnt[y_column_name] = [] + column_non_co_design_cnt[y_column_name] = [] + + column_non_co_design_efficacy[y_column_name] = [] + column_co_design_efficacy[y_column_name] = [] + + all_values = get_all_col_values_of_a_folders(input_dir_names, all_res_column_name_number, y_column_name) + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + rows = list(resultReader) + for i, row in enumerate(rows): + if i >= 1: + last_row = rows[i - 1] + if row[y_column_number] not in all_values or row[trueNum] == "False" or row[move_name_number]=="identity": + continue + + col_value = row[y_column_number] + col_values = col_value.split(";") + for idx, col_val in enumerate(col_values): + delta_x_column = (float(row[x_column_number]) - float(last_row[x_column_number]))/len(col_values) + + value_to_add_1 = (float(last_row[x_column_number]) + idx * delta_x_column, 1) + value_to_add_0 = (float(last_row[x_column_number]) + idx * delta_x_column, 0) + + # only for improvement + if float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]) < 0: + continue + + if not col_val == last_col_val: + + 
column_co_design_cnt[y_column_name].append(value_to_add_1) + column_non_co_design_cnt[y_column_name].append(value_to_add_0) + column_co_design_efficacy[y_column_name].append((float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/float(row[ref_des_dis_to_goal_column_number])) + column_non_co_design_efficacy[y_column_name].append(0) + else: + column_co_design_cnt[y_column_name].append(value_to_add_0) + column_non_co_design_cnt[y_column_name].append(value_to_add_1) + column_co_design_efficacy[y_column_name].append(0) + column_non_co_design_efficacy[y_column_name].append((float(row[ref_des_dis_to_goal_column_number]) - float(row[dis_to_goal_column_number]))/float(row[ref_des_dis_to_goal_column_number])) + + last_col_val = col_val + + + + # co_des cnt + x_values_co_design_cnt = [el[0] for el in column_co_design_cnt[y_column_name]] + y_values_co_design_cnt = [el[1] for el in column_co_design_cnt[y_column_name]] + y_values_co_design_cnt_total =sum(y_values_co_design_cnt) + total_iter = x_values_co_design_cnt[-1] + + # non co_des cnt + x_values_non_co_design_cnt = [el[0] for el in column_non_co_design_cnt[y_column_name]] + y_values_non_co_design_cnt = [el[1] for el in column_non_co_design_cnt[y_column_name]] + y_values_non_co_design_cnt_total =sum(y_values_non_co_design_cnt) + + column_co_design_rate[y_column_name] = y_values_co_design_cnt_total/total_iter + column_non_co_design_rate[y_column_name] = y_values_non_co_design_cnt_total/total_iter + + # co_des efficacy + y_values_co_design_efficacy = column_co_design_efficacy[y_column_name] + y_values_co_design_efficacy_total =sum(y_values_co_design_efficacy) + + + # non co_des efficacy + y_values_non_co_design_efficacy = column_non_co_design_efficacy[y_column_name] + y_values_non_co_design_efficacy_total =sum(y_values_non_co_design_efficacy) + + column_co_design_efficacy_rate[y_column_name] = y_values_co_design_efficacy_total/(y_values_non_co_design_efficacy_total + y_values_co_design_efficacy_total) + column_non_co_design_efficacy_rate[y_column_name] = y_values_non_co_design_efficacy_total/(y_values_non_co_design_efficacy_total + y_values_co_design_efficacy_total) + + + result = {"rate":{}, "efficacy":{}} + rate_column_co_design = {} + + result["rate"] = {"co_design":column_co_design_rate, "non_co_design": column_non_co_design_rate} + result["efficacy_rate"] = {"co_design":column_co_design_efficacy_rate, "non_co_design": column_non_co_design_efficacy_rate} + # prepare for plotting and plot + + + plt.figure() + plotdata = pd.DataFrame(result["rate"], index=y_column_name_list) + fontSize = 10 + plotdata.plot(kind='bar', fontsize=fontSize, stacked=True) + plt.xticks(fontsize=fontSize, rotation=6) + plt.yticks(fontsize=fontSize) + plt.xlabel("co design parameter", fontsize=fontSize) + plt.ylabel("co design rate", fontsize=fontSize) + plt.title("co desgin rate of different parameters", fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir,experiment_name +"_"+"co_design_rate_"+'_'.join(y_column_name_list)+".png")) + plt.close('all') + + + plt.figure() + plotdata = pd.DataFrame(result["efficacy_rate"], index=y_column_name_list) + fontSize = 10 + plotdata.plot(kind='bar', fontsize=fontSize, stacked=True) + plt.xticks(fontsize=fontSize, rotation=6) + plt.yticks(fontsize=fontSize) + 
plt.xlabel("co design parameter", fontsize=fontSize) + plt.ylabel("co design efficacy rate", fontsize=fontSize) + plt.title("co design efficacy rate of different parameters", fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/co_design_rate") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir,experiment_name+"_"+"co_design_efficacy_rate_"+'_'.join(y_column_name_list)+".png")) + plt.close('all') +""" + + +def plot_codesign_progression_per_workloads(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '20'} + x_column_name = "iteration cnt" + y_column_name_list = ["high level optimization name", "exact optimization name", "architectural principle", "comm_comp"] + + + experiment_column_value = {} + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + for y_column_name in y_column_name_list: + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + experiment_column_value[experiment_name] = [] + all_values = get_all_col_values_of_a_folders(input_dir_names, all_res_column_name_number, y_column_name) + all_values_encoding = {} + for idx, val in enumerate(all_values): + all_values_encoding[val] = idx + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + rows = list(resultReader) + for i, row in enumerate(rows): + #if row[trueNum] != "True": + # continue + if i >= 1: + if row[y_column_number] not in all_values: + continue + + col_value = row[y_column_number] + col_values = col_value.split(";") + for idx, col_val in enumerate(col_values): + last_row = rows[i-1] + delta_x_column = (float(row[x_column_number]) - float(last_row[x_column_number]))/len(col_values) + value_to_add = (float(last_row[x_column_number])+ idx*delta_x_column, col_val) + experiment_column_value[experiment_name].append(value_to_add) + + + + # prepare for plotting and plot + axis_font = {'size': '20'} + fontSize = 20 + + fig = plt.figure(figsize=(12, 8)) + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + x_values = [el[0] for el in experiment_column_value[experiment_name]] + #y_values = [all_values_encoding[el[1]] for el in experiment_column_value[experiment_name]] + y_values = [el[1] for el in experiment_column_value[experiment_name]] + + #ax.set_title("experiment vs system implicaction") + ax.tick_params(axis='both', which='major', labelsize=fontSize, rotation=60) + ax.set_xlabel(x_column_name, fontsize=20) + ax.set_ylabel(y_column_name, fontsize=20) + ax.plot(x_values, y_values, label=y_column_name, linewidth=2) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir 
= os.path.join(output_base_dir, "single_workload/progression") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + fig.savefig(os.path.join(output_dir,experiment_name+"_progression_"+'_'.join(y_column_name_list)+".png")) + # plt.show() + plt.close('all') + + fig = plt.figure(figsize=(12, 8)) + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + x_values = [el[0] for el in experiment_column_value[experiment_name]] + # y_values = [all_values_encoding[el[1]] for el in experiment_column_value[experiment_name]] + y_values = [el[1] for el in experiment_column_value[experiment_name]] + + # ax.set_title("experiment vs system implicaction") + ax.tick_params(axis='both', which='major', labelsize=fontSize, rotation=60) + ax.set_xlabel(x_column_name, fontsize=20) + ax.set_ylabel(y_column_name, fontsize=20) + ax.plot(x_values, y_values, label=y_column_name, linewidth=2) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/progression") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + fig.savefig(os.path.join(output_dir, experiment_name + "_progression_" + y_column_name + ".png")) + # plt.show() + plt.close('all') + + +def plot_3d(input_dir_names, res_column_name_number): + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '10'} + fontSize = 10 + column_value = {} + # initialize the dictionary + column_name_list = ["budget_scaling_power", "budget_scaling_area","budget_scaling_latency"] + + under_study_vars =["iteration cnt", + "local_bus_avg_theoretical_bandwidth", "local_bus_max_actual_bandwidth", + "local_bus_avg_actual_bandwidth", + "system_bus_avg_theoretical_bandwidth", "system_bus_max_actual_bandwidth", + "system_bus_avg_actual_bandwidth", "global_total_traffic", "local_total_traffic", + "global_memory_total_area", "local_memory_total_area", "ips_total_area", + "gpps_total_area","ip_cnt", "max_accel_parallelism", "avg_accel_parallelism", + "gpp_cnt", "max_gpp_parallelism", "avg_gpp_parallelism"] + + + + + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + for column_name in column_name_list + under_study_vars: + if column_name not in column_value.keys() : + column_value[column_name] = [] + column_number = res_column_name_number[column_name] + col_value = row[column_number] + col_values = col_value.split(";") + if "=" in col_values[0]: + column_value[column_name].append(float((col_values[0]).split("=")[1])) + else: + column_value[column_name].append(float(col_values[0])) + + + for idx,under_study_var in enumerate(under_study_vars): + fig_budget_blkcnt = plt.figure(figsize=(12, 12)) + plt.rc('font', **axis_font) + ax_blkcnt = fig_budget_blkcnt.add_subplot(projection='3d') + img = 
ax_blkcnt.scatter3D(column_value["budget_scaling_power"], column_value["budget_scaling_area"], column_value["budget_scaling_latency"],
+                                   c=column_value[under_study_var], cmap="bwr", s=80, label=under_study_var)
+        for idx,_ in enumerate(column_value[under_study_var]):
+            coordinate = column_value[under_study_var][idx]
+            coord_in_scientific_notation = "{:.2e}".format(coordinate)
+
+            ax_blkcnt.text(column_value["budget_scaling_power"][idx], column_value["budget_scaling_area"][idx], column_value["budget_scaling_latency"][idx], '%s' % coord_in_scientific_notation, size=fontSize)
+
+        ax_blkcnt.set_xlabel("Power Budget", fontsize=fontSize)
+        ax_blkcnt.set_ylabel("Area Budget", fontsize=fontSize)
+        ax_blkcnt.set_zlabel("Latency Budget", fontsize=fontSize)
+        ax_blkcnt.legend()
+        cbar = fig_budget_blkcnt.colorbar(img, aspect=40)
+        cbar.set_label(under_study_var, rotation=270)
+        #plt.title("{Power Budget, Area Budget, Latency Budget} VS System Block Count: " + subDirName)
+        plt.tight_layout()
+
+        output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+        output_dir = os.path.join(output_base_dir, "3D/case_studies")
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        plt.savefig(os.path.join(output_dir, under_study_var+ ".png"))
+        # plt.show()
+        plt.close('all')
+
+
+def plot_convergence_per_workloads_for_paper(input_dir_names, res_column_name_number):
+    budget_alpha = 1
+    non_optimization_alpha = .1
+    budget_marker = "_"
+    regu_marker = "."
+    budget_marker_size = 4
+    regu_marker_size = 1
+
+
+    #itrColNum = all_res_column_name_number["iteration cnt"]
+    #distColNum = all_res_column_name_number["dist_to_goal_non_cost"]
+    trueNum = all_res_column_name_number["move validity"]
+    move_name_number = all_res_column_name_number["move name"]
+
+
+    # experiment_names
+    experiment_names = []
+    file_full_addr_list = []
+    for dir_name in input_dir_names:
+        file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv")
+        file_full_addr_list.append(file_full_addr)
+        experiment_name = get_experiments_name(file_full_addr, res_column_name_number)
+        experiment_names.append(experiment_name)
+
+    #color_values = ["r","b","y","black","brown","purple"]
+    color_values = {}
+    color_values["latency_edge_detection"] = color_values["best_des_so_far_latency_edge_detection"] =color_values["latency_budget_edge_detection"] = matplotlib.colors.to_rgba("red")
+    color_values["latency_hpvm_cava"] =color_values["best_des_so_far_latency_hpvm_cava"] = color_values["latency_budget_hpvm_cava"] = matplotlib.colors.to_rgba("magenta")
+    color_values["latency_audio_decoder"] = color_values["best_des_so_far_latency_audio_decoder"] =color_values["latency_budget_audio_decoder"] = matplotlib.colors.to_rgba("orange")
+    color_values["area_non_dram"] =color_values["best_des_so_far_area_non_dram"] = color_values["area_budget"] = matplotlib.colors.to_rgba("forestgreen")
+    color_values["brown"] = (1,0,0,1)
+    color_values["power"] = color_values["best_des_so_far_power"] =color_values["power_budget"] = matplotlib.colors.to_rgba("mediumblue")
+
+    color_values["latency_budget_edge_detection"] = matplotlib.colors.to_rgba("white")
+    color_values["latency_budget_hpvm_cava"] = matplotlib.colors.to_rgba("white")
+    color_values["latency_budget_audio_decoder"] = matplotlib.colors.to_rgba("white")
+    color_values["area_budget"] = matplotlib.colors.to_rgba("white")
+    color_values["power_budget"] = matplotlib.colors.to_rgba("white")
+
+
+
+    column_name_color_val_dict = {"best_des_so_far_power":"purple",
"power_budget":"purple","best_des_so_far_area_non_dram":"blue", "area_budget":"blue", + "latency_budget_hpvm_cava":"orange", "latency_budget_audio_decoder":"yellow", "latency_budget_edge_detection":"red", + "best_des_so_far_latency_hpvm_cava":"orange", "best_des_so_far_latency_audio_decoder": "yellow","best_des_so_far_latency_edge_detection": "red", + "latency_budget":"white" + } + + + + """ + column_name_color_val_dict = {"power":"purple", "power_budget":"purple","area_non_dram":"blue", "area_budget":"blue", + "latency_budget_hpvm_cava":"orange", "latency_budget_audio_decoder":"yellow", "latency_budget_edge_detection":"red", + "latency_hpvm_cava":"orange", "latency_audio_decoder": "yellow","latency_edge_detection": "red", + "latency_budget":"white" + } + """ + + axis_font = {'size': '20'} + fontSize = 20 + x_column_name = "iteration cnt" + #y_column_name_list = ["power", "area_non_dram", "latency", "latency_budget", "power_budget","area_budget"] + y_column_name_list = ["power", "area_non_dram", "latency"] + + experiment_column_value = {} + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_column_value[experiment_name] = {} + for y_column_name in y_column_name_list: + if "budget" in y_column_name : + prefix = "" + else: + prefix = "best_des_so_far_" + y_column_name = prefix+y_column_name + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + #dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + #ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + + if not y_column_name == prefix+"latency": + experiment_column_value[experiment_name][y_column_name] = [] + + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i > 1: + if row[trueNum] == "FALSE" or row[move_name_number]=="identity": + continue + metric_chosen = row[res_column_name_number["transformation_metric"]] + workload_chosen = row[res_column_name_number["workload"]] + if "budget" in y_column_name: + alpha = budget_alpha + elif metric_chosen in y_column_name : + alpha = 1 + else: + alpha = non_optimization_alpha + + col_value = row[y_column_number] + if ";" in col_value: + col_value = col_value[:-1] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + if "budget" in y_column_name: + alpha = budget_alpha + elif workload_chosen in col_val: + alpha = 1 + else: + alpha = non_optimization_alpha + val_splitted = col_val.split("=") + value_to_add = (float(row[x_column_number]), (val_splitted[0], val_splitted[1]), alpha) + else: + value_to_add = (float(row[x_column_number]), col_val, alpha) + + if y_column_name in [prefix+"latency", prefix+"latency_budget"] : + new_tuple = (value_to_add[0], 1000*float(value_to_add[1][1]), value_to_add[2]) + if y_column_name+"_"+value_to_add[1][0] not in experiment_column_value[experiment_name].keys(): + experiment_column_value[experiment_name][y_column_name + "_" + value_to_add[1][0]] = [] + experiment_column_value[experiment_name][y_column_name+"_"+value_to_add[1][0]].append(new_tuple) + if y_column_name in [prefix+"power", prefix+"power_budget"]: + new_tuple = (value_to_add[0], float(value_to_add[1])*1000, value_to_add[2]) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + elif y_column_name in [prefix+"area_non_dram", 
prefix+"area_budget"]: + new_tuple = (value_to_add[0], float(value_to_add[1]) * 1000000, value_to_add[2]) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + + # prepare for plotting and plot + fig = plt.figure(figsize=(15, 8)) + ax = fig.add_subplot(111) + for column, values in experiment_column_value[experiment_name].items(): + print(column) + if "budget" in column: + marker = budget_marker + marker_size = budget_marker_size + else: + marker = regu_marker + marker_size =regu_marker_size + + x_values = [el[0] for el in values] + y_values = [el[1] for el in values] + colors = [] + for el in values: + color_ = list(color_values[column]) + color_[-1] = el[2] + colors.append(tuple(color_)) + #alphas = [el[2] for el in values] + ax.set_yscale('log') + + if "budget" in column: + marker = '_' + #alpha_ = .3 + else: + marker = "x" + #alpha_ = 1 + for i,x in enumerate(x_values): + + ax.plot(x_values[i], y_values[i], label=column, color=colors[i], marker=marker, markersize=marker_size) + + + #ax.set_title("experiment vs system implicaction") + ax.set_xlabel(x_column_name, fontsize=fontSize) + y_axis_name = "_".join(list(experiment_column_value[experiment_name].keys())) + ax.set_ylabel(y_axis_name, fontsize=fontSize) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/convergence") + if not os.path.exists(output_dir): + os.mkdir(output_dir) + fig.savefig(os.path.join(output_dir,experiment_name+"_convergence.png")) + # plt.show() + plt.close('all') + + +def get_budget_values(input_dir_names, res_column_name_number): + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + x_column_name = "iteration cnt" + trueNum = all_res_column_name_number["move validity"] + move_name_number = all_res_column_name_number["move name"] + + y_column_name_list = ["latency_budget", "power_budget","area_budget"] + experiment_column_value = {} + budgets = {} + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_column_value[experiment_name] = {} + for y_column_name in y_column_name_list: + if "budget" in y_column_name : + prefix = "" + else: + prefix = "best_des_so_far_" + y_column_name = prefix+y_column_name + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + #dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + #ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + + if not y_column_name == prefix+"latency": + experiment_column_value[experiment_name][y_column_name] = [] + + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i > 1: + if row[trueNum] == "FALSE" or row[move_name_number]=="identity": + continue + metric_chosen = row[res_column_name_number["transformation_metric"]] + workload_chosen = row[res_column_name_number["workload"]] + + col_value = row[y_column_number] + if ";" in 
col_value: + col_value = col_value[:-1] + col_values = col_value.split(";") + for col_val in col_values: + alpha = 1 + if "=" in col_val: + val_splitted = col_val.split("=") + value_to_add = (float(row[x_column_number]), (val_splitted[0], val_splitted[1]), alpha) + else: + value_to_add = (float(row[x_column_number]), col_val, alpha) + + if y_column_name in [prefix+"latency", prefix+"latency_budget"] : + new_tuple = (value_to_add[0], 1000*float(value_to_add[1][1]), value_to_add[2]) + if y_column_name+"_"+value_to_add[1][0] not in experiment_column_value[experiment_name].keys(): + experiment_column_value[experiment_name][y_column_name + "_" + value_to_add[1][0]] = [] + experiment_column_value[experiment_name][y_column_name+"_"+value_to_add[1][0]].append(new_tuple) + if y_column_name in [prefix+"power", prefix+"power_budget"]: + new_tuple = (value_to_add[0], float(value_to_add[1])*1000, value_to_add[2]) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + elif y_column_name in [prefix+"area_non_dram", prefix+"area_budget"]: + new_tuple = (value_to_add[0], float(value_to_add[1]) * 1000000, value_to_add[2]) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + + for column, values in experiment_column_value[experiment_name].items(): + if len(values) == 0: + continue + x_values = [el[0] for el in values] + budgets[column] = values[1][1] + + return budgets + + +def plot_convergence_per_workloads(input_dir_names, res_column_name_number): + + budgets = get_budget_values(input_dir_names, res_column_name_number) + + budget_alpha = 1 + non_optimization_alpha = .1 + budget_marker = "_" + regu_marker = "." + budget_marker_size = 4 + regu_marker_size = 1 + + + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + move_name_number = all_res_column_name_number["move name"] + + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + #color_values = ["r","b","y","black","brown","purple"] + color_values = {} + color_values["latency_edge_detection"] = color_values["best_des_so_far_latency_edge_detection"] =color_values["latency_budget_edge_detection"] = matplotlib.colors.to_rgba("red") + color_values["latency_hpvm_cava"] =color_values["best_des_so_far_latency_hpvm_cava"] = color_values["latency_budget_hpvm_cava"] = matplotlib.colors.to_rgba("magenta") + color_values["latency_audio_decoder"] = color_values["best_des_so_far_latency_audio_decoder"] =color_values["latency_budget_audio_decoder"] = matplotlib.colors.to_rgba("orange") + color_values["area_non_dram"] =color_values["best_des_so_far_area_non_dram"] = color_values["area_budget"] = matplotlib.colors.to_rgba("forestgreen") + color_values["brown"] = (1,0,0,1) + color_values["power"] = color_values["best_des_so_far_power"] =color_values["power_budget"] = matplotlib.colors.to_rgba("mediumblue") + + color_values["latency_budget_edge_detection"] = matplotlib.colors.to_rgba("white") + color_values["latency_budget_hpvm_cava"] = matplotlib.colors.to_rgba("white") + color_values["latency_budget_audio_decoder"] = matplotlib.colors.to_rgba("white") + color_values["area_budget"] 
= matplotlib.colors.to_rgba("white") + color_values["power_budget"] = matplotlib.colors.to_rgba("white") + + + + column_name_color_val_dict = {"best_des_so_far_power":"purple", "power_budget":"purple","best_des_so_far_area_non_dram":"blue", "area_budget":"blue", + "latency_budget_hpvm_cava":"orange", "latency_budget_audio_decoder":"yellow", "latency_budget_edge_detection":"red", + "best_des_so_far_latency_hpvm_cava":"orange", "best_des_so_far_latency_audio_decoder": "yellow","best_des_so_far_latency_edge_detection": "red", + "latency_budget":"white" + } + + + + """ + column_name_color_val_dict = {"power":"purple", "power_budget":"purple","area_non_dram":"blue", "area_budget":"blue", + "latency_budget_hpvm_cava":"orange", "latency_budget_audio_decoder":"yellow", "latency_budget_edge_detection":"red", + "latency_hpvm_cava":"orange", "latency_audio_decoder": "yellow","latency_edge_detection": "red", + "latency_budget":"white" + } + """ + + axis_font = {'size': '20'} + fontSize = 20 + x_column_name = "iteration cnt" + #y_column_name_list = ["power", "area_non_dram", "latency", "latency_budget", "power_budget","area_budget"] + y_column_name_list = ["power", "area_non_dram", "latency"] + + experiment_column_value = {} + for file_full_addr in file_full_addr_list: + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_column_value[experiment_name] = {} + for y_column_name in y_column_name_list: + if "budget" in y_column_name : + prefix = "" + else: + prefix = "best_des_so_far_" + y_column_name = prefix+y_column_name + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + #dis_to_goal_column_number = res_column_name_number["dist_to_goal_non_cost"] + #ref_des_dis_to_goal_column_number = res_column_name_number["ref_des_dist_to_goal_non_cost"] + + if not y_column_name == prefix+"latency": + experiment_column_value[experiment_name][y_column_name] = [] + + + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i > 1: + if row[trueNum] == "FALSE" or row[move_name_number]=="identity": + continue + metric_chosen = row[res_column_name_number["transformation_metric"]] + workload_chosen = row[res_column_name_number["workload"]] + if "budget" in y_column_name: + alpha = budget_alpha + elif metric_chosen in y_column_name : + alpha = 1 + else: + alpha = non_optimization_alpha + + col_value = row[y_column_number] + if ";" in col_value: + col_value = col_value[:-1] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + if "budget" in y_column_name: + alpha = budget_alpha + elif workload_chosen in col_val: + alpha = 1 + else: + alpha = non_optimization_alpha + val_splitted = col_val.split("=") + value_to_add = (float(row[x_column_number]), (val_splitted[0], val_splitted[1]), alpha) + else: + value_to_add = (float(row[x_column_number]), col_val, alpha) + + if y_column_name in [prefix+"latency", prefix+"latency_budget"] : + budget = budgets["latency_budget"+"_"+value_to_add[1][0]] + new_tuple = (value_to_add[0], (-budget + 1000*float(value_to_add[1][1]))/budget, value_to_add[2]) + if y_column_name+"_"+value_to_add[1][0] not in experiment_column_value[experiment_name].keys(): + experiment_column_value[experiment_name][y_column_name + "_" + value_to_add[1][0]] = [] + experiment_column_value[experiment_name][y_column_name+"_"+value_to_add[1][0]].append(new_tuple) + if 
y_column_name in [prefix+"power", prefix+"power_budget"]: + budget = budgets["power_budget"] + new_tuple = (value_to_add[0], (-budget + float(value_to_add[1])*1000)/budget, value_to_add[2]) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + elif y_column_name in [prefix+"area_non_dram", prefix+"area_budget"]: + budget = budgets["area_budget"] + new_tuple = (value_to_add[0], (-budget + float(value_to_add[1]) * 1000000)/budget, value_to_add[2]) + experiment_column_value[experiment_name][y_column_name].append(new_tuple) + + # prepare for plotting and plot + fig = plt.figure(figsize=(15, 8)) + ax = fig.add_subplot(111) + for column, values in experiment_column_value[experiment_name].items(): + if "budget" in column: + marker = budget_marker + marker_size = budget_marker_size + else: + marker = regu_marker + marker_size =regu_marker_size + + x_values = [el[0] for el in values] + y_values = [el[1] for el in values] + colors = [] + for el in values: + color_ = list(color_values[column]) + color_[-1] = el[2] + colors.append(tuple(color_)) + #alphas = [el[2] for el in values] + ax.set_yscale('log') + + if "budget" in column: + marker = '_' + #alpha_ = .3 + else: + marker = "x" + #alpha_ = 1 + for i,x in enumerate(x_values): + ax.plot(x_values[i], y_values[i]+10, label=column, color=colors[i], marker=marker, markersize=marker_size) + + + #ax.set_title("experiment vs system implicaction") + ax.set_xlabel(x_column_name, fontsize=fontSize) + y_axis_name = "_".join(list(experiment_column_value[experiment_name].keys())) + ax.set_ylabel(y_axis_name, fontsize=fontSize) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/convergence") + if not os.path.exists(output_dir): + os.mkdir(output_dir) + fig.savefig(os.path.join(output_dir,experiment_name+"_convergence.png")) + # plt.show() + plt.close('all') + +def plot_convergence_vs_time(input_dir_names, res_column_name_number): + PA_time_scaling_factor = 8500 + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '15'} + fontSize = 20 + x_column_name = "exploration_plus_simulation_time" + y_column_name_list = ["best_des_so_far_dist_to_goal_non_cost"] + y_column_name_list = ["dist_to_goal_non_cost"] + + + PA_column_experiment_value = {} + FARSI_column_experiment_value = {} + + #column_name = "move name" + for k, file_full_addr in enumerate(file_full_addr_list): + for y_column_name in y_column_name_list: + # get all possible the values of inteFARSI_results/blind_study_all_dumb_versionrest + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + PA_column_experiment_value[y_column_name] = [] + FARSI_column_experiment_value[y_column_name] = [] + PA_last_time = 0 + FARSI_last_time = 0 + with open(file_full_addr, newline='') as csvfile: + resultReader = 
csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + FARSI_last_time += float(row[x_column_number]) + FARSI_value_to_add = (float(FARSI_last_time), row[y_column_number]) + FARSI_column_experiment_value[y_column_name].append(FARSI_value_to_add) + + PA_last_time = FARSI_last_time*PA_time_scaling_factor + PA_value_to_add = (float(PA_last_time), row[y_column_number]) + PA_column_experiment_value[y_column_name].append(PA_value_to_add) + + # prepare for plotting and plot + fig = plt.figure(figsize=(12, 12)) + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + fontSize = 20 + #plt.tight_layout() + x_values = [el[0] for el in FARSI_column_experiment_value[y_column_name]] + y_values = [(float(el[1]) * 100 // 1 / 100.0) for el in FARSI_column_experiment_value[y_column_name]] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="FARSI time to completion", marker="*") + # ax.set_yscale('log') + + x_values = [el[0] for el in PA_column_experiment_value[y_column_name]] + y_values = [(float(el[1]) * 100 // 1 / 100.0) for el in PA_column_experiment_value[y_column_name]] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="PA time to completion", marker="*") + ax.set_xscale('log') + + #ax.set_title("experiment vs system implicaction") + ax.legend(loc="upper right")#bbox_to_anchor=(1, 1), loc="upper left") + ax.set_xlabel(x_column_name, fontsize=fontSize) + ax.set_ylabel(y_column_name, fontsize=fontSize) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/progression") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir,str(k)+"_" + y_column_name+"_vs_"+x_column_name+"_FARSI_vs_PA.png")) + #plt.show() + plt.close('all') + +def plot_convergence_vs_time_for_paper(input_dir_names, res_column_name_number): + PA_time_scaling_factor = 8500 + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '25'} + fontSize = 25 + x_column_name = "exploration_plus_simulation_time" + y_column_name_list = ["best_des_so_far_dist_to_goal_non_cost"] + y_column_name_list = ["dist_to_goal_non_cost"] + + + PA_column_experiment_value = {} + FARSI_column_experiment_value = {} + + #column_name = "move name" + for k, file_full_addr in enumerate(file_full_addr_list): + for y_column_name in y_column_name_list: + # get all possible the values of interest + y_column_number = res_column_name_number[y_column_name] + x_column_number = res_column_name_number[x_column_name] + PA_column_experiment_value[y_column_name] = [] + FARSI_column_experiment_value[y_column_name] = [] + PA_last_time = 0 + FARSI_last_time = 0 + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', 
quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + FARSI_last_time += float(row[x_column_number]) + FARSI_value_to_add = (float(FARSI_last_time), row[y_column_number]) + FARSI_column_experiment_value[y_column_name].append(FARSI_value_to_add) + + PA_last_time = FARSI_last_time*PA_time_scaling_factor + PA_value_to_add = (float(PA_last_time), row[y_column_number]) + PA_column_experiment_value[y_column_name].append(PA_value_to_add) + + # prepare for plotting and plot + print(k) + fig = plt.figure(figsize=(6, 6)) + # axes = plt.gca() + ymax = 2800.0 + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + #plt.tight_layout() + x_values = [el[0] for el in FARSI_column_experiment_value[y_column_name]] + y_values = [(float(el[1]) * 100 // 1 / 100.0) for el in FARSI_column_experiment_value[y_column_name]] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="FARSI", marker="*", color="green", s=10) + # ax.set_yscale('log') + + x_values = [el[0] for el in PA_column_experiment_value[y_column_name]] + y_values = [(float(el[1]) * 100 // 1 / 100.0) for el in PA_column_experiment_value[y_column_name]] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="PA", marker="*", color="hotpink", s=10) + ax.set_xscale('log') + ax.set_yscale('log') + # ax.set_yscale('linear') # Ying: by default, the y ticks are strange + ax.set_ylim([np.power(10.0, -2), np.power(10.0, 3.5)]) + + #ax.set_title("experiment vs system implicaction") + ax.legend(bbox_to_anchor=(0.5, 1.15), loc="upper center", fontsize=fontSize, ncol=2, borderpad=0, markerscale=6)#bbox_to_anchor=(1, 1), loc="upper left") + ax.set_xlabel("Time to Completion (s)", fontsize=fontSize) + ax.set_ylabel("Norm Distance to Goal", fontsize=fontSize) + # floatY = [float(i) for i in y_values] + # maxY = max(floatY) + # ax.set_yticks(np.arange(0, ymax + 1, ymax / 5)) + ax.set_xticks(np.power(10.0, [0, 4, 8])) + ax.set_yticks(np.power(10.0, [-2, -1, 0, 1, 2, 3])) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/progression") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir,str(k)+"_" + y_column_name+"_vs_"+x_column_name+"_FARSI_vs_PA.png"), bbox_inches='tight') + #plt.show() + plt.close('all') + +def plot_convergence_cross_workloads(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '20'} + x_column_name = "iteration cnt" + y_column_name_list = ["best_des_so_far_dist_to_goal_non_cost", "dist_to_goal_non_cost"] + + column_experiment_value = {} + #column_name = "move name" + for y_column_name in y_column_name_list: + # get all possible the values of interest + y_column_number = res_column_name_number[y_column_name] + 
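# res_column_name_number is assumed to map CSV header names to 0-based column
+        # indices, e.g. built once from the header row; a minimal sketch, not
+        # necessarily how this module actually constructs it:
+        #   with open(file_full_addr, newline='') as f:
+        #       header = next(csv.reader(f, delimiter=',', quotechar='|'))
+        #   res_column_name_number = {name: idx for idx, name in enumerate(header)}
+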
#x_column_number = res_column_name_number[x_column_name] + ctr = 0 + + column_experiment_value[y_column_name] = {} + # initialize the dictionary + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) +str(ctr) + column_experiment_value[y_column_name][experiment_name] = [] + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + value_to_add = (i, max(float(row[y_column_number]),.01)) + column_experiment_value[y_column_name][experiment_name].append(value_to_add) + ctr +=1 + + # prepare for plotting and plot + fig = plt.figure() + ax = fig.add_subplot(111) + #plt.tight_layout() + for experiment_name, values in column_experiment_value[y_column_name].items(): + x_values = [el[0] for el in values[:-10]] + y_values = [el[1] for el in values[:-10]] + ax.scatter(x_values, y_values, label=experiment_name[1:]) + + #ax.set_title("experiment vs system implicaction") + ax.set_yscale('log') + ax.legend(bbox_to_anchor=(1, 1), loc="best") + ax.set_xlabel(x_column_name) + ax.set_ylabel(y_column_name) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/convergence") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir,x_column_name+"_"+y_column_name+".png")) + # plt.show() + plt.close('all') + +def plot_convergence_cross_workloads_for_paper(input_dir_names, res_column_name_number): + #itrColNum = all_res_column_name_number["iteration cnt"] + #distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '25'} + fontSize = 25 + x_column_name = "Iteration" + y_column_name_list = ["best_des_so_far_dist_to_goal_non_cost", "dist_to_goal_non_cost"] + + column_experiment_value = {} + #column_name = "move name" + for y_column_name in y_column_name_list: + y_column_name_rep = y_column_name + # get all possible the values of interest + y_column_number = res_column_name_number[y_column_name] + #x_column_number = res_column_name_number[x_column_name] + + column_experiment_value[y_column_name] = {} + # initialize the dictionary + # get all the data + ctr = 0 + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) +str(ctr) + + column_experiment_value[y_column_name][experiment_name] = [] + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + value_to_add = (i, max(float(row[y_column_number]),.01)) + column_experiment_value[y_column_name][experiment_name].append(value_to_add) + + ctr +=1 + # prepare for plotting and plot + fig = plt.figure(figsize=(8, 6.5)) + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + 
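# The per-experiment labels and colors below are keyed off the last character
+    # of the experiment name (a hardcoded convention); an equivalent
+    # table-driven sketch, assuming the same "0"-"3" suffixes:
+    #   label_color_by_suffix = {"3": ("SA", (1, 0, 0, 1)),
+    #                            "2": ("FARSI", (0, 0.6, 0, 1)),
+    #                            "1": ("Task-aware", (1, 0.5, 0.0, 1)),
+    #                            "0": ("Task&Block-aware", (0.5, 0.5, 0, 1))}
+    #   labelName, color = label_color_by_suffix.get(experiment_name[-1], ("", (1.0, 0.0, 0.0, 0.5)))
+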
#plt.tight_layout() + labelName = "" + color = (1.0, 0.0, 0.0, 0.5) # "" + for experiment_name, values in column_experiment_value[y_column_name].items(): + x_values = [el[0] for el in values[:-10]] + y_values = [el[1] for el in values[:-10]] + # Ying: hardcode here + if experiment_name[-1] == "3": + labelName = "SA" + color = (1, 0, 0, 1) # "Red" + if experiment_name[-1] == "2": + labelName = "FARSI" + color = (0, 0.6, 0, 1) # "ForestGreen" + if experiment_name[-1] == "1": + labelName = "Task-aware" + color = (1, 0.5, 0.0, 1) # "DarkOrange" + if experiment_name[-1] == "0": + labelName = "Task&Block-aware" + color = (0.5, 0.5, 0, 1) # "Olive" + # Ying: hardcode finished + ax.scatter(x_values, y_values, label=labelName, color=color, marker='*', s=5) + + #ax.set_title("experiment vs system implicaction") + ax.set_yscale('log') + ax.legend(bbox_to_anchor=(0.5, 1.25), loc="upper center", fontsize=fontSize-2, ncol=2, borderpad=0, markerscale=6) + ax.set_xlabel(x_column_name, fontsize=fontSize) + # Ying: hardcode here + if y_column_name == "dist_to_goal_non_cost": + y_column_name_rep = "Norm Distance to Goal" + # Ying: hardcode finished + ax.set_ylabel(y_column_name_rep, fontsize=fontSize) + plt.xticks(np.arange(0, 7000, 2000.0)) + plt.yticks(np.power(10.0, [-2, -1, 0, 1, 2, 3])) + plt.tight_layout() + + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/convergence") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir,x_column_name+"_"+y_column_name+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + +def plot_system_implication_analysis(input_dir_names, res_column_name_number, case_study): + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, res_column_name_number) + experiment_names.append(experiment_name) + + axis_font = {'size': '10'} + + column_name_list = list(case_study.values())[0] + + column_experiment_value = {} + #column_name = "move name" + for column_name in column_name_list: + # get all possible the values of interest + column_number = res_column_name_number[column_name] + + column_experiment_value[column_name] = {} + # initialize the dictionary + column_experiment_number_dict = {} + experiment_number_dict = {} + + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, res_column_name_number) + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i >= 1: + col_value = row[column_number] + col_values = col_value.split(";") + for col_val in col_values: + column_experiment_value[column_name][experiment_name] = float(col_val) + + # prepare for plotting and plot + # plt.figure() + index = experiment_names + plotdata = pd.DataFrame(column_experiment_value, index=index) + if list(case_study.keys())[0] in ["bandwidth_analysis","traffic_analysis"]: + plotdata.plot(kind='bar', fontsize=9, rot=5, log=True) + else: + plotdata.plot(kind='bar', fontsize=9, rot=5) + + plt.legend(loc="best", fontsize="9") + plt.xlabel("experiments", fontsize="10") + plt.ylabel("system implication") 
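+    # case_study is expected to be a single-entry dict mapping a study name to
+    # the list of result columns to plot, e.g. (column name illustrative only):
+    #   case_study = {"bandwidth_analysis": ["local_bus_avg_bandwidth"]}
+    # list(case_study.keys())[0] supplies the study name used for the log-scale
+    # check and the output file name below.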
+ #plt.title("experiment vs system implicaction") + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/system_implications") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + #plt.tight_layout()list(case_study.keys())[0] + if "re_use" in list(case_study.keys())[0] or "speedup" in list(case_study.keys())[0]: + plt.yscale('log') + plt.savefig(os.path.join(output_dir,list(case_study.keys())[0]+".png")) + plt.close('all') + + + +def plot_co_design_nav_breakdown_post_processing(input_dir_names, column_column_value_experiment_frequency_dict): + column_name_list = [("exact optimization name", "neighbouring design space size", "div")] + #column_name = "move name" + for n, column_name_tuple in enumerate(column_name_list): + first_column = column_name_tuple[0] + second_column = column_name_tuple[1] + operation = column_name_tuple[2] + new_column_name = first_column+"_"+operation+"_"+second_column + + first_column_value_experiment_frequency_dict = column_column_value_experiment_frequency_dict[first_column] + second_column_value_experiment_frequency_dict = column_column_value_experiment_frequency_dict[second_column] + modified_column_value_experiment_frequency_dict = {} + + experiment_names = [] + for column_val, experiment_freq in first_column_value_experiment_frequency_dict.items(): + if column_val == "unknown": + continue + modified_column_value_experiment_frequency_dict[column_val] = {} + for experiment, freq in experiment_freq.items(): + if(second_column_value_experiment_frequency_dict[column_val][experiment]) < .000001: + modified_column_value_experiment_frequency_dict[column_val][experiment] = 0 + else: + modified_column_value_experiment_frequency_dict[column_val][experiment] = first_column_value_experiment_frequency_dict[column_val][experiment]/max(second_column_value_experiment_frequency_dict[column_val][experiment],.0000000000001) + experiment_names.append(experiment) + + axis_font = {'size': '22'} + fontSize = 22 + experiment_names = list(set(experiment_names)) + # prepare for plotting and plot + # plt.figure(n) + plt.rc('font', **axis_font) + index = experiment_names + plotdata = pd.DataFrame(modified_column_value_experiment_frequency_dict, index=index) + plotdata.plot(kind='bar', stacked=True, figsize=(13, 8)) + plt.xlabel("experiments", **axis_font) + plt.ylabel(new_column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + new_column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # plt.tight_layout() + plt.savefig(os.path.join(output_dir,'_'.join(new_column_name.split(" "))+".png"), bbox_inches='tight') + plt.tight_layout() + # plt.show() + plt.close('all') + + + +# navigation breakdown +def plot_codesign_nav_breakdown_per_workload(input_dir_names, input_all_res_column_name_number): + trueNum = input_all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = 
get_experiments_name(file_full_addr, input_all_res_column_name_number) + experiment_names.append(experiment_name) + + + axis_font = {'size': '20'} + fontSize = 20 + column_name_list = ["transformation_metric", "comm_comp", "workload"]#, "architectural principle", "high level optimization name", "exact optimization name"] + #column_name_list = ["architectural principle", "exact optimization name"] + + #column_name = "move name" + # initialize the dictionary + column_column_value_experiment_frequency_dict = {} + for file_full_addr in file_full_addr_list: + column_column_value_frequency_dict = {} + for column_name in column_name_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name(file_full_addr, input_all_res_column_name_number) + #column_column_value_frequency_dict[column_name] = {} + # get all possible the values of interest + all_values = get_all_col_values_of_a_folders(input_dir_names, input_all_res_column_name_number, column_name) + columne_number = all_res_column_name_number[column_name] + for column in all_values: + column_column_value_frequency_dict[column] = {} + column_column_value_frequency_dict[column][column_name] = 0 + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + if i > 1: + col_value = row[columne_number] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + val_splitted = col_val.split("=") + column_column_value_frequency_dict[val_splitted[0]][column_name] += float(val_splitted[1]) + else: + column_column_value_frequency_dict[col_val][column_name] += 1 + + index = column_name_list + total_cnt = 0 + for val in column_column_value_frequency_dict[column].values(): + total_cnt += val + + for col_val, column_name_val in column_column_value_frequency_dict.items(): + for column_name, val in column_name_val.items(): + column_column_value_frequency_dict[col_val][column_name] /= max(total_cnt,1) + + plotdata = pd.DataFrame(column_column_value_frequency_dict, index=index) + plotdata.plot(kind='bar', stacked=True, figsize=(10, 10)) + plt.rc('font', ** axis_font) + plt.xlabel("experiments", **axis_font) + plt.ylabel(column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "single_workload/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,"__".join(column_name_list)+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + #column_column_value_experiment_frequency_dict[column_name] = copy.deepcopy(column_column_value_frequency_dict) + + return column_column_value_experiment_frequency_dict + + + + +def plot_codesign_nav_breakdown_cross_workload(input_dir_names, input_all_res_column_name_number): + trueNum = input_all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, input_all_res_column_name_number) + 
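# One label is collected per input directory; experiment_names later serves
+        # as the row index of the pandas DataFrame behind each stacked bar plot.
+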
experiment_names.append(experiment_name) + + axis_font = {'size': '20'} + fontSize = 20 + column_name_list = ["transformation_metric", "transformation_block_type", "move name", "comm_comp", "architectural principle", "high level optimization name", "exact optimization name", "neighbouring design space size"] + #column_name_list = ["transformation_metric", "move name"]#, "comm_comp", "architectural principle", "high level optimization name", "exact optimization name", "neighbouring design space size"] + #column_name = "move name" + # initialize the dictionary + column_column_value_experiment_frequency_dict = {} + for column_name in column_name_list: + column_value_experiment_frequency_dict = {} + # get all possible the values of interest + all_values = get_all_col_values_of_a_folders(input_dir_names, input_all_res_column_name_number, column_name) + columne_number = all_res_column_name_number[column_name] + for column in all_values: + column_value_experiment_frequency_dict[column] = {} + + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, input_all_res_column_name_number) + for column_value in all_values: + column_value_experiment_frequency_dict[column_value][experiment_name] = 0 + + for i, row in enumerate(resultReader): + #if row[trueNum] != "True": + # continue + if i > 1: + try: + + # the following for workload awareness + #if row[all_res_column_name_number["move name"]] == "identity": + # continue + #if row[all_res_column_name_number["architectural principle"]] == "spatial_locality": + # continue + + + col_value = row[columne_number] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + val_splitted = col_val.split("=") + column_value_experiment_frequency_dict[val_splitted[0]][experiment_name] += float(val_splitted[1]) + else: + column_value_experiment_frequency_dict[col_val][experiment_name] += 1 + except: + print("what") + + total_cnt = {} + for el in column_value_experiment_frequency_dict.values(): + for exp, values in el.items(): + if exp not in total_cnt.keys(): + total_cnt[exp] = 0 + total_cnt[exp] += values + + for col_val, exp_vals in column_value_experiment_frequency_dict.items(): + for exp, values in exp_vals.items(): + column_value_experiment_frequency_dict[col_val][exp] = column_value_experiment_frequency_dict[col_val][exp] + column_value_experiment_frequency_dict[col_val][exp] /= total_cnt[exp] # normalize + + # prepare for plotting and plot + # plt.figure(figsize=(10, 8)) + index = experiment_names + plotdata = pd.DataFrame(column_value_experiment_frequency_dict, index=index) + plotdata.plot(kind='bar', stacked=True, figsize=(12, 10)) + plt.rc('font', **axis_font) + plt.xlabel("experiments", **axis_font) + plt.ylabel(column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,'_'.join(column_name.split(" "))+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + 
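# Snapshot the per-column frequencies before the dict is rebuilt on the next
+        # column_name iteration; the returned structure is, roughly:
+        #   {column_name: {column_value: {experiment_name: normalized_frequency}}}
+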
column_column_value_experiment_frequency_dict[column_name] = copy.deepcopy(column_value_experiment_frequency_dict) + + """ + # multi-stack plot here + index = experiment_names + plotdata = pd.DataFrame(column_column_value_experiment_frequency_dict, index=index) + + df_g = plotdata.groupby(["transformation_metric", "move name"]) + plotdata.plot(kind='bar', stacked=True, figsize=(12, 10)) + plt.rc('font', **axis_font) + plt.xlabel("experiments", **axis_font) + plt.ylabel(column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,'column____'.join(column_name.split(" "))+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + """ + return column_column_value_experiment_frequency_dict + +def plot_codesign_nav_breakdown_cross_workload_for_paper(input_dir_names, input_all_res_column_name_number): + trueNum = input_all_res_column_name_number["move validity"] + + # experiment_names + experiment_names = [] + file_full_addr_list = [] + for dir_name in input_dir_names: + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + file_full_addr_list.append(file_full_addr) + experiment_name = get_experiments_name(file_full_addr, input_all_res_column_name_number) + """ + Ying: the following lines are added to make the names clearer in the plottings + """ + if experiment_name[0] == 'a': + experiment_name = "Audio" + elif experiment_name[0] == 'h': + experiment_name = "CAVA" + elif experiment_name[0] == 'e': + experiment_name = "ED" + """ + Ying: adding finished + """ + experiment_names.append(experiment_name) + experiment_names.sort() + + axis_font = {'size': '25'} + fontSize = 25 + column_name_list = ["transformation_metric", "transformation_block_type", "move name", "comm_comp", "architectural principle", "high level optimization name", "exact optimization name", "neighbouring design space size"] + #column_name_list = ["transformation_metric", "move name"]#, "comm_comp", "architectural principle", "high level optimization name", "exact optimization name", "neighbouring design space size"] + #column_name = "move name" + # initialize the dictionary + column_column_value_experiment_frequency_dict = {} + for column_name in column_name_list: + column_value_experiment_frequency_dict = {} + # get all possible the values of interest + all_values = get_all_col_values_of_a_folders(input_dir_names, input_all_res_column_name_number, column_name) + columne_number = input_all_res_column_name_number[column_name] + for column in all_values: + """ + Ying: the following lines are added for "IC", "Mem", and "PE" + """ + if column_name == "transformation_block_type": + if column == "ic": + column = "NoC" + elif column == "mem": + column = "Mem" + elif column == "pe": + column = "PE" + + if column_name == "architectural principle": + if column == "identity" or column == "spatial_locality": + continue + elif column == "task_level_parallelism": + column = "TLP" + elif column == "loop_level_parallelism": + column = "LLP" + elif column == "customization": + column = "Customization" + """ + Ying: adding finished + """ + 
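# The renames above (and their repeats below) could equivalently be
+            # table-driven; a sketch assuming the same raw column values:
+            #   block_rename = {"ic": "NoC", "mem": "Mem", "pe": "PE"}
+            #   principle_rename = {"task_level_parallelism": "TLP",
+            #                       "loop_level_parallelism": "LLP",
+            #                       "customization": "Customization"}
+            #   column = block_rename.get(column, principle_rename.get(column, column))
+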
column_value_experiment_frequency_dict[column] = {} + + # get all the data + for file_full_addr in file_full_addr_list: + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name( file_full_addr, input_all_res_column_name_number) + """ + Ying: the following lines are added to make the names clearer in the plottings + """ + if experiment_name[0] == 'a': + experiment_name = "Audio" + elif experiment_name[0] == 'h': + experiment_name = "CAVA" + elif experiment_name[0] == 'e': + experiment_name = "ED" + """ + Ying: adding finished + """ + for column_value in all_values: + """ + Ying: the following lines are added for "IC", "Mem", and "PE" + """ + if column_name == "transformation_block_type": + if column_value == "ic": + column_value = "NoC" + elif column_value == "mem": + column_value = "Mem" + elif column_value == "pe": + column_value = "PE" + + if column_name == "architectural principle": + if column_value == "identity" or column_value == "spatial_locality": + continue + elif column_value == "task_level_parallelism": + column_value = "TLP" + elif column_value == "loop_level_parallelism": + column_value = "LLP" + elif column_value == "customization": + column_value = "Customization" + """ + Ying: adding finished + """ + column_value_experiment_frequency_dict[column_value][experiment_name] = 0 + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + if i > 1: + try: + + # the following for workload awareness + if row[input_all_res_column_name_number["architectural principle"]] == "spatial_locality" or row[all_res_column_name_number["architectural principle"]] == "identity": + continue + + col_value = row[columne_number] + col_values = col_value.split(";") + for col_val in col_values: + if "=" in col_val: + val_splitted = col_val.split("=") + column_value_experiment_frequency_dict[val_splitted[0]][experiment_name] += float(val_splitted[1]) + else: + """ + Ying: the following lines are added for "IC", "Mem", and "PE" + """ + if column_name == "transformation_block_type": + if col_val == "ic": + col_val = "NoC" + elif col_val == "mem": + col_val = "Mem" + elif col_val == "pe": + col_val = "PE" + + if column_name == "architectural principle": + if col_val == "identity" or col_val == "spatial_locality": + continue + elif col_val == "task_level_parallelism": + col_val = "TLP" + elif col_val == "loop_level_parallelism": + col_val = "LLP" + elif col_val == "customization": + col_val = "Customization" + """ + Ying: adding finished + """ + column_value_experiment_frequency_dict[col_val][experiment_name] += 1 + except: + print("what") + + total_cnt = {} + for el in column_value_experiment_frequency_dict.values(): + for exp, values in el.items(): + if exp not in total_cnt.keys(): + total_cnt[exp] = 0 + total_cnt[exp] += values + + for col_val, exp_vals in column_value_experiment_frequency_dict.items(): + for exp, values in exp_vals.items(): + column_value_experiment_frequency_dict[col_val][exp] = column_value_experiment_frequency_dict[col_val][exp] + if column_name != "architectural principle" and column_name != "comm_comp" and total_cnt[exp] != 0: # Ying: add to get rid of normalization for the two plottings + column_value_experiment_frequency_dict[col_val][exp] /= total_cnt[exp] # normalize + + # prepare for plotting and plot + # plt.figure(figsize=(6, 6)) + index = experiment_names + column_value_experiment_frequency_dict = 
dict(sorted(column_value_experiment_frequency_dict.items(), key=lambda x: x[0].lower())) + print(column_value_experiment_frequency_dict) + plotdata = pd.DataFrame(column_value_experiment_frequency_dict, index=index) + # tempC, tempE, tempA= plotdata.iloc[0].copy(), plotdata.iloc[1].copy(), plotdata.iloc[2].copy() + # plotdata.iloc[0] = tempA + # plotdata.iloc[1] = tempC + # plotdata.iloc[2] = tempE + print(plotdata) + color_list = ["mediumseagreen", "gold", "tomato"] + plotdata.plot(kind='bar', stacked=True, figsize=(6, 6.6), color=color_list) # Ying: (6, 7) for arch principle + plt.rc('font', **axis_font) + plt.xlabel("Workloads", **axis_font) + # plt.ylabel(column_name, **axis_font) # Ying: replace with the following lines + """ + Ying: set the ylabel acordingly + """ + if column_name == "architectural principle" or column_name == "comm_comp": + plt.ylabel("Iteration Count", **axis_font) + else: + plt.ylabel("Normalized Iteration Portion", **axis_font) + """ + Ying: adding finished + """ + plt.xticks(fontsize=fontSize, rotation=0) # Ying: the original one was 45) + plt.yticks(fontsize=fontSize) + if column_name == "architectural principle" or column_name == "comm_comp": + plt.yticks(np.arange(0, 200, 25.0), fontsize=fontSize) # Ying: change according to the graph ("architectural principle", "comm_comp") + # plt.title("experiment vs " + column_name, **axis_font) # Ying: comment it out as discussed + # plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) # Ying: replaced with the following line + plt.legend(bbox_to_anchor=(0.5, 1.3), loc='upper center', fontsize=fontSize, ncol=2, borderpad=0) # Ying: change according to the graph ("architectural principle", "comm_comp", "Normalized Iteration Portion") + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,'_'.join(column_name.split(" "))+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + column_column_value_experiment_frequency_dict[column_name] = copy.deepcopy(column_value_experiment_frequency_dict) + + """ + # multi-stack plot here + index = experiment_names + plotdata = pd.DataFrame(column_column_value_experiment_frequency_dict, index=index) + + df_g = plotdata.groupby(["transformation_metric", "move name"]) + plotdata.plot(kind='bar', stacked=True, figsize=(12, 10)) + plt.rc('font', **axis_font) + plt.xlabel("experiments", **axis_font) + plt.ylabel(column_name, **axis_font) + plt.xticks(fontsize=fontSize, rotation=45) + plt.yticks(fontsize=fontSize) + plt.title("experiment vs " + column_name, **axis_font) + plt.legend(bbox_to_anchor=(1, 1), loc='upper left', fontsize=fontSize) + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads/nav_breakdown") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.tight_layout() + plt.savefig(os.path.join(output_dir,'column____'.join(column_name.split(" "))+".png"), bbox_inches='tight') + # plt.show() + plt.close('all') + """ + return column_column_value_experiment_frequency_dict + + +# the function to plot distance to goal vs. 
iteration cnt +def plotDistToGoalVSitr(input_dir_names, all_res_column_name_number): + itrColNum = all_res_column_name_number["iteration cnt"] + distColNum = all_res_column_name_number["dist_to_goal_non_cost"] + trueNum = all_res_column_name_number["move validity"] + + experiment_itr_dist_to_goal_dict = {} + # iterate through directories, get data and store in a dictionary + for dir_name in input_dir_names: + itr = [] + distToGoal = [] + file_full_addr = os.path.join(dir_name, "result_summary/FARSI_simple_run_0_1_all_reults.csv") + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + experiment_name = get_experiments_name(file_full_addr, all_res_column_name_number) + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + if i > 1: + itr.append(int(row[itrColNum])) + distToGoal.append(float(row[distColNum])) + + experiment_itr_dist_to_goal_dict[experiment_name] = (itr[:], distToGoal[:]) + + plt.figure() + # iterate and plot + for experiment_name, value in experiment_itr_dist_to_goal_dict.items(): + itr, distToGoal = value[0], value[1] + if len(itr) == 0 or len(distToGoal) == 0: # no valid move + continue + plt.plot(itr, distToGoal, label=experiment_name) + plt.xlabel("Iteration Cnt") + plt.ylabel("Distance to Goal") + plt.title("Distance to Goal vs. Iteration Cnt") + + # decide on the output dir + if len(input_dir_names) == 1: + output_dir = input_dir_names[0] + else: + # dump in the top folder + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "cross_workloads") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + plt.savefig(os.path.join(output_dir, "distToGoalVSitr.png")) + # plt.show() + plt.close('all') + + +# the function to plot distance to goal vs. iteration cnt +def plotRefDistToGoalVSitr(dirName, fileName, itrColNum, refDistColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + itr = [] + refDistToGoal = [] + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + + if i > 1: + itr.append(int(row[itrColNum])) + refDistToGoal.append(float(row[refDistColNum])) + + plt.figure() + plt.plot(itr, refDistToGoal) + plt.xlabel("Iteration Cnt") + plt.ylabel("Reference Design Distance to Goal") + plt.title("Reference Design Distance to Goal vs. Iteration Cnt") + plt.savefig(dirName + fileName + "/refDistToGoalVSitr-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to do the zonal partitioning +def zonalPartition(comparedValue, zoneNum, maxValue): + unit = maxValue / zoneNum + + if comparedValue > maxValue: + return zoneNum - 1 + + if comparedValue < 0: + return 0 + + for i in range(0, zoneNum): + if comparedValue <= unit * (i + 1): + return i + + raise Exception("zonalPartition is fed by a strange value! maxValue: " + str(maxValue) + "; comparedValue: " + str(comparedValue)) + +# the function to plot simulation time vs. 
move name in a zonal format +def plotSimTimeVSmoveNameZoneDist(dirName, fileName, zoneNum, moveColNum, distColNum, simColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + splitSwapSim = np.zeros(zoneNum, dtype = float) + splitSim = np.zeros(zoneNum, dtype = float) + migrateSim = np.zeros(zoneNum, dtype = float) + swapSim = np.zeros(zoneNum, dtype = float) + tranSim = np.zeros(zoneNum, dtype = float) + routeSim = np.zeros(zoneNum, dtype = float) + identitySim = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + # print('"' + row[trueNum] + '"\t"' + row[moveColNum] + '"\t"' + row[distColNum] + '"\t"' + row[simColNum] + '"') + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[moveColNum] == "split_swap": + splitSwapSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "split": + splitSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "migrate": + migrateSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "swap": + swapSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "transfer": + tranSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "routing": + routeSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[moveColNum] == "identity": + identitySim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + else: + raise Exception("move name is not split_swap or split or migrate or swap or transfer or routing or identity! The new type: " + row[moveColNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "split_swap":splitSwapSim, + "split":splitSim, + "migrate":migrateSim, + "swap":swapSim, + "transfer":tranSim, + "routing":routeSim, + "identity":identitySim + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Simulation Time") + plt.title("Simulation Time in Each Zone based on Move Name") + plt.savefig(dirName + fileName + "/simTimeVSmoveNameZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot move generation time vs. 
move name in a zonal format +def plotMovGenTimeVSmoveNameZoneDist(dirName, fileName, zoneNum, moveColNum, distColNum, movGenColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + splitSwapMov = np.zeros(zoneNum, dtype = float) + splitMov = np.zeros(zoneNum, dtype = float) + migrateMov = np.zeros(zoneNum, dtype = float) + swapMov = np.zeros(zoneNum, dtype = float) + tranMov = np.zeros(zoneNum, dtype = float) + routeMov = np.zeros(zoneNum, dtype = float) + identityMov = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + # print('"' + row[trueNum] + '"\t"' + row[moveColNum] + '"\t"' + row[distColNum] + '"\t"' + row[movGenColNum] + '"') + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[moveColNum] == "split_swap": + splitSwapMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[moveColNum] == "split": + splitMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[moveColNum] == "migrate": + migrateMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[moveColNum] == "swap": + swapMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[moveColNum] == "transfer": + tranMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[moveColNum] == "routing": + routeMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[moveColNum] == "identity": + identityMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + else: + raise Exception("move name is not split_swap or split or migrate or swap or transfer of routing or identity! The new type: " + row[moveColNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "split_swap":splitSwapMov, + "split":splitMov, + "migrate":migrateMov, + "swap":swapMov, + "transfer":tranMov, + "routing":routeMov, + "identity":identityMov + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Move Generation Time") + plt.title("Move Generation Time in Each Zone based on Move Name") + plt.savefig(dirName + fileName + "/movGenTimeVSmoveNameZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot simulation time vs. 
comm_comp in a zonal format +def plotSimTimeVScommCompZoneDist(dirName, fileName, zoneNum, commcompColNum, distColNum, simColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + commSim = np.zeros(zoneNum, dtype = float) + compSim = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[commcompColNum] == "comm": + commSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + elif row[commcompColNum] == "comp": + compSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum]) + else: + raise Exception("comm_comp is not giving comm or comp! The new type: " + row[colNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "comm":commSim, + "comp":compSim + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Simulation Time") + plt.title("Simulation Time in Each Zone based on comm_comp") + plt.savefig(dirName + fileName + "/simTimeVScommCompZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot simulation time vs. comm_comp in a zonal format +def plotMovGenTimeVScommCompZoneDist(dirName, fileName, zoneNum, commcompColNum, distColNum, movGenColNum, trueNum): + with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + + commMov = np.zeros(zoneNum, dtype = float) + compMov = np.zeros(zoneNum, dtype = float) + + maxDist = 0 + index = [] + for i in range(0, zoneNum): + index.append(i) + + for i, row in enumerate(resultReader): + if row[trueNum] != "True": + continue + + if i == 2: + maxDist = float(row[distColNum]) + + if i > 1: + if row[commcompColNum] == "comm": + commMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + elif row[commcompColNum] == "comp": + compMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum]) + else: + raise Exception("comm_comp is not giving comm or comp! The new type: " + row[colNum]) + + # plt.figure() + plotdata = pd.DataFrame({ + "comm":commMov, + "comp":compMov + }, index = index + ) + plotdata.plot(kind = 'bar', stacked = True) + plt.xlabel("Zone decided by the max distance to goal") + plt.ylabel("Move Generation Time") + plt.title("Move Generation Time in Each Zone based on comm_comp") + plt.savefig(dirName + fileName + "/movGenTimeVScommCompZoneDist-" + fileName + ".png") + # plt.show() + plt.close('all') + +# the function to plot simulation time vs. 
+# the function to plot simulation time vs. high level optimization name in a zonal format
+def plotSimTimeVShighLevelOptZoneDist(dirName, fileName, zoneNum, optColNum, distColNum, simColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        topoSim = np.zeros(zoneNum, dtype = float)
+        tunSim = np.zeros(zoneNum, dtype = float)
+        mapSim = np.zeros(zoneNum, dtype = float)
+        idenOptSim = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[optColNum] == "topology":
+                    topoSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                elif row[optColNum] == "customization":
+                    tunSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                elif row[optColNum] == "mapping":
+                    mapSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                elif row[optColNum] == "identity":
+                    idenOptSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                else:
+                    raise Exception("high level optimization name is not giving topology or customization or mapping or identity! The new type: " + row[optColNum])
+
+        # plt.figure()
+        plotdata = pd.DataFrame({
+            "topology": topoSim,
+            "customization": tunSim,
+            "mapping": mapSim,
+            "identity": idenOptSim
+        }, index = index)
+        plotdata.plot(kind = 'bar', stacked = True)
+        plt.xlabel("Zone decided by the max distance to goal")
+        plt.ylabel("Simulation Time")
+        plt.title("Simulation Time in Each Zone based on Optimization Name")
+        plt.savefig(dirName + fileName + "/simTimeVShighLevelOptZoneDist-" + fileName + ".png")
+        # plt.show()
+        plt.close('all')
+
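+# The plot*ZoneDist helpers in this file all follow the same accumulate-then-
+# stack pattern; the sketch below (a new, hypothetical helper, not used
+# anywhere else) factors it out: bucket a per-row value by its category column
+# into distance zones. It assumes np and zonalPartition are in module scope.
+def accumulateByZone(resultReader, catColNum, distColNum, valColNum, zoneNum, categories, trueNum):
+    totals = {c: np.zeros(zoneNum, dtype=float) for c in categories}
+    maxDist = 0
+    for i, row in enumerate(resultReader):
+        if row[trueNum] != "True":
+            continue
+        if i == 2:
+            # the first data row is assumed to carry the largest distance to goal
+            maxDist = float(row[distColNum])
+        if i > 1:
+            zone = zonalPartition(float(row[distColNum]), zoneNum, maxDist)
+            totals[row[catColNum]][zone] += float(row[valColNum])
+    return totals
+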
+# the function to plot move generation time vs. high level optimization name in a zonal format
+def plotMovGenTimeVShighLevelOptZoneDist(dirName, fileName, zoneNum, optColNum, distColNum, movGenColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        topoMov = np.zeros(zoneNum, dtype = float)
+        tunMov = np.zeros(zoneNum, dtype = float)
+        mapMov = np.zeros(zoneNum, dtype = float)
+        idenOptMov = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[optColNum] == "topology":
+                    topoMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[optColNum] == "customization":
+                    tunMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[optColNum] == "mapping":
+                    mapMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[optColNum] == "identity":
+                    idenOptMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                else:
+                    raise Exception("high level optimization name is not giving topology or customization or mapping or identity! The new type: " + row[optColNum])
+
+        # plt.figure()
+        plotdata = pd.DataFrame({
+            "topology": topoMov,
+            "customization": tunMov,
+            "mapping": mapMov,
+            "identity": idenOptMov
+        }, index = index)
+        plotdata.plot(kind = 'bar', stacked = True)
+        plt.xlabel("Zone decided by the max distance to goal")
+        plt.ylabel("Transformation Generation Time")
+        plt.title("Transformation Generation Time in Each Zone based on Optimization Name")
+        plt.savefig(dirName + fileName + "/movGenTimeVShighLevelOptZoneDist-" + fileName + ".png")
+        # plt.show()
+        plt.close('all')
+
+# the function to plot simulation time vs. architectural principle in a zonal format
+def plotSimTimeVSarchVarImpZoneDist(dirName, fileName, zoneNum, archColNum, distColNum, simColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        paraSim = np.zeros(zoneNum, dtype = float)
+        custSim = np.zeros(zoneNum, dtype = float)
+        localSim = np.zeros(zoneNum, dtype = float)
+        idenImpSim = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[archColNum] == "parallelization":
+                    paraSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                elif row[archColNum] == "customization":
+                    custSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                elif row[archColNum] == "locality":
+                    localSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                elif row[archColNum] == "identity":
+                    idenImpSim[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[simColNum])
+                else:
+                    raise Exception("architectural principle is not giving parallelization or customization or locality or identity! The new type: " + row[archColNum])
+
+        # plt.figure()
+        plotdata = pd.DataFrame({
+            "parallelization": paraSim,
+            "customization": custSim,
+            "locality": localSim,
+            "identity": idenImpSim
+        }, index = index)
+        plotdata.plot(kind = 'bar', stacked = True)
+        plt.xlabel("Zone decided by the max distance to goal")
+        plt.ylabel("Simulation Time")
+        plt.title("Simulation Time in Each Zone based on Architectural Principle")
+        plt.savefig(dirName + fileName + "/simTimeVSarchVarImpZoneDist-" + fileName + ".png")
+        # plt.show()
+        plt.close('all')
+
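+# A note on the row-index checks used by all of the zonal helpers above: rows 0
+# and 1 of FARSI_simple_run_0_1_all_reults.csv appear to be header/metadata
+# rows, and exploration rows start at i == 2 with the distance to goal
+# shrinking over time, which is why the first data row's distance is taken as
+# maxDist. Illustrative (assumed) layout:
+#   row 0: column headers
+#   row 1: initial/reference design
+#   row 2: first explored design (largest distance to goal -> maxDist)
+#   row 3+: subsequent designs, bucketed by zonalPartition relative to maxDist
+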
+# the function to plot transformation generation time vs. architectural principle in a zonal format
+def plotMovGenTimeVSarchVarImpZoneDist(dirName, fileName, zoneNum, archColNum, distColNum, movGenColNum, trueNum):
+    with open(dirName + fileName + "/result_summary/FARSI_simple_run_0_1_all_reults.csv", newline='') as csvfile:
+        resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+
+        paraMov = np.zeros(zoneNum, dtype = float)
+        custMov = np.zeros(zoneNum, dtype = float)
+        localMov = np.zeros(zoneNum, dtype = float)
+        idenImpMov = np.zeros(zoneNum, dtype = float)
+
+        maxDist = 0
+        index = []
+        for i in range(0, zoneNum):
+            index.append(i)
+
+        for i, row in enumerate(resultReader):
+            if row[trueNum] != "True":
+                continue
+
+            if i == 2:
+                maxDist = float(row[distColNum])
+
+            if i > 1:
+                if row[archColNum] == "parallelization":
+                    paraMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[archColNum] == "customization":
+                    custMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[archColNum] == "locality":
+                    localMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                elif row[archColNum] == "identity":
+                    idenImpMov[zonalPartition(float(row[distColNum]), zoneNum, maxDist)] += float(row[movGenColNum])
+                else:
+                    raise Exception("architectural principle is not giving parallelization or customization or locality or identity! The new type: " + row[archColNum])
+
+        # plt.figure()
+        plotdata = pd.DataFrame({
+            "parallelization": paraMov,
+            "customization": custMov,
+            "locality": localMov,
+            "identity": idenImpMov
+        }, index = index)
+        plotdata.plot(kind = 'bar', stacked = True)
+        plt.xlabel("Zone decided by the max distance to goal")
+        plt.ylabel("Transformation Generation Time")
+        plt.title("Transformation Generation Time in Each Zone based on Architectural Principle")
+        plt.savefig(dirName + fileName + "/movGenTimeVSarchVarImpZoneDist-" + fileName + ".png")
+        # plt.show()
+        plt.close('all')
+
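+# plotBudgets3d below resolves column indices by header name via columnNum,
+# which is defined earlier in this file; a minimal sketch of its assumed
+# behavior (a hypothetical reconstruction, kept commented out so it does not
+# shadow the real definition):
+# def columnNum(dirName, fileName, columnName, mode):
+#     if mode == "simple":
+#         path = dirName + fileName + "/result_summary/FARSI_simple_run_0_1.csv"
+#         with open(path, newline='') as csvfile:
+#             for row in csv.reader(csvfile, delimiter=',', quotechar='|'):
+#                 return row.index(columnName)  # look the name up in the header row
+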
+# the function to plot the budgets vs. iteration cnt, system block count, and routing complexity in 3d
+def plotBudgets3d(dirName, subDirName):
+    newDirName = dirName + "/" + subDirName + "/"
+    if os.path.exists(newDirName + "/figures"):
+        shutil.rmtree(newDirName + "/figures")
+    resultList = os.listdir(newDirName)
+    latBudgets = []
+    powBudgets = []
+    areaBudgets = []
+    itrValues = []
+    cntValues = []
+    routingValues = []
+    workloads = []
+    for j, fileName in enumerate(resultList):
+        with open(newDirName + fileName + "/result_summary/FARSI_simple_run_0_1.csv", newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 1:
+                    itrValues.append(int(row[columnNum(newDirName, fileName, "iteration cnt", "simple")]))
+                    cntValues.append(int(row[columnNum(newDirName, fileName, "system block count", "simple")]))
+                    routingValues.append(float(row[columnNum(newDirName, fileName, "routing complexity", "simple")]))
+                    powBudgets.append(float(row[columnNum(newDirName, fileName, "power_budget", "simple")]))
+                    areaBudgets.append(float(row[columnNum(newDirName, fileName, "area_budget", "simple")]))
+                    lat = row[int(columnNum(newDirName, fileName, "latency_budget", "simple"))][:-1]
+                    latDict = dict(item.split("=") for item in lat.split(";"))
+                    if j == 0:
+                        for k in range(0, len(latDict)):
+                            latBudgets.append([])
+                            workloads.append(list(latDict.keys())[k])
+                    latList = list(latDict.values())
+                    for k in range(0, len(latList)):
+                        latBudgets[k].append(float(latList[k]))
+
+    m = ['o', 'x', '^', 's', 'd', '+', 'v', '<', '>']
+    axis_font = {'size': '10'}
+    fontSize = 10
+    os.mkdir(newDirName + "figures")
+    fig_budget_itr = plt.figure(figsize=(12, 12))
+    plt.rc('font', **axis_font)
+    ax_itr = fig_budget_itr.add_subplot(projection='3d')
+    for i in range(0, len(latBudgets)):
+        img = ax_itr.scatter3D(powBudgets, areaBudgets, latBudgets[i], c=itrValues, cmap="bwr", marker=m[i], s=80, label='{0}'.format(workloads[i]))
+        for j in range(0, len(latBudgets[i])):
+            coordinate = str(itrValues[j])
+            ax_itr.text(powBudgets[j], areaBudgets[j], latBudgets[i][j], '%s' % coordinate, size=fontSize)
+            break  # annotate only the first point of each series to avoid clutter
+    ax_itr.set_xlabel("Power Budget")
+    ax_itr.set_ylabel("Area Budget")
+    ax_itr.set_zlabel("Latency Budget")
+    ax_itr.legend()
+    cbar_itr = fig_budget_itr.colorbar(img, aspect = 40)
+    cbar_itr.set_label("Number of Iterations", rotation = 270)
+    plt.title("{Power Budget, Area Budget, Latency Budget} VS Iteration Cnt: " + subDirName)
+    plt.tight_layout()
+    plt.savefig(newDirName + "figures/budgetVSitr-" + subDirName + ".png")
+    # plt.show()
+    plt.close('all')
+
+    fig_budget_blkcnt = plt.figure(figsize=(12, 12))
+    plt.rc('font', **axis_font)
+    ax_blkcnt = fig_budget_blkcnt.add_subplot(projection='3d')
+    for i in range(0, len(latBudgets)):
+        img = ax_blkcnt.scatter3D(powBudgets, areaBudgets, latBudgets[i], c=cntValues, cmap="bwr", marker=m[i], s=80, label='{0}'.format(workloads[i]))
+        for j in range(0, len(latBudgets[i])):
+            coordinate = str(cntValues[j])
+            ax_blkcnt.text(powBudgets[j], areaBudgets[j], latBudgets[i][j], '%s' % coordinate, size=fontSize)
+            break  # annotate only the first point of each series to avoid clutter
+    ax_blkcnt.set_xlabel("Power Budget")
+    ax_blkcnt.set_ylabel("Area Budget")
+    ax_blkcnt.set_zlabel("Latency Budget")
+    ax_blkcnt.legend()
+    cbar = fig_budget_blkcnt.colorbar(img, aspect=40)
+    cbar.set_label("System Block Count", rotation=270)
+    plt.title("{Power Budget, Area Budget, Latency Budget} VS System Block Count: " + subDirName)
+    plt.tight_layout()
+    plt.savefig(newDirName + "figures/budgetVSblkcnt-" + subDirName + ".png")
+    # plt.show()
+    plt.close('all')
+
+    fig_budget_routing = plt.figure(figsize=(12, 12))
+    plt.rc('font', **axis_font)
+    ax_routing = fig_budget_routing.add_subplot(projection='3d')
+    for i in range(0, len(latBudgets)):
+        img = ax_routing.scatter3D(powBudgets, areaBudgets, latBudgets[i], c=routingValues, cmap="bwr", marker=m[i], s=80, label='{0}'.format(workloads[i]))
+        for j in range(0, len(latBudgets[i])):
+            coordinate = str(routingValues[j])
+            ax_routing.text(powBudgets[j], areaBudgets[j], latBudgets[i][j], '%s' % coordinate, size=fontSize)
+            break  # annotate only the first point of each series to avoid clutter
+    ax_routing.set_xlabel("Power Budget")
+    ax_routing.set_ylabel("Area Budget")
+    ax_routing.set_zlabel("Latency Budget")
+    ax_routing.legend()
+    cbar = fig_budget_routing.colorbar(img, aspect=40)
+    cbar.set_label("Routing Complexity", rotation=270)
+    plt.title("{Power Budget, Area Budget, Latency Budget} VS Routing Complexity: " + subDirName)
+    plt.tight_layout()
+    plt.savefig(newDirName + "figures/budgetVSroutingComplexity-" + subDirName + ".png")
+    # plt.show()
+    plt.close('all')
+
+def get_experiment_dir_list(run_folder_name):
+    workload_set_folder_list = os.listdir(run_folder_name)
+
+    experiment_full_addr_list = []
+    # iterate and generate plots
+    for workload_set_folder in workload_set_folder_list:
+        # ignore irrelevant files
+        if workload_set_folder in config_plotting.ignore_file_names:
+            continue
+
+        # get experiment folder
+        workload_set_full_addr = os.path.join(run_folder_name, workload_set_folder)
+        folder_list = os.listdir(workload_set_full_addr)
+        for experiment_name_relative_addr in folder_list:
+            if experiment_name_relative_addr in config_plotting.ignore_file_names:
+                continue
+            experiment_full_addr_list.append(os.path.join(workload_set_full_addr, experiment_name_relative_addr))
+
+    return experiment_full_addr_list
+
+
+def find_the_most_recent_directory(top_dir):
+    dirs = [os.path.join(top_dir, el) for el in os.listdir(top_dir)]
+    dirs = list(filter(os.path.isdir, dirs))
+    # sorted with the most recently modified directory first
+    dirs.sort(key=lambda x: os.path.getmtime(x), reverse=True)
+    return dirs
+
+
+def get_experiment_full_file_addr_list(experiment_full_dir_list, aggregate=False):
+    if aggregate:
+        file_name = 'result_summary/aggregate_all_results.csv'
+    else:
+        file_name = "result_summary/FARSI_simple_run_0_1.csv"
+    results = []
+    for el in experiment_full_dir_list:
+        results.append(os.path.join(el, file_name))
+
+    return results
+
+######### RADHIKA PANDAS PLOTS ############
+def simple_stack_bar_plot(output_dir):
+    labels = ['1', '2', '3']
+    customization = [13, 11.74, 11.55]
+    LLP = [100, 49, 27]
+    TLP = [2.0, 2.3, 2.14]
+    degradation = [1.07, 1.25, 1.35]
+    width = 0.35  # the width of the bars: can also be len(x) sequence
+
+    fig, ax = plt.subplots()
+
+    ax.bar(labels, customization, width, label='customization')
+    ax.bar(labels, LLP, width, bottom=customization, label='LLP')
+    ax.bar(labels, TLP, width, bottom=np.add(LLP, customization).tolist(), label='TLP')
+
+    ax.set_ylabel('Performance Driver')
+    #ax.set_title()
+    ax.legend()
+
+    output_dir_ = os.path.join(output_dir, "drivers")
+    if not os.path.exists(output_dir_):
+        os.makedirs(output_dir_)
+    plt.savefig(os.path.join(output_dir_, "performance_driver.png"))
+    plt.show()
+
+
+
+def grouped_barplot_varying_x(df, metric, metric_ylabel, varying_x, varying_x_labels, ax):
+    # [[bar heights, errs for varying_x1], [heights, errs for varying_x2]...]
+ #if metric in ["ip_cnt", "local_bus_avg_bus_width", "local_memory_total_area"]: + # print("ok") + + + + grouped_stats_list = [] + for x in varying_x: + grouped_x = df.groupby([x]) + stats = grouped_x[metric].agg([np.mean, np.std]) + grouped_stats_list.append(stats) + """ + for el in grouped_stats_list: + #x = el.at[2, "mean"] + diff = (((el.at[1, "mean"] - el.at[4, "mean"])/el.at[4, "mean"])*100) + print (metric+":"+el.index.name +" : "+ str(diff)) + """ + + if metric in [# "loop_unrolling_parallelism_speed_up_full_system", "customization_speed_up_full_system", + "task_level_parallelism_speed_up_full_system", + "interference_degradation_avg"]: + + print("metric is:" + metric) + for el in grouped_stats_list: + print(el) + start_loc = 0 + bar_width = 0.15 + offset = 0 + # [[bar locations for varying_x1], [bar locs for varying_x2]...] + grouped_bar_locs_list = [] + for x in varying_x: + n_unique_varying_x = df[x].nunique() + bound = (n_unique_varying_x-1) * (bar_width+offset) + end_loc = start_loc+bound + bar_locs = np.linspace(start_loc, end_loc, n_unique_varying_x) + grouped_bar_locs_list.append(bar_locs) + start_loc = end_loc + 2*bar_width + + #print(grouped_bar_locs_list) + + color = ["red", "orange", "green"] + ctr = 0 + for x_i,x in enumerate(varying_x): + ax.bar( + grouped_bar_locs_list[x_i], + grouped_stats_list[x_i]["mean"], + width=bar_width, + yerr=grouped_stats_list[x_i]["std"], + color = color, + label=metric_ylabel + ) + ctr +=1 + cat_xticks = [] + cat_xticklabels = [] + + xticks = [] + xticklabels = [] + for x_i,x in enumerate(varying_x): + xticklabels.extend(grouped_stats_list[x_i].index.astype(float)) + xticks.extend(grouped_bar_locs_list[x_i]) + + xticks_cat = grouped_bar_locs_list[x_i] + xticks_cat_start = xticks_cat[0] + xticks_cat_end = xticks_cat[-1] + xticks_cat_mid = xticks_cat_start + (xticks_cat_end - xticks_cat_start) / 2 + + cat_xticks.append(xticks_cat_mid) + cat_xticklabels.append("\n\n" + varying_x_labels[x_i]) + + xticks.extend(cat_xticks) + xticklabels.extend(cat_xticklabels) + + ax.set_ylabel(metric_ylabel) + #ax.set_xlabel(xlabel) + ax.set_xticks(xticks) + ax.set_xticklabels(xticklabels) + + return ax + +def pie_chart(dir_names, all_res_column_name_number, case_study): + + + file_full_addr = os.path.join(dir_names[0], "result_summary/FARSI_simple_run_0_1_all_reults.csv") + column_name_number_dic = {} + + column_name_list = case_study[1] + column_aggregate = {} + for column_name in column_name_list: + column_aggregate[column_name] = 0 + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i > 1: + try: + column_aggregate[column_name] += float(row[all_res_column_name_number[column_name]]) + except: + continue + + y = np.array(list(column_aggregate.values())) + mylabels = list(column_aggregate.keys()) + + plt.pie(y, labels=mylabels) + plt.legend() + + output_base_dir = '/'.join(dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "pie_chart") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.savefig(os.path.join(output_dir, case_study[0]+".png")) + #plt.show() + + + plt.show() + +def pie_chart_for_paper(dir_names, all_res_column_name_number, case_study): + + + file_full_addr = os.path.join(dir_names[0], "result_summary/FARSI_simple_run_0_1_all_reults.csv") + column_name_number_dic = {} + + column_name_list = case_study[1] + column_aggregate = {} + for column_name in column_name_list: + 
column_aggregate[column_name] = 0 + with open(file_full_addr, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i > 1: + try: + column_aggregate[column_name] += float(row[all_res_column_name_number[column_name]]) + except: + continue + + y = np.array(list(column_aggregate.values())) + mylabels = list(column_aggregate.keys()) + + axis_font = {'size': '28'} + fontSize = 24 + if case_study[0] == "Performance Breakdown": + plt.figure(figsize=(8, 6)) + elif case_study[0] == "Transformation_Generation_Breakdown": + plt.figure(figsize=(8, 6)) + plt.rc('font', **axis_font) + if mylabels == ['transformation generation time', 'simulation time', 'neighbour selection time']: + mylabels = ['System Generation', 'Simulation', 'System Selection'] + elif mylabels == ['metric selection time', 'dir selection time', 'kernel selection time', 'block selection time', 'transformation selection time', 'design duplication time']: + mylabels = ['Metric', 'Direction', 'Task', 'Block', 'Move', 'Design Duplication'] + plt.pie(y, autopct=lambda p: '{:1.1f}%'.format(p) if p > 0.8 else '') # Ying: original: , labels=mylabels) + if case_study[0] == "Performance Breakdown": + plt.legend(mylabels, bbox_to_anchor=(0.5, 1.5), loc="upper center", ncol=1, fontsize=fontSize) + elif case_study[0] == "Transformation_Generation_Breakdown": + plt.legend(mylabels, bbox_to_anchor=(0.5, 1.5), loc="upper center", ncol=2, fontsize=fontSize) + plt.tight_layout() + + output_base_dir = '/'.join(dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "pie_chart") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print(output_dir) + plt.savefig(os.path.join(output_dir, case_study[0]+".png"), bbox_inches='tight') + #plt.show() + + + plt.show() + plt.close('all') + +def pandas_plots(input_dir_names, all_results_files, metric): + df = pd.concat((pd.read_csv(f) for f in all_results_files)) + + #df = raw_df.loc[(raw_df["move validity"] == True)] + #df["dist_to_goal_non_cost_delta"] = df["ref_des_dist_to_goal_non_cost"] - df["dist_to_goal_non_cost"] + #df["local_traffic_ratio"] = np.divide(df["local_total_traffic"], df["local_total_traffic"] + df["global_total_traffic"]) + #metric = "global_memory_avg_freq" + #metric_ylabel = "Global memory avg freq" + #metric = "local_traffic_ratio" + + metric_ylabel = metric #"Local traffic ratio" + + + varying_x = [ + "budget_scaling_latency", + "budget_scaling_power", + "budget_scaling_area", + ] + varying_x_labels = [ + "latency", + "power", + "area", + ] + + + fig, ax = plt.subplots(1) + grouped_barplot_varying_x( + df, + metric, metric_ylabel, + varying_x, varying_x_labels, + ax + ) + output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2]) + output_dir = os.path.join(output_base_dir, "panda_study/") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + fig.savefig(os.path.join(output_dir, metric+".png")) + #plt.show() + plt.close('all') + + + + #fig.tight_layout(rect=[0, 0, 1, 1]) + #fig.savefig("/Users/behzadboro/Project_FARSI_dir/Project_FARSI_with_channels/data_collection/data/simple_run/27_point_coverage_zad/bleh.png") + #plt.close(fig) + +def grouped_barplot_varying_x_for_paper(df, metric, metric_ylabel, varying_x, varying_x_labels, ax): + # [[bar heights, errs for varying_x1], [heights, errs for varying_x2]...] 
+ grouped_stats_list = [] + for x in varying_x: + grouped_x = df.groupby([x]) + stats = grouped_x[metric].agg([np.mean, np.std]) + grouped_stats_list.append(stats) + + start_loc = 0 + bar_width = 0.15 + offset = 0 # Ying: original: 0.03 + # [[bar locations for varying_x1], [bar locs for varying_x2]...] + grouped_bar_locs_list = [] + for x in varying_x: + n_unique_varying_x = df[x].nunique() + bound = (n_unique_varying_x-1) * (bar_width+offset) + end_loc = start_loc+bound + bar_locs = np.linspace(start_loc, end_loc, n_unique_varying_x) + grouped_bar_locs_list.append(bar_locs) + start_loc = end_loc + 2*bar_width + + # print(grouped_bar_locs_list) # Ying: comment out for WTF + + color = ["mediumseagreen", "gold", "tomato"] + legendLabel = ["1X", "2X", "4X"] + ctr = 0 + coloredLocList=[[], [], []] + coloredStatsValueList = [[], [], []] + """ + Ying: add the following lines for the legend + """ + for x_i,x in enumerate(varying_x): + for i in range(0, len(grouped_bar_locs_list)): + coloredLocList[x_i].append(grouped_bar_locs_list[i][x_i]) + coloredStatsValueList[i].append(grouped_stats_list[x_i]["mean"].to_numpy()[i]) + """ + Ying: adding finished + """ + for x_i, x, in enumerate(varying_x): + # ax.bar( + # grouped_bar_locs_list[x_i], + # grouped_stats_list[x_i]["mean"], + # width=bar_width, + # # yerr=grouped_stats_list[x_i]["std"], + # color = color, + # ) + ax.bar( + coloredLocList[x_i], + coloredStatsValueList[x_i], + width=bar_width, + # yerr=grouped_stats_list[x_i]["std"], + color=color[x_i], + label=legendLabel[x_i] + ) + ctr +=1 + cat_xticks = [] + cat_xticklabels = [] + + xticks = [] + xticklabels = [] + """ + Ying: add the following lines to get rid of the numbers on the x-axis + """ + for i in range(0, 9): + xticklabels.append(' ') + """ + Ying: adding finished + """ + for x_i,x in enumerate(varying_x): + # xticklabels.extend(grouped_stats_list[x_i].index.astype(float)) # Ying: comment out and leave them for legends + xticks.extend(grouped_bar_locs_list[x_i]) + + xticks_cat = grouped_bar_locs_list[x_i] + xticks_cat_start = xticks_cat[0] + xticks_cat_end = xticks_cat[-1] + xticks_cat_mid = xticks_cat_start + (xticks_cat_end - xticks_cat_start) / 2 + + cat_xticks.append(xticks_cat_mid) + cat_xticklabels.append(varying_x_labels[x_i]) # Ying: the original code was: "\n\n" + varying_x_labels[x_i]) + + fontSize = 28 + axis_font = {'size': '28'} + xticks.extend(cat_xticks) + xticklabels.extend(cat_xticklabels) + + ax.set_ylabel(metric_ylabel, fontsize=fontSize) # Ying: add fontsize + ax.yaxis.set_label_coords(-0.15, 0.6) + #ax.set_xlabel(xlabel) + ax.set_xticks(xticks) + ax.legend(legendLabel, bbox_to_anchor=(0.5, 1.35), loc="upper center", fontsize=fontSize-4, ncol=3) # Ying: test the way to add legends # 1.35 for the figures with 1e8, etc., 1.2 for the figures without that + ax.set_xticklabels(xticklabels, fontsize=fontSize) # Ying: add fontsize + + return ax + + + +def find_average_iteration_to_distance(input_dir_names, all_results_files, intrested_distance_to_consider): + # iterate and collect all the data + heuristic_dist_iter_all = {} + for result_file in all_results_files: + + df = pd.concat((pd.read_csv(f) for f in [result_file])) + ht = df['heuristic_type'] + if list(ht)[0] not in heuristic_dist_iter_all: + heuristic_dist_iter_all[list(ht)[0]] = [] + dist_to_goal_non_cost = df["ref_des_dist_to_goal_non_cost"] + + + dist_itr = OrderedDict() + for intrested_dist in intrested_distance_to_consider: + for itr, dist in enumerate(dist_to_goal_non_cost) : + if dist < intrested_dist: + 
dist_itr[intrested_dist] = itr
+                    break
+        #if len(list(dist_itr.values())) <= 2:
+        #    continue
+        heuristic_dist_iter_all[list(ht)[0]].append(dist_itr)
+
+
+
+    # per heuristic reduce
+    heuristic_dist_iter_avg = {}
+    for heuristic, values in heuristic_dist_iter_all.items():
+        aggregate = OrderedDict()
+        for val in values:
+            for dist, itr in val.items():
+                if dist in aggregate:
+                    aggregate[dist].append(itr)
+                else:
+                    aggregate[dist] = [itr]
+
+        if heuristic not in heuristic_dist_iter_avg.keys():
+            heuristic_dist_iter_avg[heuristic] = OrderedDict()
+        for dist, all_itr in aggregate.items():
+            heuristic_dist_iter_avg[heuristic][dist] = sum(all_itr)/len(all_itr)
+
+    """
+    # compare heuristics
+    speedup = {}
+    for heuristic in heuristic_dist_iter_avg.keys():
+        speedup[heuristic] = {}
+        for dist in intrested_distance_to_consider:
+            if dist not in heuristic_dist_iter_avg[heuristic].keys():
+                speedup[heuristic][dist] = float('inf')
+            else:
+                speedup[heuristic][dist] = heuristic_dist_iter_avg[heuristic][dist]/heuristic_dist_iter_avg["FARSI"][dist]
+
+    speedup_sorted = {}
+    for heuristic in speedup.keys():
+        keys = sorted(speedup[heuristic].keys())
+        sorted_ = sorted({key: speedup[heuristic][key] for key in keys}.items())
+        speedup_sorted[heuristic] = sorted_
+    """
+
+
+    return heuristic_dist_iter_avg
+
+def heuristic_scaling_parsing(input_dir_names, all_results_files):
+    max_min_dist = 30
+
+    # iterate and collect data for the closest distance
+    heuristic_closest_dist_iter_all = {}
+    heuristic_avg_neighbourhood_size_all = {}
+    power_budget_ref = .008737
+
+    for result_file in all_results_files:
+        df = pd.concat((pd.read_csv(f) for f in [result_file]))
+        dist_to_goal_non_cost = df["ref_des_dist_to_goal_non_cost"]
+        task_count = len((df["latency"][1]).split(";")[:-1])
+        budget_scaling = float(df["power_budget"][1])/(power_budget_ref*task_count)
+        neighbourhood_sizes = []
+        for el in df["neighbouring design space size"]:
+            if (not el == 0) and not (math.isnan(el)):
+                neighbourhood_sizes.append(el)
+
+        # get avg neighbourhood size for the entire run
+        avg_neighbourhood_size = sum(neighbourhood_sizes)/len(neighbourhood_sizes)
+
+        dist_itr = OrderedDict()
+        min_dist = 1000
+        min_itr = 0
+        for itr, dist in enumerate(dist_to_goal_non_cost):
+            # track the iteration at which the closest distance was reached
+            if dist < min_dist:
+                min_dist = dist
+                min_itr = itr
+        if min_dist >= max_min_dist:
+            continue
+
+        # collect data about the closest distance and neighbourhood in a dictionary
+        if task_count not in heuristic_closest_dist_iter_all.keys():
+            heuristic_closest_dist_iter_all[task_count] = {}
+        if task_count not in heuristic_avg_neighbourhood_size_all.keys():
+            heuristic_avg_neighbourhood_size_all[task_count] = {}
+
+        if budget_scaling not in heuristic_closest_dist_iter_all[task_count].keys():
+            heuristic_closest_dist_iter_all[task_count][budget_scaling] = []
+        if budget_scaling not in heuristic_avg_neighbourhood_size_all[task_count].keys():
+            heuristic_avg_neighbourhood_size_all[task_count][budget_scaling] = []
+
+
+        heuristic_closest_dist_iter_all[task_count][budget_scaling].append(min_itr)
+        heuristic_avg_neighbourhood_size_all[task_count][budget_scaling].append(avg_neighbourhood_size)
+
+    # avg collected data
+    heuristic_closest_dist_iter_avg = {}
+    for task_count in heuristic_closest_dist_iter_all.keys():
+        heuristic_closest_dist_iter_avg[task_count] = {}
+        for budget_scaling in heuristic_closest_dist_iter_all[task_count].keys():
+            heuristic_closest_dist_iter_avg[task_count][budget_scaling] = 
sum(heuristic_closest_dist_iter_all[task_count][budget_scaling])/len(heuristic_closest_dist_iter_all[task_count][budget_scaling]) + + heuristic_avg_neighbourhood_size_avg = {} + for task_count in heuristic_avg_neighbourhood_size_all.keys(): + heuristic_avg_neighbourhood_size_avg[task_count] = {} + for budget_scaling in heuristic_avg_neighbourhood_size_all[task_count].keys(): + heuristic_avg_neighbourhood_size_avg[task_count][budget_scaling] = sum(heuristic_avg_neighbourhood_size_all[task_count][budget_scaling])/len(heuristic_avg_neighbourhood_size_all[task_count][budget_scaling]) + + + + return heuristic_closest_dist_iter_avg, heuristic_avg_neighbourhood_size_avg + + +def find_closest_dist_on_average(input_dir_names, all_results_files): + max_min_dist = 30 + # iterate and collect data for farthest distance + heuristic_closest_dist_iter_all = {} + results_for_plotting = {} + for result_file in all_results_files: + df = pd.concat((pd.read_csv(f) for f in [result_file])) + ht = df['heuristic_type'] + if list(ht)[0] not in heuristic_closest_dist_iter_all: + heuristic_closest_dist_iter_all[list(ht)[0]] = [] + dist_to_goal_non_cost = df["ref_des_dist_to_goal_non_cost"] + + + dist_itr = OrderedDict() + min_dist = 1000 + for itr, dist in enumerate(dist_to_goal_non_cost) : + min_dist = min(min_dist, dist) + if min_dist >= max_min_dist: + continue + heuristic_closest_dist_iter_all[list(ht)[0]].append(min_dist) + if int(min_dist) == 15 and list(ht)[0] == "SA": + results_for_plotting["SA"] = result_file + if int(min_dist) == 4 and list(ht)[0] == "moos": + results_for_plotting["moos"] = result_file + + + + # per heuristic reduce + heuristic_closest_dist_iter_avg = {} + max_of_closest_dist_reached = {} + for heuristic, values in heuristic_closest_dist_iter_all.items(): + #max_of_closest_dist_reached = 0 + heuristic_closest_dist_iter_avg[heuristic] = sum(values)/len(values) + max_of_closest_dist_reached[heuristic] = max(values) + + return heuristic_closest_dist_iter_avg, max_of_closest_dist_reached + + +def heuristic_scaling(input_dir_names, all_results_files, summary_res_column_name_number): + heuristic_closest_dist_iter = heuristic_scaling_parsing(input_dir_names, all_results_files) + print("ok") + +def heuristic_comparison(input_dir_names, all_results_files, summary_res_column_name_number): + intrested_distance_to_consider = [500, 200, 100, 50, 10] + #intrested_distance_to_consider = [1000, 500, 100, 10, 5, 1, .01] + heuristic_dist_along_the_way_iter_avg = find_average_iteration_to_distance(input_dir_names, all_results_files, intrested_distance_to_consider) + heuristic_closest_dist_avg, max_of_closest_dist_reached = find_closest_dist_on_average(input_dir_names, all_results_files) + #longest_dist = max(list(heuristic_closest_dist_iter_avg.values())) + heuristic_dist_iter_avg = find_average_iteration_to_distance(input_dir_names, all_results_files, [max_of_closest_dist_reached["SA"]]) + heuristic_phv = pareto_studies(input_dir_names, all_results_files, summary_res_column_name_number) + + + # calculate quality gain + #FARSI_closest_dist = heuristic_closest_dist_avg["FARSI"] + quality_gain = {} + for heu,dist in heuristic_closest_dist_avg.items(): + quality_gain[heu] = ((heuristic_closest_dist_avg["SA"] - dist)/heuristic_closest_dist_avg["SA"])*100 + + + # calculate speedup + speedup = {} + for heu, vals in heuristic_dist_iter_avg.items(): + for dist, iter in vals.items(): + if not dist == max_of_closest_dist_reached["SA"]: + continue + speedup[heu] = heuristic_dist_iter_avg["SA"][dist]/iter + + + # 
calculate speedup along the way
+    speedup_along_the_way = {}
+    for heu, vals in heuristic_dist_along_the_way_iter_avg.items():
+        if heu not in speedup_along_the_way.keys():
+            speedup_along_the_way[heu] = {}
+        for dist, iter in vals.items():
+            speedup_along_the_way[heu][dist] = iter/heuristic_dist_along_the_way_iter_avg["SA"][dist]
+
+
+    # calculate phv
+    phv_normalized = {}
+    for heu in heuristic_phv.keys():
+        phv_normalized[heu] = heuristic_phv[heu]/heuristic_phv["SA"]
+    return speedup
+
+
+def pandas_plots_for_paper(input_dir_names, all_results_files, metric):
+    df = pd.concat((pd.read_csv(f) for f in all_results_files))
+
+    #df = raw_df.loc[(raw_df["move validity"] == True)]
+    #df["dist_to_goal_non_cost_delta"] = df["ref_des_dist_to_goal_non_cost"] - df["dist_to_goal_non_cost"]
+    #df["local_traffic_ratio"] = np.divide(df["local_total_traffic"], df["local_total_traffic"] + df["global_total_traffic"])
+    #metric = "global_memory_avg_freq"
+    #metric_ylabel = "Global memory avg freq"
+    #metric = "local_traffic_ratio"
+
+    # metric_ylabel = metric #"Local traffic ratio" # Ying: replaced the underscores with whitespaces; the new code is the following line
+    metric_ylabel = ' '.join(metric.split('_'))
+    """
+    Ying: add the following lines just in case we need them
+    """
+    if metric == "ip_cnt":
+        metric_ylabel = "IP Count"
+    elif metric == "local_bus_cnt":
+        metric_ylabel = "NoC Count"
+    elif metric == "local_bus_avg_freq":
+        metric_ylabel = "NoC Avg Frequency (Hz)"
+    elif metric == "local_channel_count_per_bus_coeff_var":
+        metric_ylabel = "Channel Variation"
+    elif metric == "local_memory_area_coeff_var":
+        metric_ylabel = "Memory Area Variation"
+    elif metric == "local_bus_freq_coeff_var":
+        metric_ylabel = "NoC Frequency Variation"
+    elif metric == "local_total_traffic_reuse_no_read_in_bytes_per_cluster_avg":
+        metric_ylabel = "Memory Reuse (Bytes)"
+    elif metric == "avg_accel_parallelism":
+        metric_ylabel = "Avg ALP"
+    elif metric == "local_bus_avg_actual_bandwidth":
+        metric_ylabel = "Link Bandwidth (Bytes/s)"
+    elif metric == "local_memory_total_area":
+        metric_ylabel = "Local Memory Area (mm2)"
+    elif metric == "local_total_traffic":
+        metric_ylabel = "Local Total Traffic (Bytes)"
+    """
+    Ying: adding finished
+    """
+
+    varying_x = [
+        "budget_scaling_latency",
+        "budget_scaling_power",
+        "budget_scaling_area",
+    ]
+    varying_x_labels = [
+        "latency",
+        "power",
+        "area",
+    ]
+
+    axis_font = {'size': "28"}
+    plt.figure(figsize=(7, 6.4))
+    fig, ax = plt.subplots(1, figsize=(7, 6.4))  # Ying: add the figure size
+    grouped_barplot_varying_x_for_paper(
+        df,
+        metric, metric_ylabel,
+        varying_x, varying_x_labels,
+        ax
+    )
+    plt.rc('font', **axis_font)
+    plt.tight_layout()
+    output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+    output_dir = os.path.join(output_base_dir, "panda_study/")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    fig.savefig(os.path.join(output_dir, metric+".png"), bbox_inches='tight')
+    #plt.show()
+    plt.close('all')
+
+
+    #fig.tight_layout(rect=[0, 0, 1, 1])
+    #fig.savefig("/Users/behzadboro/Project_FARSI_dir/Project_FARSI_with_channels/data_collection/data/simple_run/27_point_coverage_zad/bleh.png")
+    #plt.close(fig)
+
+
+def pareto_studies(input_dir_names, all_result_files, summary_res_column_name_number):
+    def points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider):
+        power = point[0]
+        area = point[1]
+        # note: with "and", a point is filtered only when it exceeds both budgets
+        if power > base_budgets["power"] * budget_scale_to_consider and area > base_budgets["area"] * budget_scale_to_consider:
+            return True
+        return False
+
+    heuristic_results = {}
+
+    #system_char_to_keep_track_of = {"memory_total_area", "local_memory_total_area","pe_total_area", "ip_cnt", "ips_total_area"}
+    system_char_to_keep_track_of = {"system bus count"}
+
+    # budget scaling to consider
+    budget_scale_to_consider = .5
+    # get budget first
+    base_budgets = {}
+    for file in all_result_files:
+        with open(file, newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 1:
+                    base_budgets["power"] = float(row[summary_res_column_name_number["power_budget"]])
+                    base_budgets["area"] = float(row[summary_res_column_name_number["area_budget"]])
+                    break
+
+
+    for file in all_result_files:
+        with open(file, newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 0:
+                    continue
+                heuristic_set_name = row[summary_res_column_name_number["heuristic_type"]]
+                if heuristic_set_name not in heuristic_results.keys():
+                    heuristic_results[heuristic_set_name] = []
+                latencies = row[summary_res_column_name_number["latency"]].split(";")
+                latency = {}
+                for el in latencies:
+                    if el == "":
+                        continue
+                    workload_name = el.split("=")[0]
+                    latency[workload_name] = float(el.split("=")[1])
+
+                #workload_results[workload_set_name].append((float(power),float(area), float(system_complexity)))
+
+                area = float(row[summary_res_column_name_number["area"]])
+                power = float(row[summary_res_column_name_number["power"]])
+                system_char = {}
+                for el in system_char_to_keep_track_of:
+                    system_char[el] = float(row[summary_res_column_name_number[el]])
+                point_system_char = {(latency["audio_decoder"], latency["hpvm_cava"], latency["edge_detection"], power, area): system_char}
+                heuristic_results[heuristic_set_name].append(point_system_char)
+
+    heuristic_pareto_points = {}
+    for heuristic, points_ in heuristic_results.items():
+        points = [list(el.keys())[0] for el in points_]
+        pareto_points = find_pareto_points(list(set(points)))
+        heuristic_pareto_points[heuristic] = pareto_points
+
+    max_0, max_1, max_2, max_3, max_4 = 0, 0, 0, 0, 0
+    for heuristic, pareto_points in heuristic_pareto_points.items():
+        for el in pareto_points:
+            max_0_, max_1_, max_2_, max_3_, max_4_ = el
+            max_0 = max([max_0, max_0_])
+            max_1 = max([max_1, max_1_])
+            max_2 = max([max_2, max_2_])
+            max_3 = max([max_3, max_3_])
+            max_4 = max([max_4, max_4_])
+
+
+    heuristic_hv = {}
+    ref = [max_0, max_1, max_2, max_3, max_4]
+    for heuristic, pareto_points in heuristic_pareto_points.items():
+        hv = hypervolume(pareto_points)
+        hv_value = hv.compute(ref)
+        heuristic_hv[heuristic] = hv_value
+    return heuristic_hv
+
+
+def get_budget_optimality_advanced(input_dir_names, all_result_files, summary_res_column_name_number):
+    def points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider):
+        power = point[0]
+        area = point[1]
+        if power > base_budgets["power"] * budget_scale_to_consider and area > base_budgets["area"] * budget_scale_to_consider:
+            return True
+        return False
+
+    workload_results = {}
+
+    #system_char_to_keep_track_of = {"memory_total_area", "local_memory_total_area","pe_total_area", "ip_cnt", "ips_total_area"}
+    system_char_to_keep_track_of = {"ip_cnt"}
+
+    # budget scaling to consider
+    budget_scale_to_consider = .5
+    # get budget first
+    base_budgets = {}
+    for file in all_result_files:
+        with open(file, newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 1:
+                    if float(row[summary_res_column_name_number["budget_scaling_latency"]]) == 1 and\
+                       float(row[summary_res_column_name_number["budget_scaling_power"]]) == 1 and \
+                       float(row[summary_res_column_name_number["budget_scaling_area"]]) == 1:
+                        base_budgets["power"] = float(row[summary_res_column_name_number["power_budget"]])
+                        base_budgets["area"] = float(row[summary_res_column_name_number["area_budget"]])
+                    break
+
+
+    for file in all_result_files:
+        with open(file, newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 1:
+                    workload_set_name = row[summary_res_column_name_number["workload_set"]]
+                    if workload_set_name not in workload_results.keys():
+                        workload_results[workload_set_name] = []
+                    latency = ((row[summary_res_column_name_number["latency"]].split(";"))[0].split("="))[1]
+                    latency_budget = ((row[summary_res_column_name_number["latency_budget"]].split(";"))[0].split("="))[1]
+                    if float(latency) > float(latency_budget):
+                        continue
+
+                    #workload_results[workload_set_name].append((float(power),float(area), float(system_complexity)))
+
+                    area = float(row[summary_res_column_name_number["area"]])
+                    power = float(row[summary_res_column_name_number["power"]])
+                    system_char = {}
+                    for el in system_char_to_keep_track_of:
+                        system_char[el] = float(row[summary_res_column_name_number[el]])
+                    point_system_char = {(power, area): system_char}
+                    workload_results[workload_set_name].append(point_system_char)
+
+    workload_pareto_points = {}
+    for workload, points_ in workload_results.items():
+        points = [list(el.keys())[0] for el in points_]
+        pareto_points = find_pareto_points(list(set(points)))
+        workload_pareto_points[workload] = []
+        for point in pareto_points:
+            keys = [list(el.keys())[0] for el in workload_results[workload]]
+            idx = keys.index(point)
+            workload_pareto_points[workload].append({point: (workload_results[workload])[idx]})
+
+
+    """
+    # combine the results
+    combined_area_power = []
+    for results_combined in itertools.product(*list(workload_pareto_points.values())):
+        combined_power_area_tuple = [0,0]
+        for el in results_combined:
+            combined_power_area_tuple[0] += el[0]
+            combined_power_area_tuple[1] += el[1]
+        combined_area_power.append(combined_power_area_tuple[:])
+    """
+
+
+    all_points_in_isolation = []
+    all_points_cross_workloads = []
+
+    workload_in_isolation = {}
+    for workload, points in workload_results.items():
+        #points = [list(el.keys())[0] for el in points_]
+        if "cava" in workload and "audio" in workload and "edge_detection" in workload:
+            for point in points:
+                all_points_cross_workloads.append(point)
+        else:
+            workload_in_isolation[workload] = points
+
+
+    ctr = 0
+    workload_in_isolation_pareto = {}
+    for workload, points_ in workload_in_isolation.items():
+        workload_in_isolation_pareto[workload] = []
+        points = [list(el.keys())[0] for el in points_]
+        pareto_points = find_pareto_points(list(set(points)))
+        for point in pareto_points:
+            keys = [list(el.keys())[0] for el in workload_in_isolation[workload]]
+            idx = keys.index(point)
+            workload_in_isolation_pareto[workload].append({point: (workload_in_isolation[workload])[idx]})
+
+
+
+    combined_area_power_in_isolation = []
+    s = time.time()
+
+    workload_in_isolation_pareto_only_area_power = {}
+    for key, val in workload_in_isolation_pareto.items():
+        workload_in_isolation_pareto_only_area_power[key] = []
+        for el in val:
+            for k, v in el.items():
+                workload_in_isolation_pareto_only_area_power[key].append(k)
+
+
+    for results_combined in itertools.product(*list(workload_in_isolation_pareto_only_area_power.values())):
+        # add up all the characteristics
+        system_chars = {}
+        for el in system_char_to_keep_track_of:
+            system_chars[el] = 0
+
+        # add up area,power
+        combined_power_area_tuple = [0, 0]
+        for el in results_combined:
+            combined_power_area_tuple[0] += el[0]
+            combined_power_area_tuple[1] += el[1]
+
+        # accumulate the tracked characteristics of every point in this combination
+        # by looking each (power, area) point up in the per-workload pareto entries
+        for point in results_combined:
+            for wl, pareto_entries in workload_in_isolation_pareto.items():
+                for entry in pareto_entries:
+                    if point in entry:
+                        point_char = list(entry[point].values())[0]
+                        for char_name in system_char_to_keep_track_of:
+                            system_chars[char_name] += float(point_char[char_name])
+
+
+        #combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1], combined_power_area_tuple[2]))
+        combined_area_power_in_isolation.append((combined_power_area_tuple[0], combined_power_area_tuple[1]))
+
+    combined_area_power_in_isolation_filtered = []
+    for point in combined_area_power_in_isolation:
+        if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider):
+            combined_area_power_in_isolation_filtered.append(point)
+    combined_area_power_pareto = find_pareto_points(list(set(combined_area_power_in_isolation_filtered)))
+
+
+    all_points_cross_workloads_filtered = []
+    for point in all_points_cross_workloads:
+        if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider):
+            all_points_cross_workloads_filtered.append(point)
+    all_points_cross_workloads_area_power_pareto = find_pareto_points(list(set(all_points_cross_workloads_filtered)))
+
+
+    # prepare for plotting and plot
+    fig = plt.figure(figsize=(12, 12))
+    #plt.rc('font', **axis_font)
+    ax = fig.add_subplot(111)
+    fontSize = 20
+
+    x_values = [el[0] for el in combined_area_power_in_isolation_filtered]
+    y_values = [el[1] for el in combined_area_power_in_isolation_filtered]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="isolated design methodology", marker=".")
+
+
+    # plt.tight_layout()
+    x_values = [el[0] for el in combined_area_power_pareto]
+    y_values = [el[1] for el in combined_area_power_pareto]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="isolated design methodology pareto front", marker="x")
+
+
+    x_values = [el[0] for el in all_points_cross_workloads_filtered]
+    y_values = [el[1] for el in all_points_cross_workloads_filtered]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="cross workload methodology", marker="8")
+    ax.legend(loc="upper right")  # bbox_to_anchor=(1, 1), loc="upper left")
+
+    x_values = [el[0] for el in all_points_cross_workloads_area_power_pareto]
+    y_values = [el[1] for el in all_points_cross_workloads_area_power_pareto]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="cross workload pareto front", marker="o")
+    #for idx,_ in enumerate(x_values):
+    #    plt.text(x_values[idx], y_values[idx], s=)
+
+    ax.legend(loc="upper right")  # bbox_to_anchor=(1, 1), loc="upper left")
+
+
+    ax.set_xlabel("power", fontsize=fontSize)
+    ax.set_ylabel("area", fontsize=fontSize)
+    plt.tight_layout()
+
+    # dump in the top folder
+    output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+    output_dir = os.path.join(output_base_dir, "budget_optimality/")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    fig.savefig(os.path.join(output_dir, "budget_optimality.png"))
+    #plt.show()
+    plt.close('all')
+
+
+def get_budget_optimality_error():
+    X = ['CAVA', "ED", "Audio"]
+    optimal = [10, 20, 20, 40]
+    Myopic_budgetting = [20, 30, 25, 30]
+
+    X_axis = np.arange(len(X))
+
+    # compare the first len(X) entries of the two series side by side
+    plt.bar(X_axis - 0.2, optimal[:len(X)], 0.4, label='optimal')
+    plt.bar(X_axis + 0.2, Myopic_budgetting[:len(X)], 0.4, label='myopic budgetting')
+
+    plt.xticks(X_axis, X)
+    plt.xlabel("Workloads")
+    plt.ylabel("Budget")
+    plt.title("Optimal vs. Myopic Budgetting per Workload")
+    plt.legend()
+    plt.show()
+
+def get_budget_optimality(input_dir_names, all_result_files, reg_summary_res_column_name_number, a_e_h_summary_res_column_name_number):
+    #methodology 1:
+    e = {"power":0.00211483379112716, "area":0.00000652737916}
+    a = {"power": 0.00133071013064968, "area":0.0000048380066802}
+    hc = {"power": 0.000471573339922609, "area":0.0000019979449678}
+    e_budget = {"power":0.00190575, "area":0.0000038115}
+    a_budget = {"power":0.0015485, "area":0.0000030975}
+    hc_budget = {"power":0.001005, "area":0.00000201}
+
+
+    e_dis_budget = {"power":0, "area":0}
+    a_dis_budget = {"power":0, "area":0}
+    hc_dis_budget = {"power":0, "area":0}
+
+    for el in e_dis_budget.keys():
+        e_dis_budget[el] = (e_budget[el]-e[el])/e[el]
+
+    for el in a_dis_budget.keys():
+        a_dis_budget[el] = (a_budget[el]-a[el])/a[el]
+
+    for el in hc_dis_budget.keys():
+        hc_dis_budget[el] = (hc_budget[el]-hc[el])/hc[el]
+
+
+    combined_design_methodology_A = {"power":0, "area":0}
+    for el in e.keys():
+        combined_design_methodology_A[el] += e[el]
+    for el in a.keys():
+        combined_design_methodology_A[el] += a[el]
+    for el in hc.keys():
+        combined_design_methodology_A[el] += hc[el]
+
+
+    def get_equivalent_total(charac):
+        if charac == "ips_avg_freq":
+            return "ip_cnt"
+        if charac == "cluster_pe_cnt_avg":
+            return "ip_cnt"
+        elif charac == "avg_accel_parallelism":
+            return "ip_cnt"
+        elif charac in ["local_memory_avg_freq"]:
+            return "local_mem_cnt"
+        elif charac in ["local_bus_avg_actual_bandwidth", "local_bus_avg_theoretical_bandwidth", "local_bus_avg_bus_width", "avg_freq"]:
+            return "local_bus_count"
+        else:
+            return charac
+
+    def find_sys_char(power, area, results_with_sys_char):
+        for vals in results_with_sys_char:
+            for power_area, sys_chars in vals.items():
+                power_ = power_area[0]
+                area_ = power_area[1]
+                if power == power_ and area_ == area:
+                    return sys_chars
+    def points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider):
+        power = point[0]
+        area = point[1]
+        if power > base_budgets["power"] * budget_scale_to_consider and area > base_budgets["area"] * budget_scale_to_consider:
+            return True
+        return False
+
+
+    workload_results = {}
+    results_with_sys_char = []
+
+    system_char_to_keep_track_of = {"memory_total_area", "local_memory_total_area", "pe_total_area", "ip_cnt", "ips_total_area", "ips_avg_freq", "local_mem_cnt",
+                                    "local_bus_avg_actual_bandwidth", "local_bus_avg_theoretical_bandwidth", "local_memory_avg_freq", "local_bus_count", "local_bus_avg_bus_width", "avg_freq", "local_total_traffic",
+                                    "global_total_traffic", "local_memory_avg_freq", "global_memory_avg_freq", "gpps_total_area", "avg_gpp_parallelism", "avg_accel_parallelism", "channel_cnt",
+                                    "local_total_traffic_reuse_with_read_in_bytes",
+                                    }
+    system_char_to_keep_track_of = {"ip_cnt"
+                                    }
+    #system_char_to_show = ["local_memory_total_area"]
+    #system_char_to_show = ["avg_accel_parallelism"]
+    #system_char_to_show = ["avg_gpp_parallelism"]
+    #system_char_to_show = ["local_bus_avg_actual_bandwidth"]
+    #system_char_to_show = ["avg_freq"]  # really is buses avg freq
+    #system_char_to_show = ["local_memory_avg_freq"]  # really is buses avg freq
+    #system_char_to_show = ["ips_avg_freq"]
+    #system_char_to_show = ["gpps_total_area"]
+    #system_char_to_show = ["local_bus_avg_bus_width"]
+    #system_char_to_show = ["local_memory_avg_freq"]
+    #system_char_to_show = ["ips_total_area"]
+    system_char_to_show = ["ip_cnt"]
+    #system_char_to_show = ["local_mem_cnt"]
+    #system_char_to_show = ["global_memory_avg_freq"]
+    #system_char_to_show = ["local_bus_avg_theoretical_bandwidth"]
+    #system_char_to_show = ["local_memory_avg_freq"]
+    #system_char_to_show = ["local_total_traffic"]
+    #system_char_to_show = ["global_total_traffic"]
+    #system_char_to_show = ["local_total_traffic_reuse_with_read_in_bytes"]
+    #system_char_to_show = ["local_bus_count"]
+    #system_char_to_show = ["cluster_pe_cnt_avg"]
+    #system_char_to_show = ["ip_cnt"]
+    #system_char_to_show = ["channel_cnt"]
+
+    # budget scaling to consider
+    budget_scale_to_consider = .5
+    # get budget first
+    base_budgets = {}
+    """
+    for file in all_result_files:
+        with open(file, newline='') as csvfile:
+            if not ("a_e_h" in file and "lat_1__pow_1__area_1" in file):
+                continue
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, row in enumerate(resultReader):
+                if i == 1:
+                    print("file"+file)
+                    if float(row[reg_summary_res_column_name_number["budget_scaling_latency"]]) == 1 and\
+                       float(row[reg_summary_res_column_name_number["budget_scaling_power"]]) == 1 and \
+                       float(row[reg_summary_res_column_name_number["budget_scaling_area"]]) == 1:
+                        base_budgets["power"] = float(row[summary_res_column_name_number["power_budget"]])
+                        base_budgets["area"] = float(row[summary_res_column_name_number["area_budget"]])
+                    break
+    """
+    base_budgets = {"power":.008738, "area":.000017475}
+
+    for file in all_result_files:
+        if "a_e_h" in file and "lat_1__pow_1__area_1" in file:
+            continue
+        if ("a_e_h" in file):
+            summary_res_column_name_number = a_e_h_summary_res_column_name_number
+        else:
+            summary_res_column_name_number = reg_summary_res_column_name_number
+
+        with open(file, newline='') as csvfile:
+            resultReader = csv.reader(csvfile, delimiter=',', quotechar='|')
+            for i, 
row in enumerate(resultReader): + if i == 1: + workload_set_name = row[summary_res_column_name_number["workload_set"]] + if workload_set_name not in workload_results.keys(): + workload_results[workload_set_name] = [] + latency = ((row[summary_res_column_name_number["latency"]].split(";"))[0].split("="))[1] + latency_budget = ((row[summary_res_column_name_number["latency_budget"]].split(";"))[0].split("="))[1] + if float(latency) > float(latency_budget): + continue + + power = float(row[summary_res_column_name_number["power"]]) + area = float(row[summary_res_column_name_number["area"]]) + + system_complexity = row[summary_res_column_name_number["ip_cnt"]] # + row[summary_res_column_name_number["gpp_cnt"]] + #workload_results[workload_set_name].append((float(power),float(area), float(system_complexity))) + workload_results[workload_set_name].append((power,area)) + system_char = {} + for el in system_char_to_keep_track_of: + #if "latency" == el: + # system_char[el] = row[summary_res_column_name_number[el]] + #else: + system_char[el] = float(row[summary_res_column_name_number[el]]) + system_char["file"] = file + point_system_char = {(power, area): system_char} + results_with_sys_char.append(point_system_char) + + + + workload_pareto_points = {} + for workload, points in workload_results.items(): + workload_pareto_points[workload] = find_pareto_points(list(set(points))) + + """" + # combine the results + combined_area_power = [] + for results_combined in itertools.product(*list(workload_pareto_points.values())): + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + combined_area_power.append(combined_power_area_tuple[:]) + """ + + + all_points_in_isolation = [] + all_points_cross_workloads = [] + + workload_in_isolation = {} + for workload, points in workload_results.items(): + if "cava" in workload and "audio" in workload and "edge_detection" in workload: + for point in points: + all_points_cross_workloads.append(point) + else: + workload_in_isolation[workload] = points + + + ctr = 0 + workload_in_isolation_pareto = {} + for workload, points in workload_in_isolation.items(): + optimal_points = find_pareto_points(list(set(points))) + workload_in_isolation_pareto[workload] = optimal_points + #workload_in_isolation_pareto[workload] = points # show all points instead + + + combined_area_power_in_isolation= [] + combined_area_power_in_isolation_with_sys_char = [] + + s = time.time() + for results_combined in itertools.product(*list(workload_in_isolation_pareto.values())): + # add up all the charactersitics + combined_sys_chars = {} + + system_char_to_keep_track_of.add("file") + for el in system_char_to_keep_track_of: + if el =="file": + combined_sys_chars[el] = [] + else: + combined_sys_chars[el] = (0,0) + + # add up area,power + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + + sys_char = find_sys_char(el[0], el[1], results_with_sys_char) + for el_,val_ in sys_char.items(): + if el_ == "file": + combined_sys_chars[el_].append(val_) + continue + + if "avg" in el_: + total = sys_char[get_equivalent_total(el_)] + coeff = total + else: + coeff = 1 + #if "latency" in el_: + # combined_sys_chars[el_] = (combined_sys_chars[el_][0]+coeff, str(combined_sys_chars[el_][1])+"_"+val_) + #else: + combined_sys_chars[el_] = (combined_sys_chars[el_][0]+coeff, combined_sys_chars[el_][1]+coeff*float(val_)) + + for key, values 
in combined_sys_chars.items(): + if "avg" in key: + combined_sys_chars[key] = values[1] /max(values[0],.00000000000000000000000000000001) + elif "file" in key: + combined_sys_chars[key] = values + else: + combined_sys_chars[key] = values[1] + + if float(combined_sys_chars[system_char_to_show[0]]) == 53675449.0: + for f in combined_sys_chars["file"]: + print("/".join(f.split("/")[:-2])+"./runs/0/system_image.dot") + #if float(combined_sys_chars[system_char_to_show[0]]) == 53008113: + + #combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1], combined_power_area_tuple[2])) + combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1])) + combined_area_power_in_isolation_with_sys_char.append({(combined_power_area_tuple[0],combined_power_area_tuple[1]): combined_sys_chars}) + + #if len(combined_area_power_in_isolation)%100000 == 0: + # print("time passed is" + str(time.time()-s)) + + combined_area_power_in_isolation_filtered = [] + for point in combined_area_power_in_isolation: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + combined_area_power_in_isolation_filtered.append(point) + combined_area_power_pareto = find_pareto_points(list(set(combined_area_power_in_isolation_filtered))) + + + all_points_cross_workloads_filtered = [] + for point in all_points_cross_workloads: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + all_points_cross_workloads_filtered.append(point) + all_points_cross_workloads_area_power_pareto = find_pareto_points(list(set(all_points_cross_workloads_filtered))) + + + # prepare for plotting and plot + fig = plt.figure(figsize=(12, 12)) + #plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + fontSize = 20 + + x_values = [el[0] for el in combined_area_power_in_isolation_filtered] + y_values = [el[1] for el in combined_area_power_in_isolation_filtered] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="isolated design methodology",marker=".") + + + # plt.tight_layout() + x_values = [el[0] for el in combined_area_power_pareto] + y_values = [el[1] for el in combined_area_power_pareto] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="isolated design methodology pareto front",marker="x") + for idx, _ in enumerate(x_values) : + power= x_values[idx] + area = y_values[idx] + sys_char = find_sys_char(power, area, combined_area_power_in_isolation_with_sys_char) + + value_to_show = 0 + value_to_show = sys_char[system_char_to_show[0]] + #for el in system_char_to_show: + # value_to_show += sys_char[el] + + #if system_char_to_show[0] == "latency": + # value_in_scientific_notation = value_to_show + #else: + #value_to_show = sys_char["local_total_traffic"]/(sys_char["local_memory_total_area"]*4*10**12) + value_in_scientific_notation = "{:.10e}".format(value_to_show) + value_in_scientific_notation = value_to_show + #if idx ==0: + + #plt.text(power,area, value_in_scientific_notation) + + + x_values = [el[0] for el in all_points_cross_workloads_filtered] + y_values = [el[1] for el in all_points_cross_workloads_filtered] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="cross workload methodology",marker="8") + ax.legend(loc="upper right") # bbox_to_anchor=(1, 1), loc="upper left") + for idx, _ in enumerate(x_values) : + power= x_values[idx] + area = y_values[idx] + sys_char = find_sys_char(power, area, results_with_sys_char) + + 
value_to_show = sys_char[system_char_to_show[0]]
+        #for el in system_char_to_show:
+        #    value_to_show += sys_char[el]
+
+        #if system_char_to_show[0] == "latency":
+        #    value_in_scientific_notation = value_to_show
+        #else:
+        #value_to_show = sys_char["local_total_traffic"]/(sys_char["local_memory_total_area"]*4*10**12)
+        value_in_scientific_notation = "{:.2e}".format(value_to_show)
+        #plt.text(power, area, value_in_scientific_notation)
+
+    x_values = [el[0] for el in all_points_cross_workloads_area_power_pareto]
+    y_values = [el[1] for el in all_points_cross_workloads_area_power_pareto]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="cross workload pareto front", marker="o")
+
+    for idx, _ in enumerate(x_values):
+        power = x_values[idx]
+        area = y_values[idx]
+        sys_char = find_sys_char(power, area, results_with_sys_char)
+
+        value_to_show = sys_char[system_char_to_show[0]]
+        #value_to_show = sys_char["local_total_traffic"]/sys_char["local_memory_total_area"]
+        #value_to_show = sys_char["local_total_traffic"]/(sys_char["local_memory_total_area"]*4*10**12)
+        value_in_scientific_notation = "{:.2e}".format(value_to_show)
+        #plt.text(power, area, value_in_scientific_notation)
+
+    ax.set_xlabel("power", fontsize=fontSize)
+    ax.set_ylabel("area", fontsize=fontSize)
+    plt.tight_layout()
+
+    # dump in the top folder
+    output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+    output_dir = os.path.join(output_base_dir, "budget_optimality/")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    ax.scatter(combined_design_methodology_A["power"], combined_design_methodology_A["area"], label="methodology A", marker="+")
+    ax.legend(loc="upper right")  # bbox_to_anchor=(1, 1), loc="upper left")
+    ax.set_title(system_char_to_show[0] + " for FARSI vs in isolation")
+    #ax.set_title("memory_reuse for FARSI vs in isolation")
+    fig.savefig(os.path.join(output_dir, system_char_to_show[0] + "_budget_optimality.png"))
+    #plt.show()
+    plt.close('all')
+
+def get_budget_optimality_for_paper(input_dir_names, all_result_files, reg_summary_res_column_name_number, a_e_h_summary_res_column_name_number):
+    # methodology 1: per-workload designs generated in isolation (presumably
+    # e: edge detection, a: audio, hc: CAVA) and their individually assigned budgets
+    e = {"power": 0.00211483379112716, "area": 0.00000652737916}
+    a = {"power": 0.00133071013064968, "area": 0.0000048380066802}
+    hc = {"power": 0.000471573339922609, "area": 0.0000019979449678}
+    e_budget = {"power": 0.00190575, "area": 0.0000038115}
+    a_budget = {"power": 0.0015485, "area": 0.0000030975}
+    hc_budget = {"power": 0.001005, "area": 0.00000201}
+
+    # distance of each design to its budget, normalized to the design value
+    e_dis_budget = {"power": 0, "area": 0}
+    a_dis_budget = {"power": 0, "area": 0}
+    hc_dis_budget = {"power": 0, "area": 0}
+
+    for el in e_dis_budget.keys():
+        e_dis_budget[el] = (e_budget[el] - e[el]) / e[el]
+
+    for el in a_dis_budget.keys():
+        a_dis_budget[el] = (a_budget[el] - a[el]) / a[el]
+
+    for el in hc_dis_budget.keys():
+        hc_dis_budget[el] = (hc_budget[el] - hc[el]) / hc[el]
+
+    # combine the three isolated designs into one system-level power/area point
+    combined_design_methodology_A = {"power": 0, "area": 0}
+    for el in e.keys():
+        combined_design_methodology_A[el] += e[el]
+    for el in a.keys():
+        combined_design_methodology_A[el] += a[el]
+    for el in hc.keys():
+        combined_design_methodology_A[el] += hc[el]
+
+    def get_equivalent_total(charac):
+        if charac == "ips_avg_freq":
+            return "ip_cnt"
+        elif charac == "cluster_pe_cnt_avg":
+            return "ip_cnt"
+        elif charac == "avg_accel_parallelism":
+            
return "ip_cnt" + elif charac in ["local_memory_avg_freq"]: + return "local_mem_cnt" + elif charac in ["local_bus_avg_actual_bandwidth", "local_bus_avg_theoretical_bandwidth", "local_bus_avg_bus_width", "avg_freq"]: + return "local_bus_count" + else: + return charac + + def find_sys_char(power,area, results_with_sys_char): + for vals in results_with_sys_char: + for power_area , sys_chars in vals.items(): + power_ = power_area[0] + area_ = power_area[1] + if power == power_ and area_ == area: + return sys_chars + + def points_exceed_one_of_the_budgets(point, base_budget, budget_scaling_to_consider): + power = point[0] + area = point[1] + if power > base_budgets["power"] * budget_scale_to_consider and area > base_budgets[ + "area"] * budget_scale_to_consider: + return True + return False + + + workload_results = {} + results_with_sys_char = [] + + system_char_to_keep_track_of = {"memory_total_area", "local_memory_total_area","pe_total_area", "ip_cnt","ips_total_area", "ips_avg_freq", "local_mem_cnt", + "local_bus_avg_actual_bandwidth", "local_bus_avg_theoretical_bandwidth", "local_memory_avg_freq", "local_bus_count", "local_bus_avg_bus_width", "avg_freq", "local_total_traffic", + "global_total_traffic","local_memory_avg_freq", "global_memory_avg_freq", "gpps_total_area", "avg_gpp_parallelism", "avg_accel_parallelism", "channel_cnt", + "local_total_traffic_reuse_with_read_in_bytes", + } + system_char_to_keep_track_of = {"ip_cnt" + } + #system_char_to_show = ["local_memory_total_area"] + #system_char_to_show = ["avg_accel_parallelism"] + #system_char_to_show = ["avg_gpp_parallelism"] + #system_char_to_show = ["local_bus_avg_actual_bandwidth"] + #system_char_to_show = ["avg_freq"] # really is buses avg freq + #system_char_to_show = ["local_memory_avg_freq"] # really is buses avg freq + #system_char_to_show = ["ips_avg_freq"] + #system_char_to_show = ["gpps_total_area"] + #system_char_to_show = ["local_bus_avg_bus_width"] + #system_char_to_show = ["local_memory_avg_freq"] + #system_char_to_show = ["ips_total_area"] + system_char_to_show = ["ip_cnt"] + #system_char_to_show = ["local_mem_cnt"] + #system_char_to_show = ["global_memory_avg_freq"] + #system_char_to_show = ["local_bus_avg_theoretical_bandwidth"] + #system_char_to_show = ["local_memory_avg_freq"] + #system_char_to_show = ["local_total_traffic"] + #system_char_to_show = ["global_total_traffic"] + #system_char_to_show = ["local_total_traffic_reuse_with_read_in_bytes"] + #system_char_to_show = ["local_bus_count"] + #system_char_to_show = ["cluster_pe_cnt_avg"] + #system_char_to_show = ["ip_cnt"] + #system_char_to_show = ["channel_cnt"] + + # budget scaling to consider + budget_scale_to_consider = .5 + # get budget first + base_budgets = {} + """ + for file in all_result_files: + with open(file, newline='') as csvfile: + if not ("a_e_h" in file and "lat_1__pow_1__area_1" in file): + continue + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i == 1: + print("file"+file) + if float(row[reg_summary_res_column_name_number["budget_scaling_latency"]]) == 1 and\ + float(row[reg_summary_res_column_name_number["budget_scaling_power"]]) == 1 and \ + float(row[reg_summary_res_column_name_number["budget_scaling_area"]]) == 1: + base_budgets["power"] = float(row[summary_res_column_name_number["power_budget"]]) + base_budgets["area"] = float(row[summary_res_column_name_number["area_budget"]]) + break + """ + base_budgets = {"power":.008738,"area":.000017475} + + for file in all_result_files: 
+ if "a_e_h" in file and "lat_1__pow_1__area_1" in file: + continue + if ("a_e_h" in file): + summary_res_column_name_number = a_e_h_summary_res_column_name_number + else: + summary_res_column_name_number = reg_summary_res_column_name_number + + with open(file, newline='') as csvfile: + resultReader = csv.reader(csvfile, delimiter=',', quotechar='|') + for i, row in enumerate(resultReader): + if i == 1: + workload_set_name = row[summary_res_column_name_number["workload_set"]] + if workload_set_name not in workload_results.keys(): + workload_results[workload_set_name] = [] + latency = ((row[summary_res_column_name_number["latency"]].split(";"))[0].split("="))[1] + latency_budget = ((row[summary_res_column_name_number["latency_budget"]].split(";"))[0].split("="))[1] + if float(latency) > float(latency_budget): + continue + + power = float(row[summary_res_column_name_number["power"]]) + area = float(row[summary_res_column_name_number["area"]]) + + system_complexity = row[summary_res_column_name_number["ip_cnt"]] # + row[summary_res_column_name_number["gpp_cnt"]] + #workload_results[workload_set_name].append((float(power),float(area), float(system_complexity))) + workload_results[workload_set_name].append((power,area)) + system_char = {} + for el in system_char_to_keep_track_of: + #if "latency" == el: + # system_char[el] = row[summary_res_column_name_number[el]] + #else: + system_char[el] = float(row[summary_res_column_name_number[el]]) + system_char["file"] = file + point_system_char = {(power, area): system_char} + results_with_sys_char.append(point_system_char) + + + + workload_pareto_points = {} + for workload, points in workload_results.items(): + workload_pareto_points[workload] = find_pareto_points(list(set(points))) + + """" + # combine the results + combined_area_power = [] + for results_combined in itertools.product(*list(workload_pareto_points.values())): + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + combined_area_power.append(combined_power_area_tuple[:]) + """ + + + all_points_in_isolation = [] + all_points_cross_workloads = [] + + workload_in_isolation = {} + for workload, points in workload_results.items(): + if "cava" in workload and "audio" in workload and "edge_detection" in workload: + for point in points: + all_points_cross_workloads.append(point) + else: + workload_in_isolation[workload] = points + + + ctr = 0 + workload_in_isolation_pareto = {} + for workload, points in workload_in_isolation.items(): + optimal_points = find_pareto_points(list(set(points))) + workload_in_isolation_pareto[workload] = optimal_points + #workload_in_isolation_pareto[workload] = points # show all points instead + + + combined_area_power_in_isolation= [] + combined_area_power_in_isolation_with_sys_char = [] + + s = time.time() + for results_combined in itertools.product(*list(workload_in_isolation_pareto.values())): + # add up all the charactersitics + combined_sys_chars = {} + + system_char_to_keep_track_of.add("file") + for el in system_char_to_keep_track_of: + if el =="file": + combined_sys_chars[el] = [] + else: + combined_sys_chars[el] = (0,0) + + # add up area,power + combined_power_area_tuple = [0,0] + for el in results_combined: + combined_power_area_tuple[0] += el[0] + combined_power_area_tuple[1] += el[1] + + sys_char = find_sys_char(el[0], el[1], results_with_sys_char) + for el_,val_ in sys_char.items(): + if el_ == "file": + combined_sys_chars[el_].append(val_) + continue + + if 
"avg" in el_: + total = sys_char[get_equivalent_total(el_)] + coeff = total + else: + coeff = 1 + #if "latency" in el_: + # combined_sys_chars[el_] = (combined_sys_chars[el_][0]+coeff, str(combined_sys_chars[el_][1])+"_"+val_) + #else: + combined_sys_chars[el_] = (combined_sys_chars[el_][0]+coeff, combined_sys_chars[el_][1]+coeff*float(val_)) + + for key, values in combined_sys_chars.items(): + if "avg" in key: + combined_sys_chars[key] = values[1] /max(values[0],.00000000000000000000000000000001) + elif "file" in key: + combined_sys_chars[key] = values + else: + combined_sys_chars[key] = values[1] + + if float(combined_sys_chars[system_char_to_show[0]]) == 53675449.0: + for f in combined_sys_chars["file"]: + print("/".join(f.split("/")[:-2])+"./runs/0/system_image.dot") + #if float(combined_sys_chars[system_char_to_show[0]]) == 53008113: + + #combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1], combined_power_area_tuple[2])) + combined_area_power_in_isolation.append((combined_power_area_tuple[0],combined_power_area_tuple[1])) + combined_area_power_in_isolation_with_sys_char.append({(combined_power_area_tuple[0],combined_power_area_tuple[1]): combined_sys_chars}) + + #if len(combined_area_power_in_isolation)%100000 == 0: + # print("time passed is" + str(time.time()-s)) + + combined_area_power_in_isolation_filtered = [] + for point in combined_area_power_in_isolation: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + combined_area_power_in_isolation_filtered.append(point) + combined_area_power_pareto = find_pareto_points(list(set(combined_area_power_in_isolation_filtered))) + + + all_points_cross_workloads_filtered = [] + for point in all_points_cross_workloads: + if not points_exceed_one_of_the_budgets(point, base_budgets, budget_scale_to_consider): + all_points_cross_workloads_filtered.append(point) + all_points_cross_workloads_area_power_pareto = find_pareto_points(list(set(all_points_cross_workloads_filtered))) + + + # prepare for plotting and plot + fig = plt.figure(figsize=(7, 7)) # Ying: (7.5, 7.5) for the main paper, (5, 5) for the extended abstract + axis_font = {'size':'30'} + fontSize = 30 + plt.rc('font', **axis_font) + ax = fig.add_subplot(111) + # ax.set_xscale('log') + + x_values = [(el[0]) for el in combined_area_power_in_isolation_filtered] + y_values = [(el[1]) for el in combined_area_power_in_isolation_filtered] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="Myopic Opt",marker="+", color='gold', alpha=1, s=700) + + + # plt.tight_layout() + x_values = [(el[0]) for el in combined_area_power_pareto] + y_values = [(el[1]) for el in combined_area_power_pareto] + x_values.reverse() + y_values.reverse() + ax.scatter(x_values, y_values, label="Myopic Opt Pareto Front",marker="+", color='darkorange', alpha=0.8, s=700) + # for idx, _ in enumerate(x_values) : + # power= x_values[idx] + # area = y_values[idx] + # sys_char = find_sys_char(power, area, combined_area_power_in_isolation_with_sys_char) + # + # value_to_show = 0 + # value_to_show = sys_char[system_char_to_show[0]] + # #for el in system_char_to_show: + # # value_to_show += sys_char[el] + # + # #if system_char_to_show[0] == "latency": + # # value_in_scientific_notation = value_to_show + # #else: + # #value_to_show = sys_char["local_total_traffic"]/(sys_char["local_memory_total_area"]*4*10**12) + # value_in_scientific_notation = "{:.10e}".format(value_to_show) + # value_in_scientific_notation = 
value_to_show
+    #     #if idx == 0:
+    #     #plt.text(power, area, value_in_scientific_notation)
+
+    x_values = [(el[0]) for el in all_points_cross_workloads_filtered]
+    y_values = [(el[1]) for el in all_points_cross_workloads_filtered]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="FARSI", marker="*", color='limegreen', alpha=1, s=700)
+    # ax.legend(loc="upper right")  # bbox_to_anchor=(1, 1), loc="upper left")
+    # for idx, _ in enumerate(x_values):
+    #     power = x_values[idx]
+    #     area = y_values[idx]
+    #     sys_char = find_sys_char(power, area, results_with_sys_char)
+    #     value_to_show = sys_char[system_char_to_show[0]]
+    #     value_in_scientific_notation = "{:.2e}".format(value_to_show)
+    #     plt.text(power, area, value_in_scientific_notation)
+
+    x_values = [(el[0]) for el in all_points_cross_workloads_area_power_pareto]
+    y_values = [(el[1]) for el in all_points_cross_workloads_area_power_pareto]
+    x_values.reverse()
+    y_values.reverse()
+    ax.scatter(x_values, y_values, label="FARSI Pareto Front", marker="*", color='darkgreen', alpha=0.8, s=700)
+
+    # ax.set_xticks(np.arange(0.002, 0.009, 0.002))
+    ax.set_xlabel("Power (W)", fontsize=fontSize)
+    ax.set_ylabel("Area (mm2)", fontsize=fontSize)
+
+    # dump in the top folder
+    output_base_dir = '/'.join(input_dir_names[0].split("/")[:-2])
+    output_dir = os.path.join(output_base_dir, "budget_optimality/")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    ax.scatter(combined_design_methodology_A["power"], combined_design_methodology_A["area"], label="Myopic Budgetting", marker="x", color='red', s=700)  # Ying: comment out for the extended abstract
+    ax.legend(loc="upper left", fontsize=fontSize, bbox_to_anchor=(1, 1.1), borderpad=1)  # bbox_to_anchor=(1, 1), loc="upper left")
+    # ax.set_title(system_char_to_show[0] + " for FARSI vs in isolation")
+    #ax.set_title("memory_reuse for FARSI vs in isolation")
+    plt.tight_layout()
+    fig.savefig(os.path.join(output_dir, system_char_to_show[0] + "_budget_optimality.png"), bbox_inches='tight')
+    #plt.show()
+    plt.close('all')
+
+def find_pareto_points(points):
+    efficients = is_pareto_efficient_dumb(np.array(points))
+    pareto_points_array = [points[idx] for idx, el in enumerate(efficients) if el]
+    return pareto_points_array
+
+
+def is_pareto_efficient_dumb(costs):
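+    # Brute-force O(n^2) Pareto filter for minimization: a point is kept only
+    # if every other point is strictly worse in at least one dimension, i.e.
+    # no other point matches or beats it everywhere. Exact duplicates knock
+    # each other out, which is why the callers above de-duplicate with set()
+    # first. A small worked example (illustrative):
+    #
+    #   costs = np.array([[1, 2], [2, 1], [2, 2]])
+    #   is_pareto_efficient_dumb(costs)  # -> [True, True, False]
+    #   # (2, 2) is dominated by both (1, 2) and (2, 1), so it is dropped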
+ is_efficient = np.ones(costs.shape[0], dtype = bool) + for i, c in enumerate(costs): + is_efficient[i] = np.all(np.any(costs[:i]>c, axis=1)) and np.all(np.any(costs[i+1:]>c, axis=1)) + return is_efficient + + + +########################################### + +# the main function. comment out the plots if you do not need them +if __name__ == "__main__": + # populate parameters + run_folder_name = config_plotting.run_folder_name + if config_plotting.run_folder_name == "": + run_folder_name = find_the_most_recent_directory(config_plotting.top_result_folder)[0] + + zoneNum = config_plotting.zoneNum + # get all the experiments under the run folder + print(run_folder_name) + experiment_full_addr_list = get_experiment_dir_list(run_folder_name) + + # according to the plot type, plot + all_res_column_name_number = get_column_name_number(experiment_full_addr_list[0], "all") + all_results_files = get_experiment_full_file_addr_list(experiment_full_addr_list) + aggregate_res_column_name_number = get_column_name_number(experiment_full_addr_list[0], "aggregate") + + summary_res_column_name_number = get_column_name_number(experiment_full_addr_list[0], "simple") + case_studies = {} + case_studies["bandwidth_analysis"] = ["local_bus_avg_theoretical_bandwidth", + "local_bus_max_actual_bandwidth", + "local_bus_avg_actual_bandwidth", + "system_bus_avg_theoretical_bandwidth", + "system_bus_max_actual_bandwidth", + "system_bus_avg_actual_bandwidth", + "local_channel_avg_actual_bandwidth", + "local_channel_max_actual_bandwidth" + ] + + + case_studies["freq_analysis"] = [ + "global_memory_avg_freq", "local_memory_avg_freq", "local_bus_avg_freq",] + + case_studies["bus_width_analysis"] = [ + "global_memory_avg_bus_width","local_memory_avg_bus_width","local_bus_avg_bus_width"] + + case_studies["traffic_analysis"] = ["global_total_traffic", "local_total_traffic", + "local_memory_traffic_per_mem_avg", + "locality_in_bytes", + "local_memory_traffic_per_mem_avg", + "local_bus_traffic_avg", + ] + + + case_studies["local_mem_re_use"] =[ + "local_total_traffic_reuse_no_read_ratio", + "local_total_traffic_reuse_no_read_in_bytes", + "local_total_traffic_reuse_no_read_in_size", + "local_total_traffic_reuse_with_read_ratio", + "local_total_traffic_reuse_with_read_in_bytes", + "local_total_traffic_reuse_with_read_in_size", + "local_total_traffic_reuse_no_read_in_bytes_per_cluster_avg", + "local_total_traffic_reuse_no_read_in_size_per_cluster_avg", + "local_total_traffic_reuse_with_read_in_bytes_per_cluster_avg", + "local_total_traffic_reuse_with_read_in_size_per_cluster_avg" + ] + + case_studies["global_mem_re_use"] =[ + "global_total_traffic_reuse_no_read_ratio", + "global_total_traffic_reuse_with_read_ratio", + "global_total_traffic_reuse_with_read_in_bytes", + "global_total_traffic_reuse_with_read_in_size", + "global_total_traffic_reuse_no_read_in_bytes", + "global_total_traffic_reuse_no_read_in_size", + ] + + + case_studies["area_analysis"] = ["global_memory_total_area", "local_memory_total_area", "ips_total_area", + "gpps_total_area", + ] + case_studies["area_in_bytes_analysis"] = ["global_memory_total_bytes", "local_memory_total_bytes", "local_memory_bytes_avg" + ] + + case_studies["accel_paral_analysis"] = ["ip_cnt","max_accel_parallelism", "avg_accel_parallelism", + "gpp_cnt", "max_gpp_parallelism", "avg_gpp_parallelism"] + case_studies["system_complexity"] = ["system block count", "routing complexity", "system PE count", + "local_mem_cnt", "local_bus_cnt","local_channel_count_per_bus_avg", "channel_cnt", + 
"loop_itr_ratio_avg", + ] # , "channel_cnt"] + + case_studies["heterogeneity_var_system_compleixty"] = [ + "local_channel_count_per_bus_coeff_var", + "loop_itr_ratio_var", + # "cluster_pe_cnt_coeff_var" + ] + + case_studies["heterogeneity_std_system_compleixty"] = [ + "local_channel_count_per_bus_std", + "loop_itr_ratio_std" # , "cluster_pe_cnt_std" + ] + + """ + case_studies["speedup"] = [ + # "customization_speed_up_full_system", + # "loop_unrolling_parallelism_speed_up_full_system", + "customization_speed_up_full_system", + "task_level_parallelism_speed_up_full_system", + "interference_degradation_avg"] + """ + """ + [ + "customization_first_speed_up_avg", + "customization_second_speed_up_avg", + "parallelism_first_speed_up_avg", + "parallelism_second_speed_up_avg", + "interference_degradation_avg", + "customization_first_speed_up_full_system", + "customization_second_speed_up_full_system", + "parallelism_first_speed_up_full_system", + "parallelism_second_speed_up_full_system", + ] + case_studies["speedup"] = [ + "interference_degradation_avg", + "customization_speed_up_full_system", + "parallelism_speed_up_full_system", + "parallelism_nd_speed_up_full_system", + ] + + + """ + + + case_studies["heterogenity_area"] = [ + "local_memory_area_coeff_var", + "ips_area_coeff_var", + "pes_area_coeff_var", + + ] + + + case_studies["heterogenity_std_re_use"] = [ + "local_total_traffic_reuse_no_read_in_bytes_per_cluster_std", + "local_total_traffic_reuse_no_read_in_size_per_cluster_std", + "local_total_traffic_reuse_with_read_in_bytes_per_cluster_std", + "local_total_traffic_reuse_with_read_in_size_per_cluster_std", + ] + + case_studies["heterogenity_var_re_use"] = [ + "local_total_traffic_reuse_no_read_in_bytes_per_cluster_var", + "local_total_traffic_reuse_no_read_in_size_per_cluster_var", + "local_total_traffic_reuse_with_read_in_bytes_per_cluster_var", + "local_total_traffic_reuse_with_read_in_size_per_cluster_var", + ] + + case_studies["heterogenity_var_freq"] =[ + "local_bus_freq_coeff_var", + "local_memory_freq_coeff_var", + "ips_freq_coeff_var", + "pes_freq_coeff_var"] + + case_studies["heterogenity_std_freq"] =[ + "local_memory_freq_std", + "local_bus_freq_std", + ] + + + + case_studies["heterogenity_std_bus_width"] =[ + "local_memory_bus_width_std", + "local_bus_bus_width_std", + ] + + case_studies["heterogenity_var_bus_width"] =[ + "local_memory_bus_width_coeff_var", + "local_bus_bus_width_coeff_var", + ] + + + + + case_studies["heterogenity_std_bandwidth"]=[ + "local_bus_actual_bandwidth_std", + "local_channel_actual_bandwidth_std"] + + case_studies["heterogenity_var_bandwidth"]=[ + "local_bus_actual_bandwidth_coeff_var", + "local_channel_actual_bandwidth_coeff_var"] + + + + case_studies["heterogenity_std_traffic"] =[ + "local_memory_bytes_std", + "local_memory_traffic_per_mem_coeff_var", + "local_bus_traffic_coeff_var", + ] + + + case_studies["heterogenity_var_traffic"] =[ + "local_memory_bytes_coeff_var", + "local_memory_traffic_per_mem_coeff_var", + "local_bus_traffic_coeff_var", + ] + + if "heuristic_comparison" in config_plotting.plot_list: # Ying: optimal_budgetting_problem_08_1 + #experiment_full_addr_list = get_experiment_dir_list(config_plotting.heuristic_comparison_folder) + #all_dirs = [x[0] for x in os.walk(config_plotting.heuristic_comparison_folder)] + + metrics = ["heuristic_type", "ref_des_dist_to_goal_non_cost"] + all_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(config_plotting.heuristic_comparison_folder) for f in filenames if + os.path.splitext(f)[1] 
== '.csv'] + aggregate_results = [f for f in all_files if "aggregate_all_results" in f and not ("prev_iter" in f)] + + heuristic_comparison(experiment_full_addr_list, aggregate_results, aggregate_res_column_name_number) + + if "heuristic_scaling" in config_plotting.plot_list: # Ying: optimal_budgetting_problem_08_1 + metrics = ["heuristic_type", "ref_des_dist_to_goal_non_cost"] + all_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(config_plotting.heuristic_comparison_folder) for f in filenames if + os.path.splitext(f)[1] == '.csv'] + aggregate_results = [f for f in all_files if "aggregate_all_results" in f and not ("prev_iter" in f)] + + heuristic_scaling(experiment_full_addr_list, aggregate_results, aggregate_res_column_name_number) + + """ + if "pareto_studies" in config_plotting.plot_list: + metrics = ["heuristic_type", "ref_des_dist_to_goal_non_cost"] + all_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(config_plotting.heuristic_comparison_folder) + for f in filenames if + os.path.splitext(f)[1] == '.csv'] + aggregate_results = [f for f in all_files if "aggregate_all_results" in f and not ("prev_iter" in f)] + + heuristic_hv = pareto_studies(experiment_full_addr_list, aggregate_results, aggregate_res_column_name_number) + """ + + + if "budget_optimality" in config_plotting.plot_list: # Ying: optimal_budgetting_problem_08_1 + #get_budget_optimality_advanced(experiment_full_addr_list, all_results_files, summary_res_column_name_number) + + for el in experiment_full_addr_list: + if "a_e_h" in el and not "lat_1__pow_1__area_1" in el: + a_e_h_summary_res_column_name_number = get_column_name_number(el, "simple") + + if config_plotting.draw_for_paper: + get_budget_optimality_for_paper(experiment_full_addr_list, all_results_files, summary_res_column_name_number, a_e_h_summary_res_column_name_number) + else: + get_budget_optimality(experiment_full_addr_list, all_results_files, summary_res_column_name_number, a_e_h_summary_res_column_name_number) + + if "cross_workloads" in config_plotting.plot_list: # Ying: from for_paper/workload_awareness (PE, Mem, IC, TLP, comm_comp); or blind_study_smart_krnel_selection/blind_vs_arch_ware; or blind_study_dumb_kernel_selection/blind_vs_arch_aware; or blind_study_all_dumb_versions/blind_vs_arch_aware (SA and other aware) + # get column orders (assumption is that the column order doesn't change between experiments) + if config_plotting.draw_for_paper: + column_column_value_experiment_frequency_dict = plot_codesign_nav_breakdown_cross_workload_for_paper( + experiment_full_addr_list, all_res_column_name_number) + plot_convergence_cross_workloads_for_paper(experiment_full_addr_list, all_res_column_name_number) + + else: + column_column_value_experiment_frequency_dict = plot_codesign_nav_breakdown_cross_workload(experiment_full_addr_list, all_res_column_name_number) + plot_convergence_cross_workloads(experiment_full_addr_list, all_res_column_name_number) + + for key, val in case_studies.items(): + case_study = {key:val} + # plot_system_implication_analysis(experiment_full_addr_list, summary_res_column_name_number, case_study) + plot_co_design_nav_breakdown_post_processing(experiment_full_addr_list, column_column_value_experiment_frequency_dict) + if config_plotting.draw_for_paper: + plot_codesign_rate_efficacy_cross_workloads_updated_for_paper(experiment_full_addr_list, all_res_column_name_number) + else: + #plot_codesign_rate_efficacy_cross_workloads_updated(experiment_full_addr_list, all_res_column_name_number) + pass + + if 
"single_workload" in config_plotting.plot_list: # Ying: blind_study_all_dumb_versions/blind_vs_arch_aware + #single workload + # plot_codesign_progression_per_workloads(experiment_full_addr_list, all_res_column_name_number) + _ = plot_codesign_nav_breakdown_per_workload(experiment_full_addr_list, all_res_column_name_number) + + if config_plotting.draw_for_paper: + # plot_convergence_per_workloads_for_paper(experiment_full_addr_list, all_res_column_name_number) + plot_convergence_vs_time_for_paper(experiment_full_addr_list, all_res_column_name_number) + else: + plot_convergence_per_workloads(experiment_full_addr_list, all_res_column_name_number) + plot_convergence_vs_time(experiment_full_addr_list, all_res_column_name_number) + + if "plot_3d" in config_plotting.plot_list: + plot_3d(experiment_full_addr_list, summary_res_column_name_number) + + if "pie_chart" in config_plotting.plot_list: # Ying: 1_1_1_for_paper_07-31 + pie_chart_case_study = {"Performance Breakdown": ["transformation generation time", "simulation time", + "neighbour selection time"], + "Transformation_Generation_Breakdown": ["metric selection time", "dir selection time", "kernel selection time", + "block selection time", "transformation selection time", + "design duplication time", "metric selection time"]} + # , "architectural principle", "high level optimization name", "exact optimization name"] + + for case_study_ in pie_chart_case_study.items(): + if config_plotting.draw_for_paper: + pie_chart_for_paper(experiment_full_addr_list, all_res_column_name_number, case_study_) + else: + pie_chart(experiment_full_addr_list, all_res_column_name_number, case_study_) + + if "pandas_plots" in config_plotting.plot_list: # Ying: from scaling_of_1_2_4_across_all_budgets_07-31 + #pandas_case_studies = {} + case_studies["system_complexity"] = ["system block count", "routing complexity", "system PE count", + "local_mem_cnt", "local_bus_cnt" , "channel_cnt", "ip_cnt", "gpp_cnt"] + + case_studies["pe_parallelism"] = ["max_accel_parallelism", "avg_accel_parallelism", "avg_gpp_parallelism", "max_gpp_parallelism"] + + case_studies["ip_frequency"] = ["ips_avg_freq", "gpps_avg_freq", "ips_freq_std", "pes_freq_std", + "ips_freq_coeff_var", "pes_freq_coeff_var"] + + case_studies["pe_area"] = ["ips_total_area", "gpps_total_area", "ips_area_std", "pes_area_std", + "ips_area_coeff_var", "pes_area_coeff_var"] + + case_studies["mem_frequency"] = ["local_memory_avg_freq", "global_memory_avg_freq", + "local_memory_freq_std","local_memory_freq_coeff_var"] + + case_studies["mem_area"] = ["local_memory_total_area", "global_memory_total_area", "local_memory_area_std", + "local_memory_area_coeff_var"] + + case_studies["traffic"] = ["local_total_traffic", "global_total_traffic"] + + + case_studies["bus_width"] = ["local_bus_avg_bus_width", + "system_bus_avg_bus_width"] + + + case_studies["bus_bandwidth"] = ["local_bus_avg_actual_bandwidth", "system_bus_avg_actual_bandwidth", + "local_bus_avg_theoretical_bandwidth", "system_bus_avg_theoretical_bandwidth", + "local_bus_max_actual_bandwidth", "system_bus_max_actual_bandwidth"] + + + for case_study_name, metrics in case_studies.items(): + for metric in metrics: + if config_plotting.draw_for_paper: + pandas_plots_for_paper(experiment_full_addr_list, all_results_files, metric) + else: + pandas_plots(experiment_full_addr_list, all_results_files, metric) + + # get the the workload_set folder + # each workload_set has a bunch of experiments underneath it + workload_set_folder_list = os.listdir(run_folder_name) + + if 
"drivers" in config_plotting.plot_list: + simple_stack_bar_plot(run_folder_name) + + # iterate and generate plots + for workload_set_folder in workload_set_folder_list: + # ignore irelevant files + if workload_set_folder in config_plotting.ignore_file_names: + continue + + # start plotting + #plotBudgets3d(run_folder_name, workload_set_folder) + + + """ + # get experiment folder + workload_set_full_addr = os.path.join(run_folder_name,workload_set_folder) + folder_list = os.listdir(workload_set_full_addr) + for experiment_name_relative_addr in folder_list: + print(experiment_name_relative_addr) + if experiment_name_relative_addr in config_plotting.ignore_file_names: + continue + experiment_full_addr = os.path.join(workload_set_full_addr, experiment_name_relative_addr) + + all_res_column_name_number = get_column_name_number(experiment_full_addr, "all") + summary_res_column_name_number = get_column_name_number(experiment_full_addr, "simple") + + workload_set_full_addr +="/" # this is because you didn't use join + commcompColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "comm_comp", "all") + trueNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "move validity", "all") + optColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "high level optimization name", "all") + archColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "architectural principle", "all") + sysBlkNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "system block count", "all") + simColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "simulation time", "all") + movGenColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "transformation generation time", "all") + movColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "move name", "all") + itrNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "iteration cnt", "all") + distColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "dist_to_goal_non_cost", "all") + refDistColNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "ref_des_dist_to_goal_non_cost", "all") + latNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "latency", "all") + powNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "power", "all") + areaNum = columnNum(workload_set_full_addr, experiment_name_relative_addr, "area", "all") + + # comment or uncomment the following functions for your plottings + plotDistToGoalVSitr([experiment_full_addr], all_res_column_name_number) + plotCommCompAll(workload_set_full_addr, experiment_name_relative_addr, all_res_column_name_number) + plothighLevelOptAll(workload_set_full_addr, experiment_name_relative_addr, all_res_column_name_number) + plotArchVarImpAll(workload_set_full_addr, experiment_name_relative_addr, archColNum, trueNum) + plotSimTimeVSblk(workload_set_full_addr, experiment_name_relative_addr, sysBlkNum, simColNum, trueNum) + plotMoveGenTimeVSblk(workload_set_full_addr, experiment_name_relative_addr, sysBlkNum, movGenColNum, trueNum) + plotRefDistToGoalVSitr(workload_set_full_addr, experiment_name_relative_addr, itrNum, refDistColNum, trueNum) + plotSimTimeVSmoveNameZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, movColNum, distColNum, simColNum, trueNum) + plotMovGenTimeVSmoveNameZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, movColNum, distColNum, 
movGenColNum, trueNum)
+            plotSimTimeVScommCompZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, commcompColNum, distColNum, simColNum, trueNum)
+            plotMovGenTimeVScommCompZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, commcompColNum, distColNum, movGenColNum, trueNum)
+            plotSimTimeVShighLevelOptZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, optColNum, distColNum, simColNum, trueNum)
+            plotMovGenTimeVShighLevelOptZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, optColNum, distColNum, movGenColNum, trueNum)
+            plotSimTimeVSarchVarImpZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, archColNum, distColNum, simColNum, trueNum)
+            plotMovGenTimeVSarchVarImpZoneDist(workload_set_full_addr, experiment_name_relative_addr, zoneNum, archColNum, distColNum, movGenColNum, trueNum)
+    """
diff --git a/Project_FARSI/visualization_utils/plotting_Iulian.py b/Project_FARSI/visualization_utils/plotting_Iulian.py
new file mode 100644
index 00000000..db1d96b8
--- /dev/null
+++ b/Project_FARSI/visualization_utils/plotting_Iulian.py
@@ -0,0 +1,514 @@
+import pandas as pd
+import seaborn as sns
+import sys
+import matplotlib.pyplot as plt
+import numpy as np
+sys.path.append("..")
+#from plot_validations import *
+from sklearn.linear_model import LinearRegression
+from settings import config_plotting
+import os
+
+def abline(slope, intercept, color):
+    """Plot a line from slope and intercept"""
+    axes = plt.gca()
+    x_vals = np.array(axes.get_xlim())
+    y_vals = intercept + slope * x_vals
+    plt.plot(x_vals, y_vals, '--', color=color)
+
+def get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name, y_coord_name="Simulation Time", hue_col="FARSI or PA"):
+    # average y_coord_name over every (x coordinate, hue) pair, so each x
+    # coordinate contributes one point per hue to the scatter plots below
+    hues = set(list(reformatted_df[hue_col]))
+    avg_df_lst = []
+    for x_coord in set(reformatted_df[x_coord_name]):
+        for hue in hues:
+            selectedy_hue = list(reformatted_df.loc[(reformatted_df[x_coord_name] == x_coord) & (reformatted_df[hue_col] == hue)][y_coord_name])
+            avg_df_lst.append([np.average(selectedy_hue), hue, x_coord])
+    return pd.DataFrame(avg_df_lst, columns=[y_coord_name, hue_col, x_coord_name])
+
+
+def plot_sim_time_vs_system_char_minimal(output_dir, 
csv_file_addr): + data = pd.read_csv(csv_file_addr) + blk_cnt = list(data["blk_cnt"]) + pa_sim_time = list(data["PA simulation time"]) + farsi_sim_time = list(data["FARSI simulation time"]) + tmp_reformatted_df_data = [blk_cnt * 2, pa_sim_time + farsi_sim_time, + ["PA"] * len(blk_cnt) + ["FARSI"] * len(blk_cnt)] + reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in + range(len(blk_cnt) * 2)] + # print(reformatted_df_data[0:3]) + # exit() + # for col in reformatted_df_data: + # print("Len of col is {}".format(len(col))) + reformatted_df = pd.DataFrame(reformatted_df_data, + columns=["Block counts", "Simulation Time", + "FARSI or PA"]) + print(reformatted_df.head()) + + df_blk_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Block counts") + + + + df_avg = get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name = "Block counts", y_coord_name = "Simulation Time", hue_col = "FARSI or PA") + + #df_pe_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "PE counts") + #df_mem_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Mem counts") + #df_bus_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Bus counts") + + #print("Bola") + #print(df_blk_avg) + + + + splot = sns.scatterplot(data=df_avg, x="Block counts", y="Simulation Time", hue="FARSI or PA") + splot.set(yscale="log") + + color_per_hue = {"FARSI" : "green", "PA" : "orange"} + hues = set(list(df_avg["FARSI or PA"])) + for hue in hues: + #x required to be in matrix format in sklearn + print(np.isnan(df_avg["Simulation Time"])) + xs_hue = [[x] for x in list(df_avg.loc[(df_avg["FARSI or PA"] == hue) & (df_avg["Simulation Time"].notnull())]["Block counts"])] + ys_hue = np.array(list(df_avg.loc[(df_avg["FARSI or PA"] == hue) & (df_avg["Simulation Time"].notnull())]["Simulation Time"])) + print("xs_hue") + print(xs_hue) + + print("ys_hue") + print(ys_hue) + reg = LinearRegression().fit(xs_hue, ys_hue) + m = reg.coef_[0] + n = reg.intercept_ + abline(m, n, color_per_hue[hue]) + #plt.set_ylim(top = 10) + + + plt.savefig(os.path.join(output_dir,'block_counts_vs_simtime.png')) + + plt.close("all") + +def plot_sim_time_vs_system_char_minimal_for_paper(output_dir, csv_file_addr): + data = pd.read_csv(csv_file_addr) + blk_cnt = list(data["blk_cnt"]) + pa_sim_time = list(data["PA simulation time"]) + farsi_sim_time = list(data["FARSI simulation time"]) + tmp_reformatted_df_data = [blk_cnt * 2, pa_sim_time + farsi_sim_time, + ["PA"] * len(blk_cnt) + ["FARSI"] * len(blk_cnt)] + reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in + range(len(blk_cnt) * 2)] + # print(reformatted_df_data[0:3]) + # exit() + # for col in reformatted_df_data: + # print("Len of col is {}".format(len(col))) + reformatted_df = pd.DataFrame(reformatted_df_data, + columns=["Block Counts", "Simulation Time", + "FARSI or PA"]) + print(reformatted_df.head()) + + df_blk_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Block Counts") + + + + df_avg = get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name = "Block Counts", y_coord_name = "Simulation Time", hue_col = "FARSI or PA") + + #df_pe_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "PE counts") + #df_mem_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Mem counts") + #df_bus_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Bus counts") + + #print("Bola") + #print(df_blk_avg) + + axis_font = {'size': '20'} + fontSize = 20 + sns.set(font_scale=2, rc={'figure.figsize': (6, 4)}) + 
sns.set_style("white") + color_per_hue = {'PA': 'hotpink', 'FARSI': 'green'} + splot = sns.scatterplot(data=df_avg, x="Block Counts", y="Simulation Time", hue="FARSI or PA", sizes=(6, 6), palette=color_per_hue) + splot.set(yscale="log") + splot.legend(title="", fontsize=fontSize, loc="center right") + + hues = set(list(df_avg["FARSI or PA"])) + for hue in hues: + #x required to be in matrix format in sklearn + print(np.isnan(df_avg["Simulation Time"])) + xs_hue = [[x] for x in list(df_avg.loc[(df_avg["FARSI or PA"] == hue) & (df_avg["Simulation Time"].notnull())]["Block Counts"])] + ys_hue = np.array(list(df_avg.loc[(df_avg["FARSI or PA"] == hue) & (df_avg["Simulation Time"].notnull())]["Simulation Time"])) + print("xs_hue") + print(xs_hue) + + print("ys_hue") + print(ys_hue) + reg = LinearRegression().fit(xs_hue, ys_hue) + m = reg.coef_[0] + n = reg.intercept_ + abline(m, n, color_per_hue[hue]) + #plt.set_ylim(top = 10) + + plt.xticks(np.arange(0, 30, 10.0)) + plt.yticks(np.power(10.0, [-1, 0, 1, 2, 3])) + plt.xlabel("Block Counts") + plt.ylabel("Simulation Time (s)") + plt.tight_layout() + plt.savefig(os.path.join(output_dir,'block_counts_vs_simtime.png'), bbox_inches='tight') + # plt.show() + plt.close("all") + +""" +def plot_sim_time_vs_system_char(output_dir, csv_file_addr): + data = pd.read_csv(csv_file_addr) + blk_cnt = list(data["blk_cnt"]) + pe_cnt = list(data["pe_cnt"]) + mem_cnt = list(data["mem_cnt"]) + bus_cnt = list(data["bus_cnt"]) + pa_sim_time = list(data["PA simulation time"]) + farsi_sim_time = list(data["FARSI simulation time"]) + tmp_reformatted_df_data = [blk_cnt * 2, pe_cnt * 2, mem_cnt * 2, bus_cnt * 2, pa_sim_time + farsi_sim_time, + ["PA"] * len(blk_cnt) + ["FARSI"] * len(blk_cnt)] + reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in + range(len(blk_cnt) * 2)] + # print(reformatted_df_data[0:3]) + # exit() + # for col in reformatted_df_data: + # print("Len of col is {}".format(len(col))) + reformatted_df = pd.DataFrame(reformatted_df_data, + columns=["Block counts", "PE counts", "Mem counts", "Bus counts", "Simulation Time", + "FARSI or PA"]) + print(reformatted_df.head()) + + df_blk_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Block counts") + df_pe_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "PE counts") + df_mem_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Mem counts") + df_bus_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Bus counts") + + print("Bola") + print(df_blk_avg) + + + + splot = sns.scatterplot(data=df_blk_avg, x="Block counts", y="Simulation Time", hue="FARSI or PA") + splot.set(yscale="log") + + splot_1 = sns.scatterplot(data=df_pe_avg, x="PE counts", y="Simulation Time", hue="FARSI or PA") + splot_1.set(yscale="log") + + splot_2 = sns.scatterplot(data=df_mem_avg, x="Mem counts", y="Simulation Time", hue="FARSI or PA") + splot_1.set(yscale="log") + splot_3 = sns.scatterplot(data=df_bus_avg, x="Bus counts", y="Simulation Time", hue="FARSI or PA") + splot_1.set(yscale="log") + + plt.savefig(os.path.join(output_dir,'block_counts_vs_simtime.png')) + +""" + + + + +def plot_error_vs_system_char(output_dir, csv_file_addr): + data = pd.read_csv(csv_file_addr) + error = list(data["error"]) + blk_cnt = list(data["blk_cnt"]) + pe_cnt = list(data["pe_cnt"]) + mem_cnt = list(data["mem_cnt"]) + bus_cnt = list(data["bus_cnt"]) + #channel_cnt = list(data["channel_cnt"]) + pa_sim_time = list(data["PA simulation time"]) + farsi_sim_time = list(data["FARSI simulation time"]) + 
+ num_counts_cols = 4 + tmp_reformatted_df_data = [blk_cnt+pe_cnt+mem_cnt+bus_cnt, ["Block Counts"]*len(blk_cnt)+["PE Counts"]*len(blk_cnt) + ["Mem Counts"]*len(blk_cnt) + ["Bus Counts"]*len(bus_cnt) , error*num_counts_cols] + + reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in range(len(blk_cnt)*num_counts_cols) ] + + + + #print(reformatted_df_data[0:3]) + #exit() + #for col in reformatted_df_data: + # print("Len of col is {}".format(len(col))) + reformatted_df = pd.DataFrame(reformatted_df_data, columns = ["Counts", "ArchParam", "Error"]) + print(reformatted_df.tail()) + + + df_avg = get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name = "Counts", y_coord_name = "Error", hue_col = "ArchParam") + + color_per_hue = {"Bus Counts" : "green", "Mem Counts" : "orange", "PE Counts" : "blue", "Block Counts" : "red", "Channel Counts" : "pink"} + #df_avg = df_avg.loc[df_avg["ArchParam"] != "Bus Counts"] + splot = sns.scatterplot(data=df_avg, y = "Error", x = "Counts", hue = "ArchParam", palette = color_per_hue) + #splot.set(yscale = "log") + + + + #sklearn.linear_model.LinearRegression() + hues = set(list(df_avg["ArchParam"])) + for hue in hues: + #x required to be in matrix format in sklearn + print(np.isnan(df_avg["Error"])) + xs_hue = [[x] for x in list(df_avg.loc[(df_avg["ArchParam"] == hue) & (df_avg["Error"].notnull())]["Counts"])] + ys_hue = np.array(list(df_avg.loc[(df_avg["ArchParam"] == hue) & (df_avg["Error"].notnull())]["Error"])) + print("xs_hue") + print(xs_hue) + + print("ys_hue") + print(ys_hue) + reg = LinearRegression().fit(xs_hue, ys_hue) + m = reg.coef_[0] + n = reg.intercept_ + abline(m, n, color_per_hue[hue]) + #plt.set_ylim(top = 10) + + output_file = os.path.join(output_dir, "error_vs_system_char.png") + plt.savefig(output_file) + plt.close("all") + +def plot_error_vs_system_char_for_paper(output_dir, csv_file_addr): + data = pd.read_csv(csv_file_addr) + error = list(data["error"]) + blk_cnt = list(data["blk_cnt"]) + pe_cnt = list(data["pe_cnt"]) + mem_cnt = list(data["mem_cnt"]) + bus_cnt = list(data["bus_cnt"]) + #channel_cnt = list(data["channel_cnt"]) + pa_sim_time = list(data["PA simulation time"]) + farsi_sim_time = list(data["FARSI simulation time"]) + + num_counts_cols = 4 + tmp_reformatted_df_data = [blk_cnt+pe_cnt+mem_cnt+bus_cnt, ["Block Counts"]*len(blk_cnt)+["PE Counts"]*len(blk_cnt) + ["Memory Counts"]*len(blk_cnt) + ["NoC Counts"]*len(bus_cnt) , error*num_counts_cols] + + reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in range(len(blk_cnt)*num_counts_cols) ] + + + + #print(reformatted_df_data[0:3]) + #exit() + #for col in reformatted_df_data: + # print("Len of col is {}".format(len(col))) + reformatted_df = pd.DataFrame(reformatted_df_data, columns = ["Counts", "ArchParam", "Error"]) + print(reformatted_df.tail()) + + + df_avg = get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name = "Counts", y_coord_name = "Error", hue_col = "ArchParam") + + color_per_hue = {"NoC Counts" : "green", "Memory Counts" : "orange", "PE Counts" : "blue", "Block Counts" : "red", "Channel Counts" : "pink"} + #df_avg = df_avg.loc[df_avg["ArchParam"] != "Bus Counts"] + axis_font = {'size': '20'} + fontSize = 20 + sns.set(font_scale=2, rc={'figure.figsize': (6, 4.2)}) + sns.set_style("white") + splot = sns.scatterplot(data=df_avg, y = "Error", x = "Counts", hue = "ArchParam", palette = color_per_hue, hue_order= ["NoC Counts", "Memory Counts", "PE Counts", 
"Block Counts"], sizes=(8, 8)) + #splot.set(yscale = "log") + + #sklearn.linear_model.LinearRegression() + hues = set(list(df_avg["ArchParam"])) + splot.legend(title="", fontsize=fontSize, loc="upper right") + for hue in hues: + #x required to be in matrix format in sklearn + print(np.isnan(df_avg["Error"])) + xs_hue = [[x] for x in list(df_avg.loc[(df_avg["ArchParam"] == hue) & (df_avg["Error"].notnull())]["Counts"])] + ys_hue = np.array(list(df_avg.loc[(df_avg["ArchParam"] == hue) & (df_avg["Error"].notnull())]["Error"])) + print("xs_hue") + print(xs_hue) + + print("ys_hue") + print(ys_hue) + reg = LinearRegression().fit(xs_hue, ys_hue) + m = reg.coef_[0] + n = reg.intercept_ + abline(m, n, color_per_hue[hue]) + #plt.set_ylim(top = 10) + + plt.xticks(np.arange(-5, 30, 10.0)) + plt.yticks(np.arange(-5, 50, 10.0)) + plt.xlabel("Block Counts") + plt.ylabel("Error (%)") + plt.tight_layout() + output_file = os.path.join(output_dir, "error_vs_system_char.png") + plt.savefig(output_file, bbox_inches='tight') + # plt.show() + plt.close("all") + +def plot_latency_vs_sim_time(output_dir, csv_file_addr): + data = pd.read_csv(csv_file_addr) + blk_cnt = list(data["blk_cnt"]) + pe_cnt = list(data["pe_cnt"]) + mem_cnt = list(data["mem_cnt"]) + bus_cnt = list(data["bus_cnt"]) + pa_sim_time = list(data["PA simulation time"]) + farsi_sim_time = list(data["FARSI simulation time"]) + pa_predicted_lat = list(data["PA_predicted_latency"]) + tmp_reformatted_df_data = [pa_predicted_lat * 2, pa_sim_time + farsi_sim_time, + ["PA"] * len(blk_cnt) + ["FARSI"] * len(blk_cnt)] + reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))] for i in + range(len(blk_cnt) * 2)] + # print(reformatted_df_data[0:3]) + # exit() + # for col in reformatted_df_data: + # print("Len of col is {}".format(len(col))) + reformatted_df = pd.DataFrame(reformatted_df_data, + columns=["PA Predicted Latency", "Simulation Time", "FARSI or PA"]) + + reformatted_df = pd.DataFrame(reformatted_df_data, + columns=["PA _predicted_latencys", "Simulation Time", + "FARSI or PA"]) + print(reformatted_df.head()) + + df_blk_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "PA _predicted_latencys") + + df_avg = get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name="PA _predicted_latencys", y_coord_name="Simulation Time", + hue_col="FARSI or PA") + + # df_pe_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "PE counts") + # df_mem_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Mem counts") + # df_bus_avg = get_df_as_avg_for_each_x_coord(reformatted_df, "Bus counts") + + # print("Bola") + # print(df_blk_avg) + + splot = sns.scatterplot(data=df_avg, x="PA _predicted_latencys", y="Simulation Time", hue="FARSI or PA") + splot.set(yscale="log") + + color_per_hue = {"FARSI": "green", "PA": "orange"} + hues = set(list(df_avg["FARSI or PA"])) + for hue in hues: + # x required to be in matrix format in sklearn + print(np.isnan(df_avg["Simulation Time"])) + xs_hue = [[x] for x in list( + df_avg.loc[(df_avg["FARSI or PA"] == hue) & (df_avg["Simulation Time"].notnull())]["PA _predicted_latencys"])] + ys_hue = np.array( + list(df_avg.loc[(df_avg["FARSI or PA"] == hue) & (df_avg["Simulation Time"].notnull())]["Simulation Time"])) + print("xs_hue") + print(xs_hue) + + print("ys_hue") + print(ys_hue) + reg = LinearRegression().fit(xs_hue, ys_hue) + m = reg.coef_[0] + n = reg.intercept_ + abline(m, n, color_per_hue[hue]) + # plt.set_ylim(top = 10) + + #plt.savefig(os.path.join(output_dir, 
'block_counts_vs_simtime.png'))
+    plt.savefig(os.path.join(output_dir, 'latency_vs_sim_time.png'))
+    plt.close("all")
+
+def plot_latency_vs_sim_time_for_paper(output_dir, csv_file_addr):
+    data = pd.read_csv(csv_file_addr)
+    blk_cnt = list(data["blk_cnt"])
+    pa_sim_time = list(data["PA simulation time"])
+    farsi_sim_time = list(data["FARSI simulation time"])
+    pa_predicted_lat = list(data["PA_predicted_latency"])
+
+    # reshape into long format: one row per (predicted latency, simulation time, simulator)
+    tmp_reformatted_df_data = [pa_predicted_lat * 2, pa_sim_time + farsi_sim_time,
+                               ["PA"] * len(blk_cnt) + ["FARSI"] * len(blk_cnt)]
+    reformatted_df_data = [[tmp_reformatted_df_data[j][i] for j in range(len(tmp_reformatted_df_data))]
+                           for i in range(len(blk_cnt) * 2)]
+    reformatted_df = pd.DataFrame(reformatted_df_data,
+                                  columns=["PA Predicted Latency", "Simulation Time", "FARSI or PA"])
+
+    # average the simulation time of all samples sharing the same x coordinate
+    df_avg = get_df_as_avg_for_each_x_coord(reformatted_df, x_coord_name="PA Predicted Latency",
+                                            y_coord_name="Simulation Time", hue_col="FARSI or PA")
+
+    fontSize = 20
+    sns.set(font_scale=2, rc={'figure.figsize': (6, 4)})
+    sns.set_style("white")
+    color_per_hue = {'PA': 'hotpink', 'FARSI': 'green'}
+    splot = sns.scatterplot(data=df_avg, x="PA Predicted Latency", y="Simulation Time",
+                            hue="FARSI or PA", palette=color_per_hue)
+    splot.set(yscale="log")
+    splot.legend(title="", fontsize=fontSize, loc="center right")
+
+    # overlay a least-squares fit per simulator (sklearn expects X in matrix form)
+    hues = set(df_avg["FARSI or PA"])
+    for hue in hues:
+        valid = (df_avg["FARSI or PA"] == hue) & (df_avg["Simulation Time"].notnull())
+        xs_hue = [[x] for x in df_avg.loc[valid]["PA Predicted Latency"]]
+        ys_hue = np.array(df_avg.loc[valid]["Simulation Time"])
+        reg = LinearRegression().fit(xs_hue, ys_hue)
+        abline(reg.coef_[0], reg.intercept_, color_per_hue[hue])
+
+    plt.xticks(np.arange(0, 60, 10.0))
+    plt.yticks(np.power(10.0, [-1, 0, 1, 2, 3]))
+    plt.xlabel("Execution latency")
+    plt.ylabel("Simulation Time (s)")
+    plt.tight_layout()
+    plt.savefig(os.path.join(output_dir, 'latency_vs_sim_time.png'), bbox_inches='tight')
+    plt.close("all")
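The plotting helpers above all build their long-format frame by hand: replicate the x columns, concatenate the y columns, then transpose with a nested list comprehension. pandas ships this transformation as `melt`; a minimal sketch against the same `aggregate_data.csv` columns (the file path here is illustrative):

```python
import pandas as pd
import seaborn as sns

data = pd.read_csv("aggregate_data.csv")

# One row per (architecture parameter, count, error) triple -- the same
# long format the nested comprehension above constructs manually.
long_df = data.melt(
    id_vars=["error"],
    value_vars=["blk_cnt", "pe_cnt", "mem_cnt", "bus_cnt"],
    var_name="ArchParam",
    value_name="Counts",
)
sns.scatterplot(data=long_df, x="Counts", y="error", hue="ArchParam")
```

Averaging per x coordinate, which is what `get_df_as_avg_for_each_x_coord` appears to do, is then a one-liner: `long_df.groupby(["ArchParam", "Counts"], as_index=False)["error"].mean()`.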
+
+if __name__ == "__main__":  # Ying: for aggregate_data
+    run_folder_name = config_plotting.run_folder_name
+    csv_file_addr = os.path.join(run_folder_name, "input_data", "aggregate_data.csv")
+    output_dir = os.path.join(run_folder_name, "validation")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    if config_plotting.draw_for_paper:  # Ying: "cross_workloads", from aggregate_data
+        plot_error_vs_system_char_for_paper(output_dir, csv_file_addr)
+        plot_sim_time_vs_system_char_minimal_for_paper(output_dir, csv_file_addr)
+        plot_latency_vs_sim_time_for_paper(output_dir, csv_file_addr)
+    else:
+        plot_error_vs_system_char(output_dir, csv_file_addr)
+        plot_sim_time_vs_system_char_minimal(output_dir, csv_file_addr)
+        plot_latency_vs_sim_time(output_dir, csv_file_addr)
diff --git a/Project_FARSI/visualization_utils/vis_hardware.py b/Project_FARSI/visualization_utils/vis_hardware.py
new file mode 100644
index 00000000..034cabbc
--- /dev/null
+++ b/Project_FARSI/visualization_utils/vis_hardware.py
@@ -0,0 +1,250 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import pygraphviz as pgv
+from design_utils.design import *
+from sys import platform
+import time
+import sys
+
+# global value
+ctr = 0
+
+# This class contains information about each node (block) in the hardware graph.
+# It is only used for visualization purposes.
+class node_content():
+    # ------------------------------
+    # Functionality:
+    #       constructor
+    # Variables:
+    #       block: block of interest
+    #       mode: whether to obfuscate the node's name and content (for outside vs. inside distribution)
+    #       task_obfuscated_table: table mapping tasks' real names to their obfuscated names
+    #       block_obfuscated_table: table mapping blocks' real names to their obfuscated names
+    # ------------------------------
+    def __init__(self, block:Block, mode, task_obfuscated_table, block_obfuscated_table):
+        self.block = block
+        self.content = ""
+        if (mode == "block"):
+            self.content = block_obfuscated_table[self.block.instance_name]
+        elif (mode == "block_extra"):
+            self.content = block_obfuscated_table[self.block.instance_name]
+            if self.block.type == "ic":
+                self.content += " , width(Bits):" + str((self.block.peak_work_rate/database_input.ref_ic_clock)*8) +\
+                                " , clock(MHz):" + str(database_input.ref_ic_clock/10**6)
+            elif self.block.type == "mem":
+                self.content += " , width(Bits):" + str((self.block.peak_work_rate/database_input.ref_mem_clock)*8) +\
+                                " , clock(MHz):" + str(database_input.ref_mem_clock/10**6) \
+                                + " , size(kB):" + str(self.block.get_area()*database_input.ref_mem_work_over_area/10**3)
+            elif self.block.subtype == "gpp":
+                tasks = self.block.get_task_dirs_of_block()
+                task_names_dirs = [task_obfuscated_table[str(task_.name)] for task_, dir in tasks]
+                self.content += ":" + ",".join(task_names_dirs)
+        elif (mode == "block_task"):
+            tasks = self.block.get_task_dirs_of_block()
+            task_names_dirs = [str((task_obfuscated_table[task_.name], dir)) for task_, dir in tasks]
+            if not (self.block.instance_name in block_obfuscated_table.keys()):
+                print("block " + self.block.instance_name + " is missing from the obfuscation table")
+            self.content = block_obfuscated_table[self.block.instance_name] + "(" + self.block.type + ")" + ": " + ",".join(task_names_dirs)
+        if self.only_dummy_tasks(block): self.only_dummy = True
+        else: self.only_dummy = False
+
+    # ------------------------------
+    # Functionality:
+    #       determine whether a block hosts only dummy tasks
+    #       (PS: dummy tasks are located at the root (named with a "souurce" suffix)
+    #       and the leaves (named with a "siink" suffix) of the graph)
+    # Variables:
+    #       block: block of interest.
+    # ------------------------------
+    def only_dummy_tasks(self, block):
+        num_of_tasks = len(block.get_tasks_of_block())
+        if num_of_tasks == 2:
+            if any("souurce" in task.name for task in block.get_tasks_of_block()) and \
+               any("siink" in task.name for task in block.get_tasks_of_block()):
+                return True
+        elif num_of_tasks == 1:
+            if any("souurce" in task.name for task in block.get_tasks_of_block()) or \
+               any("siink" in task.name for task in block.get_tasks_of_block()):
+                return True
+        return False
+
+    def format_content(self, content):
+        # break long node labels into 50-character chunks so graphviz renders them legibly
+        n = 50
+        chunks = [content[i:i + n] for i in range(0, len(content), n)]
+        return "\n".join(chunks)
+
+    # ------------------------------
+    # Functionality:
+    #       get all the content: tasks, blocks, and information about each.
+    # ------------------------------
+    def get_content(self):
+        return self.format_content(self.content)
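`format_content` above wraps long labels by slicing every 50 characters, which can split words arbitrarily. The standard library's `textwrap` gives an equivalent that prefers breaking at whitespace; a drop-in sketch:

```python
import textwrap

def format_content_wrapped(content: str, width: int = 50) -> str:
    # Prefer breaking at spaces; still split over-long tokens so no
    # rendered line exceeds `width` characters.
    return "\n".join(textwrap.wrap(content, width=width,
                                   break_long_words=True,
                                   break_on_hyphens=False))
```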
+    # ------------------------------
+    # Functionality:
+    #       get the color associated with each block type/subtype.
+    # ------------------------------
+    def get_color(self):
+        if self.block.type == "mem":
+            return "cyan3"
+        elif self.block.type == "ic":
+            return "white"
+        else:
+            if self.block.subtype == "gpp" and "A53" in self.block.instance_name:
+                return "gold"
+            elif self.block.subtype == "gpp" and "G" in self.block.instance_name:
+                return "goldenrod3"
+            elif self.block.subtype == "gpp" and "P" in self.block.instance_name:
+                return "goldenrod2"
+            else:
+                return "orange"
+
+
+# ------------------------------
+# Functionality:
+#       build a dot-compatible graph recursively; the result is handed to the dot visualizer afterwards.
+# Variables:
+#       parent_block: parent of a block (the block it reads from).
+#       child_block: child of a block (the block it writes to).
+#       blocks_visited: blocks already visited (prevents double graphing during the depth-first search).
+#       hardware_dot_graph: dot-compatible graph associated with our hardware graph.
+#       task_obfuscated_table: table mapping tasks' real names to their obfuscated names.
+#       block_obfuscated_table: table mapping blocks' real names to their obfuscated names.
+#       graphing_mode: determines whether to obfuscate names and content (for outside distribution).
+# ------------------------------
+def build_dot_recursively(parent_block, child_block, blocks_visited, hardware_dot_graph,
+                          task_obfuscated_table, block_obfuscated_table, graphing_mode):
+    if (child_block, parent_block) in blocks_visited:
+        return None
+    global ctr
+    if parent_block:
+        parent_node = node_content(parent_block, graphing_mode, task_obfuscated_table, block_obfuscated_table)
+        if not parent_node.only_dummy:
+            hardware_dot_graph.add_node(parent_node.get_content(), fillcolor=parent_node.get_color())
+        child_node = node_content(child_block, graphing_mode, task_obfuscated_table, block_obfuscated_table)
+        if not child_node.only_dummy:
+            hardware_dot_graph.add_node(child_node.get_content(), fillcolor=child_node.get_color())
+        ctr += 1
+        if not parent_node.only_dummy and not child_node.only_dummy:
+            hardware_dot_graph.add_edge(parent_node.get_content(), child_node.get_content())
+    blocks_visited.append((child_block, parent_block))
+    parent_block = child_block
+    for child_block_ in parent_block.neighs:
+        build_dot_recursively(parent_block, child_block_, blocks_visited, hardware_dot_graph, task_obfuscated_table,
+                              block_obfuscated_table, graphing_mode)
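`build_dot_recursively` records visited (child, parent) pairs in a list, so each membership check is a linear scan, and very deep graphs can also hit Python's recursion limit. For reference, the same visit-each-edge-once traversal written iteratively with a set (a sketch assuming blocks expose `neighs` as above and are hashable):

```python
def iter_edges(root):
    # Depth-first walk yielding each (parent, child) edge exactly once.
    visited = set()
    stack = [(None, root)]
    while stack:
        parent, child = stack.pop()
        if (child, parent) in visited:
            continue
        visited.add((child, parent))
        if parent is not None:
            yield parent, child
        for neigh in child.neighs:
            stack.append((child, neigh))
```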
+# ------------------------------
+# Functionality:
+#       generate the obfuscation tables, so results can also be distributed to outside companies/academia.
+#       An obfuscation table maps real names to fake names.
+#       In addition, the contents (e.g., computational load) are eliminated.
+# Variables:
+#       sim_dp: design point simulation.
+# ------------------------------
+def gen_obfuscation_table(sim_dp:SimDesignPoint):
+    ctr = 0
+    task_obfuscated_table = {}
+    block_obfuscated_table = {}
+    # obfuscate the tasks
+    for task in sim_dp.get_tasks():
+        if config.DATA_DELIVEYRY == "obfuscate":
+            task_obfuscated_table[task.name] = "T" + str(ctr)
+            ctr += 1
+        else:
+            task_obfuscated_table[task.name] = task.name
+
+    # obfuscate the blocks
+    dsp_ctr = 0
+    gpp_ctr = 0
+    mem_ctr = 0
+    ic_ctr = 0
+    for block in sim_dp.get_blocks():
+        if block.type == "pe" and config.DATA_DELIVEYRY == "obfuscate":
+            if block.subtype == "ip":
+                name = ""
+                for task in block.get_tasks_of_block():
+                    name += (task_obfuscated_table[task.name] + "_")
+                block_obfuscated_table[block.instance_name] = name + "ip"
+            else:  # gpp
+                if "G" in block.instance_name:
+                    block_obfuscated_table[block.instance_name] = "DSP_G3_" + str(dsp_ctr)
+                    dsp_ctr += 1
+                elif "P" in block.instance_name:
+                    block_obfuscated_table[block.instance_name] = "DSP_P6_" + str(dsp_ctr)
+                    dsp_ctr += 1
+                elif "A" in block.instance_name:
+                    block_obfuscated_table[block.instance_name] = "GPP" + str(gpp_ctr)
+                    gpp_ctr += 1
+                else:
+                    print("this processor " + str(block.instance_name) + " is not obfuscated")
+                    exit(0)
+        elif block.type == "mem" and config.DATA_DELIVEYRY == "obfuscate":
+            block_obfuscated_table[block.instance_name] = block.instance_name.split("_")[0][1:] + str(mem_ctr)
+            mem_ctr += 1
+        elif block.type == "ic" and config.DATA_DELIVEYRY == "obfuscate":
+            block_obfuscated_table[block.instance_name] = "NOC" + str(ic_ctr)
+            ic_ctr += 1
+        else:
+            block_obfuscated_table[block.instance_name] = block.instance_name
+
+    return task_obfuscated_table, block_obfuscated_table
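The table built above boils down to a stable rename: the first time a real name is seen it gets the next counter-based alias, and repeat lookups return the same alias. That pattern in isolation (a hypothetical helper, not FARSI API):

```python
import itertools

def make_obfuscator(prefix: str = "T"):
    # Map each distinct real name to a stable alias: T0, T1, ...
    counter = itertools.count()
    table = {}

    def obfuscate(name: str) -> str:
        if name not in table:
            table[name] = prefix + str(next(counter))
        return table[name]

    return obfuscate, table
```

Keeping the table around, as `gen_obfuscation_table` does via its return values, is what lets obfuscated plots still be cross-referenced internally.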
+# ------------------------------
+# Functionality:
+#       visualize the hardware graph, the task dependencies, and the mapping between the two using a dot graph.
+# Variables:
+#       sim_dp: design point simulation.
+#       graphing_mode: determines whether to obfuscate names and content (for outside distribution).
+# ------------------------------
+def vis_hardware(sim_dp:SimDesignPoint, graphing_mode=config.hw_graphing_mode, output_folder=config.latest_visualization,
+                 output_file_name="system_image.pdf"):
+    try:
+        output_file_name_1 = os.path.join(output_folder, output_file_name)
+        output_file_name_2 = os.path.join(config.latest_visualization, "system_image.pdf")
+        if not os.path.exists(config.latest_visualization):
+            os.makedirs(config.latest_visualization)
+
+        global ctr
+        ctr = 0
+
+        # pdf rendering is only relied upon on macOS; other platforms emit .dot files
+        if not (sys.platform == "darwin"):
+            output_file_name_1 = output_file_name_1.split(".pdf")[0] + ".dot"
+            output_file_name_2 = output_file_name_2.split(".pdf")[0] + ".dot"
+
+        hardware_dot_graph = pgv.AGraph()
+        hardware_dot_graph.node_attr['style'] = 'filled'
+        hardware_graph = sim_dp.get_hardware_graph()
+        root = hardware_graph.get_root()
+
+        task_obfuscated_table, block_obfuscated_table = gen_obfuscation_table(sim_dp)
+
+        build_dot_recursively(None, root, [], hardware_dot_graph, task_obfuscated_table, block_obfuscated_table, graphing_mode)
+        hardware_dot_graph.layout(prog='circo')
+        time.sleep(.0008)
+
+        # output_file_name_1/2 are already full paths, so they are used directly here
+        hardware_dot_graph.draw(output_file_name_1, prog='circo')
+        hardware_dot_graph.draw(output_file_name_2, prog='circo')
+    except Exception:
+        print("could not draw the system_image. Moving on for now. Fix Later.")
\ No newline at end of file
diff --git a/Project_FARSI/visualization_utils/vis_sim.py b/Project_FARSI/visualization_utils/vis_sim.py
new file mode 100644
index 00000000..db008725
--- /dev/null
+++ b/Project_FARSI/visualization_utils/vis_sim.py
@@ -0,0 +1,405 @@
+#Copyright (c) Facebook, Inc. and its affiliates.
+#This source code is licensed under the MIT license found in the
+#LICENSE file in the root directory of this source tree.
+
+import matplotlib.pyplot as plt
+import plotly.graph_objs as go
+from design_utils.design import *
+
+# some global variables
+x_min = -3
+vertical_slice_size = .5
+vertical_slice_distance = 2*vertical_slice_size
+color_palette = ['orange', 'green', 'red']
+
+# ------------------------------
+# Functionality:
+#       used to offset the bars for the broken bar graph
+# ------------------------------
+def next_y_offset():
+    global vertical_slice_distance
+    offset = vertical_slice_distance
+    while True:
+        yield offset
+        offset += vertical_slice_distance
+
+# ------------------------------
+# Functionality:
+#       broken bar type plots.
+# Variables:
+#       xstart: starting points of the bars on the x axis
+#       xwidth: widths of the bars
+#       ystart: starting point of the bars on the y axis
+#       yh: bar height
+#       colors: colors associated with the bars
+# ------------------------------
+def broken_bars(xstart, xwidth, ystart, yh, colors):
+    if len(xstart) != len(xwidth) or len(xstart) != len(colors):
+        raise ValueError('xstart, xwidth and colors must have the same length')
+    shapes = []
+    for k in range(len(xstart)):
+        shapes.append(dict(type="rect",
+                           x0=xstart[k],
+                           y0=ystart,
+                           x1=xstart[k] + xwidth[k],
+                           y1=ystart + yh,
+                           fillcolor=colors[k],
+                           line_color=colors[k],
+                           opacity=.4,
+                           ))
+    return shapes
+
+# ------------------------------
+# Functionality:
+#       broken barh type plots.
+# Variables: +# name__start_width_metric_unit_dict: name of the metric, starting point, with, and it's unit, +# compacted into a dictionary. +# ------------------------------ +def plotly_broken_barh(name__start_width_metric_unit_dict): + global x_min + name_shapes_dict = {} # name_plotly dict + name_centroid_metric_unit_dict = {} + next_y = next_y_offset() + max_x = 0 + ctr = 0 + + # extract all the values + for name, values in name__start_width_metric_unit_dict.items(): + start_list = [value[0] for value in values] + width_list = [value[1] for value in values] + metric_list = [value[2] for value in values] + unit_list = [value[3] for value in values] + y = next(next_y) + + # set the text + centroid_metric_unit_list = [] + for value in values: + ctr_x = value[0]+float(value[1]/2) + ctr_y = y + vertical_slice_size/2 + metric = value[2] + unit = value[3] + centroid_metric_unit_list.append((ctr_x, ctr_y, metric, unit)) + name_centroid_metric_unit_dict[name] = centroid_metric_unit_list + + # set the shapes + max_x = max(max_x, (start_list[-1] + width_list[-1])) + name_shapes_dict[name] = broken_bars(start_list, width_list, y, vertical_slice_size, + len(start_list)*[color_palette[ctr % len(color_palette)]]) + ctr += 1 + + fig = go.Figure() + + # add all the shapes + list_of_lists_of_shapes = list(name_shapes_dict.values()) + flattented_list_of_shapes = [item for sublist in list_of_lists_of_shapes for item in sublist] + + # add annotations to the figure + flattned_list_of_centroid_metric_unit = [item for sublist in list(name_centroid_metric_unit_dict.values()) for item in sublist] + get_ctroid_x = lambda input_list: [el[0] for el in input_list] + get_ctroid_y = lambda input_list: [el[1] for el in input_list] + get_metric_unit = lambda input_list: [str(el[2])+"("+el[3]+")" for el in input_list] + + x = get_ctroid_x(flattned_list_of_centroid_metric_unit), + + fig.add_trace(go.Scatter( + x= get_ctroid_x(flattned_list_of_centroid_metric_unit), + y= get_ctroid_y(flattned_list_of_centroid_metric_unit), + mode="text", + text= get_metric_unit(flattned_list_of_centroid_metric_unit), + textposition="bottom center" + )) + x_min = -.1*max_x + name_zero_y = [] + for name, value in name_centroid_metric_unit_dict.items(): + name_zero_y.append((name, x_min, value[0][1])) + fig.add_trace(go.Scatter( + x= [el[1] for el in name_zero_y], + y= [el[2] for el in name_zero_y], + mode="text", + text= [el[0] for el in name_zero_y], + textposition="top right" + )) + fig.update_layout( + xaxis = dict(range=[x_min, 1.1*max_x], title="time"), + yaxis = dict(range=[0, next(next_y)], visible=False), + shapes= flattented_list_of_shapes) + + return fig + +# ------------------------------ +# Functionality: +# save the result into a html for visualization. +# Variables: +# fig: figure to save. +# file_addr: address of the file to output the result to. +# ------------------------------ +def save_to_html(fig, file_addr): + fig_json = fig.to_json() + + # a simple HTML template + template = """ + + + + +
+ + + + """ + + # write the JSON to the HTML template + with open(file_addr, 'w') as f: + f.write(template.format(fig_json)) + +# ------------------------------ +# Functionality: +# plot simulation progress. +# Variables: +# dp_stats: design point stats. statistical information (such as latency, energy, ...) associated with the design. +# ex_dp: example design point. +# result_folder: the folder to dump the results in. +# ------------------------------ +def plot_sim_data(dp_stats, ex_dp, result_folder): + # plot latency + kernel__metric_value_per_SOC = dp_stats.get_sim_progress("latency") + for kernel__metrtic_value in kernel__metric_value_per_SOC: + name__start_width_metric_unit_dict = {} + kernel_end_time_dict = {} + for kernel in kernel__metrtic_value.keys(): + name__start_width_metric_unit_dict[kernel.get_task_name()] = [] + + + phase_latency_dict = dp_stats.get_phase_latency() + phase_start_end_dict = {} + start = 0 + for phase, duration in phase_latency_dict.items(): + phase_start_end_dict[phase] = [start, start + duration] + start += duration + + for kernel, values in kernel__metrtic_value.items(): + for value in values: + metric = value[2] + unit = value[3] + break + break + + + for kernel,phases_operating_state in dp_stats.dp.krnl_phase_present_operating_state.items(): + last_phase_number = 9999 + last_operating_state = "na" + starting_phase = True + for phase, operating_state in phases_operating_state: + start, end = phase_start_end_dict[phase] + if last_phase_number == phase - 1 and last_operating_state == operating_state: + start = (name__start_width_metric_unit_dict[kernel.get_task_name()][-1])[0] + label = operating_state[0]+":"+str("{:.4f}".format(end)) + name__start_width_metric_unit_dict[kernel.get_task_name()][-1] = (start, end - start, label, "") + else: + label = operating_state[0]+":"+str("{:.4f}".format(end)) + name__start_width_metric_unit_dict[kernel.get_task_name()].append((start, end-start, label, "")) + if starting_phase: + kernel_end_time_dict[kernel.get_task_name()] = start + starting_phase = False + last_phase_number = phase + last_operating_state = operating_state + """ + for kernel, values in kernel__metrtic_value.items(): + first_start_time = values[0][0] + for value in values: + start = value[0] + width = value[1] + metric = value[2] + unit = value[3] + name__start_width_metric_unit_dict[kernel.get_task_name()].append((start, width, metric, unit)) + last_start = start + last_width = width + kernel_end_time_dict[kernel.get_task_name()] = first_start_time + """ + # now sort it based on end time + sorted_kernel_end_time = sorted(kernel_end_time_dict.items(), + key=operator.itemgetter(1), reverse=True) + + sorted_name__start_width_metric_unit_dict = {} + for element in sorted_kernel_end_time: + kernel_name = element[0] + sorted_name__start_width_metric_unit_dict[kernel_name] = name__start_width_metric_unit_dict[kernel_name] + fig = plotly_broken_barh(sorted_name__start_width_metric_unit_dict) + save_to_html(fig, result_folder+"/"+"latest.html") + + # color map + my_cmap = ["hotpink", "olive", "gold", "darkgreen", "turquoise", "crimson", + "lightblue", "darkorange", "yellow", + "chocolate", "darkorchid", "greenyellow"] + + # plot utilization: + fig, ax = plt.subplots() + ax.set_ylabel('Utilizaiton (%)', fontsize=15) + ax.set_xlabel('Phase', fontsize=15) + ax.set_title('Block Utilization', fontsize=15) + ctr = 0 + for type,id in ex_dp.get_designs_SOCs(): + block_phase_utilization = dp_stats.get_SOC_s_sim_utilization(type, id) + for block, phase_utilization in 
block_phase_utilization.items():
+            if not block.type == "ic":  # for now, only interconnect (ic) blocks are plotted
+                continue
+            block_name = block.instance_name
+            phase = list(phase_utilization.keys())
+            utilization = [x*100 for x in list(phase_utilization.values())]
+            plt.plot(phase, utilization, marker='>', linewidth=.6, color=my_cmap[ctr % len(my_cmap)], ms=1, label=block_name)
+            ctr += 1
+    ax.legend(prop={'size': 10}, ncol=1, bbox_to_anchor=(1.01, 1), loc='upper left')
+    fig.tight_layout()
+    plt.savefig(result_folder + "/FARSI_estimated_Block_utilization_" + str(type) + str(id))
+
+    # convert per-phase durations into (begin, end) time stamps per phase
+    sorted_listified_phase_latency_dict = sorted(dp_stats.dp.phase_latency_dict.items(), key=operator.itemgetter(0))
+    phase_begin_end = {}
+    for phase, duration in sorted_listified_phase_latency_dict:
+        if phase == -1:
+            phase_begin_end[phase] = (0, 0)
+        else:
+            last_phase_end = phase_begin_end[phase - 1][1]
+            phase_begin_end[phase] = (last_phase_end, last_phase_end + duration)
+
+    for dir__ in ["write", "read"]:
+        # plot bandwidth:
+        fig, ax = plt.subplots()
+        ax.set_ylabel('Bandwidth (' + dir__ + ')', fontsize=15)
+        ax.set_xlabel('Phase', fontsize=15)
+        ax.set_title('Path Bandwidth for ' + dir__, fontsize=15)
+        ctr = 0
+        markers = ["o", ">", "1", "8", "s", "p"]
+        seen = []
+        seen_values = {}
+        for type, id in ex_dp.get_designs_SOCs():
+            for pipe_cluster, pipe_phase_work_rate in dp_stats.get_SOC_s_pipe_cluster_pathlet_phase_work_rate(type,
+                                                                                                              id).items():
+                if pipe_cluster.cluster_type == "dummy" or not pipe_cluster.get_dir() == dir__:
+                    continue
+                if not pipe_cluster.get_block_ref().type == "ic":
+                    continue
+                block_name = '_'.join(pipe_cluster.get_block_ref().instance_name.split("_")[-3:])
+                dir_ = pipe_cluster.get_dir()
+                for path, phase_work_rate in pipe_phase_work_rate.items():
+                    in_pipe, out_pipe = path.get_in_pipe(), path.get_out_pipe()
+                    if in_pipe.get_master().type == "pe":
+                        master_name = '_'.join(in_pipe.get_master().instance_name.split("_")[:3])
+                    else:
+                        master_name = '_'.join(in_pipe.get_master().instance_name.split("_")[-3:])
+                    if out_pipe is None:
+                        slave_name = "non"
+                    else:
+                        if out_pipe.get_slave().type == "pe":
+                            slave_name = '_'.join(out_pipe.get_slave().instance_name.split("_")[:3])
+                        else:
+                            slave_name = '_'.join(out_pipe.get_slave().instance_name.split("_")[-3:])
+                    name = master_name + "__" + block_name + "__" + slave_name + "__" + dir_
+                    if name in seen:
+                        continue
+                    seen.append(name)
+                    phases = list(phase_work_rate.keys())
+                    bandwidths = [int(x / 1000000) for x in list(phase_work_rate.values())]
+                    # build a step function: both endpoints of each phase carry its bandwidth
+                    x = []
+                    y = []
+                    for phase in list(phase_work_rate.keys()):
+                        bandwidth = int(phase_work_rate[phase]/1000000)
+                        begin_t, end_t = phase_begin_end[phase]
+                        x.append(begin_t)
+                        x.append(end_t)
+                        y.append(bandwidth)
+                        y.append(bandwidth)
+                    # reuse the color of an identical curve so duplicates overlap visibly
+                    values = '_'.join([str((el[0], el[1])) for el in zip(phases, bandwidths)])
+                    if values in seen_values.keys():
+                        ctr_ = seen_values[values]
+                        plt.plot(x, y, marker=markers[ctr % len(markers)], linewidth=4,
+                                 color=my_cmap[ctr_ % len(my_cmap)], ms=1,
+                                 label=name)
+                    else:
+                        seen_values[values] = ctr
+                        plt.plot(x, y, marker=markers[ctr % len(markers)], linewidth=4,
+                                 color=my_cmap[ctr % len(my_cmap)], ms=1,
+                                 label=name)
+                    ctr += 1
+                    for a, b in zip(x, y):
+                        plt.text(a, b, str(b))
+
+        ax.legend(prop={'size': 8}, ncol=1, loc='best')
+        fig.tight_layout()
+        plt.savefig(result_folder + "/FARSI_estimated_pathlet_bandwidth_" + dir__ + str(type) + str(id))
+
+    """
+    # draw pathlet latency
+    for
dir__ in ["write", "read"]: + # plot bandwidth: + fig, ax = plt.subplots() + ax.set_ylabel('pathlet latency (in cycles)', fontsize=15) + ax.set_xlabel('Phase', fontsize=15) + ax.set_title('Path latency for ' + dir__, fontsize=15) + ctr = 0 + markers = ["o", ">", "1", "8", "s", "p"] + seen = [] + seen_values = {} + for type, id in ex_dp.get_designs_SOCs(): + for pipe_cluster, pathlet_phase_latency in dp_stats.get_SOC_s_pipe_cluster_path_phase_latency(type, + id).items(): + if pipe_cluster.cluster_type == "dummy" or not pipe_cluster.get_dir() == dir__: + continue + if not pipe_cluster.get_block_ref().type == "ic": + continue + block_name = '_'.join(pipe_cluster.get_block_ref().instance_name.split("_")[-3:]) + dir_ = pipe_cluster.get_dir() + for pathlet_, phase_latency in pathlet_phase_latency.items(): + in_pipe, out_pipe = pathlet_.get_in_pipe(), pathlet_.get_out_pipe() + if in_pipe.get_master().type == "pe": + master_name = '_'.join(in_pipe.get_master().instance_name.split("_")[:3]) + else: + master_name = '_'.join(in_pipe.get_master().instance_name.split("_")[-3:]) + if out_pipe == None: + slave_name == "non" + else: + if out_pipe.get_slave().type == "pe": + slave_name = '_'.join(out_pipe.get_slave().instance_name.split("_")[:3]) + else: + slave_name = '_'.join(out_pipe.get_slave().instance_name.split("_")[-3:]) + name = master_name + "__" + block_name + "__" + slave_name + "__" + dir_ + if name in seen: + continue + seen.append(name) + phases = list(phase_latency.keys()) + latencies= [x for x in list(phase_latency.values())] + x = [] + y = [] + for phase in list(phase_latency.keys()): + latency = phase_latency[phase] + begin_t, end_t = phase_begin_end[phase] + x.append(begin_t) + x.append(end_t) + y.append(latency) + y.append(latency) + values = '_'.join([str((el[0], el[1])) for el in zip(phases, latencies)]) + if values in seen_values.keys(): + ctr_ = seen_values[values] + plt.plot(x, y, marker=markers[ctr % len(markers)], linewidth=4, + color=my_cmap[ctr_ % len(my_cmap)], ms=1, + label=name) + else: + seen_values[values] = ctr + plt.plot(x, y, marker=markers[ctr % len(markers)], linewidth=4, + color=my_cmap[ctr % len(my_cmap)], ms=1, + label=name) + ctr += 1 + for a, b in zip(x, y): + plt.text(a, b, str(b)) + + ax.legend(prop={'size': 8}, ncol=1, loc='best') + fig.tight_layout() + plt.savefig(result_folder + "/FARSI_estimated_pathlet_latency_"+dir__ + str(type) + str(id)) + """ + + plt.close('all') diff --git a/Project_FARSI/visualization_utils/vis_stats.py b/Project_FARSI/visualization_utils/vis_stats.py new file mode 100644 index 00000000..1ddc894b --- /dev/null +++ b/Project_FARSI/visualization_utils/vis_stats.py @@ -0,0 +1,34 @@ +#Copyright (c) Facebook, Inc. and its affiliates. +#This source code is licensed under the MIT license found in the +#LICENSE file in the root directory of this source tree. + +from design_utils.design import * +import os +from sys import platform + +# ------------------------------ +# Functionality: +# visualize all the stats associated with the results. +# Variables: +# dp_stats: design point stats. statistical information (such as latency, energy, ...) associated with the design. 
+# ------------------------------ +def vis_stats(dpstats): + sorted_kernels = dpstats.get_kernels_sort() + stats_output_file = config.stats_output + with open(stats_output_file, "w") as output: + for kernel in sorted_kernels: + output.write("\n-------------\n") + output.write("kernel name:" + kernel.get_task_name()+ "\n") + output.write(" total work:"+ str(kernel.get_total_work()) + "\n") + output.write(" blocks mapped to" + str(kernel.get_block_list_names()) + "\n") + output.write(" latency" + str(kernel.stats.latency) + "\n") + phase_block_duration_bottleneck = kernel.stats.phase_block_duration_bottleneck + phase_block_duration_bottleneck_printable = [(phase, block_duration[0].instance_name, block_duration[1]) for phase, block_duration in phase_block_duration_bottleneck.items()] + output.write(" phase,block,bottlneck_duration:" + str(phase_block_duration_bottleneck_printable) + "\n") + + if platform == "linux" or platform == "linux2": + os.system("soffice --convert-to png " + stats_output_file) + os.system("convert " + stats_output_file+".png" " to " + stats_output_file+".pdf") + elif platform == "darwin": + os.system("textutil -convert html " + stats_output_file) + os.system("cupsfilter " + stats_output_file +".html > test.pdf") \ No newline at end of file diff --git a/acme/.gitignore b/acme/.gitignore new file mode 100644 index 00000000..d82fa7a9 --- /dev/null +++ b/acme/.gitignore @@ -0,0 +1,143 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# static files generated from Django application using `collectstatic` +media +static diff --git a/acme/.pylintrc b/acme/.pylintrc new file mode 100644 index 00000000..2882fa51 --- /dev/null +++ b/acme/.pylintrc @@ -0,0 +1,424 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +# TODO(b/178199529): Replace section header when pylint is updated. +[MASTER] + +# Add files or directories to the ban list. They should be base names, not +# paths. +ignore=third_party + +# Add files or directories matching the regex patterns to the ban list. The +# regex matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". 
If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=apply-builtin, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat-in-sequence, + import-error, + import-self, + import-star-module-level, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + too-few-public-methods, + too-many-ancestors, + too-many-arguments, + too-many-boolean-expressions, + too-many-branches, + too-many-instance-attributes, + too-many-locals, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unpacking-in-except, + useless-else-on-loop, + useless-suppression, + using-cmp-argument, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". This option is deprecated +# and it will be removed in Pylint 2.0. +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. 
Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. 
+single-line-if-stmt=yes + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check= + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.google.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. 
This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=StandardError, + Exception, + BaseException diff --git a/acme/.readthedocs.yaml b/acme/.readthedocs.yaml new file mode 100644 index 00000000..de5be726 --- /dev/null +++ b/acme/.readthedocs.yaml @@ -0,0 +1,8 @@ +# Read the Docs configuration. +version: 2 +sphinx: + configuration: docs/conf.py +python: + install: + - requirements: docs/requirements.txt + diff --git a/acme/CONTRIBUTING.md b/acme/CONTRIBUTING.md new file mode 100644 index 00000000..db177d4a --- /dev/null +++ b/acme/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). diff --git a/acme/LICENSE b/acme/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/acme/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
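The `MANIFEST.in` added next tells setuptools which non-Python files to bundle into source distributions; `include LICENSE` is the minimal form. For reference, common directives look like this (a generic sketch, not part of this change):

```
include LICENSE
include README.md
recursive-include docs *.rst *.md
```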
diff --git a/acme/MANIFEST.in b/acme/MANIFEST.in
new file mode 100644
index 00000000..1aba38f6
--- /dev/null
+++ b/acme/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE
diff --git a/acme/README.md b/acme/README.md
new file mode 100644
index 00000000..1eb07780
--- /dev/null
+++ b/acme/README.md
@@ -0,0 +1,118 @@
+
+# Acme: a research framework for reinforcement learning
+
+[![PyPI Python Version][pypi-versions-badge]][pypi]
+[![PyPI version][pypi-badge]][pypi]
+[![acme-tests][tests-badge]][tests]
+[![Documentation Status][rtd-badge]][documentation]
+
+[pypi-versions-badge]: https://img.shields.io/pypi/pyversions/dm-acme
+[pypi-badge]: https://badge.fury.io/py/dm-acme.svg
+[pypi]: https://pypi.org/project/dm-acme/
+[tests-badge]: https://github.com/deepmind/acme/workflows/acme-tests/badge.svg
+[tests]: https://github.com/deepmind/acme/actions/workflows/ci.yml
+[rtd-badge]: https://readthedocs.org/projects/dm-acme/badge/?version=latest
+
+Acme is a library of reinforcement learning (RL) building blocks that strives
+to expose simple, efficient, and readable agents. These agents first and
+foremost serve both as reference implementations and as strong baselines for
+algorithm performance. However, the baseline agents exposed by Acme should
+also provide enough flexibility and simplicity that they can be used as a
+starting block for novel research. Finally, the building blocks of Acme are
+designed in such a way that the agents can be written at multiple scales (e.g.
+single-stream vs. distributed agents).
+
+## Getting started
+
+The quickest way to get started is to take a look at the detailed working code
+examples found in the [examples] subdirectory. These show how to instantiate a
+number of different agents and run them within a variety of environments. See
+the [quickstart notebook][Quickstart] for an even quicker dive into using a
+single agent. Even more detail on the internal construction of an agent can be
+found inside our [tutorial notebook][Tutorial]. Finally, a full description of
+Acme and its underlying components can be found by referring to the
+[documentation]. More background information and details behind the design
+decisions can be found in our [technical report][Paper].
+
+> NOTE: Acme is first and foremost a framework for RL research written by
+> researchers, for researchers. We use it for our own work on a daily basis.
+> So with that in mind, while we will make every attempt to keep everything in
+> good working order, things may break occasionally. But if so we will make
+> our best effort to fix them as quickly as possible!
+
+[examples]: examples/
+[tutorial]: https://github.com/deepmind/acme/blob/master/examples/tutorial.ipynb
+[quickstart]: https://github.com/deepmind/acme/blob/master/examples/quickstart.ipynb
+[documentation]: https://dm-acme.readthedocs.io/
+[paper]: https://arxiv.org/abs/2006.00979
+
+## Installation
+
+We have tested Acme on Python 3.8 and 3.9. To get up and running quickly just
+follow the steps below:
+
+1. While you can install Acme in your standard python environment, we
+   *strongly* recommend using a
+   [Python virtual environment](https://docs.python.org/3/tutorial/venv.html)
+   to manage your dependencies. This should help to avoid version conflicts
+   and just generally make the installation process easier.
+
+   ```bash
+   python3 -m venv acme
+   source acme/bin/activate
+   pip install --upgrade pip setuptools wheel
+   ```
+
+1.
While the core `dm-acme` library can be installed directly, the set of
   dependencies included for installation is minimal. In particular, to run
   any of the included agents you will also need either [JAX] or [TensorFlow],
   depending on the agent. As a result we recommend installing these
   components as well, i.e.

   ```bash
   pip install dm-acme[jax,tensorflow]
   ```

1. Finally, to install a few example environments (including [gym],
   [dm_control], and [bsuite]):

   ```bash
   pip install dm-acme[envs]
   ```

1. **Installing from GitHub**: if you're interested in running the
   bleeding-edge version of Acme, you can do so by cloning the Acme GitHub
   repository and then executing the following command from the main directory
   (where `setup.py` is located):

   ```bash
   pip install .[jax,tf,testing,envs]
   ```

## Citing Acme

If you use Acme in your work, please cite the accompanying
[technical report][paper]:

```bibtex
@article{hoffman2020acme,
    title={Acme: A Research Framework for Distributed Reinforcement Learning},
    author={Matt Hoffman and Bobak Shahriari and John Aslanides and Gabriel
            Barth-Maron and Feryal Behbahani and Tamara Norman and Abbas
            Abdolmaleki and Albin Cassirer and Fan Yang and Kate Baumli and
            Sarah Henderson and Alex Novikov and Sergio Gómez Colmenarejo and
            Serkan Cabi and Caglar Gulcehre and Tom Le Paine and Andrew Cowie
            and Ziyu Wang and Bilal Piot and Nando de Freitas},
    year={2020},
    journal={arXiv preprint arXiv:2006.00979},
    url={https://arxiv.org/abs/2006.00979},
}
```

[JAX]: https://github.com/google/jax
[TensorFlow]: https://tensorflow.org
[gym]: https://github.com/openai/gym
[dm_control]: https://github.com/deepmind/dm_control
[dm_env]: https://github.com/deepmind/dm_env
[bsuite]: https://github.com/deepmind/bsuite
diff --git a/acme/acme/__init__.py b/acme/acme/__init__.py
new file mode 100644
index 00000000..5b3ac31c
--- /dev/null
+++ b/acme/acme/__init__.py
@@ -0,0 +1,38 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Acme is a framework for reinforcement learning."""
+
+# Internal import.
+
+# Expose specs and types modules.
+from acme import specs
+from acme import types
+
+# Make __version__ accessible.
+from acme._metadata import __version__
+
+# Expose core interfaces.
+from acme.core import Actor
+from acme.core import Learner
+from acme.core import Saveable
+from acme.core import VariableSource
+from acme.core import Worker
+
+# Expose the environment loop.
+from acme.environment_loop import EnvironmentLoop
+
+from acme.specs import make_environment_spec
+
+# Acme loves you.
\ No newline at end of file
diff --git a/acme/acme/_metadata.py b/acme/acme/_metadata.py
new file mode 100644
index 00000000..97ea5a98
--- /dev/null
+++ b/acme/acme/_metadata.py
@@ -0,0 +1,27 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Package metadata for acme. + +This is kept in a separate module so that it can be imported from setup.py, at +a time when acme's dependencies may not have been installed yet. +""" + +# We follow Semantic Versioning (https://semver.org/) +_MAJOR_VERSION = '0' +_MINOR_VERSION = '4' +_PATCH_VERSION = '1' + +# Example: '0.4.2' +__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) diff --git a/acme/acme/adders/__init__.py b/acme/acme/adders/__init__.py new file mode 100644 index 00000000..5d08479a --- /dev/null +++ b/acme/acme/adders/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adders for sending data from actors to replay buffers.""" + +# pylint: disable=unused-import + +from acme.adders.base import Adder +from acme.adders.wrappers import ForkingAdder +from acme.adders.wrappers import IgnoreExtrasAdder diff --git a/acme/acme/adders/base.py b/acme/acme/adders/base.py new file mode 100644 index 00000000..7067e873 --- /dev/null +++ b/acme/acme/adders/base.py @@ -0,0 +1,82 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Interface for adders which transmit data to a replay buffer.""" + +import abc + +from acme import types +import dm_env + + +class Adder(abc.ABC): + """The Adder interface. + + An adder packs together data to send to the replay buffer, and potentially + performs some reduction/transformation to this data in the process. + + All adders will use this API. Below is an illustrative example of how they + are intended to be used in a typical RL run-loop. We assume that the + environment conforms to the dm_env environment API. + + ```python + # Reset the environment and add the first observation. 
+ timestep = env.reset() + adder.add_first(timestep.observation) + + while not timestep.last(): + # Generate an action from the policy and step the environment. + action = my_policy(timestep) + timestep = env.step(action) + + # Add the action and the resulting timestep. + adder.add(action, next_timestep=timestep) + ``` + + Note that for all adders, the `add()` method expects an action taken and the + *resulting* timestep observed after taking this action. Note that this + timestep is named `next_timestep` precisely to emphasize this point. + """ + + @abc.abstractmethod + def add_first(self, timestep: dm_env.TimeStep): + """Defines the interface for an adder's `add_first` method. + + We expect this to be called at the beginning of each episode and it will + start a trajectory to be added to replay with an initial observation. + + Args: + timestep: a dm_env TimeStep corresponding to the first step. + """ + + @abc.abstractmethod + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + """Defines the adder `add` interface. + + Args: + action: A possibly nested structure corresponding to a_t. + next_timestep: A dm_env Timestep object corresponding to the resulting + data obtained by taking the given action. + extras: A possibly nested structure of extra data to add to replay. + """ + + @abc.abstractmethod + def reset(self): + """Resets the adder's buffer.""" + diff --git a/acme/acme/adders/reverb/__init__.py b/acme/acme/adders/reverb/__init__.py new file mode 100644 index 00000000..5ba789ff --- /dev/null +++ b/acme/acme/adders/reverb/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adders for Reverb replay buffers.""" + +# pylint: disable=unused-import + +from acme.adders.reverb.base import DEFAULT_PRIORITY_TABLE +from acme.adders.reverb.base import PriorityFn +from acme.adders.reverb.base import PriorityFnInput +from acme.adders.reverb.base import ReverbAdder +from acme.adders.reverb.base import Step +from acme.adders.reverb.base import Trajectory + +from acme.adders.reverb.episode import EpisodeAdder +from acme.adders.reverb.sequence import EndBehavior +from acme.adders.reverb.sequence import SequenceAdder +from acme.adders.reverb.transition import NStepTransitionAdder diff --git a/acme/acme/adders/reverb/base.py b/acme/acme/adders/reverb/base.py new file mode 100644 index 00000000..57612e69 --- /dev/null +++ b/acme/acme/adders/reverb/base.py @@ -0,0 +1,253 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adders that use Reverb (github.com/deepmind/reverb) as a backend.""" + +import abc +import time +from typing import Callable, Iterable, Mapping, NamedTuple, Optional, Sized, Union, Tuple + +from absl import logging +from acme import specs +from acme import types +from acme.adders import base +import dm_env +import numpy as np +import reverb +import tensorflow as tf +import tree + +DEFAULT_PRIORITY_TABLE = 'priority_table' +_MIN_WRITER_LIFESPAN_SECONDS = 60 +StartOfEpisodeType = Union[bool, specs.Array, tf.Tensor, tf.TensorSpec, + Tuple[()]] + + +# TODO(b/188510142): Delete Step. +class Step(NamedTuple): + """Step class used internally for reverb adders.""" + observation: types.NestedArray + action: types.NestedArray + reward: types.NestedArray + discount: types.NestedArray + start_of_episode: StartOfEpisodeType + extras: types.NestedArray = () + + +# TODO(b/188510142): Replace with proper Trajectory class. +Trajectory = Step + + +class PriorityFnInput(NamedTuple): + """The input to a priority function consisting of stacked steps.""" + observations: types.NestedArray + actions: types.NestedArray + rewards: types.NestedArray + discounts: types.NestedArray + start_of_episode: types.NestedArray + extras: types.NestedArray + + +# Define the type of a priority function and the mapping from table to function. +PriorityFn = Callable[['PriorityFnInput'], float] +PriorityFnMapping = Mapping[str, Optional[PriorityFn]] + + +def spec_like_to_tensor_spec(paths: Iterable[str], spec: specs.Array): + return tf.TensorSpec.from_spec(spec, name='/'.join(str(p) for p in paths)) + + +class ReverbAdder(base.Adder): + """Base class for Reverb adders.""" + + def __init__( + self, + client: reverb.Client, + max_sequence_length: int, + max_in_flight_items: int, + delta_encoded: bool = False, + priority_fns: Optional[PriorityFnMapping] = None, + validate_items: bool = True, + ): + """Initialize a ReverbAdder instance. + + Args: + client: A client to the Reverb backend. + max_sequence_length: The maximum length of sequences (corresponding to the + number of observations) that can be added to replay. + max_in_flight_items: The maximum number of items allowed to be "in flight" + at the same time. See `block_until_num_items` in + `reverb.TrajectoryWriter.flush` for more info. + delta_encoded: If `True` (False by default) enables delta encoding, see + `Client` for more information. + priority_fns: A mapping from table names to priority functions; if + omitted, all transitions/steps/sequences are given uniform priorities + (1.0) and placed in DEFAULT_PRIORITY_TABLE. + validate_items: Whether to validate items against the table signature + before they are sent to the server. This requires table signature to be + fetched from the server and cached locally. + """ + if priority_fns: + priority_fns = dict(priority_fns) + else: + priority_fns = {DEFAULT_PRIORITY_TABLE: None} + + self._client = client + self._priority_fns = priority_fns + self._max_sequence_length = max_sequence_length + self._delta_encoded = delta_encoded + # TODO(b/206629159): Remove this. 
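+    # NOTE: `max_in_flight_items` bounds how many unconfirmed items may be
+    # awaiting confirmation from the Reverb server before `flush` blocks; see
+    # `block_until_num_items` in `reverb.TrajectoryWriter.flush`.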
+    self._max_in_flight_items = max_in_flight_items
+    self._add_first_called = False
+
+    # This is exposed as the _writer property in such a way that it will create
+    # a new writer automatically whenever the internal __writer is None. Users
+    # should ONLY ever interact with self._writer.
+    self.__writer = None
+    # When validation is enabled, every new writer must fetch the table
+    # signatures from the Reverb server; in a distributed setup where the
+    # replay server may take a while to spin up, this fetch can fail and crash
+    # the adders.
+    self._validate_items = validate_items
+
+  def __del__(self):
+    if self.__writer is not None:
+      timeout_ms = 10_000
+      # Try to flush all appended data before closing to avoid loss of
+      # experience.
+      try:
+        self.__writer.flush(0, timeout_ms=timeout_ms)
+      except reverb.DeadlineExceededError as e:
+        logging.error(
+            'Timeout (%d ms) exceeded when flushing the writer before '
+            'deleting it. Caught Reverb exception: %s', timeout_ms, str(e))
+      self.__writer.close()
+
+  @property
+  def _writer(self) -> reverb.TrajectoryWriter:
+    if self.__writer is None:
+      self.__writer = self._client.trajectory_writer(
+          num_keep_alive_refs=self._max_sequence_length,
+          validate_items=self._validate_items)
+      self._writer_created_timestamp = time.time()
+    return self.__writer
+
+  def add_priority_table(self, table_name: str,
+                         priority_fn: Optional[PriorityFn]):
+    if table_name in self._priority_fns:
+      raise ValueError(
+          f'A priority function already exists for {table_name}. '
+          f'Existing tables: {", ".join(self._priority_fns.keys())}.'
+      )
+    self._priority_fns[table_name] = priority_fn
+
+  def reset(self, timeout_ms: Optional[int] = None):
+    """Resets the adder's buffer."""
+    if self.__writer:
+      # Flush all appended data and clear the buffers.
+      self.__writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms)
+
+      # Create a new writer unless the current one is too young.
+      # This is to reduce the relative overhead of creating a new Reverb
+      # writer.
+      if (time.time() - self._writer_created_timestamp >
+          _MIN_WRITER_LIFESPAN_SECONDS):
+        self.__writer = None
+    self._add_first_called = False
+
+  def add_first(self, timestep: dm_env.TimeStep):
+    """Record the first observation of a trajectory."""
+    if not timestep.first():
+      raise ValueError('adder.add_first must be called with an initial '
+                       'timestep (i.e. one for which timestep.first() is '
+                       'True).')
+
+    # Record the next observation but leave the history buffer row open by
+    # passing `partial_step=True`.
+    self._writer.append(dict(observation=timestep.observation,
+                             start_of_episode=timestep.first()),
+                        partial_step=True)
+    self._add_first_called = True
+
+  def add(self,
+          action: types.NestedArray,
+          next_timestep: dm_env.TimeStep,
+          extras: types.NestedArray = ()):
+    """Record an action and the following timestep."""
+
+    if not self._add_first_called:
+      raise ValueError('adder.add_first must be called before adder.add.')
+
+    # Add the timestep to the buffer.
+    has_extras = (len(extras) > 0 if isinstance(extras, Sized)  # pylint: disable=g-explicit-length-test
+                  else extras is not None)
+    current_step = dict(
+        # Observation was passed at the previous add call.
+        action=action,
+        reward=next_timestep.reward,
+        discount=next_timestep.discount,
+        # Start of episode indicator was passed at the previous add call.
+        **({'extras': extras} if has_extras else {})
+    )
+    self._writer.append(current_step)
+
+    # Record the next observation and write.
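+    # As in `add_first`, the observation is appended with `partial_step=True`
+    # so the new row stays open; it is completed by the next call to `add`, or
+    # zero-padded below if the episode has ended.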
+ self._writer.append( + dict( + observation=next_timestep.observation, + start_of_episode=next_timestep.first()), + partial_step=True) + self._write() + + if next_timestep.last(): + # Complete the row by appending zeros to remaining open fields. + # TODO(b/183945808): remove this when fields are no longer expected to be + # of equal length on the learner side. + dummy_step = tree.map_structure(np.zeros_like, current_step) + self._writer.append(dummy_step) + self._write_last() + self.reset() + + @classmethod + def signature(cls, environment_spec: specs.EnvironmentSpec, + extras_spec: types.NestedSpec = ()): + """This is a helper method for generating signatures for Reverb tables. + + Signatures are useful for validating data types and shapes, see Reverb's + documentation for details on how they are used. + + Args: + environment_spec: A `specs.EnvironmentSpec` whose fields are nested + structures with leaf nodes that have `.shape` and `.dtype` attributes. + This should come from the environment that will be used to generate + the data inserted into the Reverb table. + extras_spec: A nested structure with leaf nodes that have `.shape` and + `.dtype` attributes. The structure (and shapes/dtypes) of this must + be the same as the `extras` passed into `ReverbAdder.add`. + + Returns: + A `Step` whose leaf nodes are `tf.TensorSpec` objects. + """ + spec_step = Step( + observation=environment_spec.observations, + action=environment_spec.actions, + reward=environment_spec.rewards, + discount=environment_spec.discounts, + start_of_episode=specs.Array(shape=(), dtype=bool), + extras=extras_spec) + return tree.map_structure_with_path(spec_like_to_tensor_spec, spec_step) + + @abc.abstractmethod + def _write(self): + """Write data to replay from the buffer.""" + + @abc.abstractmethod + def _write_last(self): + """Write data to replay from the buffer.""" diff --git a/acme/acme/adders/reverb/episode.py b/acme/acme/adders/reverb/episode.py new file mode 100644 index 00000000..2b5f6448 --- /dev/null +++ b/acme/acme/adders/reverb/episode.py @@ -0,0 +1,151 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Episode adders. + +This implements full episode adders, potentially with padding. +""" + +from typing import Callable, Optional, Iterable, Tuple + +from acme import specs +from acme import types +from acme.adders.reverb import base +from acme.adders.reverb import utils + +import dm_env +import numpy as np +import reverb +import tensorflow as tf +import tree + +_PaddingFn = Callable[[Tuple[int, ...], np.dtype], np.ndarray] + + +class EpisodeAdder(base.ReverbAdder): + """Adder which adds entire episodes as trajectories.""" + + def __init__( + self, + client: reverb.Client, + max_sequence_length: int, + delta_encoded: bool = False, + priority_fns: Optional[base.PriorityFnMapping] = None, + max_in_flight_items: int = 1, + padding_fn: Optional[_PaddingFn] = None, + # Deprecated kwargs. 
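+      # `chunk_length` is only accepted for backwards compatibility; it is
+      # ignored and deleted immediately below.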
+      chunk_length: Optional[int] = None,
+  ):
+    del chunk_length
+
+    super().__init__(
+        client=client,
+        max_sequence_length=max_sequence_length,
+        delta_encoded=delta_encoded,
+        priority_fns=priority_fns,
+        max_in_flight_items=max_in_flight_items,
+    )
+    self._padding_fn = padding_fn
+
+  def add(
+      self,
+      action: types.NestedArray,
+      next_timestep: dm_env.TimeStep,
+      extras: types.NestedArray = (),
+  ):
+    if self._writer.episode_steps >= self._max_sequence_length - 1:
+      raise ValueError(
+          'The number of observations within the same episode will exceed '
+          'max_sequence_length with the addition of this transition.')
+
+    super().add(action, next_timestep, extras)
+
+  def _write(self):
+    # This adder only writes at the end of the episode; see `_write_last()`.
+    pass
+
+  def _write_last(self):
+    if self._padding_fn is not None and self._writer.episode_steps < self._max_sequence_length:
+      history = self._writer.history
+      padding_step = dict(
+          observation=history['observation'],
+          action=history['action'],
+          reward=history['reward'],
+          discount=history['discount'],
+          extras=history.get('extras', ()))
+      # Get shapes and dtypes from the last element.
+      padding_step = tree.map_structure(
+          lambda col: self._padding_fn(col[-1].shape, col[-1].dtype),
+          padding_step)
+      padding_step['start_of_episode'] = False
+      while self._writer.episode_steps < self._max_sequence_length:
+        self._writer.append(padding_step)
+
+    trajectory = tree.map_structure(lambda x: x[:], self._writer.history)
+
+    # Pack the history into a base.Trajectory structure for priority
+    # computation.
+    trajectory = base.Trajectory(**trajectory)
+
+    # Calculate the priority for this episode.
+    table_priorities = utils.calculate_priorities(self._priority_fns,
+                                                  trajectory)
+
+    # Create a prioritized item for each table.
+    for table_name, priority in table_priorities.items():
+      self._writer.create_item(table_name, priority, trajectory)
+      self._writer.flush(self._max_in_flight_items)
+
+  # TODO(b/185309817): make this into a standalone method.
+  @classmethod
+  def signature(cls,
+                environment_spec: specs.EnvironmentSpec,
+                extras_spec: types.NestedSpec = (),
+                sequence_length: Optional[int] = None):
+    """This is a helper method for generating signatures for Reverb tables.
+
+    Signatures are useful for validating data types and shapes, see Reverb's
+    documentation for details on how they are used.
+
+    Args:
+      environment_spec: A `specs.EnvironmentSpec` whose fields are nested
+        structures with leaf nodes that have `.shape` and `.dtype` attributes.
+        This should come from the environment that will be used to generate the
+        data inserted into the Reverb table.
+      extras_spec: A nested structure with leaf nodes that have `.shape` and
+        `.dtype` attributes. The structure (and shapes/dtypes) of this must be
+        the same as the `extras` passed into `ReverbAdder.add`.
+      sequence_length: An optional integer representing the expected length of
+        sequences that will be added to replay.
+
+    Returns:
+      A `Step` whose leaf nodes are `tf.TensorSpec` objects.
+ """ + + def add_time_dim(paths: Iterable[str], spec: tf.TensorSpec): + return tf.TensorSpec( + shape=(sequence_length, *spec.shape), + dtype=spec.dtype, + name='/'.join(str(p) for p in paths)) + + trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path( + add_time_dim, (environment_spec, extras_spec)) + + trajectory_spec = base.Trajectory( + *trajectory_env_spec, + start_of_episode=tf.TensorSpec( + shape=(sequence_length,), dtype=tf.bool, name='start_of_episode'), + extras=trajectory_extras_spec) + + return trajectory_spec diff --git a/acme/acme/adders/reverb/episode_test.py b/acme/acme/adders/reverb/episode_test.py new file mode 100644 index 00000000..05d1a8e2 --- /dev/null +++ b/acme/acme/adders/reverb/episode_test.py @@ -0,0 +1,111 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Episode adders.""" + +from acme.adders.reverb import episode as adders +from acme.adders.reverb import test_utils +import numpy as np + +from absl.testing import absltest +from absl.testing import parameterized + + +class EpisodeAdderTest(test_utils.AdderTestMixin, parameterized.TestCase): + + @parameterized.parameters(2, 10, 50) + def test_adder(self, max_sequence_length): + adder = adders.EpisodeAdder(self.client, max_sequence_length) + + # Create a simple trajectory to add. + observations = range(max_sequence_length) + first, steps = test_utils.make_trajectory(observations) + + expected_episode = test_utils.make_sequence(observations) + self.run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=[expected_episode], + signature=adder.signature(*test_utils.get_specs(steps[0]))) + + @parameterized.parameters(2, 10, 50) + def test_max_sequence_length(self, max_sequence_length): + adder = adders.EpisodeAdder(self.client, max_sequence_length) + + first, steps = test_utils.make_trajectory(range(max_sequence_length + 1)) + adder.add_first(first) + for action, step in steps[:-1]: + adder.add(action, step) + + # We should have max_sequence_length-1 timesteps that have been written, + # where the -1 is due to the dangling observation (ie we have actually + # seen max_sequence_length observations). + self.assertEqual(self.num_items(), 0) + + # Adding one more step should raise an error. + with self.assertRaises(ValueError): + action, step = steps[-1] + adder.add(action, step) + + # Since the last insert failed it should not affect the internal state. + self.assertEqual(self.num_items(), 0) + + @parameterized.parameters((2, 1), (10, 2), (50, 5)) + def test_padding(self, max_sequence_length, padding): + adder = adders.EpisodeAdder( + self.client, + max_sequence_length + padding, + padding_fn=np.zeros) + + # Create a simple trajectory to add. 
+    observations = range(max_sequence_length)
+    first, steps = test_utils.make_trajectory(observations)
+
+    expected_episode = test_utils.make_sequence(observations)
+    for _ in range(padding):
+      expected_episode.append((0, 0, 0.0, 0.0, False, ()))
+
+    self.run_test_adder(
+        adder=adder,
+        first=first,
+        steps=steps,
+        expected_items=[expected_episode],
+        signature=adder.signature(*test_utils.get_specs(steps[0])))
+
+  @parameterized.parameters((2, 1), (10, 2), (50, 5))
+  def test_nonzero_padding(self, max_sequence_length, padding):
+    adder = adders.EpisodeAdder(
+        self.client,
+        max_sequence_length + padding,
+        padding_fn=lambda s, d: np.zeros(s, d) - 1)
+
+    # Create a simple trajectory to add.
+    observations = range(max_sequence_length)
+    first, steps = test_utils.make_trajectory(observations)
+
+    expected_episode = test_utils.make_sequence(observations)
+    for _ in range(padding):
+      expected_episode.append((-1, -1, -1.0, -1.0, False, ()))
+
+    self.run_test_adder(
+        adder=adder,
+        first=first,
+        steps=steps,
+        expected_items=[expected_episode],
+        signature=adder.signature(*test_utils.get_specs(steps[0])))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/adders/reverb/sequence.py b/acme/acme/adders/reverb/sequence.py
new file mode 100644
index 00000000..7d0669e4
--- /dev/null
+++ b/acme/acme/adders/reverb/sequence.py
@@ -0,0 +1,296 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sequence adders.
+
+This implements adders which add sequences or partial trajectories.
+"""
+
+import enum
+import operator
+from typing import Iterable, Optional
+
+from acme import specs
+from acme import types
+from acme.adders.reverb import base
+from acme.adders.reverb import utils
+
+import numpy as np
+import reverb
+import tensorflow as tf
+import tree
+
+
+class EndBehavior(enum.Enum):
+  """Class to enumerate available options for writing behavior at episode ends.
+
+  Example:
+
+    sequence_length = 3
+    period = 2
+
+  Episode steps (digits) and writing events (W):
+
+      1 2 3 4 5 6
+          W   W
+
+  First two sequences:
+
+      1 2 3
+      . . 3 4 5
+
+  Here are the last written sequences for each end of episode behavior:
+
+     WRITE     . . . 4 5 6
+     CONTINUE  . . . . 5 6 F
+     ZERO_PAD  . . . . 5 6 0
+     TRUNCATE  . . . . 5 6
+
+  Key:
+    F: First step of the next episode
+    0: Zero-filled Step
+  """
+  WRITE = 'write_buffer'
+  CONTINUE = 'continue_to_next_episode'
+  ZERO_PAD = 'zero_pad_til_next_write'
+  TRUNCATE = 'write_truncated_buffer'
+
+
+class SequenceAdder(base.ReverbAdder):
+  """An adder which adds sequences of fixed length."""
+
+  def __init__(
+      self,
+      client: reverb.Client,
+      sequence_length: int,
+      period: int,
+      *,
+      delta_encoded: bool = False,
+      priority_fns: Optional[base.PriorityFnMapping] = None,
+      max_in_flight_items: Optional[int] = 2,
+      end_of_episode_behavior: Optional[EndBehavior] = None,
+      # Deprecated kwargs.
+      chunk_length: Optional[int] = None,
+      pad_end_of_episode: Optional[bool] = None,
+      break_end_of_episode: Optional[bool] = None,
+      validate_items: bool = True,
+  ):
+    """Makes a SequenceAdder instance.
+
+    Args:
+      client: See docstring for BaseAdder.
+      sequence_length: The fixed length of sequences we wish to add.
+      period: The period with which we add sequences. If less than
+        sequence_length, overlapping sequences are added. If equal to
+        sequence_length, sequences are exactly non-overlapping.
+      delta_encoded: If `True` (False by default) enables delta encoding, see
+        `Client` for more information.
+      priority_fns: See docstring for BaseAdder.
+      max_in_flight_items: The maximum number of items allowed to be "in
+        flight" at the same time. See `block_until_num_items` in
+        `reverb.TrajectoryWriter.flush` for more info.
+      end_of_episode_behavior: Determines how sequences at the end of the
+        episode are handled (default `EndBehavior.ZERO_PAD`). See the
+        docstring for `EndBehavior` for more information.
+      chunk_length: Deprecated and unused.
+      pad_end_of_episode: If `True` (default) then upon end of episode the
+        current sequence will be padded (with observations, actions, etc.
+        whose values are 0) until its length is `sequence_length`. If `False`
+        then the last sequence in the episode may have length less than
+        `sequence_length`.
+      break_end_of_episode: If `False` (`True` by default) sequences are not
+        broken on env reset. In this case `pad_end_of_episode` is not used.
+      validate_items: Whether to validate items against the table signature
+        before they are sent to the server. This requires table signature to
+        be fetched from the server and cached locally.
+    """
+    del chunk_length
+    super().__init__(
+        client=client,
+        # We need an additional space in the buffer for the partial step the
+        # base.ReverbAdder will add with the next observation.
+        max_sequence_length=sequence_length+1,
+        delta_encoded=delta_encoded,
+        priority_fns=priority_fns,
+        max_in_flight_items=max_in_flight_items,
+        validate_items=validate_items)
+
+    if pad_end_of_episode and not break_end_of_episode:
+      raise ValueError(
+          'Can\'t set pad_end_of_episode=True and break_end_of_episode=False at'
+          ' the same time, since those behaviors are incompatible.')
+
+    self._period = period
+    self._sequence_length = sequence_length
+
+    if end_of_episode_behavior and (pad_end_of_episode is not None or
+                                    break_end_of_episode is not None):
+      raise ValueError(
+          'Using end_of_episode_behavior and either '
+          'pad_end_of_episode or break_end_of_episode is not permitted. '
+          'Please use only end_of_episode_behavior instead.')
+
+    # Set pad_end_of_episode and break_end_of_episode to default values.
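+    # The legacy flags map onto `EndBehavior` as follows:
+    #   break_end_of_episode=False                           -> CONTINUE
+    #   break_end_of_episode=True, pad_end_of_episode=True   -> ZERO_PAD
+    #   break_end_of_episode=True, pad_end_of_episode=False  -> TRUNCATE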
+    if end_of_episode_behavior is None and pad_end_of_episode is None:
+      pad_end_of_episode = True
+    if end_of_episode_behavior is None and break_end_of_episode is None:
+      break_end_of_episode = True
+
+    self._end_of_episode_behavior = EndBehavior.ZERO_PAD
+    if pad_end_of_episode is not None or break_end_of_episode is not None:
+      if not break_end_of_episode:
+        self._end_of_episode_behavior = EndBehavior.CONTINUE
+      elif break_end_of_episode and pad_end_of_episode:
+        self._end_of_episode_behavior = EndBehavior.ZERO_PAD
+      elif break_end_of_episode and not pad_end_of_episode:
+        self._end_of_episode_behavior = EndBehavior.TRUNCATE
+      else:
+        raise ValueError(
+            'Reached an unexpected configuration of the SequenceAdder '
+            f'with break_end_of_episode={break_end_of_episode} '
+            f'and pad_end_of_episode={pad_end_of_episode}.')
+    elif isinstance(end_of_episode_behavior, EndBehavior):
+      self._end_of_episode_behavior = end_of_episode_behavior
+    else:
+      raise ValueError('end_of_episode_behavior must be an instance of '
+                       f'EndBehavior, received {end_of_episode_behavior}.')
+
+  def reset(self):
+    """Resets the adder's buffer."""
+    # If we do not write on end of episode, we should not reset the writer.
+    if self._end_of_episode_behavior is EndBehavior.CONTINUE:
+      return
+
+    super().reset()
+
+  def _write(self):
+    self._maybe_create_item(self._sequence_length)
+
+  def _write_last(self):
+    # Maybe determine the delta to the next time we would write a sequence.
+    if self._end_of_episode_behavior in (EndBehavior.TRUNCATE,
+                                         EndBehavior.ZERO_PAD):
+      delta = self._sequence_length - self._writer.episode_steps
+      if delta < 0:
+        delta = (self._period + delta) % self._period
+
+    # Handle various end-of-episode cases.
+    if self._end_of_episode_behavior is EndBehavior.CONTINUE:
+      self._maybe_create_item(self._sequence_length, end_of_episode=True)
+
+    elif self._end_of_episode_behavior is EndBehavior.WRITE:
+      # Drop episodes that are too short.
+      if self._writer.episode_steps < self._sequence_length:
+        return
+      self._maybe_create_item(
+          self._sequence_length, end_of_episode=True, force=True)
+
+    elif self._end_of_episode_behavior is EndBehavior.TRUNCATE:
+      self._maybe_create_item(
+          self._sequence_length - delta,
+          end_of_episode=True,
+          force=True)
+
+    elif self._end_of_episode_behavior is EndBehavior.ZERO_PAD:
+      zero_step = tree.map_structure(lambda x: np.zeros_like(x[-2].numpy()),
+                                     self._writer.history)
+      for _ in range(delta):
+        self._writer.append(zero_step)
+
+      self._maybe_create_item(
+          self._sequence_length, end_of_episode=True, force=True)
+    else:
+      raise ValueError(
+          f'Unhandled end of episode behavior: {self._end_of_episode_behavior}.'
+          ' This should never happen, please contact Acme dev team.')
+
+  def _maybe_create_item(self,
+                         sequence_length: int,
+                         *,
+                         end_of_episode: bool = False,
+                         force: bool = False):
+
+    # Check conditions under which a new item is created.
+    first_write = self._writer.episode_steps == sequence_length
+    # NOTE(bshahr): the following line assumes that the only way
+    # sequence_length is less than self._sequence_length, is if the episode is
+    # shorter than self._sequence_length.
+    period_reached = (
+        self._writer.episode_steps > self._sequence_length and
+        ((self._writer.episode_steps - self._sequence_length) % self._period
+         == 0))
+
+    if not first_write and not period_reached and not force:
+      return
+
+    # TODO(b/183945808): will need to change to adhere to the new protocol.
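+    # Mid-episode the history ends with a partially-filled (open) step that
+    # holds only the next observation, so it is excluded from the trajectory;
+    # at the end of an episode the final step is complete and is included.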
+ if not end_of_episode: + get_traj = operator.itemgetter(slice(-sequence_length - 1, -1)) + else: + get_traj = operator.itemgetter(slice(-sequence_length, None)) + + history = self._writer.history + trajectory = base.Trajectory(**tree.map_structure(get_traj, history)) + + # Compute priorities for the buffer. + table_priorities = utils.calculate_priorities(self._priority_fns, + trajectory) + + # Create a prioritized item for each table. + for table_name, priority in table_priorities.items(): + self._writer.create_item(table_name, priority, trajectory) + self._writer.flush(self._max_in_flight_items) + + # TODO(bshahr): make this into a standalone method. Class methods should be + # used as alternative constructors or when modifying some global state, + # neither of which is done here. + @classmethod + def signature(cls, environment_spec: specs.EnvironmentSpec, + extras_spec: types.NestedSpec = (), + sequence_length: Optional[int] = None): + """This is a helper method for generating signatures for Reverb tables. + + Signatures are useful for validating data types and shapes, see Reverb's + documentation for details on how they are used. + + Args: + environment_spec: A `specs.EnvironmentSpec` whose fields are nested + structures with leaf nodes that have `.shape` and `.dtype` attributes. + This should come from the environment that will be used to generate + the data inserted into the Reverb table. + extras_spec: A nested structure with leaf nodes that have `.shape` and + `.dtype` attributes. The structure (and shapes/dtypes) of this must + be the same as the `extras` passed into `ReverbAdder.add`. + sequence_length: An optional integer representing the expected length of + sequences that will be added to replay. + + Returns: + A `Trajectory` whose leaf nodes are `tf.TensorSpec` objects. + """ + + def add_time_dim(paths: Iterable[str], spec: tf.TensorSpec): + return tf.TensorSpec(shape=(sequence_length, *spec.shape), + dtype=spec.dtype, + name='/'.join(str(p) for p in paths)) + + trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path( + add_time_dim, (environment_spec, extras_spec)) + + spec_step = base.Trajectory( + *trajectory_env_spec, + start_of_episode=tf.TensorSpec( + shape=(sequence_length,), dtype=tf.bool, name='start_of_episode'), + extras=trajectory_extras_spec) + + return spec_step diff --git a/acme/acme/adders/reverb/sequence_test.py b/acme/acme/adders/reverb/sequence_test.py new file mode 100644 index 00000000..d50f1250 --- /dev/null +++ b/acme/acme/adders/reverb/sequence_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for sequence adders.""" + +from acme.adders.reverb import sequence as adders +from acme.adders.reverb import test_cases +from acme.adders.reverb import test_utils + +from absl.testing import absltest +from absl.testing import parameterized + + +class SequenceAdderTest(test_utils.AdderTestMixin, parameterized.TestCase): + + @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_SEQUENCE_ADDER) + def test_adder(self, + sequence_length: int, + period: int, + first, + steps, + expected_sequences, + end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD, + repeat_episode_times: int = 1): + adder = adders.SequenceAdder( + self.client, + sequence_length=sequence_length, + period=period, + end_of_episode_behavior=end_behavior) + super().run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=expected_sequences, + repeat_episode_times=repeat_episode_times, + end_behavior=end_behavior, + signature=adder.signature(*test_utils.get_specs(steps[0]))) + + @parameterized.parameters( + (True, True, adders.EndBehavior.ZERO_PAD), + (False, True, adders.EndBehavior.TRUNCATE), + (False, False, adders.EndBehavior.CONTINUE), + ) + def test_end_of_episode_behavior_set_correctly(self, pad_end_of_episode, + break_end_of_episode, + expected_behavior): + adder = adders.SequenceAdder( + self.client, + sequence_length=5, + period=3, + pad_end_of_episode=pad_end_of_episode, + break_end_of_episode=break_end_of_episode) + self.assertEqual(adder._end_of_episode_behavior, expected_behavior) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/adders/reverb/structured.py b/acme/acme/adders/reverb/structured.py new file mode 100644 index 00000000..b99f1715 --- /dev/null +++ b/acme/acme/adders/reverb/structured.py @@ -0,0 +1,424 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic adders that wraps Reverb's StructuredWriter.""" + +import itertools +import time + +from typing import Callable, List, Optional, Sequence, Sized + +from absl import logging +from acme import specs +from acme import types +from acme.adders import base as adders_base +from acme.adders.reverb import base as reverb_base +from acme.adders.reverb import sequence as sequence_adder +import dm_env +import numpy as np +import reverb +from reverb import structured_writer as sw +import tensorflow as tf +import tree + +Step = reverb_base.Step +Trajectory = reverb_base.Trajectory +EndBehavior = sequence_adder.EndBehavior + +_RESET_WRITER_EVERY_SECONDS = 60 + + +class StructuredAdder(adders_base.Adder): + """Generic Adder which writes to Reverb using Reverb's `StructuredWriter`. + + The StructuredAdder is a thin wrapper around Reverb's `StructuredWriter` and + its behaviour is determined by the configs to __init__. 
Much of the behaviour
+  provided by other Adders can be replicated using `StructuredAdder`, but
+  there are a few noteworthy differences:
+
+   * The behaviour of `StructuredAdder` can be thought of as the union of all
+     its configs. This means that a single adder is capable of inserting items
+     of different structures into any number of tables WITHOUT any data
+     duplication. Other adders are only capable of writing items of the same
+     structure into multiple tables.
+   * The complete structure of the step must be known at construction time
+     when using the StructuredAdder. This is not the case for other Adders as
+     they allow the structure of the step to become expanded over time.
+   * The `StructuredAdder` assigns all items the same priority (1.0) as it
+     does not currently support custom priority computations.
+   * The StructuredAdder is completely generic and thus does not perform any
+     preprocessing on the data (e.g. cumulative rewards as done by the
+     NStepTransitionAdder) before writing it to Reverb. The user is instead
+     expected to perform preprocessing in the dataset pipeline on the learner.
+  """
+
+  def __init__(self, client: reverb.Client, max_in_flight_items: int,
+               configs: Sequence[sw.Config], step_spec: Step):
+    """Initialize a StructuredAdder instance.
+
+    Args:
+      client: A client to the Reverb backend.
+      max_in_flight_items: The maximum number of items allowed to be "in
+        flight" at the same time. See `block_until_num_items` in
+        `reverb.TrajectoryWriter.flush` for more info.
+      configs: Configurations defining the behaviour of the wrapped Reverb
+        writer.
+      step_spec: Spec of the step that is going to be inserted in the Adder.
+        It can be created with `create_step_spec` using the environment spec
+        and the extras spec.
+    """
+
+    # We validate the configs by attempting to infer the signatures of all
+    # targeted tables.
+    for table, table_configs in itertools.groupby(configs, lambda c: c.table):
+      try:
+        sw.infer_signature(list(table_configs), step_spec)
+      except ValueError as e:
+        raise ValueError(
+            f'Received invalid configs for table {table}: {str(e)}') from e
+
+    self._client = client
+    self._configs = tuple(configs)
+    self._none_step: Step = tree.map_structure(lambda _: None, step_spec)
+    self._max_in_flight_items = max_in_flight_items
+
+    self._writer = None
+    self._writer_created_at = None
+
+  def __del__(self):
+    if self._writer is None:
+      return
+
+    # Try to flush all appended data before closing to avoid loss of
+    # experience.
+    try:
+      self._writer.flush(0, timeout_ms=10_000)
+    except reverb.DeadlineExceededError as e:
+      logging.error(
+          'Timeout (10 s) exceeded when flushing the writer before '
+          'deleting it. Caught Reverb exception: %s', str(e))
+
+  def _make_step(self, **kwargs) -> Step:
+    """Complete the step with None in the missing positions."""
+    return Step(**{**self._none_step._asdict(), **kwargs})
+
+  @property
+  def configs(self):
+    return self._configs
+
+  def reset(self, timeout_ms: Optional[int] = None):
+    """Marks the active episode as completed and flushes pending items."""
+    if self._writer is not None:
+      # Flush all pending items.
+      self._writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms)
+
+      # Create a new writer unless the current one is too young.
+      # This is to reduce the relative overhead of creating a new Reverb
+      # writer.
+      if time.time() - self._writer_created_at > _RESET_WRITER_EVERY_SECONDS:
+        self._writer = None
+
+  def add_first(self, timestep: dm_env.TimeStep):
+    """Record the first observation of an episode."""
+    if not timestep.first():
+      raise ValueError(
+          'adder.add_first called with a timestep that was not the first of '
+          'its episode (i.e. one for which timestep.first() is not True).')
+
+    if self._writer is None:
+      self._writer = self._client.structured_writer(self._configs)
+      self._writer_created_at = time.time()
+
+    # Record the next observation but leave the history buffer row open by
+    # passing `partial_step=True`.
+    self._writer.append(
+        data=self._make_step(
+            observation=timestep.observation,
+            start_of_episode=timestep.first()),
+        partial_step=True)
+    self._writer.flush(self._max_in_flight_items)
+
+  def add(self,
+          action: types.NestedArray,
+          next_timestep: dm_env.TimeStep,
+          extras: types.NestedArray = ()):
+    """Record an action and the following timestep."""
+
+    if not self._writer.step_is_open:
+      raise ValueError('adder.add_first must be called before adder.add.')
+
+    # Add the timestep to the buffer.
+    has_extras = (
+        len(extras) > 0 if isinstance(extras, Sized)  # pylint: disable=g-explicit-length-test
+        else extras is not None)
+
+    current_step = self._make_step(
+        action=action,
+        reward=next_timestep.reward,
+        discount=next_timestep.discount,
+        extras=extras if has_extras else self._none_step.extras)
+    self._writer.append(current_step)
+
+    # Record the next observation and write.
+    self._writer.append(
+        data=self._make_step(
+            observation=next_timestep.observation,
+            start_of_episode=next_timestep.first()),
+        partial_step=True)
+    self._writer.flush(self._max_in_flight_items)
+
+    if next_timestep.last():
+      # Complete the row by appending zeros to remaining open fields.
+      # TODO(b/183945808): remove this when fields are no longer expected to be
+      # of equal length on the learner side.
+      dummy_step = tree.map_structure(
+          lambda x: None if x is None else np.zeros_like(x), current_step)
+      self._writer.append(dummy_step)
+      self.reset()
+
+
+def create_step_spec(
+    environment_spec: specs.EnvironmentSpec, extras_spec: types.NestedSpec = ()
+) -> Step:
+  return Step(
+      *environment_spec,
+      start_of_episode=tf.TensorSpec([], tf.bool, 'start_of_episode'),
+      extras=extras_spec)
+
+
+def _last_n(n: int, step_spec: Step) -> Trajectory:
+  """Constructs a sequence with the last n elements of all the Step fields."""
+  return Trajectory(*tree.map_structure(lambda x: x[-n:], step_spec))
+
+
+def create_sequence_config(
+    step_spec: Step,
+    sequence_length: int,
+    period: int,
+    table: str = reverb_base.DEFAULT_PRIORITY_TABLE,
+    end_of_episode_behavior: EndBehavior = EndBehavior.TRUNCATE,
+    sequence_pattern: Callable[[int, Step], Trajectory] = _last_n,
+) -> List[sw.Config]:
+  """Generates configs that produce the same behaviour as `SequenceAdder`.
+
+  NOTE! ZERO_PAD is not supported as the same behaviour can be achieved by
+  writing with TRUNCATE and then adding padding in the dataset pipeline on the
+  learner.
+
+  Args:
+    step_spec: The full structure of the data which will be appended to the
+      Reverb `StructuredWriter` in each step. Please use `create_step_spec` to
+      create `step_spec`.
+    sequence_length: The number of steps that each trajectory should span.
+    period: The period with which we add sequences. If less than
+      sequence_length, overlapping sequences are added. If equal to
+      sequence_length, sequences are exactly non-overlapping.
+    table: Name of the Reverb table to write items to. Defaults to the default
+      Acme table.
+    end_of_episode_behavior: Determines how sequences at the end of the episode
+      are handled (default `EndBehavior.TRUNCATE`). See the docstring of
+      `EndBehavior` for more information.
+    sequence_pattern: Transformation to obtain a sequence given the length
+      and the shape of the step.
+
+  Returns:
+    A list of configs for `StructuredAdder` to produce the described behaviour.
+
+  Raises:
+    ValueError: If sequence_length is <= 0.
+    NotImplementedError: If `end_of_episode_behavior` is `ZERO_PAD`.
+  """
+  if sequence_length <= 0:
+    raise ValueError(f'sequence_length must be > 0 but got {sequence_length}.')
+
+  if end_of_episode_behavior == EndBehavior.ZERO_PAD:
+    raise NotImplementedError(
+        'Zero-padding is not supported. Please use TRUNCATE instead.')
+
+  if end_of_episode_behavior == EndBehavior.CONTINUE:
+    raise NotImplementedError('Merging episodes is not supported.')
+
+  def _sequence_pattern(n: int) -> sw.Pattern:
+    return sw.pattern_from_transform(step_spec,
+                                     lambda step: sequence_pattern(n, step))
+
+  # The base config is considered for all but the last step in the episode. No
+  # trajectories are created for the first `sequence_length - 1` steps and
+  # then a new trajectory is inserted every `period` steps.
+  base_config = sw.create_config(
+      pattern=_sequence_pattern(sequence_length),
+      table=table,
+      conditions=[
+          sw.Condition.step_index() >= sequence_length - 1,
+          sw.Condition.step_index() % period == (sequence_length - 1) % period,
+      ])
+
+  end_of_episode_configs = []
+  if end_of_episode_behavior == EndBehavior.WRITE:
+    # Simply write a trajectory in exactly the same way as the base config. The
+    # only difference here is that we ALWAYS create a trajectory even if it
+    # doesn't align with the `period`. The exceptions to the rule are episodes
+    # that are shorter than `sequence_length` steps, which are completely
+    # ignored.
+    config = sw.create_config(
+        pattern=_sequence_pattern(sequence_length),
+        table=table,
+        conditions=[
+            sw.Condition.is_end_episode(),
+            sw.Condition.step_index() >= sequence_length - 1,
+        ])
+    end_of_episode_configs.append(config)
+  elif end_of_episode_behavior == EndBehavior.TRUNCATE:
+    # The first trajectory is written at step index `sequence_length - 1` and
+    # then written every `period` steps. This means that
+    # `step_index % period` will always be equal to the below value every time
+    # a trajectory is written.
+    target = (sequence_length - 1) % period
+
+    # When the episode ends we still want to capture the steps that have been
+    # appended since the last item was created. We do this by creating a config
+    # for all `step_index % period`, except `target`, and condition these
+    # configs so that they are only triggered when `end_episode` is called.
+    for x in range(period):
+      # When the last step is aligned with the period of the inserts then no
+      # action is required as the item was already generated by `base_config`.
+      if x == target:
+        continue
+
+      # If we were to pad the trajectory then we'll need to continue adding
+      # padding until `step_index % period` is equal to `target` again. We can
+      # exploit this relation by conditioning the config to only be applied for
+      # a single value of `step_index % period`. This constraint means that we
+      # can infer the number of padding steps required until the next write
+      # would have occurred if the episode didn't end.
+      #
+      # Now if we assume that the padding instead is added on the dataset (or
+      # the trajectory is simply truncated) then we can infer from the above
+      # that the number of real steps in this padded trajectory will be the
+      # difference between `sequence_length` and the number of pad steps.
+      num_pad_steps = (target - x) % period
+      unpadded_length = sequence_length - num_pad_steps
+
+      config = sw.create_config(
+          pattern=_sequence_pattern(unpadded_length),
+          table=table,
+          conditions=[
+              sw.Condition.is_end_episode(),
+              sw.Condition.step_index() % period == x,
+              sw.Condition.step_index() >= sequence_length,
+          ])
+      end_of_episode_configs.append(config)
+
+    # The above configs will capture the "remainder" of any episode that is at
+    # least `sequence_length` steps long. However, if the entire episode is
+    # shorter than `sequence_length` then data might still be lost. We avoid
+    # this by simply creating `sequence_length-1` configs that capture the last
+    # `x` steps iff the entire episode is `x` steps long.
+    for x in range(1, sequence_length):
+      config = sw.create_config(
+          pattern=_sequence_pattern(x),
+          table=table,
+          conditions=[
+              sw.Condition.is_end_episode(),
+              sw.Condition.step_index() == x - 1,
+          ])
+      end_of_episode_configs.append(config)
+  else:
+    raise ValueError(
+        f'Unexpected `end_of_episode_behavior`: {end_of_episode_behavior}')
+
+  return [base_config] + end_of_episode_configs
+
+
+def create_n_step_transition_config(
+    step_spec: Step,
+    n_step: int,
+    table: str = reverb_base.DEFAULT_PRIORITY_TABLE) -> List[sw.Config]:
+  """Generates configs that replicate the behaviour of NStepTransitionAdder.
+
+  Please see the docstring of NStepTransitionAdder for more details.
+
+  NOTE! In contrast to NStepTransitionAdder, the trajectories written by the
+  `StructuredWriter` do not include the precomputed cumulative reward and
+  discounts. Instead the trajectory includes the raw rewards and discounts
+  required to compute these values.
+
+  Args:
+    step_spec: The full structure of the data which will be appended to the
+      Reverb `StructuredWriter` in each step. Please use `create_step_spec` to
+      create `step_spec`.
+    n_step: The "N" in N-step transition. See the class docstring for the
+      precise definition of what an N-step transition is. `n_step` must be at
+      least 1; if it is exactly 1 we get the standard one-step transition,
+      i.e. (s_t, a_t, r_t, d_t, s_{t+1}, e_t).
+    table: Name of the Reverb table to write items to. Defaults to the default
+      Acme table.
+
+  Returns:
+    A list of configs for `StructuredAdder` to produce the described behaviour.
+  """
+
+  def _make_pattern(n: int):
+    ref_step = sw.create_reference_step(step_spec)
+
+    get_first = lambda x: x[-(n + 1):-n]
+    get_all = lambda x: x[-(n + 1):-1]
+    get_first_and_last = lambda x: x[-(n + 1)::n]
+
+    tmap = tree.map_structure
+
+    # We use the exact same structure as we do when writing sequences, except
+    # that we trim the number of steps in each sub tree. This has the benefit
+    # that the postprocessing used to transform these items into N-step
+    # transition structures (cumulative rewards and discounts etc.) can be
+    # applied on full sequence items as well. The only difference is that the
+    # latter is more wasteful than the trimmed down version we write here.
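+    # For instance (editor's illustration), with n=2 the window holds three
+    # steps [-3, -2, -1]: get_first selects step -3, get_all selects steps
+    # -3 and -2, and get_first_and_last (x[-3::2]) selects steps -3 and -1.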
+ return Trajectory( + observation=tmap(get_first_and_last, ref_step.observation), + action=tmap(get_first, ref_step.action), + reward=tmap(get_all, ref_step.reward), + discount=tmap(get_all, ref_step.discount), + start_of_episode=tmap(get_first, ref_step.start_of_episode), + extras=tmap(get_first, ref_step.extras)) + + # At the start of the episodes we'll add shorter transitions. + start_of_episode_configs = [] + for n in range(1, n_step): + config = sw.create_config( + pattern=_make_pattern(n), + table=table, + conditions=[ + sw.Condition.step_index() == n, + ], + ) + start_of_episode_configs.append(config) + + # During all other steps we'll add a full N-step transition. + base_config = sw.create_config(pattern=_make_pattern(n_step), table=table) + + # When the episode ends we'll add shorter transitions. + end_of_episode_configs = [] + for n in range(n_step - 1, 0, -1): + config = sw.create_config( + pattern=_make_pattern(n), + table=table, + conditions=[ + sw.Condition.is_end_episode(), + # If the entire episode is shorter than n_step then the episode + # start configs will already create an item that covers all the + # steps so we add this filter here to avoid adding it again. + sw.Condition.step_index() != n, + ], + ) + end_of_episode_configs.append(config) + + return start_of_episode_configs + [base_config] + end_of_episode_configs diff --git a/acme/acme/adders/reverb/structured_test.py b/acme/acme/adders/reverb/structured_test.py new file mode 100644 index 00000000..761536e1 --- /dev/null +++ b/acme/acme/adders/reverb/structured_test.py @@ -0,0 +1,186 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for structured.""" + +from typing import Sequence + +from acme import types +from acme.adders.reverb import sequence as adders +from acme.adders.reverb import structured +from acme.adders.reverb import test_cases +from acme.adders.reverb import test_utils +from acme.utils import tree_utils +import dm_env +import numpy as np +from reverb import structured_writer as sw +import tree + +from absl.testing import absltest +from absl.testing import parameterized + + +class StructuredAdderTest(test_utils.AdderTestMixin, parameterized.TestCase): + + @parameterized.named_parameters(*test_cases.BASE_TEST_CASES_FOR_SEQUENCE_ADDER + ) + def test_sequence_adder(self, + sequence_length: int, + period: int, + first, + steps, + expected_sequences, + end_behavior: adders.EndBehavior, + repeat_episode_times: int = 1): + + env_spec, extras_spec = test_utils.get_specs(steps[0]) + step_spec = structured.create_step_spec(env_spec, extras_spec) + + should_pad_trajectory = end_behavior == adders.EndBehavior.ZERO_PAD + + def _maybe_zero_pad(flat_trajectory): + trajectory = tree.unflatten_as(step_spec, flat_trajectory) + + if not should_pad_trajectory: + return trajectory + + padding_length = sequence_length - flat_trajectory[0].shape[0] + if padding_length == 0: + return trajectory + + padding = tree.map_structure( + lambda x: np.zeros([padding_length, *x.shape[1:]], x.dtype), + trajectory) + + return tree.map_structure(lambda *x: np.concatenate(x), trajectory, + padding) + + # The StructuredAdder does not support adding padding steps as we assume + # that the padding will be added on the learner side. + if end_behavior == adders.EndBehavior.ZERO_PAD: + end_behavior = adders.EndBehavior.TRUNCATE + + configs = structured.create_sequence_config( + step_spec=step_spec, + sequence_length=sequence_length, + period=period, + end_of_episode_behavior=end_behavior) + adder = structured.StructuredAdder( + client=self.client, + max_in_flight_items=0, + configs=configs, + step_spec=step_spec) + + super().run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=expected_sequences, + repeat_episode_times=repeat_episode_times, + end_behavior=end_behavior, + item_transform=_maybe_zero_pad, + signature=sw.infer_signature(configs, step_spec)) + + @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER) + def test_transition_adder(self, n_step: int, additional_discount: float, + first: dm_env.TimeStep, + steps: Sequence[dm_env.TimeStep], + expected_transitions: Sequence[types.Transition]): + + env_spec, extras_spec = test_utils.get_specs(steps[0]) + step_spec = structured.create_step_spec(env_spec, extras_spec) + + def _as_n_step_transition(flat_trajectory): + trajectory = tree.unflatten_as(step_spec, flat_trajectory) + + rewards, discount = _compute_cumulative_quantities( + rewards=trajectory.reward, + discounts=trajectory.discount, + additional_discount=additional_discount, + n_step=tree.flatten(trajectory.reward)[0].shape[0]) + + tmap = tree.map_structure + return types.Transition( + observation=tmap(lambda x: x[0], trajectory.observation), + action=tmap(lambda x: x[0], trajectory.action), + reward=rewards, + discount=discount, + next_observation=tmap(lambda x: x[-1], trajectory.observation), + extras=tmap(lambda x: x[0], trajectory.extras)) + + configs = structured.create_n_step_transition_config( + step_spec=step_spec, n_step=n_step) + + adder = structured.StructuredAdder( + client=self.client, + max_in_flight_items=0, + configs=configs, + step_spec=step_spec) + + 
super().run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=expected_transitions, + stack_sequence_fields=False, + item_transform=_as_n_step_transition, + signature=sw.infer_signature(configs, step_spec)) + + +def _compute_cumulative_quantities(rewards: types.NestedArray, + discounts: types.NestedArray, + additional_discount: float, n_step: int): + """Stolen from TransitionAdder.""" + + # Give the same tree structure to the n-step return accumulator, + # n-step discount accumulator, and self.discount, so that they can be + # iterated in parallel using tree.map_structure. + rewards, discounts, self_discount = tree_utils.broadcast_structures( + rewards, discounts, additional_discount) + flat_rewards = tree.flatten(rewards) + flat_discounts = tree.flatten(discounts) + flat_self_discount = tree.flatten(self_discount) + + # Copy total_discount as it is otherwise read-only. + total_discount = [np.copy(a[0]) for a in flat_discounts] + + # Broadcast n_step_return to have the broadcasted shape of + # reward * discount. + n_step_return = [ + np.copy(np.broadcast_to(r[0], + np.broadcast(r[0], d).shape)) + for r, d in zip(flat_rewards, total_discount) + ] + + # NOTE: total_discount will have one less self_discount applied to it than + # the value of self._n_step. This is so that when the learner/update uses + # an additional discount we don't apply it twice. Inside the following loop + # we will apply this right before summing up the n_step_return. + for i in range(1, n_step): + for nsr, td, r, d, sd in zip(n_step_return, total_discount, flat_rewards, + flat_discounts, flat_self_discount): + # Equivalent to: `total_discount *= self._discount`. + td *= sd + # Equivalent to: `n_step_return += reward[i] * total_discount`. + nsr += r[i] * td + # Equivalent to: `total_discount *= discount[i]`. + td *= d[i] + + n_step_return = tree.unflatten_as(rewards, n_step_return) + total_discount = tree.unflatten_as(rewards, total_discount) + return n_step_return, total_discount + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/adders/reverb/test_cases.py b/acme/acme/adders/reverb/test_cases.py new file mode 100644 index 00000000..fac0ffb3 --- /dev/null +++ b/acme/acme/adders/reverb/test_cases.py @@ -0,0 +1,847 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases used by multiple test files.""" + +from acme import types +from acme.adders.reverb import sequence as sequence_adder +import dm_env +import numpy as np + +# Define the main set of test cases; these are given as parameterized tests to +# the test_adder method and describe a trajectory to add to replay and the +# expected transitions that should result from this trajectory. The expected +# transitions are of the form: (observation, action, reward, discount, +# next_observation, extras). 
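+#
+# For example (editor's illustration): in the 'TwoStep' case below, the second
+# expected transition accumulates an n-step return of 1.0 + 1.0 * 0.5 * 1.0 =
+# 1.5 and a total discount of 0.5 * 0.5 = 0.25.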
+TEST_CASES_FOR_TRANSITION_ADDER = [ + dict( + testcase_name='OneStepFinalReward', + n_step=1, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=0.0, observation=2)), + (0, dm_env.transition(reward=0.0, observation=3)), + (0, dm_env.termination(reward=1.0, observation=4)), + ), + expected_transitions=( + types.Transition(1, 0, 0.0, 1.0, 2), + types.Transition(2, 0, 0.0, 1.0, 3), + types.Transition(3, 0, 1.0, 0.0, 4), + )), + dict( + testcase_name='OneStepDict', + n_step=1, + additional_discount=1.0, + first=dm_env.restart({'foo': 1}), + steps=( + (0, dm_env.transition(reward=0.0, observation={'foo': 2})), + (0, dm_env.transition(reward=0.0, observation={'foo': 3})), + (0, dm_env.termination(reward=1.0, observation={'foo': 4})), + ), + expected_transitions=( + types.Transition({'foo': 1}, 0, 0.0, 1.0, {'foo': 2}), + types.Transition({'foo': 2}, 0, 0.0, 1.0, {'foo': 3}), + types.Transition({'foo': 3}, 0, 1.0, 0.0, {'foo': 4}), + )), + dict( + testcase_name='OneStepExtras', + n_step=1, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + ( + 0, + dm_env.transition(reward=0.0, observation=2), + { + 'state': 0 + }, + ), + ( + 0, + dm_env.transition(reward=0.0, observation=3), + { + 'state': 1 + }, + ), + ( + 0, + dm_env.termination(reward=1.0, observation=4), + { + 'state': 2 + }, + ), + ), + expected_transitions=( + types.Transition(1, 0, 0.0, 1.0, 2, {'state': 0}), + types.Transition(2, 0, 0.0, 1.0, 3, {'state': 1}), + types.Transition(3, 0, 1.0, 0.0, 4, {'state': 2}), + )), + dict( + testcase_name='OneStepExtrasZeroes', + n_step=1, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + ( + 0, + dm_env.transition(reward=0.0, observation=2), + np.zeros(1), + ), + ( + 0, + dm_env.transition(reward=0.0, observation=3), + np.zeros(1), + ), + ( + 0, + dm_env.termination(reward=1.0, observation=4), + np.zeros(1), + ), + ), + expected_transitions=( + types.Transition(1, 0, 0.0, 1.0, 2, np.zeros(1)), + types.Transition(2, 0, 0.0, 1.0, 3, np.zeros(1)), + types.Transition(3, 0, 1.0, 0.0, 4, np.zeros(1)), + )), + dict( + testcase_name='TwoStep', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=1.0, observation=2, discount=0.5)), + (0, dm_env.transition(reward=1.0, observation=3, discount=0.5)), + (0, dm_env.termination(reward=1.0, observation=4)), + ), + expected_transitions=( + types.Transition(1, 0, 1.0, 0.50, 2), + types.Transition(1, 0, 1.5, 0.25, 3), + types.Transition(2, 0, 1.5, 0.00, 4), + types.Transition(3, 0, 1.0, 0.00, 4), + )), + dict( + testcase_name='TwoStepStructuredReward', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, + dm_env.transition(reward=(1.0, 2.0), observation=2, discount=0.5)), + (0, + dm_env.transition(reward=(1.0, 2.0), observation=3, discount=0.5)), + (0, dm_env.termination(reward=(1.0, 2.0), observation=4)), + ), + expected_transitions=( + types.Transition(1, 0, (1.0, 2.0), (0.50, 0.50), 2), + types.Transition(1, 0, (1.5, 3.0), (0.25, 0.25), 3), + types.Transition(2, 0, (1.5, 3.0), (0.00, 0.00), 4), + types.Transition(3, 0, (1.0, 2.0), (0.00, 0.00), 4), + )), + dict( + testcase_name='TwoStepNDArrayReward', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, + dm_env.transition( + reward=np.array((1.0, 2.0)), observation=2, discount=0.5)), + (0, + dm_env.transition( + reward=np.array((1.0, 2.0)), observation=3, discount=0.5)), + (0, dm_env.termination(reward=np.array((1.0, 2.0)), 
observation=4)), + ), + expected_transitions=( + types.Transition(1, 0, np.array((1.0, 2.0)), np.array((0.50, 0.50)), + 2), + types.Transition(1, 0, np.array((1.5, 3.0)), np.array((0.25, 0.25)), + 3), + types.Transition(2, 0, np.array((1.5, 3.0)), np.array((0.00, 0.00)), + 4), + types.Transition(3, 0, np.array((1.0, 2.0)), np.array((0.00, 0.00)), + 4), + )), + dict( + testcase_name='TwoStepStructuredDiscount', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, + dm_env.transition( + reward=1.0, observation=2, discount={ + 'a': 0.5, + 'b': 0.1 + })), + (0, + dm_env.transition( + reward=1.0, observation=3, discount={ + 'a': 0.5, + 'b': 0.1 + })), + (0, dm_env.termination(reward=1.0, + observation=4)._replace(discount={ + 'a': 0.0, + 'b': 0.0 + })), + ), + expected_transitions=( + types.Transition(1, 0, { + 'a': 1.0, + 'b': 1.0 + }, { + 'a': 0.50, + 'b': 0.10 + }, 2), + types.Transition(1, 0, { + 'a': 1.5, + 'b': 1.1 + }, { + 'a': 0.25, + 'b': 0.01 + }, 3), + types.Transition(2, 0, { + 'a': 1.5, + 'b': 1.1 + }, { + 'a': 0.00, + 'b': 0.00 + }, 4), + types.Transition(3, 0, { + 'a': 1.0, + 'b': 1.0 + }, { + 'a': 0.00, + 'b': 0.00 + }, 4), + )), + dict( + testcase_name='TwoStepNDArrayDiscount', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, + dm_env.transition( + reward=1.0, observation=2, discount=np.array((0.5, 0.1)))), + (0, + dm_env.transition( + reward=1.0, observation=3, discount=np.array((0.5, 0.1)))), + (0, dm_env.termination( + reward=1.0, + observation=4)._replace(discount=np.array((0.0, 0.0)))), + ), + expected_transitions=( + types.Transition(1, 0, np.array((1.0, 1.0)), np.array((0.50, 0.10)), + 2), + types.Transition(1, 0, np.array((1.5, 1.1)), np.array((0.25, 0.01)), + 3), + types.Transition(2, 0, np.array((1.5, 1.1)), np.array((0.00, 0.00)), + 4), + types.Transition(3, 0, np.array((1.0, 1.0)), np.array((0.00, 0.00)), + 4), + )), + dict( + testcase_name='TwoStepBroadcastedNDArrays', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, + dm_env.transition( + reward=np.array([[1.0, 2.0]]), + observation=2, + discount=np.array([[0.5], [0.1]]))), + (0, + dm_env.transition( + reward=np.array([[1.0, 2.0]]), + observation=3, + discount=np.array([[0.5], [0.1]]))), + (0, dm_env.termination( + reward=np.array([[1.0, 2.0]]), + observation=4)._replace(discount=np.array([[0.0], [0.0]]))), + ), + expected_transitions=( + types.Transition(1, 0, np.array([[1.0, 2.0], [1.0, 2.0]]), + np.array([[0.50], [0.10]]), 2), + types.Transition(1, 0, np.array([[1.5, 3.0], [1.1, 2.2]]), + np.array([[0.25], [0.01]]), 3), + types.Transition(2, 0, np.array([[1.5, 3.0], [1.1, 2.2]]), + np.array([[0.00], [0.00]]), 4), + types.Transition(3, 0, np.array([[1.0, 2.0], [1.0, 2.0]]), + np.array([[0.00], [0.00]]), 4), + )), + dict( + testcase_name='TwoStepStructuredBroadcastedNDArrays', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, + dm_env.transition( + reward={'a': np.array([[1.0, 2.0]])}, + observation=2, + discount=np.array([[0.5], [0.1]]))), + (0, + dm_env.transition( + reward={'a': np.array([[1.0, 2.0]])}, + observation=3, + discount=np.array([[0.5], [0.1]]))), + (0, + dm_env.termination( + reward={ + 'a': np.array([[1.0, 2.0]]) + }, observation=4)._replace(discount=np.array([[0.0], [0.0]]))), + ), + expected_transitions=( + types.Transition(1, 0, {'a': np.array([[1.0, 2.0], [1.0, 2.0]])}, + {'a': np.array([[0.50], [0.10]])}, 2), + types.Transition(1, 0, {'a': np.array([[1.5, 3.0], [1.1, 
2.2]])}, + {'a': np.array([[0.25], [0.01]])}, 3), + types.Transition(2, 0, {'a': np.array([[1.5, 3.0], [1.1, 2.2]])}, + {'a': np.array([[0.00], [0.00]])}, 4), + types.Transition(3, 0, {'a': np.array([[1.0, 2.0], [1.0, 2.0]])}, + {'a': np.array([[0.00], [0.00]])}, 4), + )), + dict( + testcase_name='TwoStepWithExtras', + n_step=2, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + ( + 0, + dm_env.transition(reward=1.0, observation=2, discount=0.5), + { + 'state': 0 + }, + ), + ( + 0, + dm_env.transition(reward=1.0, observation=3, discount=0.5), + { + 'state': 1 + }, + ), + ( + 0, + dm_env.termination(reward=1.0, observation=4), + { + 'state': 2 + }, + ), + ), + expected_transitions=( + types.Transition(1, 0, 1.0, 0.50, 2, {'state': 0}), + types.Transition(1, 0, 1.5, 0.25, 3, {'state': 0}), + types.Transition(2, 0, 1.5, 0.00, 4, {'state': 1}), + types.Transition(3, 0, 1.0, 0.00, 4, {'state': 2}), + )), + dict( + testcase_name='ThreeStepDiscounted', + n_step=3, + additional_discount=0.4, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=1.0, observation=2, discount=0.5)), + (0, dm_env.transition(reward=1.0, observation=3, discount=0.5)), + (0, dm_env.termination(reward=1.0, observation=4)), + ), + expected_transitions=( + types.Transition(1, 0, 1.00, 0.5, 2), + types.Transition(1, 0, 1.20, 0.1, 3), + types.Transition(1, 0, 1.24, 0.0, 4), + types.Transition(2, 0, 1.20, 0.0, 4), + types.Transition(3, 0, 1.00, 0.0, 4), + )), + dict( + testcase_name='ThreeStepVaryingReward', + n_step=3, + additional_discount=0.5, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.termination(reward=7.0, observation=5)), + ), + expected_transitions=( + types.Transition(1, 0, 2, 1.00, 2), + types.Transition(1, 0, 2 + 0.5 * 3, 0.50, 3), + types.Transition(1, 0, 2 + 0.5 * 3 + 0.25 * 5, 0.25, 4), + types.Transition(2, 0, 3 + 0.5 * 5 + 0.25 * 7, 0.00, 5), + types.Transition(3, 0, 5 + 0.5 * 7, 0.00, 5), + types.Transition(4, 0, 7, 0.00, 5), + )), + dict( + testcase_name='SingleTransitionEpisode', + n_step=4, + additional_discount=1.0, + first=dm_env.restart(1), + steps=((0, dm_env.termination(reward=1.0, observation=2)),), + expected_transitions=(types.Transition(1, 0, 1.00, 0.0, 2),)), + dict( + testcase_name='EpisodeShorterThanN', + n_step=4, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=1.0, observation=2)), + (0, dm_env.termination(reward=1.0, observation=3)), + ), + expected_transitions=( + types.Transition(1, 0, 1.00, 1.0, 2), + types.Transition(1, 0, 2.00, 0.0, 3), + types.Transition(2, 0, 1.00, 0.0, 3), + )), + dict( + testcase_name='EpisodeEqualToN', + n_step=3, + additional_discount=1.0, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=1.0, observation=2)), + (0, dm_env.termination(reward=1.0, observation=3)), + ), + expected_transitions=( + types.Transition(1, 0, 1.00, 1.0, 2), + types.Transition(1, 0, 2.00, 0.0, 3), + types.Transition(2, 0, 1.00, 0.0, 3), + )), +] + +BASE_TEST_CASES_FOR_SEQUENCE_ADDER = [ + dict( + testcase_name='PeriodOne', + sequence_length=3, + period=1, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.termination(reward=7.0, observation=5)), + ), + expected_sequences=( + # 
(observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 0.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 0.0, False, ()), + (5, 0, 0.0, 0.0, False, ()), + ], + ), + end_behavior=sequence_adder.EndBehavior.ZERO_PAD, + ), + dict( + testcase_name='PeriodTwo', + sequence_length=3, + period=2, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.termination(reward=7.0, observation=5)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 0.0, False, ()), + (5, 0, 0.0, 0.0, False, ()), + ], + ), + end_behavior=sequence_adder.EndBehavior.ZERO_PAD, + ), + dict( + testcase_name='EarlyTerminationPeriodOne', + sequence_length=3, + period=1, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.termination(reward=3.0, observation=3)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 0.0, False, ()), + (3, 0, 0.0, 0.0, False, ()), + ],), + end_behavior=sequence_adder.EndBehavior.ZERO_PAD, + ), + dict( + testcase_name='EarlyTerminationPeriodTwo', + sequence_length=3, + period=2, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.termination(reward=3.0, observation=3)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 0.0, False, ()), + (3, 0, 0.0, 0.0, False, ()), + ],), + end_behavior=sequence_adder.EndBehavior.ZERO_PAD, + ), + dict( + testcase_name='EarlyTerminationPaddingPeriodOne', + sequence_length=4, + period=1, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.termination(reward=3.0, observation=3)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 0.0, False, ()), + (3, 0, 0.0, 0.0, False, ()), + (0, 0, 0.0, 0.0, False, ()), + ],), + end_behavior=sequence_adder.EndBehavior.ZERO_PAD, + ), + dict( + testcase_name='EarlyTerminationPaddingPeriodTwo', + sequence_length=4, + period=2, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.termination(reward=3.0, observation=3)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 0.0, False, ()), + (3, 0, 0.0, 0.0, False, ()), + (0, 0, 0.0, 0.0, False, ()), + ],), + end_behavior=sequence_adder.EndBehavior.ZERO_PAD, + ), + dict( + testcase_name='EarlyTerminationNoPadding', + sequence_length=4, + period=1, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.termination(reward=3.0, observation=3)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 0.0, False, ()), + (3, 0, 0.0, 0.0, 
False, ()), + ],), + end_behavior=sequence_adder.EndBehavior.TRUNCATE, + ), + dict( + testcase_name='LongEpisodePadding', + sequence_length=3, + period=3, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.transition(reward=7.0, observation=5)), + (0, dm_env.transition(reward=9.0, observation=6)), + (0, dm_env.transition(reward=11.0, observation=7)), + (0, dm_env.termination(reward=13.0, observation=8)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 11.0, 1.0, False, ()), + ], + [ + (7, 0, 13.0, 0.0, False, ()), + (8, 0, 0.0, 0.0, False, ()), + (0, 0, 0.0, 0.0, False, ()), + ], + ), + end_behavior=sequence_adder.EndBehavior.ZERO_PAD, + ), + dict( + testcase_name='LongEpisodeNoPadding', + sequence_length=3, + period=3, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.transition(reward=7.0, observation=5)), + (0, dm_env.transition(reward=9.0, observation=6)), + (0, dm_env.transition(reward=11.0, observation=7)), + (0, dm_env.termination(reward=13.0, observation=8)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 11.0, 1.0, False, ()), + ], + [ + (7, 0, 13.0, 0.0, False, ()), + (8, 0, 0.0, 0.0, False, ()), + ], + ), + end_behavior=sequence_adder.EndBehavior.TRUNCATE, + ), + dict( + testcase_name='EndBehavior_WRITE', + sequence_length=3, + period=2, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.transition(reward=7.0, observation=5)), + (0, dm_env.termination(reward=8.0, observation=6)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 8.0, 0.0, False, ()), + ], + [ + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 8.0, 0.0, False, ()), + (6, 0, 0.0, 0.0, False, ()), + ], + ), + end_behavior=sequence_adder.EndBehavior.WRITE, + ), +] + +TEST_CASES_FOR_SEQUENCE_ADDER = BASE_TEST_CASES_FOR_SEQUENCE_ADDER + [ + dict( + testcase_name='NonBreakingSequenceOnEpisodeReset', + sequence_length=3, + period=2, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.transition(reward=7.0, observation=5)), + (0, dm_env.transition(reward=9.0, observation=6)), + (0, dm_env.transition(reward=11.0, observation=7)), + (0, dm_env.termination(reward=13.0, observation=8)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 
5.0, 1.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + ], + [ + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 11.0, 1.0, False, ()), + (7, 0, 13.0, 0.0, False, ()), + ], + ), + end_behavior=sequence_adder.EndBehavior.CONTINUE, + repeat_episode_times=1), + dict( + testcase_name='NonBreakingSequenceMultipleTerminatedEpisodes', + sequence_length=3, + period=2, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.transition(reward=7.0, observation=5)), + (0, dm_env.transition(reward=9.0, observation=6)), + (0, dm_env.termination(reward=13.0, observation=7)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + ], + [ + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 13.0, 0.0, False, ()), + (7, 0, 0.0, 0.0, False, ()), + ], + [ + (7, 0, 0.0, 0.0, False, ()), + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + ], + [ + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + ], + [ + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 13.0, 0.0, False, ()), + ], + [ + (6, 0, 13.0, 0.0, False, ()), + (7, 0, 0.0, 0.0, False, ()), + (1, 0, 2.0, 1.0, True, ()), + ], + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + ], + [ + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 13.0, 0.0, False, ()), + (7, 0, 0.0, 0.0, False, ()), + ], + ), + end_behavior=sequence_adder.EndBehavior.CONTINUE, + repeat_episode_times=3), + dict( + testcase_name='NonBreakingSequenceMultipleTruncatedEpisodes', + sequence_length=3, + period=2, + first=dm_env.restart(1), + steps=( + (0, dm_env.transition(reward=2.0, observation=2)), + (0, dm_env.transition(reward=3.0, observation=3)), + (0, dm_env.transition(reward=5.0, observation=4)), + (0, dm_env.transition(reward=7.0, observation=5)), + (0, dm_env.transition(reward=9.0, observation=6)), + (0, dm_env.truncation(reward=13.0, observation=7)), + ), + expected_sequences=( + # (observation, action, reward, discount, start_of_episode, extra) + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + ], + [ + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 13.0, 1.0, False, ()), + (7, 0, 0.0, 0.0, False, ()), + ], + [ + (7, 0, 0.0, 0.0, False, ()), + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + ], + [ + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + ], + [ + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 13.0, 1.0, False, ()), + ], + [ + (6, 0, 13.0, 1.0, False, ()), + (7, 0, 0.0, 0.0, False, ()), + (1, 0, 2.0, 1.0, True, ()), + ], + [ + (1, 0, 2.0, 1.0, True, ()), + (2, 0, 3.0, 1.0, False, ()), + (3, 0, 5.0, 1.0, False, ()), + ], + [ + (3, 0, 5.0, 1.0, False, ()), + (4, 0, 7.0, 1.0, False, ()), + (5, 0, 9.0, 1.0, False, ()), + ], + [ + (5, 0, 9.0, 1.0, False, ()), + (6, 0, 13.0, 1.0, False, ()), + (7, 0, 0.0, 0.0, False, ()), 
+            ],
+        ),
+        end_behavior=sequence_adder.EndBehavior.CONTINUE,
+        repeat_episode_times=3),
+]
diff --git a/acme/acme/adders/reverb/test_utils.py b/acme/acme/adders/reverb/test_utils.py
new file mode 100644
index 00000000..6ed9a9ac
--- /dev/null
+++ b/acme/acme/adders/reverb/test_utils.py
@@ -0,0 +1,233 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for testing Reverb adders."""
+
+from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union
+
+from acme import specs
+from acme import types
+from acme.adders import base as adders_base
+from acme.adders import reverb as adders
+from acme.utils import tree_utils
+import dm_env
+import numpy as np
+import reverb
+import tensorflow as tf
+import tree
+
+from absl.testing import absltest
+
+StepWithExtra = Tuple[Any, dm_env.TimeStep, Any]
+StepWithoutExtra = Tuple[Any, dm_env.TimeStep]
+Step = TypeVar('Step', StepWithExtra, StepWithoutExtra)
+
+
+def make_trajectory(observations):
+  """Make a simple trajectory from a sequence of observations.
+
+  Args:
+    observations: a sequence of observations.
+
+  Returns:
+    a tuple (first, steps) where first contains the initial dm_env.TimeStep
+    object and steps contains a list of (action, step) tuples. The number of
+    steps is len(observations) - 1.
+  """
+  first = dm_env.restart(observations[0])
+  middle = [(0, dm_env.transition(reward=0.0, observation=observation))
+            for observation in observations[1:-1]]
+  last = (0, dm_env.termination(reward=0.0, observation=observations[-1]))
+  return first, middle + [last]
+
+
+def make_sequence(observations):
+  """Create a sequence of timesteps of the form `first, [second, ..., last]`."""
+  first, steps = make_trajectory(observations)
+  observation = first.observation
+  sequence = []
+  start_of_episode = True
+  for action, timestep in steps:
+    extras = ()
+    sequence.append((observation, action, timestep.reward, timestep.discount,
+                     start_of_episode, extras))
+    observation = timestep.observation
+    start_of_episode = False
+  sequence.append((observation, 0, 0.0, 0.0, False, ()))
+  return sequence
+
+
+def _numeric_to_spec(x: Union[float, int, np.ndarray]):
+  if isinstance(x, np.ndarray):
+    return specs.Array(shape=x.shape, dtype=x.dtype)
+  elif isinstance(x, (float, int)):
+    return specs.Array(shape=(), dtype=type(x))
+  else:
+    raise ValueError(f'Unsupported numeric: {type(x)}')
+
+
+def get_specs(step):
+  """Infer spec from an example step."""
+  env_spec = tree.map_structure(
+      _numeric_to_spec,
+      specs.EnvironmentSpec(
+          observations=step[1].observation,
+          actions=step[0],
+          rewards=step[1].reward,
+          discounts=step[1].discount))
+
+  has_extras = len(step) == 3
+  if has_extras:
+    extras_spec = tree.map_structure(_numeric_to_spec, step[2])
+  else:
+    extras_spec = ()
+
+  return env_spec, extras_spec
+
+
+class AdderTestMixin(absltest.TestCase):
+  """A helper mixin for testing Reverb adders.
+ + Note that any test inheriting from this mixin must also inherit from something + that provides the Python unittest assert methods. + """ + + server: reverb.Server + client: reverb.Client + + @classmethod + def setUpClass(cls): + super().setUpClass() + + replay_table = reverb.Table.queue(adders.DEFAULT_PRIORITY_TABLE, 1000) + cls.server = reverb.Server([replay_table]) + cls.client = reverb.Client(f'localhost:{cls.server.port}') + + def tearDown(self): + self.client.reset(adders.DEFAULT_PRIORITY_TABLE) + super().tearDown() + + @classmethod + def tearDownClass(cls): + cls.server.stop() + super().tearDownClass() + + def num_episodes(self): + info = self.client.server_info(1)[adders.DEFAULT_PRIORITY_TABLE] + return info.num_episodes + + def num_items(self): + info = self.client.server_info(1)[adders.DEFAULT_PRIORITY_TABLE] + return info.current_size + + def items(self): + sampler = self.client.sample( + table=adders.DEFAULT_PRIORITY_TABLE, + num_samples=self.num_items(), + emit_timesteps=False) + return [sample.data for sample in sampler] # pytype: disable=attribute-error + + def run_test_adder( + self, + adder: adders_base.Adder, + first: dm_env.TimeStep, + steps: Sequence[Step], + expected_items: Sequence[Any], + signature: types.NestedSpec, + pack_expected_items: bool = False, + stack_sequence_fields: bool = True, + repeat_episode_times: int = 1, + end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD, + item_transform: Optional[Callable[[Sequence[np.ndarray]], Any]] = None): + """Runs a unit test case for the adder. + + Args: + adder: The instance of `Adder` that is being tested. + first: The first `dm_env.TimeStep` that is used to call + `Adder.add_first()`. + steps: A sequence of (action, timestep) tuples that are passed to + `Adder.add()`. + expected_items: The sequence of items that are expected to be created + by calling the adder's `add_first()` method on `first` and `add()` on + all of the elements in `steps`. + signature: Signature that written items must be compatible with. + pack_expected_items: Deprecated and not used. If true the expected items + are given unpacked and need to be packed in a list before comparison. + stack_sequence_fields: Whether to stack the sequence fields of the + expected items before comparing to the observed items. Usually False + for transition adders and True for both episode and sequence adders. + repeat_episode_times: How many times to run an episode. + end_behavior: How end of episode should be handled. + item_transform: Transformation of item simulating the work done by the + dataset pipeline on the learner in a real setup. + """ + + del pack_expected_items + + if not steps: + raise ValueError('At least one step must be given.') + + has_extras = len(steps[0]) == 3 + for _ in range(repeat_episode_times): + # Add all the data up to the final step. + adder.add_first(first) + for step in steps[:-1]: + action, ts = step[0], step[1] + + if has_extras: + extras = step[2] + else: + extras = () + + adder.add(action, next_timestep=ts, extras=extras) + + # Add the final step. + adder.add(*steps[-1]) + + # Force run the destructor to trigger the flushing of all pending items. + getattr(adder, '__del__', lambda: None)() + + # Ending the episode should close the writer. No new writer should yet have + # been created as it is constructed lazily. + if end_behavior is not adders.EndBehavior.CONTINUE: + self.assertEqual(self.num_episodes(), repeat_episode_times) + + # Make sure our expected and observed data match. 
+
+    observed_items = self.items()
+
+    # Check matching number of items.
+    self.assertEqual(len(expected_items), len(observed_items))
+
+    # Check items are matching according to numpy's almost_equal.
+    for expected_item, observed_item in zip(expected_items, observed_items):
+      if stack_sequence_fields:
+        expected_item = tree_utils.stack_sequence_fields(expected_item)
+
+      # Apply the transformation which would be done by the dataset in a real
+      # setup.
+      if item_transform:
+        observed_item = item_transform(observed_item)
+
+      tree.map_structure(np.testing.assert_array_almost_equal,
+                         tree.flatten(expected_item),
+                         tree.flatten(observed_item))
+
+    # Make sure the signature matches what is being written by Reverb.
+    def _check_signature(spec: tf.TensorSpec, value: np.ndarray):
+      self.assertTrue(spec.is_compatible_with(tf.convert_to_tensor(value)))
+
+    # Check that it is possible to unpack observed using the signature.
+    for item in observed_items:
+      tree.map_structure(_check_signature, tree.flatten(signature),
+                         tree.flatten(item))
diff --git a/acme/acme/adders/reverb/transition.py b/acme/acme/adders/reverb/transition.py
new file mode 100644
index 00000000..fe3e16f7
--- /dev/null
+++ b/acme/acme/adders/reverb/transition.py
@@ -0,0 +1,307 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transition adders.
+
+This implements an N-step transition adder which collapses trajectory sequences
+into a single transition, simplifying to a simple transition adder when N=1.
+"""
+
+import copy
+from typing import Optional, Tuple
+
+from acme import specs
+from acme import types
+from acme.adders.reverb import base
+from acme.adders.reverb import utils
+from acme.utils import tree_utils
+
+import numpy as np
+import reverb
+import tree
+
+
+class NStepTransitionAdder(base.ReverbAdder):
+  """An N-step transition adder.
+
+  This will buffer a sequence of N timesteps in order to form a single N-step
+  transition which is added to Reverb for future retrieval.
+
+  For N=1 the data added to replay will be a standard one-step transition which
+  takes the form:
+
+        (s_t, a_t, r_t, d_t, s_{t+1}, e_t)
+
+  where:
+
+    s_t = state observation at time t
+    a_t = the action taken from s_t
+    r_t = reward ensuing from action a_t
+    d_t = environment discount ensuing from action a_t. This discount is
+        applied to future rewards after r_t.
+    e_t [Optional] = extra data that the agent persists in replay.
+
+  For N greater than 1, transitions are of the form:
+
+        (s_t, a_t, R_{t:t+n}, D_{t:t+n}, s_{t+N}, e_t),
+
+  where:
+
+    s_t = State (observation) at time t.
+    a_t = Action taken from state s_t.
+    g = the additional discount, used by the agent to discount future returns.
+    R_{t:t+n} = N-step discounted return, i.e. accumulated over N rewards:
+          R_{t:t+n} := r_t + g * d_t * r_{t+1} + ...
+              + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1}.
+    D_{t:t+n}: N-step product of agent discounts g_i and environment
+      "discounts" d_i.
+          D_{t:t+n} := g^{n-1} * d_{t} * ...
* d_{t+n-1}, + For most environments d_i is 1 for all steps except the last, + i.e. it is the episode termination signal. + s_{t+n}: The "arrival" state, i.e. the state at time t+n. + e_t [Optional]: A nested structure of any 'extras' the user wishes to add. + + Notes: + - At the beginning and end of episodes, shorter transitions are added. + That is, at the beginning of the episode, it will add: + (s_0 -> s_1), (s_0 -> s_2), ..., (s_0 -> s_n), (s_1 -> s_{n+1}) + + And at the end of the episode, it will add: + (s_{T-n+1} -> s_T), (s_{T-n+2} -> s_T), ... (s_{T-1} -> s_T). + - We add the *first* `extra` of each transition, not the *last*, i.e. + if extras are provided, we get e_t, not e_{t+n}. + """ + + def __init__( + self, + client: reverb.Client, + n_step: int, + discount: float, + *, + priority_fns: Optional[base.PriorityFnMapping] = None, + max_in_flight_items: int = 5, + ): + """Creates an N-step transition adder. + + Args: + client: A `reverb.Client` to send the data to replay through. + n_step: The "N" in N-step transition. See the class docstring for the + precise definition of what an N-step transition is. `n_step` must be at + least 1, in which case we use the standard one-step transition, i.e. + (s_t, a_t, r_t, d_t, s_t+1, e_t). + discount: Discount factor to apply. This corresponds to the agent's + discount in the class docstring. + priority_fns: See docstring for BaseAdder. + max_in_flight_items: The maximum number of items allowed to be "in flight" + at the same time. See `block_until_num_items` in + `reverb.TrajectoryWriter.flush` for more info. + + Raises: + ValueError: If n_step is less than 1. + """ + # Makes the additional discount a float32, which means that it will be + # upcast if rewards/discounts are float64 and left alone otherwise. + self.n_step = n_step + self._discount = tree.map_structure(np.float32, discount) + self._first_idx = 0 + self._last_idx = 0 + + super().__init__( + client=client, + max_sequence_length=n_step + 1, + priority_fns=priority_fns, + max_in_flight_items=max_in_flight_items) + + def add(self, *args, **kwargs): + # Increment the indices for the start and end of the window for computing + # n-step returns. + if self._writer.episode_steps >= self.n_step: + self._first_idx += 1 + self._last_idx += 1 + + super().add(*args, **kwargs) + + def reset(self): + super().reset() + self._first_idx = 0 + self._last_idx = 0 + + @property + def _n_step(self) -> int: + """Effective n-step, which may vary at starts and ends of episodes.""" + return self._last_idx - self._first_idx + + def _write(self): + # Convenient getters for use in tree operations. + get_first = lambda x: x[self._first_idx] + get_last = lambda x: x[self._last_idx] + # Note: this getter is meant to be used on a TrajectoryWriter.history to + # obtain its numpy values. + get_all_np = lambda x: x[self._first_idx:self._last_idx].numpy() + + # Get the state, action, next_state, as well as possibly extras for the + # transition that is about to be written. + history = self._writer.history + s, a = tree.map_structure(get_first, + (history['observation'], history['action'])) + s_ = tree.map_structure(get_last, history['observation']) + + # Maybe get extras to add to the transition later. + if 'extras' in history: + extras = tree.map_structure(get_first, history['extras']) + + # Note: at the beginning of an episode we will add the initial N-1 + # transitions (of size 1, 2, ...) 
and at the end of an episode (when
+    # called from _write_last) we will write the final transitions of size (N,
+    # N-1, ...). See the Note in the docstring.
+    # Get numpy view of the steps to be fed into the priority functions.
+    reward, discount = tree.map_structure(
+        get_all_np, (history['reward'], history['discount']))
+
+    # Compute discounted return and geometric discount over n steps.
+    n_step_return, total_discount = self._compute_cumulative_quantities(
+        reward, discount)
+
+    # Append the computed n-step return and total discount.
+    # Note: if this call to _write() is within a call to _write_last(), then
+    # this is the only data being appended and so it is not a partial append.
+    self._writer.append(
+        dict(n_step_return=n_step_return, total_discount=total_discount),
+        partial_step=self._writer.episode_steps <= self._last_idx)
+    # This should be done immediately after self._writer.append so the history
+    # includes the recently appended data.
+    history = self._writer.history
+
+    # Form the n-step transition by using the following:
+    # the first observation and action in the buffer, along with the cumulative
+    # reward and discount computed above.
+    n_step_return, total_discount = tree.map_structure(
+        lambda x: x[-1], (history['n_step_return'], history['total_discount']))
+    transition = types.Transition(
+        observation=s,
+        action=a,
+        reward=n_step_return,
+        discount=total_discount,
+        next_observation=s_,
+        extras=(extras if 'extras' in history else ()))
+
+    # Calculate the priority for this transition.
+    table_priorities = utils.calculate_priorities(self._priority_fns,
+                                                  transition)
+
+    # Insert the transition into replay along with its priority.
+    for table, priority in table_priorities.items():
+      self._writer.create_item(
+          table=table, priority=priority, trajectory=transition)
+      self._writer.flush(self._max_in_flight_items)
+
+  def _write_last(self):
+    # Write the remaining shorter transitions by alternating writing and
+    # incrementing first_idx. Note that last_idx will no longer be incremented
+    # once we're in this method's scope.
+    self._first_idx += 1
+    while self._first_idx < self._last_idx:
+      self._write()
+      self._first_idx += 1
+
+  def _compute_cumulative_quantities(
+      self, rewards: types.NestedArray, discounts: types.NestedArray
+  ) -> Tuple[types.NestedArray, types.NestedArray]:
+
+    # Give the same tree structure to the n-step return accumulator,
+    # n-step discount accumulator, and self.discount, so that they can be
+    # iterated in parallel using tree.map_structure.
+    rewards, discounts, self_discount = tree_utils.broadcast_structures(
+        rewards, discounts, self._discount)
+    flat_rewards = tree.flatten(rewards)
+    flat_discounts = tree.flatten(discounts)
+    flat_self_discount = tree.flatten(self_discount)
+
+    # Copy total_discount as it is otherwise read-only.
+    total_discount = [np.copy(a[0]) for a in flat_discounts]
+
+    # Broadcast n_step_return to have the broadcasted shape of
+    # reward * discount.
+    n_step_return = [
+        np.copy(np.broadcast_to(r[0],
+                                np.broadcast(r[0], d).shape))
+        for r, d in zip(flat_rewards, total_discount)
+    ]
+
+    # NOTE: total_discount will have one less self_discount applied to it than
+    # the value of self._n_step. This is so that when the learner/update uses
+    # an additional discount we don't apply it twice. Inside the following loop
+    # we will apply this right before summing up the n_step_return.
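+    # As a concrete illustration (editor's note, numbers are hypothetical):
+    # for n_step=2 with rewards [r0, r1], environment discounts [d0, d1] and
+    # agent discount g, the loop below yields
+    #   n_step_return = r0 + g * d0 * r1
+    #   total_discount = g * d0 * d1.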
+ for i in range(1, self._n_step): + for nsr, td, r, d, sd in zip(n_step_return, total_discount, flat_rewards, + flat_discounts, flat_self_discount): + # Equivalent to: `total_discount *= self._discount`. + td *= sd + # Equivalent to: `n_step_return += reward[i] * total_discount`. + nsr += r[i] * td + # Equivalent to: `total_discount *= discount[i]`. + td *= d[i] + + n_step_return = tree.unflatten_as(rewards, n_step_return) + total_discount = tree.unflatten_as(rewards, total_discount) + return n_step_return, total_discount + + # TODO(bshahr): make this into a standalone method. Class methods should be + # used as alternative constructors or when modifying some global state, + # neither of which is done here. + @classmethod + def signature(cls, + environment_spec: specs.EnvironmentSpec, + extras_spec: types.NestedSpec = ()): + + # This function currently assumes that self._discount is a scalar. + # If it ever becomes a nested structure and/or a np.ndarray, this method + # will need to know its structure / shape. This is because the signature + # discount shape is the environment's discount shape and this adder's + # discount shape broadcasted together. Also, the reward shape is this + # signature discount shape broadcasted together with the environment + # reward shape. As long as self._discount is a scalar, it will not affect + # either the signature discount shape nor the signature reward shape, so we + # can ignore it. + + rewards_spec, step_discounts_spec = tree_utils.broadcast_structures( + environment_spec.rewards, environment_spec.discounts) + rewards_spec = tree.map_structure(_broadcast_specs, rewards_spec, + step_discounts_spec) + step_discounts_spec = tree.map_structure(copy.deepcopy, step_discounts_spec) + + transition_spec = types.Transition( + environment_spec.observations, + environment_spec.actions, + rewards_spec, + step_discounts_spec, + environment_spec.observations, # next_observation + extras_spec) + + return tree.map_structure_with_path(base.spec_like_to_tensor_spec, + transition_spec) + + +def _broadcast_specs(*args: specs.Array) -> specs.Array: + """Like np.broadcast, but for specs.Array. + + Args: + *args: one or more specs.Array instances. + + Returns: + A specs.Array with the broadcasted shape and dtype of the specs in *args. + """ + bc_info = np.broadcast(*tuple(a.generate_value() for a in args)) + dtype = np.result_type(*tuple(a.dtype for a in args)) + return specs.Array(shape=bc_info.shape, dtype=dtype) diff --git a/acme/acme/adders/reverb/transition_test.py b/acme/acme/adders/reverb/transition_test.py new file mode 100644 index 00000000..0c668d70 --- /dev/null +++ b/acme/acme/adders/reverb/transition_test.py @@ -0,0 +1,43 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
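+
+# A minimal usage sketch of the adder under test (editor's illustration; a
+# running Reverb server, a dm_env environment `env` and an `action` are
+# assumed):
+#
+#   client = reverb.Client('localhost:8000')
+#   adder = NStepTransitionAdder(client, n_step=3, discount=0.99)
+#   adder.add_first(env.reset())
+#   timestep = env.step(action)
+#   adder.add(action, next_timestep=timestep)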
+
+"""Tests for NStepTransition adders."""
+
+from acme.adders.reverb import test_cases
+from acme.adders.reverb import test_utils
+from acme.adders.reverb import transition as adders
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class NStepTransitionAdderTest(test_utils.AdderTestMixin,
+                               parameterized.TestCase):
+
+  @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER)
+  def test_adder(self, n_step, additional_discount, first, steps,
+                 expected_transitions):
+    adder = adders.NStepTransitionAdder(self.client, n_step,
+                                        additional_discount)
+    super().run_test_adder(
+        adder=adder,
+        first=first,
+        steps=steps,
+        expected_items=expected_transitions,
+        stack_sequence_fields=False,
+        signature=adder.signature(*test_utils.get_specs(steps[0])))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/adders/reverb/utils.py b/acme/acme/adders/reverb/utils.py
new file mode 100644
index 00000000..9c7974a2
--- /dev/null
+++ b/acme/acme/adders/reverb/utils.py
@@ -0,0 +1,96 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for Reverb-based adders."""
+
+from typing import Dict, Union
+
+from acme import types
+from acme.adders.reverb import base
+import jax.numpy as jnp
+import numpy as np
+import tree
+
+
+def zeros_like(x: Union[np.ndarray, int, float, np.number]):
+  """Returns a zero-filled object of the same (d)type and shape as the input.
+
+  The difference between this and `np.zeros_like()` is that this works well
+  with `np.number`, `int`, `float`, and `jax.numpy.DeviceArray` objects without
+  converting them to `np.ndarray`s.
+
+  Args:
+    x: The object to replace with 0s.
+
+  Returns:
+    A zero-filled object of the same (d)type and shape as the input.
+  """
+  if isinstance(x, (int, float, np.number)):
+    return type(x)(0)
+  elif isinstance(x, jnp.DeviceArray):
+    return jnp.zeros_like(x)
+  elif isinstance(x, np.ndarray):
+    return np.zeros_like(x)
+  else:
+    raise ValueError(
+        f'Input ({type(x)}) must be either a numpy array, an int, or a float.')
+
+
+def final_step_like(step: base.Step,
+                    next_observation: types.NestedArray) -> base.Step:
+  """Returns a final step with all fields except the observation zero-filled."""
+  # Make zero-filled components so we can fill out the last step.
+  zero_action, zero_reward, zero_discount, zero_extras = tree.map_structure(
+      zeros_like, (step.action, step.reward, step.discount, step.extras))
+
+  # Return a final step that only has next_observation.
+  return base.Step(
+      observation=next_observation,
+      action=zero_action,
+      reward=zero_reward,
+      discount=zero_discount,
+      start_of_episode=False,
+      extras=zero_extras)
+
+
+def calculate_priorities(
+    priority_fns: base.PriorityFnMapping,
+    trajectory_or_transition: Union[base.Trajectory, types.Transition],
+) -> Dict[str, float]:
+  """Helper used to calculate the priority of a Trajectory or Transition.
+ + This helper converts the leaves of the Trajectory or Transition from + `reverb.TrajectoryColumn` objects into numpy arrays. The converted Trajectory + or Transition is then passed into each of the functions in `priority_fns`. + + Args: + priority_fns: a mapping from table names to priority functions (i.e. a + callable of type PriorityFn). The given function will be used to generate + the priority (a float) for the given table. + trajectory_or_transition: the trajectory or transition used to compute + priorities. + + Returns: + A dictionary mapping from table names to the priority (a float) for the + given collection Trajectory or Transition. + """ + if any([priority_fn is not None for priority_fn in priority_fns.values()]): + + trajectory_or_transition = tree.map_structure(lambda col: col.numpy(), + trajectory_or_transition) + + return { + table: (priority_fn(trajectory_or_transition) if priority_fn else 1.0) + for table, priority_fn in priority_fns.items() + } diff --git a/acme/acme/adders/wrappers.py b/acme/acme/adders/wrappers.py new file mode 100644 index 00000000..9b26944e --- /dev/null +++ b/acme/acme/adders/wrappers.py @@ -0,0 +1,62 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A library of useful adder wrappers.""" + +from typing import Iterable + +from acme import types +from acme.adders import base +import dm_env + + +class ForkingAdder(base.Adder): + """An adder that forks data into several other adders.""" + + def __init__(self, adders: Iterable[base.Adder]): + self._adders = adders + + def reset(self): + for adder in self._adders: + adder.reset() + + def add_first(self, timestep: dm_env.TimeStep): + for adder in self._adders: + adder.add_first(timestep) + + def add(self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = ()): + for adder in self._adders: + adder.add(action, next_timestep, extras) + + +class IgnoreExtrasAdder(base.Adder): + """An adder that ignores extras.""" + + def __init__(self, adder: base.Adder): + self._adder = adder + + def reset(self): + self._adder.reset() + + def add_first(self, timestep: dm_env.TimeStep): + self._adder.add_first(timestep) + + def add(self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = ()): + self._adder.add(action, next_timestep) diff --git a/acme/acme/agents/__init__.py b/acme/acme/agents/__init__.py new file mode 100644 index 00000000..15d16f85 --- /dev/null +++ b/acme/acme/agents/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Agent implementations."""
diff --git a/acme/acme/agents/agent.py b/acme/acme/agents/agent.py
new file mode 100644
index 00000000..678b4961
--- /dev/null
+++ b/acme/acme/agents/agent.py
@@ -0,0 +1,136 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The base agent interface."""
+
+import math
+from typing import List, Optional, Sequence
+
+from acme import core
+from acme import types
+import dm_env
+import numpy as np
+import reverb
+
+
+def _calculate_num_learner_steps(num_observations: int,
+                                 min_observations: int,
+                                 observations_per_step: float) -> int:
+  """Calculates the number of learner steps to do at step=num_observations."""
+  n = num_observations - min_observations
+  if n < 0:
+    # Do not do any learner steps until you have seen min_observations.
+    return 0
+  if observations_per_step > 1:
+    # One batch every 1/obs_per_step observations, otherwise zero.
+    return int(n % int(observations_per_step) == 0)
+  else:
+    # Always return 1/obs_per_step batches every observation.
+    return int(1 / observations_per_step)
+
+
+class Agent(core.Actor, core.VariableSource):
+  """Agent class which combines acting and learning.
+
+  This provides an implementation of the `Actor` interface which acts and
+  learns. It takes as input instances of both `acme.Actor` and `acme.Learner`
+  classes, and implements the policy, observation, and update methods which
+  defer to the underlying actor and learner.
+
+  The only real logic implemented by this class is that it controls the number
+  of observations to make before running a learner step. This is done by
+  passing the number of `min_observations` to use and a ratio of
+  `observations_per_step` := num_actor_actions / num_learner_steps.
+
+  Note that `observations_per_step` can also be in the range [0, 1] in order
+  to allow the agent to take more than 1 learner step per action.
+ """ + + def __init__(self, actor: core.Actor, learner: core.Learner, + min_observations: Optional[int] = None, + observations_per_step: Optional[float] = None, + iterator: Optional[core.PrefetchingIterator] = None, + replay_tables: Optional[List[reverb.Table]] = None): + self._actor = actor + self._learner = learner + self._min_observations = min_observations + self._observations_per_step = observations_per_step + self._num_observations = 0 + self._iterator = iterator + self._replay_tables = replay_tables + self._batch_size_upper_bounds = [1_000_000_000] * len( + replay_tables) if replay_tables else None + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + return self._actor.select_action(observation) + + def observe_first(self, timestep: dm_env.TimeStep): + self._actor.observe_first(timestep) + + def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): + self._num_observations += 1 + self._actor.observe(action, next_timestep) + + def _has_data_for_training(self): + if self._iterator.ready(): + return True + for (table, batch_size) in zip(self._replay_tables, + self._batch_size_upper_bounds): + if not table.can_sample(batch_size): + return False + return True + + def update(self): + if self._iterator: + # Perform learner steps as long as iterator has data. + update_actor = False + while self._has_data_for_training(): + # Run learner steps (usually means gradient steps). + total_batches = self._iterator.retrieved_elements() + self._learner.step() + current_batches = self._iterator.retrieved_elements() - total_batches + assert current_batches == 1, ( + 'Learner step must retrieve exactly one element from the iterator' + f' (retrieved {current_batches}). Otherwise agent can deadlock. ' + 'Example cause is that your chosen agent' + 's Builder has a ' + '`make_learner` factory that prefetches the data but it ' + 'shouldn' + 't.') + self._batch_size_upper_bounds = [ + math.ceil(t.info.rate_limiter_info.sample_stats.completed / + (total_batches + 1)) for t in self._replay_tables + ] + update_actor = True + if update_actor: + # Update the actor weights only when learner was updated. + self._actor.update() + return + + # If dataset is not provided, follback to the old logic. + # TODO(stanczyk): Remove when not used. + num_steps = _calculate_num_learner_steps( + num_observations=self._num_observations, + min_observations=self._min_observations, + observations_per_step=self._observations_per_step, + ) + for _ in range(num_steps): + # Run learner steps (usually means gradient steps). + self._learner.step() + if num_steps > 0: + # Update the actor weights when learner updates. + self._actor.update() + + def get_variables(self, names: Sequence[str]) -> List[List[np.ndarray]]: + return self._learner.get_variables(names) diff --git a/acme/acme/agents/jax/__init__.py b/acme/acme/agents/jax/__init__.py new file mode 100644 index 00000000..f06d7218 --- /dev/null +++ b/acme/acme/agents/jax/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX agents.""" diff --git a/acme/acme/agents/jax/actor_core.py b/acme/acme/agents/jax/actor_core.py new file mode 100644 index 00000000..387b5ef6 --- /dev/null +++ b/acme/acme/agents/jax/actor_core.py @@ -0,0 +1,170 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ActorCore interface definition.""" + +import dataclasses +from typing import Callable, Generic, Mapping, Tuple, TypeVar, Union + +from acme import types +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax.types import PRNGKey +import chex +import jax +import jax.numpy as jnp + + +NoneType = type(None) +# The state of the actor. This could include recurrent network state or any +# other state which needs to be propagated through the select_action calls. +State = TypeVar('State') +# The extras to be passed to the observe method. +Extras = TypeVar('Extras') +RecurrentState = TypeVar('RecurrentState') + + +@dataclasses.dataclass +class ActorCore(Generic[State, Extras]): + """Pure functions that define the algorithm-specific actor functionality.""" + init: Callable[[PRNGKey], State] + select_action: Callable[[ + networks_lib.Params, networks_lib.Observation, State + ], Tuple[networks_lib.Action, State]] + get_extras: Callable[[State], Extras] + + +# A simple feed forward policy which produces no extras and takes only an RNGKey +# as a state. 
+FeedForwardPolicy = Callable[ + [networks_lib.Params, PRNGKey, networks_lib.Observation], + networks_lib.Action] + +FeedForwardPolicyWithExtra = Callable[ + [networks_lib.Params, PRNGKey, networks_lib.Observation], + Tuple[networks_lib.Action, types.NestedArray]] + +RecurrentPolicy = Callable[[ + networks_lib.Params, PRNGKey, networks_lib + .Observation, RecurrentState +], Tuple[networks_lib.Action, RecurrentState]] + +Policy = Union[FeedForwardPolicy, FeedForwardPolicyWithExtra, RecurrentPolicy] + + +def batched_feed_forward_to_actor_core( + policy: FeedForwardPolicy +) -> ActorCore[PRNGKey, NoneType]: + """A convenience adaptor from FeedForwardPolicy to ActorCore.""" + + def select_action(params: networks_lib.Params, + observation: networks_lib.Observation, + state: PRNGKey): + rng = state + rng1, rng2 = jax.random.split(rng) + observation = utils.add_batch_dim(observation) + action = utils.squeeze_batch_dim(policy(params, rng1, observation)) + return action, rng2 + + def init(rng: PRNGKey) -> PRNGKey: + return rng + + def get_extras(unused_rng: PRNGKey) -> NoneType: + return None + return ActorCore(init=init, select_action=select_action, + get_extras=get_extras) + + +@chex.dataclass(frozen=True, mappable_dataclass=False) +class SimpleActorCoreStateWithExtras: + rng: PRNGKey + extras: Mapping[str, jnp.ndarray] + + +def unvectorize_select_action(actor_core: ActorCore) -> ActorCore: + """Makes an actor core's select_action method expect unbatched arguments.""" + + def unvectorized_select_action( + params: networks_lib.Params, + observations: networks_lib.Observation, + state: State, + ) -> Tuple[networks_lib.Action, State]: + observations, state = utils.add_batch_dim((observations, state)) + actions, state = actor_core.select_action(params, observations, state) + return utils.squeeze_batch_dim((actions, state)) + + return ActorCore( + init=actor_core.init, + select_action=unvectorized_select_action, + get_extras=actor_core.get_extras) + + +def batched_feed_forward_with_extras_to_actor_core( + policy: FeedForwardPolicyWithExtra +) -> ActorCore[SimpleActorCoreStateWithExtras, Mapping[str, jnp.ndarray]]: + """A convenience adaptor from FeedForwardPolicy to ActorCore.""" + + def select_action(params: networks_lib.Params, + observation: networks_lib.Observation, + state: SimpleActorCoreStateWithExtras): + rng = state.rng + rng1, rng2 = jax.random.split(rng) + observation = utils.add_batch_dim(observation) + action, extras = utils.squeeze_batch_dim(policy(params, rng1, observation)) + return action, SimpleActorCoreStateWithExtras(rng2, extras) + + def init(rng: PRNGKey) -> SimpleActorCoreStateWithExtras: + return SimpleActorCoreStateWithExtras(rng, {}) + + def get_extras( + state: SimpleActorCoreStateWithExtras) -> Mapping[str, jnp.ndarray]: + return state.extras + return ActorCore(init=init, select_action=select_action, + get_extras=get_extras) + + +@chex.dataclass(frozen=True, mappable_dataclass=False) +class SimpleActorCoreRecurrentState(Generic[RecurrentState]): + rng: PRNGKey + recurrent_state: RecurrentState + + +def batched_recurrent_to_actor_core( + recurrent_policy: RecurrentPolicy, initial_core_state: RecurrentState +) -> ActorCore[SimpleActorCoreRecurrentState[RecurrentState], Mapping[ + str, jnp.ndarray]]: + """Returns ActorCore for a recurrent policy.""" + def select_action(params: networks_lib.Params, + observation: networks_lib.Observation, + state: SimpleActorCoreRecurrentState[RecurrentState]): + # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. 
+ rng = state.rng + rng, policy_rng = jax.random.split(rng) + observation = utils.add_batch_dim(observation) + recurrent_state = utils.add_batch_dim(state.recurrent_state) + action, new_recurrent_state = utils.squeeze_batch_dim(recurrent_policy( + params, policy_rng, observation, recurrent_state)) + return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state) + + initial_core_state = utils.squeeze_batch_dim(initial_core_state) + def init(rng: PRNGKey) -> SimpleActorCoreRecurrentState[RecurrentState]: + return SimpleActorCoreRecurrentState(rng, initial_core_state) + + def get_extras( + state: SimpleActorCoreRecurrentState[RecurrentState] + ) -> Mapping[str, jnp.ndarray]: + return {'core_state': state.recurrent_state} + + return ActorCore(init=init, select_action=select_action, + get_extras=get_extras) diff --git a/acme/acme/agents/jax/actors.py b/acme/acme/agents/jax/actors.py new file mode 100644 index 00000000..20426c2e --- /dev/null +++ b/acme/acme/agents/jax/actors.py @@ -0,0 +1,99 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple JAX actors.""" + +from typing import Generic, Optional + +from acme import adders +from acme import core +from acme import types +from acme.agents.jax import actor_core +from acme.jax import networks as network_lib +from acme.jax import utils +from acme.jax import variable_utils +import dm_env +import jax + + +class GenericActor(core.Actor, Generic[actor_core.State, actor_core.Extras]): + """A generic actor implemented on top of ActorCore. + + An actor based on a policy which takes observations and outputs actions. It + also adds experiences to replay and updates the actor weights from the policy + on the learner. + """ + + def __init__( + self, + actor: actor_core.ActorCore[actor_core.State, actor_core.Extras], + random_key: network_lib.PRNGKey, + variable_client: Optional[variable_utils.VariableClient], + adder: Optional[adders.Adder] = None, + jit: bool = True, + backend: Optional[str] = 'cpu', + per_episode_update: bool = False + ): + """Initializes a feed forward actor. + + Args: + actor: actor core. + random_key: Random key. + variable_client: The variable client to get policy parameters from. + adder: An adder to add experiences to. + jit: Whether or not to jit the passed ActorCore's pure functions. + backend: Which backend to use when jitting the policy. + per_episode_update: if True, updates variable client params once at the + beginning of each episode + """ + self._random_key = random_key + self._variable_client = variable_client + self._adder = adder + self._state = None + + # Unpack ActorCore, jitting if requested. 
+ if jit: + self._init = jax.jit(actor.init, backend=backend) + self._policy = jax.jit(actor.select_action, backend=backend) + else: + self._init = actor.init + self._policy = actor.select_action + self._get_extras = actor.get_extras + self._per_episode_update = per_episode_update + + @property + def _params(self): + return self._variable_client.params if self._variable_client else [] + + def select_action(self, + observation: network_lib.Observation) -> types.NestedArray: + action, self._state = self._policy(self._params, observation, self._state) + return utils.to_numpy(action) + + def observe_first(self, timestep: dm_env.TimeStep): + self._random_key, key = jax.random.split(self._random_key) + self._state = self._init(key) + if self._adder: + self._adder.add_first(timestep) + if self._variable_client and self._per_episode_update: + self._variable_client.update_and_wait() + + def observe(self, action: network_lib.Action, next_timestep: dm_env.TimeStep): + if self._adder: + self._adder.add( + action, next_timestep, extras=self._get_extras(self._state)) + + def update(self, wait: bool = False): + if self._variable_client and not self._per_episode_update: + self._variable_client.update(wait) diff --git a/acme/acme/agents/jax/actors_test.py b/acme/acme/agents/jax/actors_test.py new file mode 100644 index 00000000..941e7a20 --- /dev/null +++ b/acme/acme/agents/jax/actors_test.py @@ -0,0 +1,142 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
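For reference, a self-contained sketch of the state threading `GenericActor` performs on an `ActorCore` (assuming the `actor_core` module above is importable from acme; the parameter-free random policy is purely illustrative):

```python
import jax
import jax.numpy as jnp
from acme.agents.jax import actor_core as actor_core_lib

def policy(params, key, observations):
  # Toy stateless policy: ignore params and sample one of 3 actions per row.
  del params
  return jax.random.randint(key, (observations.shape[0],), 0, 3)

core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
state = core.init(jax.random.PRNGKey(42))  # here the state is just a PRNGKey
# select_action batches the single observation, applies the policy, and
# returns the unbatched action plus the split key for the next call.
action, state = core.select_action({}, jnp.zeros((5,)), state)
print(action, core.get_extras(state))  # e.g.: 1 None
```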
+ +"""Tests for actors.""" +from typing import Optional, Tuple + +from acme import environment_loop +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.jax import utils +from acme.jax import variable_utils +from acme.testing import fakes +import dm_env +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + +from absl.testing import absltest +from absl.testing import parameterized + + +def _make_fake_env() -> dm_env.Environment: + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.DiscreteArray(num_values=3), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0., maximum=1.), + ) + return fakes.Environment(env_spec, episode_length=10) + + +class ActorTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('policy', False), + ('policy_with_extras', True)) + def test_feedforward(self, has_extras): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) + + def policy(inputs: jnp.ndarray): + action_values = hk.Sequential([ + hk.Flatten(), + hk.Linear(env_spec.actions.num_values), + ])( + inputs) + action = jnp.argmax(action_values, axis=-1) + if has_extras: + return action, (action_values,) + else: + return action + + policy = hk.transform(policy) + + rng = hk.PRNGSequence(1) + dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations)) + params = policy.init(next(rng), dummy_obs) + + variable_source = fakes.VariableSource(params) + variable_client = variable_utils.VariableClient(variable_source, 'policy') + + if has_extras: + actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( + policy.apply) + else: + actor_core = actor_core_lib.batched_feed_forward_to_actor_core( + policy.apply) + actor = actors.GenericActor( + actor_core, + random_key=jax.random.PRNGKey(1), + variable_client=variable_client) + + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) + + +def _transform_without_rng(f): + return hk.without_apply_rng(hk.transform(f)) + + +class RecurrentActorTest(absltest.TestCase): + + def test_recurrent(self): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) + output_size = env_spec.actions.num_values + obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations)) + rng = hk.PRNGSequence(1) + + @_transform_without_rng + def network(inputs: jnp.ndarray, state: hk.LSTMState): + return hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), + hk.LSTM(output_size)])(inputs, state) + + @_transform_without_rng + def initial_state(batch_size: Optional[int] = None): + network = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), + hk.LSTM(output_size)]) + return network.initial_state(batch_size) + + initial_state = initial_state.apply(initial_state.init(next(rng)), 1) + params = network.init(next(rng), obs, initial_state) + + def policy( + params: jnp.ndarray, + key: jnp.ndarray, + observation: jnp.ndarray, + core_state: hk.LSTMState + ) -> Tuple[jnp.ndarray, hk.LSTMState]: + del key # Unused for test-case deterministic policy. 
+ action_values, core_state = network.apply(params, observation, core_state) + actions = jnp.argmax(action_values, axis=-1) + return actions, core_state + + variable_source = fakes.VariableSource(params) + variable_client = variable_utils.VariableClient(variable_source, 'policy') + + actor_core = actor_core_lib.batched_recurrent_to_actor_core( + policy, initial_state) + actor = actors.GenericActor(actor_core, jax.random.PRNGKey(1), + variable_client) + + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/ail/README.md b/acme/acme/agents/jax/ail/README.md new file mode 100644 index 00000000..5e1b546a --- /dev/null +++ b/acme/acme/agents/jax/ail/README.md @@ -0,0 +1,17 @@ +# Adversarial Imitation Learning (AIL) + +This folder contains a modular implementation of an Adversarial +Imitation Learning agent. +The initial algorithm is Generative Adversarial Imitation Learning +(GAIL - [Ho et al., 2016]), but many more tricks and variations are +available. +The corresponding paper ([Orsini et al., 2021]) explains and discusses +the utility of all those tricks. + +AIL requires an off-policy RL algorithm to work, passed in as an +`ActorLearnerBuilder`. + +If you use this code, please cite [Orsini et al., 2021]. + +[Ho et al., 2016]: https://arxiv.org/abs/1606.03476 +[Orsini et al., 2021]: https://arxiv.org/abs/2106.00672 diff --git a/acme/acme/agents/jax/ail/__init__.py b/acme/acme/agents/jax/ail/__init__.py new file mode 100644 index 00000000..df302c49 --- /dev/null +++ b/acme/acme/agents/jax/ail/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of a AIL agent.""" + +from acme.agents.jax.ail import losses +from acme.agents.jax.ail import rewards +from acme.agents.jax.ail.builder import AILBuilder +from acme.agents.jax.ail.config import AILConfig +from acme.agents.jax.ail.dac import DACBuilder +from acme.agents.jax.ail.dac import DACConfig +from acme.agents.jax.ail.gail import GAILBuilder +from acme.agents.jax.ail.gail import GAILConfig +from acme.agents.jax.ail.learning import AILLearner +from acme.agents.jax.ail.networks import AILNetworks +from acme.agents.jax.ail.networks import AIRLModule +from acme.agents.jax.ail.networks import compute_ail_reward +from acme.agents.jax.ail.networks import DiscriminatorMLP +from acme.agents.jax.ail.networks import DiscriminatorModule +from acme.agents.jax.ail.networks import make_discriminator diff --git a/acme/acme/agents/jax/ail/builder.py b/acme/acme/agents/jax/ail/builder.py new file mode 100644 index 00000000..b71350c9 --- /dev/null +++ b/acme/acme/agents/jax/ail/builder.py @@ -0,0 +1,331 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Adversarial Imitation Learning (AIL) Builder."""
+
+import functools
+import itertools
+from typing import Callable, Generic, Iterator, List, Optional, Tuple
+
+from acme import adders
+from acme import core
+from acme import specs
+from acme import types
+from acme.adders import reverb as adders_reverb
+from acme.agents.jax import builders
+from acme.agents.jax.ail import config as ail_config
+from acme.agents.jax.ail import learning
+from acme.agents.jax.ail import losses
+from acme.agents.jax.ail import networks as ail_networks
+from acme.datasets import reverb as datasets
+from acme.jax import types as jax_types
+from acme.jax import utils
+from acme.jax.imitation_learning_types import DirectPolicyNetwork
+from acme.utils import counting
+from acme.utils import loggers
+from acme.utils import reverb_utils
+import jax
+import numpy as np
+import optax
+import reverb
+from reverb import rate_limiters
+import tree
+
+
+def _split_transitions(
+    transitions: types.Transition,
+    index: int) -> Tuple[types.Transition, types.Transition]:
+  """Splits the given transitions on the first axis at the given index.
+
+  Args:
+    transitions: Transitions to split.
+    index: Splitting index.
+
+  Returns:
+    A pair of transitions, the first containing elements before the index
+    (exclusive) and the second after the index (inclusive).
+  """
+  return (tree.map_structure(lambda x: x[:index], transitions),
+          tree.map_structure(lambda x: x[index:], transitions))
+
+
+def _rebatch(iterator: Iterator[types.Transition],
+             batch_size: int) -> Iterator[types.Transition]:
+  """Rebatches the iterator with the given batch size.
+
+  Args:
+    iterator: Iterator to rebatch.
+    batch_size: New batch size.
+
+  Yields:
+    Transitions with the new batch size.
+  """
+  data = next(iterator)
+  while True:
+    while len(data.reward) < batch_size:
+      # Ensure we can get enough demonstrations.
+      next_data = next(iterator)
+      data = tree.map_structure(lambda *args: np.concatenate(list(args)), data,
+                                next_data)
+    output, data = _split_transitions(data, batch_size)
+    yield output
+
+
+def _mix_arrays(
+    replay: np.ndarray,
+    demo: np.ndarray,
+    index: int,
+    seed: int) -> np.ndarray:
+  """Mixes `replay` and `demo`.
+
+  Args:
+    replay: Replay data to mix. Only the first `index` elements will be used.
+    demo: Demonstration data to mix.
+    index: Amount of replay data we should include.
+    seed: RNG seed.
+
+  Returns:
+    An array with replay elements up to 'index' and all the demos.
+  """
+  # We're throwing away some replay data here. We have to if we want to make
+  # sure the output info field is correct.
+  output = np.concatenate((replay[:index], demo))
+  return np.random.default_rng(seed=seed).permutation(output)
+
+
+def _generate_samples_with_demonstrations(
+    demonstration_iterator: Iterator[types.Transition],
+    replay_iterator: Iterator[reverb.ReplaySample],
+    policy_to_expert_data_ratio: int,
+    batch_size) -> Iterator[reverb.ReplaySample]:
+  """Generator which creates samples containing demonstrations.
+
+  It takes the demonstrations and replay iterators and generates batches of
+  the same size as the replay iterator's, such that each batch has, on
+  average, the ratio of policy and expert data specified in
+  policy_to_expert_data_ratio. There are no constraints on how the
+  demonstrations and replay samples are batched.
+
+  Args:
+    demonstration_iterator: Iterator of demonstrations.
+    replay_iterator: Replay buffer sample iterator.
+    policy_to_expert_data_ratio: Number of policy transitions per expert
+      transition.
+    batch_size: Output batch size, which should match the replay batch size.
+
+  Yields:
+    Samples having a mix of demonstrations and policy data. The info will match
+    the current replay sample info and the batch size will be the same as the
+    replay_iterator data batch size.
+  """
+  count = 0
+  if batch_size % (policy_to_expert_data_ratio + 1) != 0:
+    raise ValueError(
+        'policy_to_expert_data_ratio + 1 must divide the batch size but '
+        f'{batch_size} % {policy_to_expert_data_ratio + 1} != 0')
+  demo_insertion_size = batch_size // (policy_to_expert_data_ratio + 1)
+  policy_insertion_size = batch_size - demo_insertion_size
+
+  demonstration_iterator = _rebatch(demonstration_iterator,
+                                    demo_insertion_size)
+  for sample, demos in zip(replay_iterator, demonstration_iterator):
+    output_transitions = tree.map_structure(
+        functools.partial(_mix_arrays,
+                          index=policy_insertion_size,
+                          seed=count),
+        sample.data, demos)
+    count += 1
+    yield reverb.ReplaySample(info=sample.info, data=output_transitions)
+
+
+class AILBuilder(builders.ActorLearnerBuilder[ail_networks.AILNetworks,
+                                              DirectPolicyNetwork,
+                                              learning.AILSample],
+                 Generic[ail_networks.DirectRLNetworks, DirectPolicyNetwork]):
+  """AIL Builder."""
+
+  def __init__(
+      self,
+      rl_agent: builders.ActorLearnerBuilder[ail_networks.DirectRLNetworks,
+                                             DirectPolicyNetwork,
+                                             reverb.ReplaySample],
+      config: ail_config.AILConfig, discriminator_loss: losses.Loss,
+      make_demonstrations: Callable[[int], Iterator[types.Transition]]):
+    """Implements a builder for AIL using rl_agent as the forward RL algorithm.
+
+    Args:
+      rl_agent: The standard RL agent used by AIL to optimize the generator.
+      config: An AIL config.
+      discriminator_loss: The loss function for the discriminator to minimize.
+      make_demonstrations: A function that returns an iterator with
+        demonstrations to be imitated.
+ """ + self._rl_agent = rl_agent + self._config = config + self._discriminator_loss = discriminator_loss + self._make_demonstrations = make_demonstrations + + def make_learner(self, + random_key: jax_types.PRNGKey, + networks: ail_networks.AILNetworks, + dataset: Iterator[learning.AILSample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None) -> core.Learner: + counter = counter or counting.Counter() + direct_rl_counter = counting.Counter(counter, 'direct_rl') + batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size( + self._config) + + direct_rl_learner_key, discriminator_key = jax.random.split(random_key) + + direct_rl_learner = functools.partial( + self._rl_agent.make_learner, + direct_rl_learner_key, + networks.direct_rl_networks, + logger_fn=logger_fn, + environment_spec=environment_spec, + replay_client=replay_client, + counter=direct_rl_counter) + + discriminator_optimizer = ( + self._config.discriminator_optimizer or optax.adam(1e-5)) + + return learning.AILLearner( + counter, + direct_rl_learner_factory=direct_rl_learner, + loss_fn=self._discriminator_loss, + iterator=dataset, + discriminator_optimizer=discriminator_optimizer, + ail_network=networks, + discriminator_key=discriminator_key, + is_sequence_based=self._config.is_sequence_based, + num_sgd_steps_per_step=batch_size_per_learner_step // + self._config.discriminator_batch_size, + policy_variable_name=self._config.policy_variable_name, + logger=logger_fn('learner', steps_key=counter.get_steps_key())) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: DirectPolicyNetwork, + ) -> List[reverb.Table]: + replay_tables = self._rl_agent.make_replay_tables(environment_spec, policy) + if self._config.share_iterator: + return replay_tables + replay_tables.append( + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=rate_limiters.MinSize(self._config.min_replay_size), + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec))) + return replay_tables + + # This function does not expose all the iterators used by the learner when + # share_iterator is False, making further wrapping impossible. + # TODO(eorsini): Expose all iterators. + # Currently GAIL uses 3 iterators, instead we can make it use a single + # iterator and return this one here. The way to achieve this would be: + # * Create the 3 iterators here. + # * zip them and return them here. + # * upzip them in the learner (this step will not be necessary once we move to + # stateless learners) + # This should work fine as the 3 iterators are always iterated in parallel + # (i.e. at every step we call next once on each of them). + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[learning.AILSample]: + batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size( + self._config) + + iterator_demonstration = self._make_demonstrations( + batch_size_per_learner_step) + + direct_iterator = self._rl_agent.make_dataset_iterator(replay_client) + + if self._config.share_iterator: + # In order to reuse the iterator return values and not lose a 2x factor on + # sample efficiency, we need to use itertools.tee(). 
+ discriminator_iterator, direct_iterator = itertools.tee(direct_iterator) + else: + discriminator_iterator = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=ail_config.get_per_learner_step_batch_size(self._config), + prefetch_size=self._config.prefetch_size).as_numpy_iterator() + + if self._config.policy_to_expert_data_ratio is not None: + iterator_demonstration, iterator_demonstration2 = itertools.tee( + iterator_demonstration) + direct_iterator = _generate_samples_with_demonstrations( + iterator_demonstration2, direct_iterator, + self._config.policy_to_expert_data_ratio, + self._config.direct_rl_batch_size) + + is_sequence_based = self._config.is_sequence_based + + # Don't flatten the discriminator batch if the iterator is not shared. + process_discriminator_sample = functools.partial( + reverb_utils.replay_sample_to_sars_transition, + is_sequence=is_sequence_based and self._config.share_iterator, + flatten_batch=is_sequence_based and self._config.share_iterator, + strip_last_transition=is_sequence_based and self._config.share_iterator) + + discriminator_iterator = ( + # Remove the extras to have the same nested structure as demonstrations. + process_discriminator_sample(sample)._replace(extras=()) + for sample in discriminator_iterator) + + return utils.device_put((learning.AILSample(*sample) for sample in zip( + discriminator_iterator, direct_iterator, iterator_demonstration)), + jax.devices()[0]) + + def make_adder( + self, replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[DirectPolicyNetwork]) -> Optional[adders.Adder]: + direct_rl_adder = self._rl_agent.make_adder(replay_client, environment_spec, + policy) + if self._config.share_iterator: + return direct_rl_adder + ail_adder = adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=1, + discount=self._config.discount) + + # Some direct rl algorithms (such as PPO), might be passing extra data + # which we won't be able to process here properly, so we need to ignore them + return adders.ForkingAdder( + [adders.IgnoreExtrasAdder(ail_adder), direct_rl_adder]) + + def make_actor( + self, + random_key: jax_types.PRNGKey, + policy: DirectPolicyNetwork, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor(random_key, policy, environment_spec, + variable_source, adder) + + def make_policy(self, + networks: ail_networks.AILNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> DirectPolicyNetwork: + return self._rl_agent.make_policy(networks.direct_rl_networks, + environment_spec, evaluation) diff --git a/acme/acme/agents/jax/ail/builder_test.py b/acme/acme/agents/jax/ail/builder_test.py new file mode 100644 index 00000000..800dbc3f --- /dev/null +++ b/acme/acme/agents/jax/ail/builder_test.py @@ -0,0 +1,56 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the builder generator."""
+from acme import types
+from acme.agents.jax.ail import builder
+import numpy as np
+import reverb
+
+from absl.testing import absltest
+
+_REWARD = np.zeros((3,))
+
+
+class BuilderTest(absltest.TestCase):
+
+  def test_weighted_generator(self):
+    data0 = types.Transition(np.array([[1], [2], [3]]), (), _REWARD, (), ())
+    it0 = iter([data0])
+
+    data1 = types.Transition(np.array([[4], [5], [6]]), (), _REWARD, (), ())
+    data2 = types.Transition(np.array([[7], [8], [9]]), (), _REWARD, (), ())
+    it1 = iter([
+        reverb.ReplaySample(
+            info=reverb.SampleInfo(
+                *[() for _ in reverb.SampleInfo.tf_dtypes()]),
+            data=data1),
+        reverb.ReplaySample(
+            info=reverb.SampleInfo(
+                *[() for _ in reverb.SampleInfo.tf_dtypes()]),
+            data=data2)
+    ])
+
+    weighted_it = builder._generate_samples_with_demonstrations(
+        it0, it1, policy_to_expert_data_ratio=2, batch_size=3)
+
+    np.testing.assert_array_equal(
+        next(weighted_it).data.observation, np.array([[1], [4], [5]]))
+    np.testing.assert_array_equal(
+        next(weighted_it).data.observation, np.array([[7], [8], [2]]))
+    self.assertRaises(StopIteration, lambda: next(weighted_it))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/jax/ail/config.py b/acme/acme/agents/jax/ail/config.py
new file mode 100644
index 00000000..c5541de5
--- /dev/null
+++ b/acme/acme/agents/jax/ail/config.py
@@ -0,0 +1,78 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AIL config."""
+import dataclasses
+from typing import Optional
+
+import optax
+
+
+@dataclasses.dataclass
+class AILConfig:
+  """Configuration options for AIL.
+
+  Attributes:
+    direct_rl_batch_size: Batch size of the direct rl algorithm (measured in
+      transitions).
+    is_sequence_based: If True, the direct rl algorithm is using the
+      SequenceAdder data format. Otherwise the learner assumes that the direct
+      rl algorithm is using the NStepTransitionAdder.
+    share_iterator: If True, AIL will use the same iterator for the
+      discriminator network training as the direct rl algorithm.
+    num_sgd_steps_per_step: Only used if 'share_iterator' is False. Denotes how
+      many gradient updates to perform per learner step.
+    discriminator_batch_size: Batch size for training the discriminator.
+    policy_variable_name: The name of the policy variable used to retrieve
+      direct_rl policy parameters.
+    discriminator_optimizer: Optimizer for the discriminator. If not specified
+      it is set to Adam with a learning rate of 1e-5.
+    replay_table_name: The name of the reverb replay table to use.
+    prefetch_size: How many batches to prefetch.
+    discount: Discount to use for TD updates.
+    min_replay_size: Minimal size of the replay buffer.
+    max_replay_size: Maximal size of the replay buffer.
+    policy_to_expert_data_ratio: If not None, the direct RL learner will
+      receive expert transitions in the given proportion.
+      policy_to_expert_data_ratio + 1 must divide the direct RL batch size.
+  """
+  direct_rl_batch_size: int
+  is_sequence_based: bool = False
+  share_iterator: bool = True
+  num_sgd_steps_per_step: int = 1
+  discriminator_batch_size: int = 256
+  policy_variable_name: Optional[str] = None
+  discriminator_optimizer: Optional[optax.GradientTransformation] = None
+  replay_table_name: str = 'ail_table'
+  prefetch_size: int = 4
+  discount: float = 0.99
+  min_replay_size: int = 1000
+  max_replay_size: int = int(1e6)
+  policy_to_expert_data_ratio: Optional[int] = None
+
+  def __post_init__(self):
+    assert self.direct_rl_batch_size % self.discriminator_batch_size == 0
+
+
+def get_per_learner_step_batch_size(config: AILConfig) -> int:
+  """Returns how many transitions should be sampled per direct learner step."""
+  # If the iterators are tied, the discriminator learning batch size has to
+  # match the direct RL one.
+  if config.share_iterator:
+    assert (config.direct_rl_batch_size % config.discriminator_batch_size) == 0
+    return config.direct_rl_batch_size
+  # Otherwise each iteration of the discriminator will sample a batch which
+  # will be split into num_sgd_steps_per_step batches, each of size
+  # discriminator_batch_size.
+  return config.discriminator_batch_size * config.num_sgd_steps_per_step
diff --git a/acme/acme/agents/jax/ail/dac.py b/acme/acme/agents/jax/ail/dac.py
new file mode 100644
index 00000000..d048f6e1
--- /dev/null
+++ b/acme/acme/agents/jax/ail/dac.py
@@ -0,0 +1,65 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Builder for DAC.
+
+https://arxiv.org/pdf/1809.02925.pdf
+"""
+
+import dataclasses
+from typing import Callable, Iterator
+
+from acme import types
+from acme.agents.jax import actor_core as actor_core_lib
+from acme.agents.jax import td3
+from acme.agents.jax.ail import builder
+from acme.agents.jax.ail import config as ail_config
+from acme.agents.jax.ail import losses
+
+
+@dataclasses.dataclass
+class DACConfig:
+  """Configuration options specific to DAC.
+
+  Attributes:
+    ail_config: AIL config.
+    td3_config: TD3 config.
+    entropy_coefficient: Entropy coefficient of the discriminator loss.
+    gradient_penalty_coefficient: Coefficient for the gradient penalty term in
+      the discriminator loss.
+  """
+  ail_config: ail_config.AILConfig
+  td3_config: td3.TD3Config
+  entropy_coefficient: float = 1e-3
+  gradient_penalty_coefficient: float = 10.
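A hypothetical wiring sketch for the DAC builder defined next (values are illustrative; note that `AILConfig.__post_init__` requires `discriminator_batch_size` to divide `direct_rl_batch_size`):

```python
from acme.agents.jax import td3
from acme.agents.jax.ail import config as ail_config
from acme.agents.jax.ail import dac

dac_config = dac.DACConfig(
    ail_config=ail_config.AILConfig(direct_rl_batch_size=256),
    td3_config=td3.TD3Config(),  # default TD3 hyperparameters
    entropy_coefficient=1e-3,
    gradient_penalty_coefficient=10.,
)
# DACBuilder(dac_config, make_demonstrations=...) then composes a TD3Builder
# with a GAIL loss augmented by a gradient penalty, as shown below.
```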
+ + +class DACBuilder(builder.AILBuilder[td3.TD3Networks, + actor_core_lib.FeedForwardPolicy]): + """DAC Builder.""" + + def __init__(self, config: DACConfig, + make_demonstrations: Callable[[int], + Iterator[types.Transition]]): + + td3_builder = td3.TD3Builder(config.td3_config) + dac_loss = losses.add_gradient_penalty( + losses.gail_loss(entropy_coefficient=config.entropy_coefficient), + gradient_penalty_coefficient=config.gradient_penalty_coefficient, + gradient_penalty_target=1.) + super().__init__( + td3_builder, + config=config.ail_config, + discriminator_loss=dac_loss, + make_demonstrations=make_demonstrations) diff --git a/acme/acme/agents/jax/ail/gail.py b/acme/acme/agents/jax/ail/gail.py new file mode 100644 index 00000000..c5ba2904 --- /dev/null +++ b/acme/acme/agents/jax/ail/gail.py @@ -0,0 +1,52 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Builder for GAIL. + +https://arxiv.org/pdf/1606.03476.pdf +""" + +import dataclasses +from typing import Callable, Iterator + +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import ppo +from acme.agents.jax.ail import builder +from acme.agents.jax.ail import config as ail_config +from acme.agents.jax.ail import losses + + +@dataclasses.dataclass +class GAILConfig: + """Configuration options specific to GAIL.""" + ail_config: ail_config.AILConfig + ppo_config: ppo.PPOConfig + + +class GAILBuilder(builder.AILBuilder[ppo.PPONetworks, + actor_core_lib.FeedForwardPolicyWithExtra] + ): + """GAIL Builder.""" + + def __init__(self, config: GAILConfig, + make_demonstrations: Callable[[int], + Iterator[types.Transition]]): + + ppo_builder = ppo.PPOBuilder(config.ppo_config) + super().__init__( + ppo_builder, + config=config.ail_config, + discriminator_loss=losses.gail_loss(), + make_demonstrations=make_demonstrations) diff --git a/acme/acme/agents/jax/ail/learning.py b/acme/acme/agents/jax/ail/learning.py new file mode 100644 index 00000000..64f625ff --- /dev/null +++ b/acme/acme/agents/jax/ail/learning.py @@ -0,0 +1,293 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
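Mirroring the DAC wiring above, a hypothetical construction sketch for GAIL; the zero-filled demonstration iterator is a stand-in for a real dataset loader, and `ppo.PPOConfig()` is assumed to provide usable defaults:

```python
import itertools
import numpy as np
from acme import types
from acme.agents.jax import ppo
from acme.agents.jax.ail import config as ail_config
from acme.agents.jax.ail import gail

def make_demonstrations(batch_size: int):
  # Stand-in demonstrations: an endless stream of zero-filled transitions.
  zeros = types.Transition(
      observation=np.zeros((batch_size, 4), np.float32),
      action=np.zeros((batch_size, 2), np.float32),
      reward=np.zeros((batch_size,), np.float32),
      discount=np.ones((batch_size,), np.float32),
      next_observation=np.zeros((batch_size, 4), np.float32))
  return itertools.repeat(zeros)

builder = gail.GAILBuilder(
    gail.GAILConfig(
        ail_config=ail_config.AILConfig(direct_rl_batch_size=256),
        ppo_config=ppo.PPOConfig()),
    make_demonstrations=make_demonstrations)
```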
+ +"""AIL learner implementation.""" +import functools +import itertools +import time +from typing import Any, Callable, Iterator, List, NamedTuple, Optional, Tuple + +import acme +from acme import types +from acme.agents.jax.ail import losses +from acme.agents.jax.ail import networks as ail_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import reverb_utils +import jax +import optax +import reverb + + +class DiscriminatorTrainingState(NamedTuple): + """Contains training state for the discriminator.""" + # State of the optimizer used to optimize the discriminator parameters. + optimizer_state: optax.OptState + + # Parameters of the discriminator. + discriminator_params: networks_lib.Params + + # State of the discriminator + discriminator_state: losses.State + + # For AIRL variants, we need the policy params to compute the loss. + policy_params: Optional[networks_lib.Params] + + # Key for random number generation. + key: networks_lib.PRNGKey + + # Training step of the discriminator. + steps: int + + +class TrainingState(NamedTuple): + """Contains training state of the AIL learner.""" + rewarder_state: DiscriminatorTrainingState + learner_state: Any + + +def ail_update_step( + state: DiscriminatorTrainingState, data: Tuple[types.Transition, + types.Transition], + optimizer: optax.GradientTransformation, + ail_network: ail_networks.AILNetworks, + loss_fn: losses.Loss) -> Tuple[DiscriminatorTrainingState, losses.Metrics]: + """Run an update steps on the given transitions. + + Args: + state: The learner state. + data: Demo and rb transitions. + optimizer: Discriminator optimizer. + ail_network: AIL networks. + loss_fn: Discriminator loss to minimize. + + Returns: + A new state and metrics. + """ + demo_transitions, rb_transitions = data + key, discriminator_key, loss_key = jax.random.split(state.key, 3) + + def compute_loss( + discriminator_params: networks_lib.Params) -> losses.LossOutput: + discriminator_fn = functools.partial( + ail_network.discriminator_network.apply, + discriminator_params, + state.policy_params, + is_training=True, + rng=discriminator_key) + return loss_fn(discriminator_fn, state.discriminator_state, + demo_transitions, rb_transitions, loss_key) + + loss_grad = jax.grad(compute_loss, has_aux=True) + + grads, (loss, new_discriminator_state) = loss_grad(state.discriminator_params) + + update, optimizer_state = optimizer.update( + grads, + state.optimizer_state, + params=state.discriminator_params) + discriminator_params = optax.apply_updates(state.discriminator_params, update) + + new_state = DiscriminatorTrainingState( + optimizer_state=optimizer_state, + discriminator_params=discriminator_params, + discriminator_state=new_discriminator_state, + policy_params=state.policy_params, # Not modified. 
+ key=key, + steps=state.steps + 1, + ) + return new_state, loss + + +class AILSample(NamedTuple): + discriminator_sample: types.Transition + direct_sample: reverb.ReplaySample + demonstration_sample: types.Transition + + +class AILLearner(acme.Learner): + """AIL learner.""" + + def __init__( + self, + counter: counting.Counter, + direct_rl_learner_factory: Callable[[Iterator[reverb.ReplaySample]], + acme.Learner], + loss_fn: losses.Loss, + iterator: Iterator[AILSample], + discriminator_optimizer: optax.GradientTransformation, + ail_network: ail_networks.AILNetworks, + discriminator_key: networks_lib.PRNGKey, + is_sequence_based: bool, + num_sgd_steps_per_step: int = 1, + policy_variable_name: Optional[str] = None, + logger: Optional[loggers.Logger] = None): + """AIL Learner. + + Args: + counter: Counter. + direct_rl_learner_factory: Function that creates the direct RL learner + when passed a replay sample iterator. + loss_fn: Discriminator loss. + iterator: Iterator that returns AILSamples. + discriminator_optimizer: Discriminator optax optimizer. + ail_network: AIL networks. + discriminator_key: RNG key. + is_sequence_based: If True, a direct rl algorithm is using SequenceAdder + data format. Otherwise the learner assumes that the direct rl algorithm + is using NStepTransitionAdder. + num_sgd_steps_per_step: Number of discriminator gradient updates per step. + policy_variable_name: The name of the policy variable to retrieve + direct_rl policy parameters. + logger: Logger. + """ + self._is_sequence_based = is_sequence_based + + state_key, networks_key = jax.random.split(discriminator_key) + + # Generator expression that works the same as an iterator. + # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions + iterator, direct_rl_iterator = itertools.tee(iterator) + direct_rl_iterator = ( + self._process_sample(sample.direct_sample) + for sample in direct_rl_iterator) + self._direct_rl_learner = direct_rl_learner_factory(direct_rl_iterator) + + self._iterator = iterator + + if policy_variable_name is not None: + + def get_policy_params(): + return self._direct_rl_learner.get_variables([policy_variable_name])[0] + + self._get_policy_params = get_policy_params + + else: + self._get_policy_params = lambda: None + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Use the JIT compiler. + self._update_step = functools.partial( + ail_update_step, + optimizer=discriminator_optimizer, + ail_network=ail_network, + loss_fn=loss_fn) + self._update_step = utils.process_multiple_batches(self._update_step, + num_sgd_steps_per_step) + self._update_step = jax.jit(self._update_step) + + discriminator_params, discriminator_state = ( + ail_network.discriminator_network.init(networks_key)) + self._state = DiscriminatorTrainingState( + optimizer_state=discriminator_optimizer.init(discriminator_params), + discriminator_params=discriminator_params, + discriminator_state=discriminator_state, + policy_params=self._get_policy_params(), + key=state_key, + steps=0, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. 
+ self._timestamp = None + + self._get_reward = jax.jit( + functools.partial( + ail_networks.compute_ail_reward, networks=ail_network)) + + def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: + """Updates the reward of the replay sample. + + Args: + sample: Replay sample to update the reward to. + + Returns: + The replay sample with an updated reward. + """ + transitions = reverb_utils.replay_sample_to_sars_transition( + sample, is_sequence=self._is_sequence_based) + rewards = self._get_reward(self._state.discriminator_params, + self._state.discriminator_state, + self._state.policy_params, transitions) + + return sample._replace(data=sample.data._replace(reward=rewards)) + + def step(self): + sample = next(self._iterator) + rb_transitions = sample.discriminator_sample + demo_transitions = sample.demonstration_sample + + if demo_transitions.reward.shape != rb_transitions.reward.shape: + raise ValueError( + 'Different shapes for demo transitions and rb_transitions: ' + f'{demo_transitions.reward.shape} != {rb_transitions.reward.shape}') + + # Update the parameters of the policy before doing a gradient step. + state = self._state._replace(policy_params=self._get_policy_params()) + self._state, metrics = self._update_step(state, + (demo_transitions, rb_transitions)) + + # The order is important for AIRL. + # In AIRL, the discriminator update depends on the logpi of the direct rl + # policy. + # When updating the discriminator, we want the logpi for which the + # transitions were made with and not an updated one. + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + self._direct_rl_learner.step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[Any]: + rewarder_dict = {'discriminator': self._state.discriminator_params} + + learner_names = [name for name in names if name not in rewarder_dict] + learner_dict = {} + if learner_names: + learner_dict = dict( + zip(learner_names, + self._direct_rl_learner.get_variables(learner_names))) + + variables = [ + rewarder_dict.get(name, learner_dict.get(name, None)) for name in names + ] + return variables + + def save(self) -> TrainingState: + return TrainingState( + rewarder_state=self._state, + learner_state=self._direct_rl_learner.save()) + + def restore(self, state: TrainingState): + self._state = state.rewarder_state + self._direct_rl_learner.restore(state.learner_state) diff --git a/acme/acme/agents/jax/ail/learning_test.py b/acme/acme/agents/jax/ail/learning_test.py new file mode 100644 index 00000000..9d1eb76d --- /dev/null +++ b/acme/acme/agents/jax/ail/learning_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the AIL learner.""" +import functools + +from acme import specs +from acme import types +from acme.agents.jax.ail import learning as ail_learning +from acme.agents.jax.ail import losses +from acme.agents.jax.ail import networks as ail_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +import haiku as hk +import jax +import numpy as np +import optax + +from absl.testing import absltest + + +def _make_discriminator(spec): + def discriminator(*args, **kwargs) -> networks_lib.Logits: + return ail_networks.DiscriminatorModule( + environment_spec=spec, + use_action=False, + use_next_obs=False, + network_core=ail_networks.DiscriminatorMLP([]))(*args, **kwargs) + + discriminator_transformed = hk.without_apply_rng( + hk.transform_with_state(discriminator)) + return ail_networks.make_discriminator( + environment_spec=spec, + discriminator_transformed=discriminator_transformed) + + +class AilLearnerTest(absltest.TestCase): + + def test_step(self): + simple_spec = specs.Array(shape=(), dtype=float) + + spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec, + simple_spec) + + discriminator = _make_discriminator(spec) + ail_network = ail_networks.AILNetworks( + discriminator, imitation_reward_fn=lambda x: x, direct_rl_networks=None) + + loss = losses.gail_loss() + + optimizer = optax.adam(.01) + + step = jax.jit(functools.partial( + ail_learning.ail_update_step, + optimizer=optimizer, + ail_network=ail_network, + loss_fn=loss)) + + zero_transition = types.Transition( + np.array([0.]), np.array([0.]), 0., 0., np.array([0.])) + zero_transition = utils.add_batch_dim(zero_transition) + + one_transition = types.Transition( + np.array([1.]), np.array([0.]), 0., 0., np.array([0.])) + one_transition = utils.add_batch_dim(one_transition) + + key = jax.random.PRNGKey(0) + discriminator_params, discriminator_state = discriminator.init(key) + + state = ail_learning.DiscriminatorTrainingState( + optimizer_state=optimizer.init(discriminator_params), + discriminator_params=discriminator_params, + discriminator_state=discriminator_state, + policy_params=None, + key=key, + steps=0, + ) + + expected_loss = [1.062, 1.057, 1.052] + + for i in range(3): + state, loss = step(state, (one_transition, zero_transition)) + self.assertAlmostEqual(loss['total_loss'], expected_loss[i], places=3) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/ail/losses.py b/acme/acme/agents/jax/ail/losses.py new file mode 100644 index 00000000..eff34a7a --- /dev/null +++ b/acme/acme/agents/jax/ail/losses.py @@ -0,0 +1,236 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""AIL discriminator losses.""" + +import functools +from typing import Callable, Dict, Optional, Tuple + +from acme import types +from acme.jax import networks as networks_lib +import jax +import jax.numpy as jnp +import tensorflow_probability as tfp +import tree + +tfp = tfp.experimental.substrates.jax +tfd = tfp.distributions + +# The loss is a function taking the discriminator, its state, the demo +# transition and the replay buffer transitions. +# It returns the loss as a float and a debug dictionary with the new state. +State = networks_lib.Params +DiscriminatorOutput = Tuple[networks_lib.Logits, State] +DiscriminatorFn = Callable[[State, types.Transition], DiscriminatorOutput] +Metrics = Dict[str, float] +LossOutput = Tuple[float, Tuple[Metrics, State]] +Loss = Callable[[ + DiscriminatorFn, State, types.Transition, types.Transition, networks_lib + .PRNGKey +], LossOutput] + + +def _binary_cross_entropy_loss(logit: jnp.ndarray, + label: jnp.ndarray) -> jnp.ndarray: + return label * jax.nn.softplus(-logit) + (1 - label) * jax.nn.softplus(logit) + + +@jax.vmap +def _weighted_average(x: jnp.ndarray, y: jnp.ndarray, + lambdas: jnp.ndarray) -> jnp.ndarray: + return lambdas * x + (1. - lambdas) * y + + +def _label_data( + rb_transitions: types.Transition, + demonstration_transitions: types.Transition, mixup_alpha: Optional[float], + key: networks_lib.PRNGKey) -> Tuple[types.Transition, jnp.ndarray]: + """Create a tuple data, labels by concatenating the rb and dem transitions.""" + data = tree.map_structure(lambda x, y: jnp.concatenate([x, y]), + rb_transitions, demonstration_transitions) + labels = jnp.concatenate([ + jnp.zeros(rb_transitions.reward.shape), + jnp.ones(demonstration_transitions.reward.shape) + ]) + + if mixup_alpha is not None: + lambda_key, mixup_key = jax.random.split(key) + + lambdas = tfd.Beta(mixup_alpha, mixup_alpha).sample( + len(labels), seed=lambda_key) + + shuffled_data = tree.map_structure( + lambda x: jax.random.permutation(key=mixup_key, x=x), data) + shuffled_labels = jax.random.permutation(key=mixup_key, x=labels) + + data = tree.map_structure(lambda x, y: _weighted_average(x, y, lambdas), + data, shuffled_data) + labels = _weighted_average(labels, shuffled_labels, lambdas) + + return data, labels + + +def _logit_bernoulli_entropy(logits: networks_lib.Logits) -> jnp.ndarray: + return (1. 
- jax.nn.sigmoid(logits)) * logits - jax.nn.log_sigmoid(logits) + + +def gail_loss(entropy_coefficient: float = 0., + mixup_alpha: Optional[float] = None) -> Loss: + """Computes the standard GAIL loss.""" + + def loss_fn( + discriminator_fn: DiscriminatorFn, + discriminator_state: State, + demo_transitions: types.Transition, rb_transitions: types.Transition, + rng_key: networks_lib.PRNGKey) -> LossOutput: + + data, labels = _label_data( + rb_transitions=rb_transitions, + demonstration_transitions=demo_transitions, + mixup_alpha=mixup_alpha, + key=rng_key) + logits, discriminator_state = discriminator_fn(discriminator_state, data) + + classification_loss = jnp.mean(_binary_cross_entropy_loss(logits, labels)) + + entropy = jnp.mean(_logit_bernoulli_entropy(logits)) + entropy_loss = -entropy_coefficient * entropy + + total_loss = classification_loss + entropy_loss + + metrics = { + 'total_loss': total_loss, + 'entropy_loss': entropy_loss, + 'classification_loss': classification_loss + } + return total_loss, (metrics, discriminator_state) + + return loss_fn + + +def pugail_loss(positive_class_prior: float, + entropy_coefficient: float, + pugail_beta: Optional[float] = None) -> Loss: + """Computes the PUGAIL loss (https://arxiv.org/pdf/1911.00459.pdf).""" + + def loss_fn( + discriminator_fn: DiscriminatorFn, + discriminator_state: State, + demo_transitions: types.Transition, rb_transitions: types.Transition, + rng_key: networks_lib.PRNGKey) -> LossOutput: + del rng_key + + demo_logits, discriminator_state = discriminator_fn(discriminator_state, + demo_transitions) + rb_logits, discriminator_state = discriminator_fn(discriminator_state, + rb_transitions) + + # Quick Maths: + # output = logit(D) = ln(D) - ln(1-D) + # -softplus(-output) = ln(D) + # softplus(output) = -ln(1-D) + + # prior * -ln(D(expert)) + positive_loss = positive_class_prior * -jax.nn.log_sigmoid(demo_logits) + # -ln(1 - D(policy)) - prior * -ln(1 - D(expert)) + negative_loss = jax.nn.softplus( + rb_logits) - positive_class_prior * jax.nn.softplus(demo_logits) + if pugail_beta is not None: + negative_loss = jnp.clip(negative_loss, a_min=-1. * pugail_beta) + + classification_loss = jnp.mean(positive_loss + negative_loss) + + entropy = jnp.mean( + _logit_bernoulli_entropy(jnp.concatenate([demo_logits, rb_logits]))) + entropy_loss = -entropy_coefficient * entropy + + total_loss = classification_loss + entropy_loss + + metrics = { + 'total_loss': total_loss, + 'positive_loss': jnp.mean(positive_loss), + 'negative_loss': jnp.mean(negative_loss), + 'demo_logits': jnp.mean(demo_logits), + 'rb_logits': jnp.mean(rb_logits), + 'entropy_loss': entropy_loss, + 'classification_loss': classification_loss + } + return total_loss, (metrics, discriminator_state) + + return loss_fn + + +def _make_gradient_penalty_data(rb_transitions: types.Transition, + demonstration_transitions: types.Transition, + key: networks_lib.PRNGKey) -> types.Transition: + lambdas = tfd.Uniform().sample(len(rb_transitions.reward), seed=key) + return tree.map_structure(lambda x, y: _weighted_average(x, y, lambdas), + rb_transitions, demonstration_transitions) + + +@functools.partial(jax.vmap, in_axes=(0, None, None)) +def _compute_gradient_penalty(gradient_penalty_data: types.Transition, + discriminator_fn: Callable[[types.Transition], + float], + gradient_penalty_target: float) -> float: + """Computes a penalty based on the gradient norm on the data.""" + # The input should not be batched. 
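+  # _compute_gradient_penalty is wrapped in jax.vmap with in_axes=(0, None,
+  # None): the transition is mapped over its leading batch axis while the
+  # discriminator and the target are broadcast. Each vmapped instance
+  # therefore sees a single unbatched transition (hence the assert below),
+  # and the per-example penalties are averaged later in add_gradient_penalty.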
+ assert not gradient_penalty_data.reward.shape + discriminator_gradient_fn = jax.grad(discriminator_fn) + gradients = discriminator_gradient_fn(gradient_penalty_data) + gradients = tree.map_structure(lambda x: x.flatten(), gradients) + gradients = jnp.concatenate([gradients.observation, gradients.action, + gradients.next_observation]) + gradient_norms = jnp.linalg.norm(gradients + 1e-8) + k = gradient_penalty_target * jnp.ones_like(gradient_norms) + return jnp.mean(jnp.square(gradient_norms - k)) + + +def add_gradient_penalty(base_loss: Loss, + gradient_penalty_coefficient: float, + gradient_penalty_target: float) -> Loss: + """Adds a gradient penalty to the base_loss.""" + + if not gradient_penalty_coefficient: + return base_loss + + def loss_fn(discriminator_fn: DiscriminatorFn, + discriminator_state: State, + demo_transitions: types.Transition, + rb_transitions: types.Transition, + rng_key: networks_lib.PRNGKey) -> LossOutput: + super_key, gradient_penalty_key = jax.random.split(rng_key) + + partial_loss, (losses, discriminator_state) = base_loss( + discriminator_fn, discriminator_state, demo_transitions, rb_transitions, + super_key) + + gradient_penalty_data = _make_gradient_penalty_data( + rb_transitions=rb_transitions, + demonstration_transitions=demo_transitions, + key=gradient_penalty_key) + def apply_discriminator_fn(transitions: types.Transition) -> float: + logits, _ = discriminator_fn(discriminator_state, transitions) + return logits + gradient_penalty = gradient_penalty_coefficient * jnp.mean( + _compute_gradient_penalty(gradient_penalty_data, apply_discriminator_fn, + gradient_penalty_target)) + + losses['gradient_penalty'] = gradient_penalty + total_loss = partial_loss + gradient_penalty + losses['total_loss'] = total_loss + + return total_loss, (losses, discriminator_state) + + return loss_fn diff --git a/acme/acme/agents/jax/ail/losses_test.py b/acme/acme/agents/jax/ail/losses_test.py new file mode 100644 index 00000000..e38943a8 --- /dev/null +++ b/acme/acme/agents/jax/ail/losses_test.py @@ -0,0 +1,79 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the AIL discriminator losses.""" + +from acme import types +from acme.agents.jax.ail import losses +from acme.jax import networks as networks_lib +import jax +import jax.numpy as jnp +import tree + +from absl.testing import absltest + + +class AilLossTest(absltest.TestCase): + + def test_gradient_penalty(self): + + def dummy_discriminator( + transition: types.Transition) -> networks_lib.Logits: + return transition.observation + jnp.square(transition.action) + + zero_transition = types.Transition(0., 0., 0., 0., 0.) + zero_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0), + zero_transition) + self.assertEqual( + losses._compute_gradient_penalty(zero_transition, dummy_discriminator, + 0.), 1**2 + 0**2) + + one_transition = types.Transition(1., 1., 0., 0., 0.) 
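+    # For the dummy discriminator obs + act**2, the gradient at
+    # (obs=1, act=1) is 1 w.r.t. the observation and 2*act = 2 w.r.t. the
+    # action (and 0 w.r.t. the next observation), so with a zero target the
+    # penalty is the squared gradient norm 1**2 + 2**2.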
+    one_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0),
+                                        one_transition)
+    self.assertEqual(
+        losses._compute_gradient_penalty(one_transition, dummy_discriminator,
+                                         0.), 1**2 + 2**2)
+
+  def test_pugail(self):
+
+    def dummy_discriminator(
+        state: losses.State,
+        transition: types.Transition) -> losses.DiscriminatorOutput:
+      return transition.observation, state
+
+    zero_transition = types.Transition(.1, 0., 0., 0., 0.)
+    zero_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0),
+                                         zero_transition)
+
+    one_transition = types.Transition(1., 0., 0., 0., 0.)
+    one_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0),
+                                        one_transition)
+
+    prior = .7
+    loss_fn = losses.pugail_loss(
+        positive_class_prior=prior, entropy_coefficient=0.)
+    loss, _ = loss_fn(dummy_discriminator, {}, one_transition,
+                      zero_transition, ())
+
+    d_one = jax.nn.sigmoid(dummy_discriminator({}, one_transition)[0])
+    d_zero = jax.nn.sigmoid(dummy_discriminator({}, zero_transition)[0])
+    expected_loss = -prior * jnp.log(
+        d_one) + -jnp.log(1. - d_zero) - prior * -jnp.log(1 - d_one)
+
+    self.assertAlmostEqual(loss, expected_loss, places=6)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/jax/ail/networks.py b/acme/acme/agents/jax/ail/networks.py
new file mode 100644
index 00000000..7928a62e
--- /dev/null
+++ b/acme/acme/agents/jax/ail/networks.py
@@ -0,0 +1,399 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Networks definitions for the AIL agent.
+
+AIRL network architecture follows https://arxiv.org/pdf/1710.11248.pdf.
+"""
+import dataclasses
+import functools
+from typing import Any, Callable, Generic, Iterable, Optional
+
+from acme import specs
+from acme import types
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.jax.imitation_learning_types import DirectRLNetworks
+import haiku as hk
+import jax
+from jax import numpy as jnp
+import numpy as np
+
+# Function from discriminator logit to imitation reward.
+ImitationRewardFn = Callable[[networks_lib.Logits], jnp.ndarray]
+State = networks_lib.Params
+
+
+@dataclasses.dataclass
+class AILNetworks(Generic[DirectRLNetworks]):
+  """AIL networks data class.
+
+  Attributes:
+    discriminator_network: Network which takes as input
+      (observations, actions, next_observations, direct_rl_params)
+      and returns the logit of the discriminator.
+      If the discriminator does not need direct_rl_params you can pass ().
+    imitation_reward_fn: Function from the discriminator logit to the
+      imitation reward.
+    direct_rl_networks: Networks of the direct RL algorithm.
+ """ + discriminator_network: networks_lib.FeedForwardNetwork + imitation_reward_fn: ImitationRewardFn + direct_rl_networks: DirectRLNetworks + + +def compute_ail_reward(discriminator_params: networks_lib.Params, + discriminator_state: State, + policy_params: Optional[networks_lib.Params], + transitions: types.Transition, + networks: AILNetworks) -> jnp.ndarray: + """Computes the AIL reward for a given transition. + + Args: + discriminator_params: Parameters of the discriminator network. + discriminator_state: State of the discriminator network. + policy_params: Parameters of the direct RL policy. + transitions: Transitions to compute the reward for. + networks: AIL networks. + + Returns: + The rewards as an ndarray. + """ + logits, _ = networks.discriminator_network.apply( + discriminator_params, + policy_params, + discriminator_state, + transitions, + is_training=False, + rng=None) + return networks.imitation_reward_fn(logits) + + +class SpectralNormalizedLinear(hk.Module): + """SpectralNormalizedLinear module. + + This is a Linear layer with a upper-bounded Lipschitz. It is used in iResNet. + + Reference: + Behrmann et al. Invertible Residual Networks. ICML 2019. + https://arxiv.org/pdf/1811.00995.pdf + """ + + def __init__( + self, + output_size: int, + lipschitz_coeff: float, + with_bias: bool = True, + w_init: Optional[hk.initializers.Initializer] = None, + b_init: Optional[hk.initializers.Initializer] = None, + name: Optional[str] = None, + ): + """Constructs the SpectralNormalizedLinear module. + + Args: + output_size: Output dimensionality. + lipschitz_coeff: Spectral normalization coefficient. + with_bias: Whether to add a bias to the output. + w_init: Optional initializer for weights. By default, uses random values + from truncated normal, with stddev ``1 / sqrt(fan_in)``. See + https://arxiv.org/abs/1502.03167v3. + b_init: Optional initializer for bias. By default, zero. + name: Name of the module. + """ + super().__init__(name=name) + self.input_size = None + self.output_size = output_size + self.with_bias = with_bias + self.w_init = w_init + self.b_init = b_init or jnp.zeros + self.lipschitz_coeff = lipschitz_coeff + self.num_iterations = 100 + self.eps = 1e-6 + + def get_normalized_weights(self, + weights: jnp.ndarray, + renormalize: bool = False) -> jnp.ndarray: + + def _l2_normalize(x, axis=None, eps=1e-12): + return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) + + output_size = self.output_size + dtype = weights.dtype + assert output_size == weights.shape[-1] + sigma = hk.get_state('sigma', (), init=jnp.ones) + if renormalize: + # Power iterations to compute spectral norm V*W*U^T. 
+      u = hk.get_state(
+          'u', (1, output_size), dtype, init=hk.initializers.RandomNormal())
+      for _ in range(self.num_iterations):
+        v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps)
+        u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps)
+      u = jax.lax.stop_gradient(u)
+      v = jax.lax.stop_gradient(v)
+      sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]
+      hk.set_state('u', u)
+      hk.set_state('v', v)
+      hk.set_state('sigma', sigma)
+    factor = jnp.maximum(1, sigma / self.lipschitz_coeff)
+    return weights / factor
+
+  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
+    """Computes a linear transform of the input."""
+    if not inputs.shape:
+      raise ValueError('Input must not be scalar.')
+
+    input_size = self.input_size = inputs.shape[-1]
+    output_size = self.output_size
+    dtype = inputs.dtype
+
+    w_init = self.w_init
+    if w_init is None:
+      stddev = 1. / np.sqrt(self.input_size)
+      w_init = hk.initializers.TruncatedNormal(stddev=stddev)
+    w = hk.get_parameter('w', [input_size, output_size], dtype, init=w_init)
+    w = self.get_normalized_weights(w, renormalize=True)
+
+    out = jnp.dot(inputs, w)
+
+    if self.with_bias:
+      b = hk.get_parameter('b', [self.output_size], dtype, init=self.b_init)
+      b = jnp.broadcast_to(b, out.shape)
+      out = out + b
+
+    return out
+
+
+class DiscriminatorMLP(hk.Module):
+  """A multi-layer perceptron module."""
+
+  def __init__(
+      self,
+      hidden_layer_sizes: Iterable[int],
+      w_init: Optional[hk.initializers.Initializer] = None,
+      b_init: Optional[hk.initializers.Initializer] = None,
+      with_bias: bool = True,
+      activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
+      input_dropout_rate: float = 0.,
+      hidden_dropout_rate: float = 0.,
+      spectral_normalization_lipschitz_coeff: Optional[float] = None,
+      name: Optional[str] = None
+  ):
+    """Constructs an MLP.
+
+    Args:
+      hidden_layer_sizes: Hidden layer sizes.
+      w_init: Initializer for :class:`~haiku.Linear` weights.
+      b_init: Initializer for :class:`~haiku.Linear` bias. Must be ``None`` if
+        ``with_bias=False``.
+      with_bias: Whether or not to apply a bias in each layer.
+      activation: Activation function to apply between :class:`~haiku.Linear`
+        layers. Defaults to ReLU.
+      input_dropout_rate: Dropout on the input.
+      hidden_dropout_rate: Dropout on the hidden layer outputs.
+      spectral_normalization_lipschitz_coeff: If not None, the network will
+        have spectral normalization with the given constant.
+      name: Optional name for this module.
+
+    Raises:
+      ValueError: If ``with_bias`` is ``False`` and ``b_init`` is not ``None``.
+    """
+    if not with_bias and b_init is not None:
+      raise ValueError('When with_bias=False b_init must not be set.')
+
+    super().__init__(name=name)
+    self._activation = activation
+    self._input_dropout_rate = input_dropout_rate
+    self._hidden_dropout_rate = hidden_dropout_rate
+    layer_sizes = list(hidden_layer_sizes) + [1]
+
+    if spectral_normalization_lipschitz_coeff is not None:
+      layer_lipschitz_coeff = np.power(spectral_normalization_lipschitz_coeff,
+                                       1.
/ len(layer_sizes)) + layer_module = functools.partial( + SpectralNormalizedLinear, + lipschitz_coeff=layer_lipschitz_coeff, + w_init=w_init, + b_init=b_init, + with_bias=with_bias) + else: + layer_module = functools.partial( + hk.Linear, + w_init=w_init, + b_init=b_init, + with_bias=with_bias) + + layers = [] + for index, output_size in enumerate(layer_sizes): + layers.append( + layer_module(output_size=output_size, name=f'linear_{index}')) + self._layers = tuple(layers) + + def __call__( + self, + inputs: jnp.ndarray, + is_training: bool, + rng: Optional[networks_lib.PRNGKey], + ) -> networks_lib.Logits: + rng = hk.PRNGSequence(rng) if rng is not None else None + + out = inputs + for i, layer in enumerate(self._layers): + if is_training: + dropout_rate = ( + self._input_dropout_rate if i == 0 else self._hidden_dropout_rate) + out = hk.dropout(next(rng), dropout_rate, out) + out = layer(out) + if i < len(self._layers) - 1: + out = self._activation(out) + + return out + + +class DiscriminatorModule(hk.Module): + """Discriminator module that concatenates its inputs.""" + + def __init__(self, + environment_spec: specs.EnvironmentSpec, + use_action: bool, + use_next_obs: bool, + network_core: Callable[..., Any], + observation_embedding: Callable[[networks_lib.Observation], + jnp.ndarray] = lambda x: x, + name='discriminator'): + super().__init__(name=name) + self._use_action = use_action + self._environment_spec = environment_spec + self._use_next_obs = use_next_obs + self._network_core = network_core + self._observation_embedding = observation_embedding + + def __call__(self, observations: networks_lib.Observation, + actions: networks_lib.Action, + next_observations: networks_lib.Observation, is_training: bool, + rng: networks_lib.PRNGKey) -> networks_lib.Logits: + observations = self._observation_embedding(observations) + if self._use_next_obs: + next_observations = self._observation_embedding(next_observations) + data = jnp.concatenate([observations, next_observations], axis=-1) + else: + data = observations + if self._use_action: + action_spec = self._environment_spec.actions + if isinstance(action_spec, specs.DiscreteArray): + actions = jax.nn.one_hot(actions, + action_spec.num_values) + data = jnp.concatenate([data, actions], axis=-1) + output = self._network_core(data, is_training, rng) + output = jnp.squeeze(output, axis=-1) + return output + + +class AIRLModule(hk.Module): + """AIRL Module.""" + + def __init__(self, + environment_spec: specs.EnvironmentSpec, + use_action: bool, + use_next_obs: bool, + discount: float, + g_core: Callable[..., Any], + h_core: Callable[..., Any], + observation_embedding: Callable[[networks_lib.Observation], + jnp.ndarray] = lambda x: x, + name='airl'): + super().__init__(name=name) + self._environment_spec = environment_spec + self._use_action = use_action + self._use_next_obs = use_next_obs + self._discount = discount + self._g_core = g_core + self._h_core = h_core + self._observation_embedding = observation_embedding + + def __call__(self, observations: networks_lib.Observation, + actions: networks_lib.Action, + next_observations: networks_lib.Observation, + is_training: bool, + rng: networks_lib.PRNGKey) -> networks_lib.Logits: + g_output = DiscriminatorModule( + environment_spec=self._environment_spec, + use_action=self._use_action, + use_next_obs=self._use_next_obs, + network_core=self._g_core, + observation_embedding=self._observation_embedding, + name='airl_g')(observations, actions, next_observations, is_training, + rng) + h_module = 
DiscriminatorModule(
+        environment_spec=self._environment_spec,
+        use_action=False,
+        use_next_obs=False,
+        network_core=self._h_core,
+        observation_embedding=self._observation_embedding,
+        name='airl_h')
+    return (g_output + self._discount * h_module(next_observations, (),
+                                                 (), is_training, rng) -
+            h_module(observations, (), (), is_training, rng))
+
+
+# TODO(eorsini): Manipulate FeedForwardNetworks instead of transforms to
+# increase compatibility with Flax.
+def make_discriminator(
+    environment_spec: specs.EnvironmentSpec,
+    discriminator_transformed: hk.TransformedWithState,
+    logpi_fn: Optional[Callable[
+        [networks_lib.Params, networks_lib.Observation, networks_lib.Action],
+        jnp.ndarray]] = None
+) -> networks_lib.FeedForwardNetwork:
+  """Creates the discriminator network.
+
+  Args:
+    environment_spec: Environment spec.
+    discriminator_transformed: Haiku-transformed discriminator module.
+    logpi_fn: If the policy logpi function is provided, its output will be
+      removed from the discriminator logit.
+
+  Returns:
+    The network.
+  """
+
+  def apply_fn(params: hk.Params,
+               policy_params: networks_lib.Params,
+               state: hk.State,
+               transitions: types.Transition,
+               is_training: bool,
+               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
+    output, state = discriminator_transformed.apply(
+        params, state, transitions.observation, transitions.action,
+        transitions.next_observation, is_training, rng)
+    if logpi_fn is not None:
+      logpi = logpi_fn(policy_params, transitions.observation,
+                       transitions.action)
+
+      # Quick Maths:
+      # D = exp(output)/(exp(output) + pi(a|s))
+      # logit(D) = log(D/(1-D)) = log(exp(output)/pi(a|s))
+      # logit(D) = output - logpi
+      return output - logpi, state
+    return output, state
+
+  dummy_obs = utils.zeros_like(environment_spec.observations)
+  dummy_obs = utils.add_batch_dim(dummy_obs)
+  dummy_actions = utils.zeros_like(environment_spec.actions)
+  dummy_actions = utils.add_batch_dim(dummy_actions)
+
+  return networks_lib.FeedForwardNetwork(
+      # pylint: disable=g-long-lambda
+      init=lambda rng: discriminator_transformed.init(
+          rng, dummy_obs, dummy_actions, dummy_obs, False, rng),
+      apply=apply_fn)
diff --git a/acme/acme/agents/jax/ail/rewards.py b/acme/acme/agents/jax/ail/rewards.py
new file mode 100644
index 00000000..d1a8fe61
--- /dev/null
+++ b/acme/acme/agents/jax/ail/rewards.py
@@ -0,0 +1,76 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AIL logits to AIL reward."""
+from typing import Optional
+
+from acme.agents.jax.ail import networks as ail_networks
+from acme.jax import networks as networks_lib
+import jax
+import jax.numpy as jnp
+
+
+def fairl_reward(
+    max_reward_magnitude: Optional[float] = None
+) -> ail_networks.ImitationRewardFn:
+  """The FAIRL reward function (https://arxiv.org/pdf/1911.02256.pdf).
+
+  Args:
+    max_reward_magnitude: Clipping value for the reward.
+
+  Returns:
+    The function from logit to imitation reward.
+ """ + + def imitation_reward(logits: networks_lib.Logits) -> float: + rewards = jnp.exp(jnp.clip(logits, a_max=20.)) * -logits + if max_reward_magnitude is not None: + # pylint: disable=invalid-unary-operand-type + rewards = jnp.clip( + rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude) + return rewards + + return imitation_reward + + +def gail_reward( + reward_balance: float = .5, + max_reward_magnitude: Optional[float] = None +) -> ail_networks.ImitationRewardFn: + """GAIL reward function (https://arxiv.org/pdf/1606.03476.pdf). + + Args: + reward_balance: 1 means log(D) reward, 0 means -log(1-D) and other values + mean an average of the two. + max_reward_magnitude: Clipping value for the reward. + + Returns: + The function from logit to imitation reward. + """ + + def imitation_reward(logits: networks_lib.Logits) -> float: + # Quick Maths: + # logits = ln(D) - ln(1-D) + # -softplus(-logits) = ln(D) + # softplus(logits) = -ln(1-D) + rewards = ( + reward_balance * -jax.nn.softplus(-logits) + + (1 - reward_balance) * jax.nn.softplus(logits)) + if max_reward_magnitude is not None: + # pylint: disable=invalid-unary-operand-type + rewards = jnp.clip( + rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude) + return rewards + + return imitation_reward diff --git a/acme/acme/agents/jax/ars/README.md b/acme/acme/agents/jax/ars/README.md new file mode 100644 index 00000000..7ce802a9 --- /dev/null +++ b/acme/acme/agents/jax/ars/README.md @@ -0,0 +1,7 @@ +# Augmented Random Search (ARS) + +This folder contains an implementation of the ARS algorithm +([Mania et al., 2018]). + + +[Mania et al., 2018]: https://arxiv.org/pdf/1803.07055.pdf diff --git a/acme/acme/agents/jax/ars/__init__.py b/acme/acme/agents/jax/ars/__init__.py new file mode 100644 index 00000000..38a58510 --- /dev/null +++ b/acme/acme/agents/jax/ars/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ARS agent.""" + +from acme.agents.jax.ars.builder import ARSBuilder +from acme.agents.jax.ars.config import ARSConfig +from acme.agents.jax.ars.networks import make_networks +from acme.agents.jax.ars.networks import make_policy_network diff --git a/acme/acme/agents/jax/ars/builder.py b/acme/acme/agents/jax/ars/builder.py new file mode 100644 index 00000000..01fc8c3f --- /dev/null +++ b/acme/acme/agents/jax/ars/builder.py @@ -0,0 +1,160 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""ARS Builder.""" +from typing import Dict, Iterator, List, Optional, Tuple + +import acme +from acme import adders +from acme import core +from acme import specs +from acme.adders import reverb as adders_reverb +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.ars import config as ars_config +from acme.agents.jax.ars import learning +from acme.jax import networks as networks_lib +from acme.jax import running_statistics +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import numpy as np +import reverb + + +def get_policy(policy_network: networks_lib.FeedForwardNetwork, + normalization_apply_fn) -> actor_core_lib.FeedForwardPolicy: + """Returns a function that computes actions.""" + + def apply( + params: networks_lib.Params, key: networks_lib.PRNGKey, + obs: networks_lib.Observation + ) -> Tuple[networks_lib.Action, Dict[str, jnp.ndarray]]: + del key + params_key, policy_params, normalization_params = params + normalized_obs = normalization_apply_fn(obs, normalization_params) + action = policy_network.apply(policy_params, normalized_obs) + return action, { + 'params_key': + jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), params_key) + } + + return apply + + +class ARSBuilder( + builders.ActorLearnerBuilder[networks_lib.FeedForwardNetwork, + Tuple[str, networks_lib.FeedForwardNetwork], + reverb.ReplaySample]): + """ARS Builder.""" + + def __init__( + self, + config: ars_config.ARSConfig, + spec: specs.EnvironmentSpec, + ): + self._config = config + self._spec = spec + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: networks_lib.FeedForwardNetwork, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + return learning.ARSLearner(self._spec, networks, random_key, self._config, + dataset, counter, logger_fn('learner')) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Tuple[str, networks_lib.FeedForwardNetwork], + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + assert variable_source is not None + + kname, policy = policy + + normalization_apply_fn = ( + running_statistics.normalize if self._config.normalize_observations else + (lambda a, b: a)) + policy_to_run = get_policy(policy, normalization_apply_fn) + + actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( + policy_to_run) + variable_client = variable_utils.VariableClient(variable_source, kname, + device='cpu') + return actors.GenericActor( + actor_core, + random_key, + variable_client, + adder, + backend='cpu', + per_episode_update=True) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: Tuple[str, networks_lib.FeedForwardNetwork], + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + del policy + extra_spec = { + 'params_key': (np.zeros(shape=(), dtype=np.int32), + np.zeros(shape=(), dtype=np.int32), + np.zeros(shape=(), 
dtype=np.bool_)), + } + signature = adders_reverb.EpisodeAdder.signature( + environment_spec, sequence_length=None, extras_spec=extra_spec) + return [ + reverb.Table.queue( + name=self._config.replay_table_name, + max_size=10000, # a big number + signature=signature) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + dataset = reverb.TrajectoryDataset.from_table_signature( + server_address=replay_client.server_address, + table=self._config.replay_table_name, + max_in_flight_samples_per_worker=1) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + + def make_adder( + self, replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[Tuple[str, networks_lib.FeedForwardNetwork]] + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment.""" + del environment_spec, policy + + return adders_reverb.EpisodeAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + max_sequence_length=2000, + ) diff --git a/acme/acme/agents/jax/ars/config.py b/acme/acme/agents/jax/ars/config.py new file mode 100644 index 00000000..7658714b --- /dev/null +++ b/acme/acme/agents/jax/ars/config.py @@ -0,0 +1,31 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ARS config.""" +import dataclasses + +from acme.adders import reverb as adders_reverb + + +@dataclasses.dataclass +class ARSConfig: + """Configuration options for ARS.""" + num_steps: int = 1000000 + normalize_observations: bool = True + step_size: float = 0.015 + num_directions: int = 60 + exploration_noise_std: float = 0.025 + top_directions: int = 20 + reward_shift: float = 1.0 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE diff --git a/acme/acme/agents/jax/ars/learning.py b/acme/acme/agents/jax/ars/learning.py new file mode 100644 index 00000000..7ad7f3cf --- /dev/null +++ b/acme/acme/agents/jax/ars/learning.py @@ -0,0 +1,282 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
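+#
+# Overview (a summary of the implementation below): each learner iteration
+# samples num_directions Gaussian perturbations, enqueues 2 * num_directions
+# perturbed policies (one per sign of each direction) for the actors to pick
+# up via get_variables, collects the resulting episode returns back through a
+# Reverb queue, and applies the ARS update using only the top-scoring
+# directions.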
+ +"""ARS learner implementation.""" + +import collections +import threading +import time +from typing import Any, Deque, Dict, Iterator, List, NamedTuple, Optional + +import acme +from acme import specs +from acme.adders import reverb as acme_reverb +from acme.agents.jax.ars import config as ars_config +from acme.agents.jax.ars import networks as ars_networks +from acme.jax import networks as networks_lib +from acme.jax import running_statistics +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +import jax +import numpy as np +import reverb + + +class PerturbationKey(NamedTuple): + training_iteration: int + perturbation_id: int + is_opposite: bool + + +class EvaluationResult(NamedTuple): + total_reward: float + observation: networks_lib.Observation + + +class EvaluationRequest(NamedTuple): + key: PerturbationKey + policy_params: networks_lib.Params + normalization_params: networks_lib.Params + + +class TrainingState(NamedTuple): + """Contains training state for the learner.""" + key: networks_lib.PRNGKey + normalizer_params: networks_lib.Params + policy_params: networks_lib.Params + training_iteration: int + + +class EvaluationState(NamedTuple): + """Contains training state for the learner.""" + key: networks_lib.PRNGKey + evaluation_queue: Deque[EvaluationRequest] + received_results: Dict[PerturbationKey, EvaluationResult] + noises: List[networks_lib.Params] + + +class ARSLearner(acme.Learner): + """ARS learner.""" + + _state: TrainingState + + def __init__( + self, + spec: specs.EnvironmentSpec, + networks: networks_lib.FeedForwardNetwork, + rng: networks_lib.PRNGKey, + config: ars_config.ARSConfig, + iterator: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None): + + self._config = config + self._lock = threading.Lock() + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Iterator on demonstration transitions. + self._iterator = iterator + + if self._config.normalize_observations: + normalizer_params = running_statistics.init_state(spec.observations) + self._normalizer_update_fn = running_statistics.update + else: + normalizer_params = () + self._normalizer_update_fn = lambda a, b: a + + rng1, rng2, tmp = jax.random.split(rng, 3) + # Create initial state. + self._training_state = TrainingState( + key=rng1, + policy_params=networks.init(tmp), + normalizer_params=normalizer_params, + training_iteration=0) + self._evaluation_state = EvaluationState( + key=rng2, + evaluation_queue=collections.deque(), + received_results={}, + noises=[]) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. 
+ self._timestamp = None + + def _generate_perturbations(self): + with self._lock: + rng, noise_key = jax.random.split(self._evaluation_state.key) + self._evaluation_state = EvaluationState( + key=rng, + evaluation_queue=collections.deque(), + received_results={}, + noises=[]) + + all_noise = jax.random.normal( + noise_key, + shape=(self._config.num_directions,) + + self._training_state.policy_params.shape, + dtype=self._training_state.policy_params.dtype) + for i in range(self._config.num_directions): + noise = all_noise[i] + self._evaluation_state.noises.append(noise) + for direction in (-1, 1): + self._evaluation_state.evaluation_queue.append( + EvaluationRequest( + PerturbationKey(self._training_state.training_iteration, i, + direction == -1), + self._training_state.policy_params + + direction * noise * self._config.exploration_noise_std, + self._training_state.normalizer_params)) + + def _read_results(self): + while len(self._evaluation_state.received_results + ) != self._config.num_directions * 2: + data = next(self._iterator).data + data = acme_reverb.Step(*data) + + # validation + params_key = data.extras['params_key'] + training_step, perturbation_id, is_opposite = params_key + # If the incoming data does not correspond to the current iteration, + # we simply ignore it. + if not np.all( + training_step[:-1] == self._training_state.training_iteration): + continue + + # The whole episode should be run with the same policy, so let's check + # for that. + assert np.all(perturbation_id[:-1] == perturbation_id[0]) + assert np.all(is_opposite[:-1] == is_opposite[0]) + + perturbation_id = perturbation_id[0].item() + is_opposite = is_opposite[0].item() + + total_reward = np.sum(data.reward - self._config.reward_shift) + k = PerturbationKey(self._training_state.training_iteration, + perturbation_id, is_opposite) + if k in self._evaluation_state.received_results: + continue + self._evaluation_state.received_results[k] = EvaluationResult( + total_reward, data.observation) + + def _update_model(self) -> int: + # Update normalization params. + real_actor_steps = 0 + normalizer_params = self._training_state.normalizer_params + for _, value in self._evaluation_state.received_results.items(): + real_actor_steps += value.observation.shape[0] - 1 + normalizer_params = self._normalizer_update_fn(normalizer_params, + value.observation) + + # Keep only top directions. + top_directions = [] + for i in range(self._config.num_directions): + reward_forward = self._evaluation_state.received_results[PerturbationKey( + self._training_state.training_iteration, i, False)].total_reward + reward_reverse = self._evaluation_state.received_results[PerturbationKey( + self._training_state.training_iteration, i, True)].total_reward + top_directions.append((max(reward_forward, reward_reverse), i)) + top_directions.sort() + top_directions = top_directions[-self._config.top_directions:] + + # Compute reward_std. + reward = [] + for _, i in top_directions: + reward.append(self._evaluation_state.received_results[PerturbationKey( + self._training_state.training_iteration, i, False)].total_reward) + reward.append(self._evaluation_state.received_results[PerturbationKey( + self._training_state.training_iteration, i, True)].total_reward) + reward_std = np.std(reward) + + # Compute new policy params. 
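+    # This implements the ARS (V2-t style) update from Mania et al., 2018:
+    #   theta <- theta + (alpha / (b * sigma_R)) *
+    #            sum_{i in top-b} (r(theta + nu*delta_i) -
+    #                              r(theta - nu*delta_i)) * delta_i
+    # where alpha = step_size, b = top_directions, sigma_R is the std of the
+    # collected top rewards and nu = exploration_noise_std.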
+    policy_params = self._training_state.policy_params
+    curr_sum = np.zeros_like(policy_params)
+    for _, i in top_directions:
+      reward_forward = self._evaluation_state.received_results[PerturbationKey(
+          self._training_state.training_iteration, i, False)].total_reward
+      reward_reverse = self._evaluation_state.received_results[PerturbationKey(
+          self._training_state.training_iteration, i, True)].total_reward
+      curr_sum += self._evaluation_state.noises[i] * (
+          reward_forward - reward_reverse)
+
+    policy_params = policy_params + self._config.step_size / (
+        self._config.top_directions * reward_std) * curr_sum
+
+    self._training_state = TrainingState(
+        key=self._training_state.key,
+        normalizer_params=normalizer_params,
+        policy_params=policy_params,
+        training_iteration=self._training_state.training_iteration)
+    return real_actor_steps
+
+  def step(self):
+    self._training_state = self._training_state._replace(
+        training_iteration=self._training_state.training_iteration + 1)
+    self._generate_perturbations()
+    self._read_results()
+    real_actor_steps = self._update_model()
+
+    # Compute elapsed time.
+    timestamp = time.time()
+    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
+    self._timestamp = timestamp
+
+    # Increment counts and record the current time.
+    counts = self._counter.increment(
+        steps=1,
+        real_actor_steps=real_actor_steps,
+        learner_episodes=2 * self._config.num_directions,
+        walltime=elapsed_time)
+
+    # Attempts to write the logs.
+    self._logger.write(counts)
+
+  def get_variables(self, names: List[str]) -> List[Any]:
+    assert (names == [ars_networks.BEHAVIOR_PARAMS_NAME] or
+            names == [ars_networks.EVAL_PARAMS_NAME])
+    if names == [ars_networks.EVAL_PARAMS_NAME]:
+      return [PerturbationKey(-1, -1, False),
+              self._training_state.policy_params,
+              self._training_state.normalizer_params]
+    should_sleep = False
+    while True:
+      if should_sleep:
+        time.sleep(0.1)
+        should_sleep = False
+      with self._lock:
+        if not self._evaluation_state.evaluation_queue:
+          should_sleep = True
+          continue
+        data = self._evaluation_state.evaluation_queue.pop()
+        # If this perturbation was already evaluated, we simply skip it.
+        if data.key in self._evaluation_state.received_results:
+          continue
+        # If an actor fails, we still need to re-evaluate the same
+        # perturbation, so we add it back to the end of the queue.
+        self._evaluation_state.evaluation_queue.append(data)
+        return [data]
+
+  def save(self) -> TrainingState:
+    return self._training_state
+
+  def restore(self, state: TrainingState):
+    self._training_state = state
diff --git a/acme/acme/agents/jax/ars/networks.py b/acme/acme/agents/jax/ars/networks.py
new file mode 100644
index 00000000..55c0a537
--- /dev/null
+++ b/acme/acme/agents/jax/ars/networks.py
@@ -0,0 +1,52 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""ARS networks definition.""" + +from typing import Tuple + +from acme import specs +from acme.jax import networks as networks_lib +import jax.numpy as jnp + + +BEHAVIOR_PARAMS_NAME = 'policy' +EVAL_PARAMS_NAME = 'eval' + + +def make_networks( + spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork: + """Creates networks used by the agent. + + The model used by the ARS paper is a simple clipped linear model. + + Args: + spec: an environment spec + + Returns: + A FeedForwardNetwork network. + """ + + obs_size = spec.observations.shape[0] + act_size = spec.actions.shape[0] + return networks_lib.FeedForwardNetwork( + init=lambda _: jnp.zeros((obs_size, act_size)), + apply=lambda matrix, obs: jnp.clip(jnp.matmul(obs, matrix), -1, 1)) + + +def make_policy_network( + network: networks_lib.FeedForwardNetwork, + eval_mode: bool = True) -> Tuple[str, networks_lib.FeedForwardNetwork]: + params_name = EVAL_PARAMS_NAME if eval_mode else BEHAVIOR_PARAMS_NAME + return (params_name, network) diff --git a/acme/acme/agents/jax/bc/README.md b/acme/acme/agents/jax/bc/README.md new file mode 100644 index 00000000..3dc46931 --- /dev/null +++ b/acme/acme/agents/jax/bc/README.md @@ -0,0 +1,20 @@ +# Behavioral Cloning (BC) + +This folder contains an implementation for supervised learning of a policy from +a dataset of observations and target actions. This is an approach of Imitation +Learning known as Behavioral Cloning, introduced by [Pomerleau, 1989]. + +Several losses are implemented: + +* Mean squared error (mse) +* Cross entropy (logp) +* Peer Behavioral Cloning (peerbc), a regularization scheme from + [Wang et al., 2021] +* Reward-regularized Classification for Apprenticeship Learning (rcal), + another regularization scheme from [Piot et al., 2014], defined for discrete + action environments (or discretized action-spaces in case of continuous + control). + +[Pomerleau, 1989]: https://papers.nips.cc/paper/95-alvinn-an-autonomous-land-vehicle-in-a-neural-network.pdf +[Wang et al., 2021]: https://arxiv.org/pdf/2010.01748.pdf +[Piot et al., 2014]: https://www.cristal.univ-lille.fr/~pietquin/pdf/AAMAS_2014_BPMGOP.pdf diff --git a/acme/acme/agents/jax/bc/__init__.py b/acme/acme/agents/jax/bc/__init__.py new file mode 100644 index 00000000..dee94a5a --- /dev/null +++ b/acme/acme/agents/jax/bc/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Implementation of a behavior cloning (BC) agent.""" + +from acme.agents.jax.bc import pretraining +from acme.agents.jax.bc.builder import BCBuilder +from acme.agents.jax.bc.config import BCConfig +from acme.agents.jax.bc.learning import BCLearner +from acme.agents.jax.bc.losses import BCLoss +from acme.agents.jax.bc.losses import logp +from acme.agents.jax.bc.losses import mse +from acme.agents.jax.bc.losses import peerbc +from acme.agents.jax.bc.losses import rcal +from acme.agents.jax.bc.networks import BCNetworks +from acme.agents.jax.bc.networks import BCPolicyNetwork +from acme.agents.jax.bc.networks import convert_policy_value_to_bc_network +from acme.agents.jax.bc.networks import convert_to_bc_network diff --git a/acme/acme/agents/jax/bc/agent_test.py b/acme/acme/agents/jax/bc/agent_test.py new file mode 100644 index 00000000..e266b49a --- /dev/null +++ b/acme/acme/agents/jax/bc/agent_test.py @@ -0,0 +1,191 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the BC agent.""" + +from acme import specs +from acme import types +from acme.agents.jax import bc +from acme.jax import networks as networks_lib +from acme.jax import types as jax_types +from acme.jax import utils +from acme.testing import fakes +import chex +import haiku as hk +import jax +import jax.numpy as jnp +from jax.scipy import special +import numpy as np +import optax +import rlax + +from absl.testing import absltest +from absl.testing import parameterized + + +def make_networks(spec: specs.EnvironmentSpec, + discrete_actions: bool = False) -> bc.BCNetworks: + """Creates networks used by the agent.""" + + if discrete_actions: + final_layer_size = spec.actions.num_values + else: + final_layer_size = np.prod(spec.actions.shape, dtype=int) + + def _actor_fn(obs, is_training=False, key=None): + # is_training and key allows to defined train/test dependant modules + # like dropout. + del is_training + del key + if discrete_actions: + network = hk.nets.MLP([64, 64, final_layer_size]) + else: + network = hk.Sequential([ + networks_lib.LayerNormMLP([64, 64], activate_final=True), + networks_lib.NormalTanhDistribution(final_layer_size), + ]) + return network(obs) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + + # Create dummy observations and actions to create network parameters. 
+  dummy_obs = utils.zeros_like(spec.observations)
+  dummy_obs = utils.add_batch_dim(dummy_obs)
+  policy_network = networks_lib.FeedForwardNetwork(
+      lambda key: policy.init(key, dummy_obs), policy.apply)
+  bc_policy_network = bc.convert_to_bc_network(policy_network)
+
+  if discrete_actions:
+
+    def sample_fn(logits: networks_lib.NetworkOutput,
+                  key: jax_types.PRNGKey) -> networks_lib.Action:
+      return rlax.epsilon_greedy(epsilon=0.0).sample(key, logits)
+
+    def log_prob(logits: networks_lib.NetworkOutput,
+                 actions: networks_lib.Action) -> networks_lib.LogProb:
+      max_logits = jnp.max(logits, axis=-1, keepdims=True)
+      logits = logits - max_logits
+      logits_actions = jnp.sum(
+          jax.nn.one_hot(actions, spec.actions.num_values) * logits, axis=-1)
+
+      log_prob = logits_actions - special.logsumexp(logits, axis=-1)
+      return log_prob
+
+  else:
+
+    def sample_fn(distribution: networks_lib.NetworkOutput,
+                  key: jax_types.PRNGKey) -> networks_lib.Action:
+      return distribution.sample(seed=key)
+
+    def log_prob(distribution: networks_lib.NetworkOutput,
+                 actions: networks_lib.Action) -> networks_lib.LogProb:
+      return distribution.log_prob(actions)
+
+  return bc.BCNetworks(bc_policy_network, sample_fn, log_prob)
+
+
+class BCTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      ('logp',),
+      ('mse',),
+      ('peerbc',)
+  )
+  def test_continuous_actions(self, loss_name):
+    with chex.fake_pmap_and_jit():
+      num_sgd_steps_per_step = 1
+      num_steps = 5
+
+      # Create a fake environment to test with.
+      environment = fakes.ContinuousEnvironment(
+          episode_length=10, bounded=True, action_dim=6)
+
+      spec = specs.make_environment_spec(environment)
+      dataset_demonstration = fakes.transition_dataset(environment)
+      dataset_demonstration = dataset_demonstration.map(
+          lambda sample: types.Transition(*sample.data))
+      dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()
+
+      # Construct the agent.
+      networks = make_networks(spec)
+
+      if loss_name == 'logp':
+        loss_fn = bc.logp()
+      elif loss_name == 'mse':
+        loss_fn = bc.mse()
+      elif loss_name == 'peerbc':
+        loss_fn = bc.peerbc(bc.logp(), zeta=0.1)
+      else:
+        raise ValueError
+
+      learner = bc.BCLearner(
+          networks=networks,
+          random_key=jax.random.PRNGKey(0),
+          loss_fn=loss_fn,
+          optimizer=optax.adam(0.01),
+          prefetching_iterator=utils.sharded_prefetch(dataset_demonstration),
+          num_sgd_steps_per_step=num_sgd_steps_per_step)
+
+      # Train the agent.
+      for _ in range(num_steps):
+        learner.step()
+
+  @parameterized.parameters(
+      ('logp',),
+      ('rcal',))
+  def test_discrete_actions(self, loss_name):
+    with chex.fake_pmap_and_jit():
+
+      num_sgd_steps_per_step = 1
+      num_steps = 5
+
+      # Create a fake environment to test with.
+      environment = fakes.DiscreteEnvironment(
+          num_actions=10, num_observations=100, obs_shape=(10,),
+          obs_dtype=np.float32)
+
+      spec = specs.make_environment_spec(environment)
+      dataset_demonstration = fakes.transition_dataset(environment)
+      dataset_demonstration = dataset_demonstration.map(
+          lambda sample: types.Transition(*sample.data))
+      dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()
+
+      # Construct the agent.
+ networks = make_networks(spec, discrete_actions=True) + + if loss_name == 'logp': + loss_fn = bc.logp() + + elif loss_name == 'rcal': + base_loss_fn = bc.logp() + loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1) + + else: + raise ValueError + + learner = bc.BCLearner( + networks=networks, + random_key=jax.random.PRNGKey(0), + loss_fn=loss_fn, + optimizer=optax.adam(0.01), + prefetching_iterator=utils.sharded_prefetch(dataset_demonstration), + num_sgd_steps_per_step=num_sgd_steps_per_step) + + # Train the agent + for _ in range(num_steps): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/bc/builder.py b/acme/acme/agents/jax/bc/builder.py new file mode 100644 index 00000000..92476195 --- /dev/null +++ b/acme/acme/agents/jax/bc/builder.py @@ -0,0 +1,113 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BC Builder.""" +from typing import Iterator, Optional + +from acme import core +from acme import specs +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.bc import config as bc_config +from acme.agents.jax.bc import learning +from acme.agents.jax.bc import losses +from acme.agents.jax.bc import networks as bc_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import optax + + +class BCBuilder(builders.OfflineBuilder[bc_networks.BCNetworks, + actor_core_lib.FeedForwardPolicy, + types.Transition]): + """BC Builder.""" + + def __init__( + self, + config: bc_config.BCConfig, + loss_fn: losses.BCLoss, + loss_has_aux: bool = False, + ): + """Creates a BC learner, an evaluation policy and an eval actor. + + Args: + config: a config with BC hps. + loss_fn: BC loss to use. + loss_has_aux: Whether the loss function returns auxiliary metrics as a + second argument. 
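+
+    Note: the loss factories in losses.py (mse, logp, peerbc, rcal) return
+      a bare scalar, so loss_has_aux should remain False when using them.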
+ """ + self._config = config + self._loss_fn = loss_fn + self._loss_has_aux = loss_has_aux + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: bc_networks.BCNetworks, + dataset: Iterator[types.Transition], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec + + return learning.BCLearner( + networks=networks, + random_key=random_key, + loss_fn=self._loss_fn, + optimizer=optax.adam(learning_rate=self._config.learning_rate), + prefetching_iterator=utils.sharded_prefetch(dataset), + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + loss_has_aux=self._loss_has_aux, + logger=logger_fn('learner'), + counter=counter) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, 'policy', device='cpu') + return actors.GenericActor( + actor_core, random_key, variable_client, backend='cpu') + + def make_policy(self, + networks: bc_networks.BCNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + del environment_spec, evaluation + + def evaluation_policy( + params: networks_lib.Params, key: networks_lib.PRNGKey, + observation: networks_lib.Observation) -> networks_lib.Action: + apply_key, sample_key = jax.random.split(key) + network_output = networks.policy_network.apply( + params, observation, is_training=False, key=apply_key) + return networks.sample_fn(network_output, sample_key) + + return evaluation_policy diff --git a/acme/acme/agents/jax/bc/config.py b/acme/acme/agents/jax/bc/config.py new file mode 100644 index 00000000..15fa1ff8 --- /dev/null +++ b/acme/acme/agents/jax/bc/config.py @@ -0,0 +1,28 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Config classes for BC.""" +import dataclasses + + +@dataclasses.dataclass +class BCConfig: + """Configuration options for BC. + + Attributes: + learning_rate: Learning rate. + num_sgd_steps_per_step: How many gradient updates to perform per step. + """ + learning_rate: float = 1e-4 + num_sgd_steps_per_step: int = 1 diff --git a/acme/acme/agents/jax/bc/learning.py b/acme/acme/agents/jax/bc/learning.py new file mode 100644 index 00000000..11406bda --- /dev/null +++ b/acme/acme/agents/jax/bc/learning.py @@ -0,0 +1,200 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BC learner implementation."""
+
+import time
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union, Iterator
+
+import acme
+from acme import types
+from acme.agents.jax.bc import losses
+from acme.agents.jax.bc import networks as bc_networks
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import loggers
+import jax
+import jax.numpy as jnp
+import optax
+
+_PMAP_AXIS_NAME = 'data'
+
+
+class TrainingState(NamedTuple):
+  """Contains training state for the learner."""
+  optimizer_state: optax.OptState
+  policy_params: networks_lib.Params
+  key: networks_lib.PRNGKey
+  steps: int
+
+
+def _create_loss_metrics(
+    loss_has_aux: bool,
+    loss_result: Union[jnp.ndarray, Tuple[jnp.ndarray, loggers.LoggingData]],
+    gradients: jnp.ndarray,
+):
+  """Creates loss metrics for logging."""
+  # Validate input.
+  if loss_has_aux and not (len(loss_result) == 2 and isinstance(
+      loss_result[0], jnp.ndarray) and isinstance(loss_result[1], dict)):
+    raise ValueError('Could not parse loss value and metrics from loss_fn\'s '
+                     'output. Since loss_has_aux is enabled, loss_fn must '
+                     'return loss_value and auxiliary metrics.')
+
+  if not loss_has_aux and not isinstance(loss_result, jnp.ndarray):
+    raise ValueError(f'Loss returned type {type(loss_result)}. However, it '
+                     'should return a jnp.ndarray, given that '
+                     'loss_has_aux = False.')
+
+  # Maybe unpack loss result.
+  if loss_has_aux:
+    loss, metrics = loss_result
+  else:
+    loss = loss_result
+    metrics = {}
+
+  # Complete metrics dict and return it.
+  metrics['loss'] = loss
+  metrics['gradient_norm'] = optax.global_norm(gradients)
+  return metrics
+
+
+class BCLearner(acme.Learner):
+  """BC learner.
+
+  This is the learning component of a BC agent. It takes a Transitions iterator
+  as input and implements update functionality to learn from this iterator.
+  """
+
+  _state: TrainingState
+
+  def __init__(self,
+               networks: bc_networks.BCNetworks,
+               random_key: networks_lib.PRNGKey,
+               loss_fn: losses.BCLoss,
+               optimizer: optax.GradientTransformation,
+               prefetching_iterator: Iterator[types.Transition],
+               num_sgd_steps_per_step: int,
+               loss_has_aux: bool = False,
+               logger: Optional[loggers.Logger] = None,
+               counter: Optional[counting.Counter] = None):
+    """Behavior Cloning Learner.
+
+    Args:
+      networks: BC networks.
+      random_key: RNG key.
+      loss_fn: BC loss to use.
+      optimizer: Optax optimizer.
+      prefetching_iterator: A sharded prefetching iterator as output by
+        `acme.jax.utils.sharded_prefetch`. Please see the documentation for
+        `sharded_prefetch` for more details.
+      num_sgd_steps_per_step: Number of gradient updates per step.
+      loss_has_aux: Whether the loss function returns auxiliary metrics as a
+        second argument.
+      logger: Logger.
+      counter: Counter.
+    """
+    def sgd_step(
+        state: TrainingState,
+        transitions: types.Transition,
+    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
+
+      loss_and_grad = jax.value_and_grad(
+          loss_fn, argnums=1, has_aux=loss_has_aux)
+
+      # Compute losses and their gradients.
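+      # (jax.value_and_grad with argnums=1 differentiates loss_fn with respect
+      # to its second positional argument, the policy params; with
+      # has_aux=True the returned loss_result is a (loss, metrics) pair.)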
+ key, key_input = jax.random.split(state.key) + loss_result, gradients = loss_and_grad(networks, state.policy_params, + key_input, transitions) + + # Combine the gradient across all devices (by taking their mean). + gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME) + + # Compute and combine metrics across all devices. + metrics = _create_loss_metrics(loss_has_aux, loss_result, gradients) + metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME) + + policy_update, optimizer_state = optimizer.update(gradients, + state.optimizer_state, + state.policy_params) + policy_params = optax.apply_updates(state.policy_params, policy_update) + + new_state = TrainingState( + optimizer_state=optimizer_state, + policy_params=policy_params, + key=key, + steps=state.steps + 1, + ) + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter(prefix='learner') + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Split the input batch to `num_sgd_steps_per_step` minibatches in order + # to achieve better performance on accelerators. + sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) + self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) + + random_key, init_key = jax.random.split(random_key) + policy_params = networks.policy_network.init(init_key) + optimizer_state = optimizer.init(policy_params) + + # Create initial state. + state = TrainingState( + optimizer_state=optimizer_state, + policy_params=policy_params, + key=random_key, + steps=0, + ) + self._state = utils.replicate_in_all_devices(state) + + self._timestamp = None + + self._prefetching_iterator = prefetching_iterator + + def step(self): + # Get a batch of Transitions. + transitions = next(self._prefetching_iterator) + self._state, metrics = self._sgd_step(self._state, transitions) + metrics = utils.get_from_first_device(metrics) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + variables = { + 'policy': utils.get_from_first_device(self._state.policy_params), + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. + return jax.tree_map(utils.get_from_first_device, self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state) diff --git a/acme/acme/agents/jax/bc/losses.py b/acme/acme/agents/jax/bc/losses.py new file mode 100644 index 00000000..de39fbda --- /dev/null +++ b/acme/acme/agents/jax/bc/losses.py @@ -0,0 +1,143 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Offline losses used in variants of BC."""
+from typing import Callable, Optional, Tuple, Union
+
+from acme import types
+from acme.agents.jax.bc import networks as bc_networks
+from acme.jax import networks as networks_lib
+from acme.jax import types as jax_types
+from acme.utils import loggers
+import jax
+import jax.numpy as jnp
+
+
+loss_args = [
+    bc_networks.BCNetworks, networks_lib.Params, networks_lib.PRNGKey,
+    types.Transition
+]
+BCLossWithoutAux = Callable[loss_args, jnp.ndarray]
+BCLossWithAux = Callable[loss_args, Tuple[jnp.ndarray, loggers.LoggingData]]
+BCLoss = Union[BCLossWithoutAux, BCLossWithAux]
+
+
+def mse() -> BCLossWithoutAux:
+  """Mean squared error loss."""
+
+  def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params,
+           key: jax_types.PRNGKey,
+           transitions: types.Transition) -> jnp.ndarray:
+    key, key_dropout = jax.random.split(key)
+    dist_params = networks.policy_network.apply(
+        params, transitions.observation, is_training=True, key=key_dropout)
+    action = networks.sample_fn(dist_params, key)
+    return jnp.mean(jnp.square(action - transitions.action))
+
+  return loss
+
+
+def logp() -> BCLossWithoutAux:
+  """Log probability loss."""
+
+  def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params,
+           key: jax_types.PRNGKey,
+           transitions: types.Transition) -> jnp.ndarray:
+    logits = networks.policy_network.apply(
+        params, transitions.observation, is_training=True, key=key)
+    logp_action = networks.log_prob(logits, transitions.action)
+    return -jnp.mean(logp_action)
+
+  return loss
+
+
+def peerbc(base_loss_fn: BCLossWithoutAux, zeta: float) -> BCLossWithoutAux:
+  """Peer-BC loss from https://arxiv.org/pdf/2010.01748.pdf.
+
+  Args:
+    base_loss_fn: the base loss to add the peer regularization on top of.
+    zeta: the weight of the regularization.
+  Returns:
+    The loss.
+  """
+
+  def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params,
+           key: jax_types.PRNGKey,
+           transitions: types.Transition) -> jnp.ndarray:
+    key_perm, key_bc_loss, key_permuted_loss = jax.random.split(key, 3)
+
+    permutation_keys = jax.random.split(key_perm, transitions.action.shape[0])
+    permuted_actions = jax.vmap(
+        jax.random.permutation, in_axes=(0, 0))(permutation_keys,
+                                                transitions.action)
+    permuted_transition = transitions._replace(action=permuted_actions)
+    bc_loss = base_loss_fn(networks, params, key_bc_loss, transitions)
+    permuted_loss = base_loss_fn(networks, params, key_permuted_loss,
+                                 permuted_transition)
+    return bc_loss - zeta * permuted_loss
+
+  return loss
+
+
+def rcal(base_loss_fn: BCLossWithoutAux,
+         discount: float,
+         alpha: float,
+         num_bins: Optional[int] = None) -> BCLossWithoutAux:
+  """RCAL loss.
+
+  See https://www.cristal.univ-lille.fr/~pietquin/pdf/AAMAS_2014_BPMGOP.pdf.
+
+  Args:
+    base_loss_fn: the base loss to add RCAL on top of.
+    discount: the gamma discount used in RCAL.
+    alpha: the regularization parameter.
+    num_bins: how many bins were used for discretization. If None the
+      environment was originally discrete already.
+  Returns:
+    The loss function.
+  """
+
+  def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params,
+           key: jax_types.PRNGKey,
+           transitions: types.Transition) -> jnp.ndarray:
+
+    def logits_fn(key: jax_types.PRNGKey,
+                  observations: networks_lib.Observation,
+                  actions: Optional[networks_lib.Action] = None):
+      logits = networks.policy_network.apply(
+          params, observations, key=key, is_training=True)
+      if num_bins:
+        logits = jnp.reshape(logits, list(logits.shape[:-1]) + [-1, num_bins])
+      if actions is None:
+        actions = jnp.argmax(logits, axis=-1)
+      logits_actions = jnp.sum(
+          jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
+      return logits_actions
+
+    key, key1, key2 = jax.random.split(key, 3)
+
+    logits_a_tm1 = logits_fn(key1, transitions.observation, transitions.action)
+    logits_a_t = logits_fn(key2, transitions.next_observation)
+
+    # RCAL, by making a parallel between the logits of BC and Q-values,
+    # defines a regularization loss that encourages the implicit reward
+    # (inferred by inverting the Bellman equation) to be sparse.
+    # NOTE: In case of discretized envs jnp.mean goes over batch and num_bins
+    # dimensions.
+    regularization_loss = jnp.mean(
+        jnp.abs(logits_a_tm1 - discount * logits_a_t)
+    )
+
+    loss = base_loss_fn(networks, params, key, transitions)
+    return loss + alpha * regularization_loss
+
+  return loss
diff --git a/acme/acme/agents/jax/bc/networks.py b/acme/acme/agents/jax/bc/networks.py
new file mode 100644
index 00000000..8a0830ac
--- /dev/null
+++ b/acme/acme/agents/jax/bc/networks.py
@@ -0,0 +1,119 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Network definitions for BC."""
+
+import dataclasses
+from typing import Callable, Optional, Protocol
+
+from acme.jax import networks as networks_lib
+from acme.jax import types
+
+
+class ApplyFn(Protocol):
+
+  def __call__(self,
+               params: networks_lib.Params,
+               observation: networks_lib.Observation,
+               *args,
+               is_training: bool,
+               key: Optional[types.PRNGKey] = None,
+               **kwargs) -> networks_lib.NetworkOutput:
+    ...
+
+
+@dataclasses.dataclass
+class BCPolicyNetwork:
+  """Holds a pair of pure functions defining a policy network for BC.
+
+  This is a feed-forward network taking params, obs, is_training, key as input.
+
+  Attributes:
+    init: A pure function. Initializes and returns the network's parameters.
+    apply: A pure function. Computes and returns the outputs of a forward pass.
+  """
+  init: Callable[[types.PRNGKey], networks_lib.Params]
+  apply: ApplyFn
+
+
+def identity_sample(output: networks_lib.NetworkOutput,
+                    key: types.PRNGKey) -> networks_lib.Action:
+  """Placeholder sampling function for non-distributional networks."""
+  del key
+  return output
+
+
+@dataclasses.dataclass
+class BCNetworks:
+  """The network and pure functions for the BC agent.
+
+  Attributes:
+    policy_network: The policy network.
+    sample_fn: A pure function. Samples an action based on the network output.
+      Must be set for distributional networks. Otherwise identity.
+ log_prob: A pure function. Computes log-probability for an action. + Must be set for distributional networks. Otherwise None. + """ + policy_network: BCPolicyNetwork + sample_fn: networks_lib.SampleFn = identity_sample + log_prob: Optional[networks_lib.LogProbFn] = None + + +def convert_to_bc_network( + policy_network: networks_lib.FeedForwardNetwork) -> BCPolicyNetwork: + """Converts a policy network from SAC/TD3/D4PG/.. into a BC policy network. + + Args: + policy_network: FeedForwardNetwork taking the observation as input and + returning action representation compatible with one of the BC losses. + + Returns: + The BC policy network taking observation, is_training, key as input. + """ + + def apply(params: networks_lib.Params, + observation: networks_lib.Observation, + *args, + is_training: bool = False, + key: Optional[types.PRNGKey] = None, + **kwargs) -> networks_lib.NetworkOutput: + del is_training, key + return policy_network.apply(params, observation, *args, **kwargs) + + return BCPolicyNetwork(policy_network.init, apply) + + +def convert_policy_value_to_bc_network( + policy_value_network: networks_lib.FeedForwardNetwork) -> BCPolicyNetwork: + """Converts a policy-value network (e.g. from PPO) into a BC policy network. + + Args: + policy_value_network: FeedForwardNetwork taking the observation as input. + + Returns: + The BC policy network taking observation, is_training, key as input. + """ + + def apply(params: networks_lib.Params, + observation: networks_lib.Observation, + *args, + is_training: bool = False, + key: Optional[types.PRNGKey] = None, + **kwargs) -> networks_lib.NetworkOutput: + del is_training, key + actions, _ = policy_value_network.apply(params, observation, *args, + **kwargs) + return actions + + return BCPolicyNetwork(policy_value_network.init, apply) diff --git a/acme/acme/agents/jax/bc/pretraining.py b/acme/acme/agents/jax/bc/pretraining.py new file mode 100644 index 00000000..b761578b --- /dev/null +++ b/acme/acme/agents/jax/bc/pretraining.py @@ -0,0 +1,63 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools to train a policy network with BC.""" +from typing import Callable, Iterator + +from acme import types +from acme.agents.jax.bc import learning +from acme.agents.jax.bc import losses +from acme.agents.jax.bc import networks as bc_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +import jax +import optax + + +def train_with_bc(make_demonstrations: Callable[[int], + Iterator[types.Transition]], + networks: bc_networks.BCNetworks, + loss: losses.BCLoss, + num_steps: int = 100000) -> networks_lib.Params: + """Trains the given network with BC and returns the params. + + Args: + make_demonstrations: A function (batch_size) -> iterator with demonstrations + to be imitated. + networks: Network taking (params, obs, is_training, key) as input + loss: BC loss to use. + num_steps: number of training steps + + Returns: + The trained network params. 
+  """
+  demonstration_iterator = make_demonstrations(256)
+  prefetching_iterator = utils.sharded_prefetch(
+      demonstration_iterator,
+      buffer_size=2,
+      num_threads=jax.local_device_count())
+
+  learner = learning.BCLearner(
+      networks=networks,
+      random_key=jax.random.PRNGKey(0),
+      loss_fn=loss,
+      prefetching_iterator=prefetching_iterator,
+      optimizer=optax.adam(1e-4),
+      num_sgd_steps_per_step=1)
+
+  # Train the agent
+  for _ in range(num_steps):
+    learner.step()
+
+  return learner.get_variables(['policy'])[0]
diff --git a/acme/acme/agents/jax/bc/pretraining_test.py b/acme/acme/agents/jax/bc/pretraining_test.py
new file mode 100644
index 00000000..b298cc8c
--- /dev/null
+++ b/acme/acme/agents/jax/bc/pretraining_test.py
@@ -0,0 +1,94 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for bc_initialization."""
+
+from acme import specs
+from acme.agents.jax import bc
+from acme.agents.jax import sac
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.testing import fakes
+import haiku as hk
+import jax
+import numpy as np
+
+from absl.testing import absltest
+
+
+def make_networks(spec: specs.EnvironmentSpec) -> bc.BCNetworks:
+  """Creates networks used by the agent."""
+
+  final_layer_size = np.prod(spec.actions.shape, dtype=int)
+
+  def _actor_fn(obs, is_training=False, key=None):
+    # is_training and key allow one to define train/test-dependent modules
+    # like dropout.
+    del is_training
+    del key
+    network = networks_lib.LayerNormMLP([64, 64, final_layer_size],
+                                        activate_final=False)
+    return jax.nn.tanh(network(obs))
+
+  policy = hk.without_apply_rng(hk.transform(_actor_fn))
+
+  # Create dummy observations and actions to create network parameters.
+  dummy_obs = utils.zeros_like(spec.observations)
+  dummy_obs = utils.add_batch_dim(dummy_obs)
+  policy_network = bc.BCPolicyNetwork(lambda key: policy.init(key, dummy_obs),
+                                      policy.apply)
+
+  return bc.BCNetworks(policy_network)
+
+
+class BcPretrainingTest(absltest.TestCase):
+
+  def test_bc_initialization(self):
+    # Create a fake environment to test with.
+    environment = fakes.ContinuousEnvironment(
+        episode_length=10, bounded=True, action_dim=6)
+    spec = specs.make_environment_spec(environment)
+
+    # Construct the agent.
+    nets = make_networks(spec)
+
+    loss = bc.mse()
+
+    bc.pretraining.train_with_bc(
+        fakes.transition_iterator(environment), nets, loss, num_steps=100)
+
+  def test_sac_to_bc_networks(self):
+    # Create a fake environment to test with.
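+    # (The check below relies on convert_to_bc_network being a thin wrapper:
+    # it only drops the is_training/key arguments, so for identical params
+    # the SAC policy head and the converted BC network must agree exactly.)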
+ environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6) + spec = specs.make_environment_spec(environment) + + sac_nets = sac.make_networks(spec, hidden_layer_sizes=(4, 4)) + bc_nets = bc.convert_to_bc_network(sac_nets.policy_network) + + rng = jax.random.PRNGKey(0) + dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) + + sac_params = sac_nets.policy_network.init(rng) + sac_output = sac_nets.policy_network.apply(sac_params, dummy_obs) + + bc_params = bc_nets.init(rng) + bc_output = bc_nets.apply(bc_params, dummy_obs, is_training=False, key=None) + + np.testing.assert_array_equal(sac_output.mode(), bc_output.mode()) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/builders.py b/acme/acme/agents/jax/builders.py new file mode 100644 index 00000000..7daabfd8 --- /dev/null +++ b/acme/acme/agents/jax/builders.py @@ -0,0 +1,226 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RL agent Builder interface.""" + +import abc +from typing import Generic, Iterator, List, Optional, TypeVar + +from acme import adders +from acme import core +from acme import specs +from acme.jax import networks as networks_lib +from acme.utils import counting +from acme.utils import loggers +import reverb + +Networks = TypeVar('Networks') # Container for all agent network components. +Policy = TypeVar('Policy') # Function or container for agent policy functions. +Sample = TypeVar('Sample') # Sample from the demonstrations or replay buffer. + + +class OfflineBuilder(abc.ABC, Generic[Networks, Policy, Sample]): + """Interface for defining the components of an offline RL agent. + + Implementations of this interface contain a complete specification of a + concrete offline RL agent. An instance of this class can be used to build an + offline RL agent that operates either locally or in a distributed setup. + """ + + @abc.abstractmethod + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: Networks, + dataset: Iterator[Sample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + """Creates an instance of the learner. + + Args: + random_key: A key for random number generation. + networks: struct describing the networks needed by the learner; this is + specific to the learner in question. + dataset: iterator over demonstration samples. + logger_fn: factory providing loggers used for logging progress. + environment_spec: A container for all relevant environment specs. + counter: a Counter which allows for recording of counts (learner steps, + evaluator steps, etc.) distributed throughout the agent. 
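+
+    Returns:
+      A core.Learner that learns from the offline dataset of Samples.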
+ """ + + @abc.abstractmethod + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + """Create an actor instance to be used for evaluation. + + Args: + random_key: A key for random number generation. + policy: Instance of a policy expected by the algorithm corresponding to + this builder. + environment_spec: A container for all relevant environment specs. + variable_source: A source providing the necessary actor parameters. + """ + + @abc.abstractmethod + def make_policy(self, networks: Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool) -> Policy: + """Creates the agent policy to be used for evaluation. + + Args: + networks: struct describing the networks needed to generate the policy. + environment_spec: struct describing the specs of the environment. + evaluation: This flag is present for consistency with the + ActorLearnerBuilder, in which case data-generating actors and evaluation + actors can behave differently. For OfflineBuilders, this should be set + to True. + + Returns: + Policy to be used for evaluation. The exact form of this object may differ + from one agent to the next; it could be a simple callable, a nest of + callables, or an ActorCore for instance. + """ + + +class ActorLearnerBuilder(OfflineBuilder[Networks, Policy, Sample], + Generic[Networks, Policy, Sample]): + """Defines an interface for defining the components of an RL agent. + + Implementations of this interface contain a complete specification of a + concrete RL agent. An instance of this class can be used to build an + RL agent which interacts with the environment either locally or in a + distributed setup. + """ + + @abc.abstractmethod + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: Policy, + ) -> List[reverb.Table]: + """Create tables to insert data into. + + Args: + environment_spec: A container for all relevant environment specs. + policy: Agent's policy which can be used to extract the extras_spec. + + Returns: + The replay tables used to store the experience the agent uses to train. + """ + + @abc.abstractmethod + def make_dataset_iterator( + self, + replay_client: reverb.Client, + ) -> Iterator[Sample]: + """Create a dataset iterator to use for learning/updating the agent.""" + + @abc.abstractmethod + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[Policy], + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment. + + Args: + replay_client: Reverb Client which points to the replay server. + environment_spec: specs of the environment. + policy: Agent's policy which can be used to extract the extras_spec. + """ + # TODO(sabela): make the parameters non-optional. + + @abc.abstractmethod + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + """Create an actor instance. + + Args: + random_key: A key for random number generation. + policy: Instance of a policy expected by the algorithm corresponding to + this builder. + environment_spec: A container for all relevant environment specs. + variable_source: A source providing the necessary actor parameters. + adder: How data is recorded (e.g. 
added to replay). + """ + + @abc.abstractmethod + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: Networks, + dataset: Iterator[Sample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + """Creates an instance of the learner. + + Args: + random_key: A key for random number generation. + networks: struct describing the networks needed by the learner; this can + be specific to the learner in question. + dataset: iterator over samples from replay. + logger_fn: factory providing loggers used for logging progress. + environment_spec: A container for all relevant environment specs. + replay_client: client which allows communication with replay. Note that + this is only intended to be used for updating priorities. Samples should + be obtained from `dataset`. + counter: a Counter which allows for recording of counts (learner steps, + actor steps, etc.) distributed throughout the agent. + """ + + def make_policy(self, + networks: Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> Policy: + """Creates the agent policy. + + Creates the agent policy given the collection of network components and + environment spec. An optional boolean can be given to indicate if the + policy will be used for evaluation. + + Args: + networks: struct describing the networks needed to generate the policy. + environment_spec: struct describing the specs of the environment. + evaluation: when true, a version of the policy to use for evaluation + should be returned. This is algorithm-specific so if an algorithm makes + no distinction between behavior and evaluation policies this boolean may + be ignored. + + Returns: + Behavior policy or evaluation policy for the agent. + """ + # TODO(sabela): make abstract once all agents implement it. + del networks, environment_spec, evaluation + raise NotImplementedError + +# TODO(sinopalnikov): deprecated, migrate all users and remove. +GenericActorLearnerBuilder = ActorLearnerBuilder diff --git a/acme/acme/agents/jax/cql/README.md b/acme/acme/agents/jax/cql/README.md new file mode 100644 index 00000000..ceee1aa2 --- /dev/null +++ b/acme/acme/agents/jax/cql/README.md @@ -0,0 +1,7 @@ +# Conservative Q-Learning (CQL) + +CQL (1) is an offline RL algorithm. It is based on an offline version of SAC +with an additional regularizing ("conservative") component in the critic loss. + +(1) [Kumar et al., *Conservative Q-Learning for Offline Reinforcement Learning*, +2020](https://arxiv.org/abs/2006.04779) diff --git a/acme/acme/agents/jax/cql/__init__.py b/acme/acme/agents/jax/cql/__init__.py new file mode 100644 index 00000000..238ed878 --- /dev/null +++ b/acme/acme/agents/jax/cql/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
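+#
+# A minimal usage sketch (mirroring agent_test.py below; the optimizer
+# settings and step counts are illustrative, not prescribed):
+#
+#   networks = make_networks(spec, hidden_layer_sizes=(256, 256))
+#   learner = CQLLearner(
+#       batch_size, networks, jax.random.PRNGKey(0),
+#       demonstrations=demonstration_iterator,
+#       policy_optimizer=optax.adam(3e-5),
+#       critic_optimizer=optax.adam(3e-4))
+#   for _ in range(num_steps):
+#     learner.step()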
+ +"""Implementation of the CQL agent.""" + +from acme.agents.jax.cql.builder import CQLBuilder +from acme.agents.jax.cql.config import CQLConfig +from acme.agents.jax.cql.learning import CQLLearner +from acme.agents.jax.cql.networks import CQLNetworks +from acme.agents.jax.cql.networks import make_networks diff --git a/acme/acme/agents/jax/cql/agent_test.py b/acme/acme/agents/jax/cql/agent_test.py new file mode 100644 index 00000000..2babe689 --- /dev/null +++ b/acme/acme/agents/jax/cql/agent_test.py @@ -0,0 +1,62 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the CQL agent.""" + +from acme import specs +from acme.agents.jax import cql +from acme.testing import fakes +import jax +import optax + +from absl.testing import absltest + + +class CQLTest(absltest.TestCase): + + def test_train(self): + seed = 0 + num_iterations = 6 + batch_size = 64 + + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6) + spec = specs.make_environment_spec(environment) + + # Construct the agent. + networks = cql.make_networks( + spec, hidden_layer_sizes=(8, 8)) + dataset = fakes.transition_iterator(environment) + key = jax.random.PRNGKey(seed) + learner = cql.CQLLearner( + batch_size, + networks, + key, + demonstrations=dataset(batch_size), + policy_optimizer=optax.adam(3e-5), + critic_optimizer=optax.adam(3e-4), + fixed_cql_coefficient=5., + cql_lagrange_threshold=None, + target_entropy=0.1, + num_bc_iters=2, + num_sgd_steps_per_step=1) + + # Train the agent + for _ in range(num_iterations): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/cql/builder.py b/acme/acme/agents/jax/cql/builder.py new file mode 100644 index 00000000..8fe8173c --- /dev/null +++ b/acme/acme/agents/jax/cql/builder.py @@ -0,0 +1,109 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
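+#
+# A sketch of how the pieces fit together (the runner wiring is illustrative;
+# names are as defined in this file and in config.py):
+#
+#   builder = CQLBuilder(cql_config.CQLConfig())
+#   learner = builder.make_learner(key, networks, dataset, logger_fn, spec)
+#   policy = builder.make_policy(networks, spec, evaluation=True)
+#   actor = builder.make_actor(key, policy, spec, variable_source=learner)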
+ +"""CQL Builder.""" +from typing import Iterator, Optional + +from acme import core +from acme import specs +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.cql import config as cql_config +from acme.agents.jax.cql import learning +from acme.agents.jax.cql import networks as cql_networks +from acme.jax import networks as networks_lib +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import optax + + +class CQLBuilder(builders.OfflineBuilder[cql_networks.CQLNetworks, + actor_core_lib.FeedForwardPolicy, + types.Transition]): + """CQL Builder.""" + + def __init__( + self, + config: cql_config.CQLConfig, + ): + """Creates a CQL learner, an evaluation policy and an eval actor. + + Args: + config: a config with CQL hps. + """ + self._config = config + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: cql_networks.CQLNetworks, + dataset: Iterator[types.Transition], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec + + return learning.CQLLearner( + batch_size=self._config.batch_size, + networks=networks, + random_key=random_key, + demonstrations=dataset, + policy_optimizer=optax.adam(self._config.policy_learning_rate), + critic_optimizer=optax.adam(self._config.critic_learning_rate), + tau=self._config.tau, + fixed_cql_coefficient=self._config.fixed_cql_coefficient, + cql_lagrange_threshold=self._config.cql_lagrange_threshold, + cql_num_samples=self._config.cql_num_samples, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + reward_scale=self._config.reward_scale, + discount=self._config.discount, + fixed_entropy_coefficient=self._config.fixed_entropy_coefficient, + target_entropy=self._config.target_entropy, + num_bc_iters=self._config.num_bc_iters, + logger=logger_fn('learner'), + counter=counter) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, 'policy', device='cpu') + return actors.GenericActor( + actor_core, random_key, variable_client, backend='cpu') + + def make_policy(self, networks: cql_networks.CQLNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + del environment_spec, evaluation + + def evaluation_policy( + params: networks_lib.Params, key: networks_lib.PRNGKey, + observation: networks_lib.Observation) -> networks_lib.Action: + dist_params = networks.policy_network.apply(params, observation) + return networks.sample_eval(dist_params, key) + + return evaluation_policy diff --git a/acme/acme/agents/jax/cql/config.py b/acme/acme/agents/jax/cql/config.py new file mode 100644 index 00000000..44b2b26b --- /dev/null +++ b/acme/acme/agents/jax/cql/config.py @@ -0,0 +1,58 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config classes for CQL."""
+import dataclasses
+from typing import Optional
+
+
+@dataclasses.dataclass
+class CQLConfig:
+  """Configuration options for CQL.
+
+  Attributes:
+    batch_size: batch size.
+    policy_learning_rate: learning rate for the policy optimizer.
+    critic_learning_rate: learning rate for the Q-function optimizer.
+    tau: target smoothing coefficient.
+    fixed_cql_coefficient: the value for the cql coefficient. If None, an
+      adaptive coefficient will be used.
+    cql_lagrange_threshold: a threshold that controls the adaptive loss for the
+      cql coefficient.
+    cql_num_samples: number of samples used to compute logsumexp(Q) via
+      importance sampling.
+    num_sgd_steps_per_step: how many gradient updates to perform per batch.
+      The batch is split into this many smaller batches, so batch_size should
+      be a multiple of num_sgd_steps_per_step.
+    reward_scale: reward scale.
+    discount: discount to use for TD updates.
+    fixed_entropy_coefficient: coefficient applied to the entropy bonus. If
+      None, an adaptive coefficient will be used.
+    target_entropy: target entropy when using an adaptive entropy bonus.
+    num_bc_iters: number of BC steps for actor initialization.
+  """
+  batch_size: int = 256
+  policy_learning_rate: float = 3e-5
+  critic_learning_rate: float = 3e-4
+  tau: float = 0.005
+  fixed_cql_coefficient: Optional[float] = 5.
+  cql_lagrange_threshold: Optional[float] = None
+  cql_num_samples: int = 10
+  num_sgd_steps_per_step: int = 1
+  reward_scale: float = 1.0
+  discount: float = 0.99
+  fixed_entropy_coefficient: Optional[float] = 0.
+  target_entropy: Optional[float] = 0
+  num_bc_iters: int = 50_000
diff --git a/acme/acme/agents/jax/cql/learning.py b/acme/acme/agents/jax/cql/learning.py
new file mode 100644
index 00000000..82b14068
--- /dev/null
+++ b/acme/acme/agents/jax/cql/learning.py
@@ -0,0 +1,481 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
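+#
+# At a glance (Kumar et al., 2020): the critic minimizes the usual SAC TD
+# error plus a conservative term
+#   cql_alpha * E_s[ logsumexp_a Q(s, a) - Q(s, a_data) ],
+# which pushes Q-values down on out-of-distribution actions and up on dataset
+# actions; cql_alpha is either fixed or tuned against a Lagrange threshold.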
+
+"""CQL learner implementation."""
+
+import time
+from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple
+import acme
+from acme import types
+from acme.agents.jax.cql.networks import apply_and_sample_n
+from acme.agents.jax.cql.networks import CQLNetworks
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import loggers
+import jax
+import jax.numpy as jnp
+import optax
+
+
+_CQL_COEFFICIENT_MAX_VALUE = 1E6
+_CQL_GRAD_CLIPPING_VALUE = 40
+
+
+class TrainingState(NamedTuple):
+  """Contains training state for the learner."""
+  policy_optimizer_state: optax.OptState
+  critic_optimizer_state: optax.OptState
+  policy_params: networks_lib.Params
+  critic_params: networks_lib.Params
+  target_critic_params: networks_lib.Params
+  key: networks_lib.PRNGKey
+
+  # Optimizer and value of the alpha parameter from SAC (entropy temperature).
+  # These fields are only used with an adaptive coefficient (when
+  # fixed_entropy_coefficient is None in the CQLLearner).
+  alpha_optimizer_state: Optional[optax.OptState] = None
+  log_sac_alpha: Optional[networks_lib.Params] = None
+
+  # Optimizer and value of the alpha parameter from CQL (regularization
+  # coefficient).
+  # These fields are only used with an adaptive coefficient (when
+  # fixed_cql_coefficient is None in the CQLLearner).
+  cql_optimizer_state: Optional[optax.OptState] = None
+  log_cql_alpha: Optional[networks_lib.Params] = None
+
+  steps: int = 0
+
+
+class CQLLearner(acme.Learner):
+  """CQL learner.
+
+  Learning component of the Conservative Q-Learning algorithm from
+  [Kumar et al., 2020] https://arxiv.org/abs/2006.04779.
+  """
+
+  _state: TrainingState
+
+  def __init__(self,
+               batch_size: int,
+               networks: CQLNetworks,
+               random_key: networks_lib.PRNGKey,
+               demonstrations: Iterator[types.Transition],
+               policy_optimizer: optax.GradientTransformation,
+               critic_optimizer: optax.GradientTransformation,
+               tau: float = 0.005,
+               fixed_cql_coefficient: Optional[float] = None,
+               cql_lagrange_threshold: Optional[float] = None,
+               cql_num_samples: int = 10,
+               num_sgd_steps_per_step: int = 1,
+               reward_scale: float = 1.0,
+               discount: float = 0.99,
+               fixed_entropy_coefficient: Optional[float] = None,
+               target_entropy: Optional[float] = 0,
+               num_bc_iters: int = 50_000,
+               counter: Optional[counting.Counter] = None,
+               logger: Optional[loggers.Logger] = None):
+    """Initializes the CQL learner.
+
+    Args:
+      batch_size: batch size.
+      networks: CQL networks.
+      random_key: a key for random number generation.
+      demonstrations: an iterator over training data.
+      policy_optimizer: the policy optimizer.
+      critic_optimizer: the Q-function optimizer.
+      tau: target smoothing coefficient.
+      fixed_cql_coefficient: the value for the cql coefficient. If None, an
+        adaptive coefficient will be used.
+      cql_lagrange_threshold: a threshold that controls the adaptive loss for
+        the cql coefficient.
+      cql_num_samples: number of samples used to compute logsumexp(Q) via
+        importance sampling.
+      num_sgd_steps_per_step: how many gradient updates to perform per batch.
+        The batch is split into this many smaller batches, so the batch size
+        should be a multiple of num_sgd_steps_per_step.
+      reward_scale: reward scale.
+      discount: discount to use for TD updates.
+      fixed_entropy_coefficient: coefficient applied to the entropy bonus. If
+        None, an adaptive coefficient will be used.
+      target_entropy: target entropy when using an adaptive entropy bonus.
+      num_bc_iters: number of BC steps for actor initialization.
+      counter: counter object used to keep track of steps.
+      logger: logger object to be used by learner.
+    """
+    self._num_bc_iters = num_bc_iters
+    adaptive_entropy_coefficient = fixed_entropy_coefficient is None
+    action_spec = networks.environment_specs.actions
+    if adaptive_entropy_coefficient:
+      # sac_alpha is the temperature parameter that determines the relative
+      # importance of the entropy term versus the reward.
+      log_sac_alpha = jnp.asarray(0., dtype=jnp.float32)
+      alpha_optimizer = optax.adam(learning_rate=3e-4)
+      alpha_optimizer_state = alpha_optimizer.init(log_sac_alpha)
+    else:
+      if target_entropy:
+        raise ValueError('target_entropy should not be set when '
+                         'fixed_entropy_coefficient is provided')
+
+    adaptive_cql_coefficient = fixed_cql_coefficient is None
+    if adaptive_cql_coefficient:
+      log_cql_alpha = jnp.asarray(0., dtype=jnp.float32)
+      cql_optimizer = optax.adam(learning_rate=3e-4)
+      cql_optimizer_state = cql_optimizer.init(log_cql_alpha)
+    else:
+      if cql_lagrange_threshold:
+        raise ValueError('cql_lagrange_threshold should not be set when '
+                         'fixed_cql_coefficient is provided')
+
+    def alpha_loss(log_sac_alpha: jnp.ndarray,
+                   policy_params: networks_lib.Params,
+                   transitions: types.Transition,
+                   key: jnp.ndarray) -> jnp.ndarray:
+      """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
+      dist_params = networks.policy_network.apply(policy_params,
+                                                  transitions.observation)
+      action = networks.sample(dist_params, key)
+      log_prob = networks.log_prob(dist_params, action)
+      sac_alpha = jnp.exp(log_sac_alpha)
+      sac_alpha_loss = sac_alpha * jax.lax.stop_gradient(-log_prob -
+                                                         target_entropy)
+      return jnp.mean(sac_alpha_loss)
+
+    def sac_critic_loss(q_old_action: jnp.ndarray,
+                        policy_params: networks_lib.Params,
+                        target_critic_params: networks_lib.Params,
+                        transitions: types.Transition,
+                        key: networks_lib.PRNGKey) -> jnp.ndarray:
+      """Computes the SAC part of the loss."""
+      next_dist_params = networks.policy_network.apply(
+          policy_params, transitions.next_observation)
+      next_action = networks.sample(next_dist_params, key)
+      next_q = networks.critic_network.apply(target_critic_params,
+                                             transitions.next_observation,
+                                             next_action)
+      next_v = jnp.min(next_q, axis=-1)
+      target_q = jax.lax.stop_gradient(transitions.reward * reward_scale +
+                                       transitions.discount * discount *
+                                       next_v)
+      return jnp.mean(jnp.square(q_old_action - jnp.expand_dims(target_q, -1)))
+
+    def batched_critic(actions: jnp.ndarray,
+                       critic_params: networks_lib.Params,
+                       observation: jnp.ndarray) -> jnp.ndarray:
+      """Applies the critic network to a batch of sampled actions."""
+      actions = jax.lax.stop_gradient(actions)
+      tiled_actions = jnp.reshape(actions, (batch_size * cql_num_samples, -1))
+      tiled_states = jnp.tile(observation, [cql_num_samples, 1])
+      tiled_q = networks.critic_network.apply(critic_params, tiled_states,
+                                              tiled_actions)
+      return jnp.reshape(tiled_q, (cql_num_samples, batch_size, -1))
+
+    def cql_critic_loss(q_old_action: jnp.ndarray,
+                        critic_params: networks_lib.Params,
+                        policy_params: networks_lib.Params,
+                        transitions: types.Transition,
+                        key: networks_lib.PRNGKey) -> jnp.ndarray:
+      """Computes the CQL part of the loss."""
+      # The CQL part of the loss is
+      #   logsumexp(Q(s,·)) - Q(s,a),
+      # where s is the current state and a the action in the dataset (so
+      # Q(s,a) is simply q_old_action).
+      # We need to estimate logsumexp(Q). This is done with importance
+      # sampling (IS). This function implements the unlabeled equation on
+      # page 29, Appx. F, of https://arxiv.org/abs/2006.04779.
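+      # (Concretely: logsumexp_a Q(s, a) is approximated by drawing N =
+      # cql_num_samples actions a_i from a proposal distribution mu and
+      # computing
+      #   log mean_i exp(Q(s, a_i) - log mu(a_i)),
+      # which is why each sampled Q below has its log-proposal-probability
+      # subtracted before the logsumexp.)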
+ # Here, IS is done with the uniform distribution and the policy in the + # current state s. In their implementation, the authors also add the + # policy in the transiting state s': + # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/rlkit/torch/sac/cql.py, + # (l. 233-236). + + key_policy, key_policy_next, key_uniform = jax.random.split(key, 3) + + def sampled_q(obs, key): + actions, log_probs = apply_and_sample_n( + key, networks, policy_params, obs, cql_num_samples) + return batched_critic(actions, critic_params, + transitions.observation) - jax.lax.stop_gradient( + jnp.expand_dims(log_probs, -1)) + + # Sample wrt policy in s + sampled_q_from_policy = sampled_q(transitions.observation, key_policy) + + # Sample wrt policy in s' + sampled_q_from_policy_next = sampled_q(transitions.next_observation, + key_policy_next) + + # Sample wrt uniform + actions_uniform = jax.random.uniform( + key_uniform, (cql_num_samples, batch_size) + action_spec.shape, + minval=action_spec.minimum, maxval=action_spec.maximum) + log_prob_uniform = -jnp.sum( + jnp.log(action_spec.maximum - action_spec.minimum)) + sampled_q_from_uniform = ( + batched_critic(actions_uniform, critic_params, + transitions.observation) - log_prob_uniform) + + # Combine the samplings + combined = jnp.concatenate( + (sampled_q_from_uniform, sampled_q_from_policy, + sampled_q_from_policy_next), + axis=0) + lse_q = jax.nn.logsumexp(combined, axis=0, b=1. / (3 * cql_num_samples)) + + return jnp.mean(lse_q - q_old_action) + + def critic_loss(critic_params: networks_lib.Params, + policy_params: networks_lib.Params, + target_critic_params: networks_lib.Params, + cql_alpha: jnp.ndarray, transitions: types.Transition, + key: networks_lib.PRNGKey) -> jnp.ndarray: + """Computes the full critic loss.""" + key_cql, key_sac = jax.random.split(key, 2) + q_old_action = networks.critic_network.apply(critic_params, + transitions.observation, + transitions.action) + cql_loss = cql_critic_loss(q_old_action, critic_params, policy_params, + transitions, key_cql) + sac_loss = sac_critic_loss(q_old_action, policy_params, + target_critic_params, transitions, key_sac) + return cql_alpha * cql_loss + sac_loss + + def cql_lagrange_loss(log_cql_alpha: jnp.ndarray, + critic_params: networks_lib.Params, + policy_params: networks_lib.Params, + transitions: types.Transition, + key: jnp.ndarray) -> jnp.ndarray: + """Computes the loss that optimizes the cql coefficient.""" + cql_alpha = jnp.exp(log_cql_alpha) + q_old_action = networks.critic_network.apply(critic_params, + transitions.observation, + transitions.action) + return -cql_alpha * ( + cql_critic_loss(q_old_action, critic_params, policy_params, + transitions, key) - cql_lagrange_threshold) + + def actor_loss(policy_params: networks_lib.Params, + critic_params: networks_lib.Params, sac_alpha: jnp.ndarray, + transitions: types.Transition, key: jnp.ndarray, + in_initial_bc_iters: bool) -> jnp.ndarray: + """Computes the loss for the policy.""" + dist_params = networks.policy_network.apply(policy_params, + transitions.observation) + if in_initial_bc_iters: + log_prob = networks.log_prob(dist_params, transitions.action) + actor_loss = -jnp.mean(log_prob) + else: + action = networks.sample(dist_params, key) + log_prob = networks.log_prob(dist_params, action) + q_action = networks.critic_network.apply(critic_params, + transitions.observation, + action) + min_q = jnp.min(q_action, axis=-1) + actor_loss = jnp.mean(sac_alpha * log_prob - min_q) + return actor_loss + + alpha_grad = jax.value_and_grad(alpha_loss) + 
+    cql_lagrange_grad = jax.value_and_grad(cql_lagrange_loss)
+    critic_grad = jax.value_and_grad(critic_loss)
+    actor_grad = jax.value_and_grad(actor_loss)
+
+    def update_step(
+        state: TrainingState,
+        rb_transitions: types.Transition,
+        in_initial_bc_iters: bool,
+    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
+
+      key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)
+
+      if adaptive_entropy_coefficient:
+        alpha_loss, alpha_grads = alpha_grad(state.log_sac_alpha,
+                                             state.policy_params,
+                                             rb_transitions, key_alpha)
+        sac_alpha = jnp.exp(state.log_sac_alpha)
+      else:
+        sac_alpha = fixed_entropy_coefficient
+
+      if adaptive_cql_coefficient:
+        cql_lagrange_loss, cql_lagrange_grads = cql_lagrange_grad(
+            state.log_cql_alpha, state.critic_params, state.policy_params,
+            rb_transitions, key_critic)
+        cql_lagrange_grads = jnp.clip(cql_lagrange_grads,
+                                      -_CQL_GRAD_CLIPPING_VALUE,
+                                      _CQL_GRAD_CLIPPING_VALUE)
+        cql_alpha = jnp.exp(state.log_cql_alpha)
+        cql_alpha = jnp.clip(
+            cql_alpha, a_min=0., a_max=_CQL_COEFFICIENT_MAX_VALUE)
+      else:
+        cql_alpha = fixed_cql_coefficient
+
+      critic_loss, critic_grads = critic_grad(state.critic_params,
+                                              state.policy_params,
+                                              state.target_critic_params,
+                                              cql_alpha, rb_transitions,
+                                              key_critic)
+      actor_loss, actor_grads = actor_grad(state.policy_params,
+                                           state.critic_params, sac_alpha,
+                                           rb_transitions, key_actor,
+                                           in_initial_bc_iters)
+
+      # Apply policy gradients
+      actor_update, policy_optimizer_state = policy_optimizer.update(
+          actor_grads, state.policy_optimizer_state)
+      policy_params = optax.apply_updates(state.policy_params, actor_update)
+
+      # Apply critic gradients
+      critic_update, critic_optimizer_state = critic_optimizer.update(
+          critic_grads, state.critic_optimizer_state)
+      critic_params = optax.apply_updates(state.critic_params, critic_update)
+
+      new_target_critic_params = jax.tree_map(
+          lambda x, y: x * (1 - tau) + y * tau, state.target_critic_params,
+          critic_params)
+
+      metrics = {
+          'critic_loss': critic_loss,
+          'actor_loss': actor_loss,
+      }
+
+      new_state = TrainingState(
+          policy_optimizer_state=policy_optimizer_state,
+          critic_optimizer_state=critic_optimizer_state,
+          policy_params=policy_params,
+          critic_params=critic_params,
+          target_critic_params=new_target_critic_params,
+          key=key,
+          alpha_optimizer_state=state.alpha_optimizer_state,
+          log_sac_alpha=state.log_sac_alpha,
+          steps=state.steps + 1,
+      )
+      if adaptive_entropy_coefficient and (not in_initial_bc_iters):
+        # Apply sac_alpha gradients
+        alpha_update, alpha_optimizer_state = alpha_optimizer.update(
+            alpha_grads, state.alpha_optimizer_state)
+        log_sac_alpha = optax.apply_updates(state.log_sac_alpha, alpha_update)
+        metrics.update({
+            'alpha_loss': alpha_loss,
+            'sac_alpha': jnp.exp(log_sac_alpha),
+        })
+        new_state = new_state._replace(
+            alpha_optimizer_state=alpha_optimizer_state,
+            log_sac_alpha=log_sac_alpha)
+      else:
+        metrics['alpha_loss'] = 0.
+        metrics['sac_alpha'] = fixed_entropy_coefficient
+
+      if adaptive_cql_coefficient:
+        # Apply cql coeff gradients
+        cql_update, cql_optimizer_state = cql_optimizer.update(
+            cql_lagrange_grads, state.cql_optimizer_state)
+        log_cql_alpha = optax.apply_updates(state.log_cql_alpha, cql_update)
+        metrics.update({
+            'cql_lagrange_loss': cql_lagrange_loss,
+            'cql_alpha': jnp.exp(log_cql_alpha),
+        })
+        new_state = new_state._replace(
+            cql_optimizer_state=cql_optimizer_state,
+            log_cql_alpha=log_cql_alpha)
+
+      return new_state, metrics
+
+    # General learner book-keeping and loggers.
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.make_default_logger(
+        'learner',
+        asynchronous=True,
+        serialize_fn=utils.fetch_devicearray,
+        steps_key=self._counter.get_steps_key())
+
+    # Iterator on demonstration transitions.
+    self._demonstrations = demonstrations
+
+    # Use the JIT compiler.
+    update_step_in_initial_bc_iters = utils.process_multiple_batches(
+        lambda x, y: update_step(x, y, True), num_sgd_steps_per_step)
+    update_step_rest = utils.process_multiple_batches(
+        lambda x, y: update_step(x, y, False), num_sgd_steps_per_step)
+
+    self._update_step_in_initial_bc_iters = jax.jit(
+        update_step_in_initial_bc_iters)
+    self._update_step_rest = jax.jit(update_step_rest)
+
+    # Create initial state.
+    key_policy, key_q, training_state_key = jax.random.split(random_key, 3)
+    del random_key
+    policy_params = networks.policy_network.init(key_policy)
+    policy_optimizer_state = policy_optimizer.init(policy_params)
+    critic_params = networks.critic_network.init(key_q)
+    critic_optimizer_state = critic_optimizer.init(critic_params)
+
+    self._state = TrainingState(
+        policy_optimizer_state=policy_optimizer_state,
+        critic_optimizer_state=critic_optimizer_state,
+        policy_params=policy_params,
+        critic_params=critic_params,
+        target_critic_params=critic_params,
+        key=training_state_key,
+        steps=0)
+
+    if adaptive_entropy_coefficient:
+      self._state = self._state._replace(
+          alpha_optimizer_state=alpha_optimizer_state,
+          log_sac_alpha=log_sac_alpha)
+    if adaptive_cql_coefficient:
+      self._state = self._state._replace(
+          cql_optimizer_state=cql_optimizer_state, log_cql_alpha=log_cql_alpha)
+
+    # Do not record timestamps until after the first learning step is done.
+    # This is to avoid including the time it takes for actors to come online and
+    # fill the replay buffer.
+    self._timestamp = None
+
+  def step(self):
+    # Get a batch of demonstration transitions. There are no extras to drop
+    # here because this learner consumes a fixed offline dataset rather than
+    # data inserted into Reverb.
+    transitions = next(self._demonstrations)
+
+    counts = self._counter.get_counts()
+    cur_step = counts.get('learner_steps', 0)
+    in_initial_bc_iters = cur_step < self._num_bc_iters
+
+    if in_initial_bc_iters:
+      self._state, metrics = self._update_step_in_initial_bc_iters(
+          self._state, transitions)
+    else:
+      self._state, metrics = self._update_step_rest(self._state, transitions)
+
+    # Compute elapsed time.
+    timestamp = time.time()
+    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
+    self._timestamp = timestamp
+
+    # Increment counts and record the current time
+    counts = self._counter.increment(steps=1, walltime=elapsed_time)
+
+    # Attempts to write the logs.
+    self._logger.write({**metrics, **counts})
+
+  def get_variables(self, names: List[str]) -> List[Any]:
+    variables = {
+        'policy': self._state.policy_params,
+    }
+    return [variables[name] for name in names]
+
+  def save(self) -> TrainingState:
+    return self._state
+
+  def restore(self, state: TrainingState):
+    self._state = state
diff --git a/acme/acme/agents/jax/cql/networks.py b/acme/acme/agents/jax/cql/networks.py
new file mode 100644
index 00000000..0593489b
--- /dev/null
+++ b/acme/acme/agents/jax/cql/networks.py
@@ -0,0 +1,60 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Network definitions for the CQL agent."""
+import dataclasses
+from typing import Optional, Tuple
+
+from acme import specs
+from acme.agents.jax import sac
+from acme.jax import networks as networks_lib
+import jax
+import jax.numpy as jnp
+
+
+@dataclasses.dataclass
+class CQLNetworks:
+  """Network and pure functions for the CQL agent."""
+  policy_network: networks_lib.FeedForwardNetwork
+  critic_network: networks_lib.FeedForwardNetwork
+  log_prob: networks_lib.LogProbFn
+  sample: Optional[networks_lib.SampleFn]
+  sample_eval: Optional[networks_lib.SampleFn]
+  environment_specs: specs.EnvironmentSpec
+
+
+def apply_and_sample_n(key: networks_lib.PRNGKey,
+                       networks: CQLNetworks,
+                       params: networks_lib.Params, obs: jnp.ndarray,
+                       num_samples: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Applies the policy and samples num_samples actions."""
+  dist_params = networks.policy_network.apply(params, obs)
+  sampled_actions = jnp.array([
+      networks.sample(dist_params, key_n)
+      for key_n in jax.random.split(key, num_samples)
+  ])
+  sampled_log_probs = networks.log_prob(dist_params, sampled_actions)
+  return sampled_actions, sampled_log_probs
+
+
+def make_networks(
+    spec: specs.EnvironmentSpec, **kwargs) -> CQLNetworks:
+  sac_networks = sac.make_networks(spec, **kwargs)
+  return CQLNetworks(
+      policy_network=sac_networks.policy_network,
+      critic_network=sac_networks.q_network,
+      log_prob=sac_networks.log_prob,
+      sample=sac_networks.sample,
+      sample_eval=sac_networks.sample_eval,
+      environment_specs=spec)
diff --git a/acme/acme/agents/jax/crr/README.md b/acme/acme/agents/jax/crr/README.md
new file mode 100644
index 00000000..bf8bd6a8
--- /dev/null
+++ b/acme/acme/agents/jax/crr/README.md
@@ -0,0 +1,11 @@
+# Critic Regularized Regression (CRR)
+
+This folder contains an implementation of the CRR algorithm
+([Wang et al., 2020]). It is an offline RL algorithm to learn policies from data
+using a form of critic-regularized regression.
+
+For the advantage estimate, a sampled mean is used. See the losses.py file for
+possible weighting coefficients for the policy loss (including exponential
+estimated advantage). The policy network assumes a continuous action space.
+
+[Wang et al., 2020]: https://arxiv.org/abs/2006.15134
diff --git a/acme/acme/agents/jax/crr/__init__.py b/acme/acme/agents/jax/crr/__init__.py
new file mode 100644
index 00000000..d594cf23
--- /dev/null
+++ b/acme/acme/agents/jax/crr/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Implementation of the Critic Regularized Regression (CRR) agent.""" + +from acme.agents.jax.crr.builder import CRRBuilder +from acme.agents.jax.crr.config import CRRConfig +from acme.agents.jax.crr.learning import CRRLearner +from acme.agents.jax.crr.losses import policy_loss_coeff_advantage_exp +from acme.agents.jax.crr.losses import policy_loss_coeff_advantage_indicator +from acme.agents.jax.crr.losses import policy_loss_coeff_constant +from acme.agents.jax.crr.losses import PolicyLossCoeff +from acme.agents.jax.crr.networks import CRRNetworks +from acme.agents.jax.crr.networks import make_networks diff --git a/acme/acme/agents/jax/crr/agent_test.py b/acme/acme/agents/jax/crr/agent_test.py new file mode 100644 index 00000000..2f925206 --- /dev/null +++ b/acme/acme/agents/jax/crr/agent_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the CRR agent.""" + +from acme import specs +from acme.agents.jax import crr +from acme.testing import fakes +import jax +import optax + +from absl.testing import absltest +from absl.testing import parameterized + + +class CRRTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('exp', crr.policy_loss_coeff_advantage_exp), + ('indicator', crr.policy_loss_coeff_advantage_indicator), + ('all', crr.policy_loss_coeff_constant)) + def test_train(self, policy_loss_coeff_fn): + seed = 0 + num_iterations = 5 + batch_size = 64 + grad_updates_per_batch = 1 + + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6) + spec = specs.make_environment_spec(environment) + + # Construct the learner. + networks = crr.make_networks( + spec, policy_layer_sizes=(8, 8), critic_layer_sizes=(8, 8)) + key = jax.random.PRNGKey(seed) + dataset = fakes.transition_iterator(environment) + learner = crr.CRRLearner( + networks, + key, + discount=0.95, + target_update_period=2, + policy_loss_coeff_fn=policy_loss_coeff_fn, + iterator=dataset(batch_size * grad_updates_per_batch), + policy_optimizer=optax.adam(1e-4), + critic_optimizer=optax.adam(1e-4), + grad_updates_per_batch=grad_updates_per_batch) + + # Train the learner. + for _ in range(num_iterations): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/crr/builder.py b/acme/acme/agents/jax/crr/builder.py new file mode 100644 index 00000000..7cdd9b53 --- /dev/null +++ b/acme/acme/agents/jax/crr/builder.py @@ -0,0 +1,106 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CRR Builder.""" +from typing import Iterator, Optional + +from acme import core +from acme import specs +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.crr import config as crr_config +from acme.agents.jax.crr import learning +from acme.agents.jax.crr import losses +from acme.agents.jax.crr import networks as crr_networks +from acme.jax import networks as networks_lib +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import optax + + +class CRRBuilder(builders.OfflineBuilder[crr_networks.CRRNetworks, + actor_core_lib.FeedForwardPolicy, + types.Transition]): + """CRR Builder.""" + + def __init__( + self, + config: crr_config.CRRConfig, + policy_loss_coeff_fn: losses.PolicyLossCoeff, + ): + """Creates a CRR learner, an evaluation policy and an eval actor. + + Args: + config: a config with CRR hps. + policy_loss_coeff_fn: set the loss function for the policy. + """ + self._config = config + self._policy_loss_coeff_fn = policy_loss_coeff_fn + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: crr_networks.CRRNetworks, + dataset: Iterator[types.Transition], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec + + return learning.CRRLearner( + networks=networks, + random_key=random_key, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + iterator=dataset, + policy_loss_coeff_fn=self._policy_loss_coeff_fn, + policy_optimizer=optax.adam(self._config.learning_rate), + critic_optimizer=optax.adam(self._config.learning_rate), + use_sarsa_target=self._config.use_sarsa_target, + logger=logger_fn('learner'), + counter=counter) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, 'policy', device='cpu') + return actors.GenericActor( + actor_core, random_key, variable_client, backend='cpu') + + def make_policy(self, networks: crr_networks.CRRNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + del environment_spec, evaluation + + def evaluation_policy( + params: networks_lib.Params, key: networks_lib.PRNGKey, + observation: networks_lib.Observation) -> networks_lib.Action: + dist_params = networks.policy_network.apply(params, observation) + return networks.sample_eval(dist_params, key) + + return evaluation_policy diff --git a/acme/acme/agents/jax/crr/config.py b/acme/acme/agents/jax/crr/config.py new file mode 100644 index 00000000..9bfda702 --- /dev/null +++ 
b/acme/acme/agents/jax/crr/config.py @@ -0,0 +1,34 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Config classes for CRR.""" +import dataclasses + + +@dataclasses.dataclass +class CRRConfig: + """Configuration options for CRR. + + Attributes: + learning_rate: Learning rate. + discount: discount to use for TD updates. + target_update_period: period to update target's parameters. + use_sarsa_target: compute on-policy target using iterator's actions rather + than sampled actions. + Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf). + """ + learning_rate: float = 3e-4 + discount: float = 0.99 + target_update_period: int = 100 + use_sarsa_target: bool = False diff --git a/acme/acme/agents/jax/crr/learning.py b/acme/acme/agents/jax/crr/learning.py new file mode 100644 index 00000000..7d416241 --- /dev/null +++ b/acme/acme/agents/jax/crr/learning.py @@ -0,0 +1,260 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CRR learner implementation.""" + +import time +from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple + +import acme +from acme import types +from acme.agents.jax.crr.losses import PolicyLossCoeff +from acme.agents.jax.crr.networks import CRRNetworks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax + + +class TrainingState(NamedTuple): + """Contains training state for the learner.""" + policy_params: networks_lib.Params + target_policy_params: networks_lib.Params + critic_params: networks_lib.Params + target_critic_params: networks_lib.Params + policy_opt_state: optax.OptState + critic_opt_state: optax.OptState + steps: int + key: networks_lib.PRNGKey + + +class CRRLearner(acme.Learner): + """Critic Regularized Regression (CRR) learner. + + This is the learning component of a CRR agent as described in + https://arxiv.org/abs/2006.15134. 
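+
+  The policy is trained by advantage-weighted regression on dataset actions,
+  with the weight produced by `policy_loss_coeff_fn` (see losses.py), while
+  the critic is trained with a TD(0) target.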
+ """ + + _state: TrainingState + + def __init__(self, + networks: CRRNetworks, + random_key: networks_lib.PRNGKey, + discount: float, + target_update_period: int, + policy_loss_coeff_fn: PolicyLossCoeff, + iterator: Iterator[types.Transition], + policy_optimizer: optax.GradientTransformation, + critic_optimizer: optax.GradientTransformation, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + grad_updates_per_batch: int = 1, + use_sarsa_target: bool = False): + """Initializes the CRR learner. + + Args: + networks: CRR networks. + random_key: a key for random number generation. + discount: discount to use for TD updates. + target_update_period: period to update target's parameters. + policy_loss_coeff_fn: set the loss function for the policy. + iterator: an iterator over training data. + policy_optimizer: the policy optimizer. + critic_optimizer: the Q-function optimizer. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + grad_updates_per_batch: how many gradient updates given a sampled batch. + use_sarsa_target: compute on-policy target using iterator's actions rather + than sampled actions. + Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf). + When set to `True`, `target_policy_params` are unused. + """ + + critic_network = networks.critic_network + policy_network = networks.policy_network + + def policy_loss( + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + transition: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + # Compute the loss coefficients. + coeff = policy_loss_coeff_fn(networks, policy_params, critic_params, + transition, key) + coeff = jax.lax.stop_gradient(coeff) + # Return the weighted loss. + dist_params = policy_network.apply(policy_params, transition.observation) + logp_action = networks.log_prob(dist_params, transition.action) + return -jnp.mean(logp_action * coeff) + + def critic_loss( + critic_params: networks_lib.Params, + target_policy_params: networks_lib.Params, + target_critic_params: networks_lib.Params, + transition: types.Transition, + key: networks_lib.PRNGKey, + ): + # Sample the next action. + if use_sarsa_target: + # TODO(b/222674779): use N-steps Trajectories to get the next actions. + assert 'next_action' in transition.extras, ( + 'next actions should be given as extras for one step RL.') + next_action = transition.extras['next_action'] + else: + next_dist_params = policy_network.apply(target_policy_params, + transition.next_observation) + next_action = networks.sample(next_dist_params, key) + # Calculate the value of the next state and action. + next_q = critic_network.apply(target_critic_params, + transition.next_observation, next_action) + target_q = transition.reward + transition.discount * discount * next_q + target_q = jax.lax.stop_gradient(target_q) + + q = critic_network.apply(critic_params, transition.observation, + transition.action) + q_error = q - target_q + # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. + # TODO(sertan): Replace with a distributional critic. CRR paper states + # that this may perform better. 
+ return 0.5 * jnp.mean(jnp.square(q_error)) + + policy_loss_and_grad = jax.value_and_grad(policy_loss) + critic_loss_and_grad = jax.value_and_grad(critic_loss) + + def sgd_step( + state: TrainingState, + transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + key, key_policy, key_critic = jax.random.split(state.key, 3) + + # Compute losses and their gradients. + policy_loss_value, policy_gradients = policy_loss_and_grad( + state.policy_params, state.critic_params, transitions, key_policy) + critic_loss_value, critic_gradients = critic_loss_and_grad( + state.critic_params, state.target_policy_params, + state.target_critic_params, transitions, key_critic) + + # Get optimizer updates and state. + policy_updates, policy_opt_state = policy_optimizer.update( + policy_gradients, state.policy_opt_state) + critic_updates, critic_opt_state = critic_optimizer.update( + critic_gradients, state.critic_opt_state) + + # Apply optimizer updates to parameters. + policy_params = optax.apply_updates(state.policy_params, policy_updates) + critic_params = optax.apply_updates(state.critic_params, critic_updates) + + steps = state.steps + 1 + + # Periodically update target networks. + target_policy_params, target_critic_params = optax.periodic_update( + (policy_params, critic_params), + (state.target_policy_params, state.target_critic_params), steps, + target_update_period) + + new_state = TrainingState( + policy_params=policy_params, + target_policy_params=target_policy_params, + critic_params=critic_params, + target_critic_params=target_critic_params, + policy_opt_state=policy_opt_state, + critic_opt_state=critic_opt_state, + steps=steps, + key=key, + ) + + metrics = { + 'policy_loss': policy_loss_value, + 'critic_loss': critic_loss_value, + } + + return new_state, metrics + + sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch) + self._sgd_step = jax.jit(sgd_step) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Create prefetching dataset iterator. + self._iterator = iterator + + # Create the network parameters and copy into the target network parameters. + key, key_policy, key_critic = jax.random.split(random_key, 3) + initial_policy_params = policy_network.init(key_policy) + initial_critic_params = critic_network.init(key_critic) + initial_target_policy_params = initial_policy_params + initial_target_critic_params = initial_critic_params + + # Initialize optimizers. + initial_policy_opt_state = policy_optimizer.init(initial_policy_params) + initial_critic_opt_state = critic_optimizer.init(initial_critic_params) + + # Create initial state. + self._state = TrainingState( + policy_params=initial_policy_params, + target_policy_params=initial_target_policy_params, + critic_params=initial_critic_params, + target_critic_params=initial_target_critic_params, + policy_opt_state=initial_policy_opt_state, + critic_opt_state=initial_critic_opt_state, + steps=0, + key=key, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + transitions = next(self._iterator) + + self._state, metrics = self._sgd_step(self._state, transitions) + + # Compute elapsed time. 
+    timestamp = time.time()
+    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
+    self._timestamp = timestamp
+
+    # Increment counts and record the current time
+    counts = self._counter.increment(steps=1, walltime=elapsed_time)
+
+    # Attempts to write the logs.
+    self._logger.write({**metrics, **counts})
+
+  def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
+    # We only expose the target policy and critic variables; the online
+    # parameters are an internal detail of the learner.
+    variables = {
+        'policy': self._state.target_policy_params,
+        'critic': self._state.target_critic_params,
+    }
+    return [variables[name] for name in names]
+
+  def save(self) -> TrainingState:
+    return self._state
+
+  def restore(self, state: TrainingState):
+    self._state = state
diff --git a/acme/acme/agents/jax/crr/losses.py b/acme/acme/agents/jax/crr/losses.py
new file mode 100644
index 00000000..d42d0656
--- /dev/null
+++ b/acme/acme/agents/jax/crr/losses.py
@@ -0,0 +1,99 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loss (weight) functions for CRR."""
+
+from typing import Callable
+
+from acme import types
+from acme.agents.jax.crr.networks import CRRNetworks
+from acme.jax import networks as networks_lib
+import jax.numpy as jnp
+
+PolicyLossCoeff = Callable[[
+    CRRNetworks,
+    networks_lib.Params,
+    networks_lib.Params,
+    types.Transition,
+    networks_lib.PRNGKey,
+], jnp.ndarray]
+
+
+def _compute_advantage(networks: CRRNetworks,
+                       policy_params: networks_lib.Params,
+                       critic_params: networks_lib.Params,
+                       transition: types.Transition,
+                       key: networks_lib.PRNGKey,
+                       num_action_samples: int = 4) -> jnp.ndarray:
+  """Returns the advantage for the transition."""
+  # Sample `num_action_samples` actions from the current policy.
+  replicated_observation = jnp.broadcast_to(transition.observation,
+                                            (num_action_samples,) +
+                                            transition.observation.shape)
+  dist_params = networks.policy_network.apply(policy_params,
+                                              replicated_observation)
+  actions = networks.sample(dist_params, key)
+  # Compute the state-action values for the sampled actions.
+  q_actions = networks.critic_network.apply(critic_params,
+                                            replicated_observation, actions)
+  # Take the mean as the state-value estimate. It is also possible to take the
+  # maximum, aka CRR(max); see table 1 in CRR paper.
+  q_estimate = jnp.mean(q_actions, axis=0)
+  # Compute the advantage.
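+  # I.e. A(s, a) = Q(s, a) - E_{a' ~ pi(s)}[Q(s, a')], with the expectation
+  # replaced by the sampled mean computed above.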
+  q = networks.critic_network.apply(critic_params, transition.observation,
+                                    transition.action)
+  return q - q_estimate
+
+
+def policy_loss_coeff_advantage_exp(
+    networks: CRRNetworks,
+    policy_params: networks_lib.Params,
+    critic_params: networks_lib.Params,
+    transition: types.Transition,
+    key: networks_lib.PRNGKey,
+    num_action_samples: int = 4,
+    beta: float = 1.0,
+    ratio_upper_bound: float = 20.0) -> jnp.ndarray:
+  """Exponential advantage weighting; see equation (4) in CRR paper."""
+  advantage = _compute_advantage(networks, policy_params, critic_params,
+                                 transition, key, num_action_samples)
+  return jnp.minimum(jnp.exp(advantage / beta), ratio_upper_bound)
+
+
+def policy_loss_coeff_advantage_indicator(
+    networks: CRRNetworks,
+    policy_params: networks_lib.Params,
+    critic_params: networks_lib.Params,
+    transition: types.Transition,
+    key: networks_lib.PRNGKey,
+    num_action_samples: int = 4) -> jnp.ndarray:
+  """Indicator advantage weighting; see equation (3) in CRR paper."""
+  advantage = _compute_advantage(networks, policy_params, critic_params,
+                                 transition, key, num_action_samples)
+  return jnp.heaviside(advantage, 0.)
+
+
+def policy_loss_coeff_constant(networks: CRRNetworks,
+                               policy_params: networks_lib.Params,
+                               critic_params: networks_lib.Params,
+                               transition: types.Transition,
+                               key: networks_lib.PRNGKey,
+                               value: float = 1.0) -> jnp.ndarray:
+  """Constant weights."""
+  del networks
+  del policy_params
+  del critic_params
+  del transition
+  del key
+  return value
diff --git a/acme/acme/agents/jax/crr/networks.py b/acme/acme/agents/jax/crr/networks.py
new file mode 100644
index 00000000..69677912
--- /dev/null
+++ b/acme/acme/agents/jax/crr/networks.py
@@ -0,0 +1,86 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Network definitions for CRR."""
+
+import dataclasses
+from typing import Callable, Tuple
+
+from acme import specs
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+@dataclasses.dataclass
+class CRRNetworks:
+  """Network and pure functions for the CRR agent."""
+  policy_network: networks_lib.FeedForwardNetwork
+  critic_network: networks_lib.FeedForwardNetwork
+  log_prob: networks_lib.LogProbFn
+  sample: networks_lib.SampleFn
+  sample_eval: networks_lib.SampleFn
+
+
+def make_networks(
+    spec: specs.EnvironmentSpec,
+    policy_layer_sizes: Tuple[int, ...] = (256, 256),
+    critic_layer_sizes: Tuple[int, ...] = (256, 256),
+    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
+) -> CRRNetworks:
+  """Creates networks used by the agent."""
+  num_actions = np.prod(spec.actions.shape, dtype=int)
+
+  # Create dummy observations and actions to create network parameters.
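+  # (Haiku's transformed init functions require example inputs in order to
+  # infer the parameter shapes.)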
+  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
+  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
+
+  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
+    network = hk.Sequential([
+        hk.nets.MLP(
+            list(policy_layer_sizes),
+            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
+            activation=activation,
+            activate_final=True),
+        networks_lib.NormalTanhDistribution(num_actions),
+    ])
+    return network(obs)
+
+  policy = hk.without_apply_rng(hk.transform(_policy_fn))
+  policy_network = networks_lib.FeedForwardNetwork(
+      lambda key: policy.init(key, dummy_obs), policy.apply)
+
+  def _critic_fn(obs, action):
+    network = hk.Sequential([
+        hk.nets.MLP(
+            list(critic_layer_sizes) + [1],
+            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
+            activation=activation),
+    ])
+    data = jnp.concatenate([obs, action], axis=-1)
+    return network(data)
+
+  critic = hk.without_apply_rng(hk.transform(_critic_fn))
+  critic_network = networks_lib.FeedForwardNetwork(
+      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)
+
+  return CRRNetworks(
+      policy_network=policy_network,
+      critic_network=critic_network,
+      log_prob=lambda params, actions: params.log_prob(actions),
+      sample=lambda params, key: params.sample(seed=key),
+      sample_eval=lambda params, key: params.mode())
diff --git a/acme/acme/agents/jax/d4pg/README.md b/acme/acme/agents/jax/d4pg/README.md
new file mode 100644
index 00000000..840b3357
--- /dev/null
+++ b/acme/acme/agents/jax/d4pg/README.md
@@ -0,0 +1,24 @@
+# Distributed Distributional Deep Deterministic Policy Gradient (D4PG)
+
+This folder contains an implementation of the D4PG agent introduced in
+([Barth-Maron et al., 2018]), which extends previous Deterministic Policy
+Gradient (DPG) algorithms ([Silver et al., 2014]; [Lillicrap et al., 2015]) by
+using a distributional Q-network similar to C51 ([Bellemare et al., 2017]).
+
+Note that since the synchronous agent is not distributed (i.e. not using
+multiple asynchronous actors), it is not, strictly speaking, D4PG; a more
+accurate name would be Distributional DDPG. In this algorithm, the critic
+outputs a distribution over state-action values; in this particular case this
+discrete distribution is parametrized as in C51.
+
+Detailed notes:
+
+-   The `vmin|vmax` hyperparameters of the distributional critic may need
+    tuning depending on your environment's rewards. A good rule of thumb is to
+    set `vmax` to the discounted sum of the maximum instantaneous rewards for
+    the maximum episode length; then set `vmin` to `-vmax`. For example, with
+    rewards in [0, 1], a discount of 0.99 and episodes of at most 1000 steps,
+    this gives `vmax = sum_t 0.99^t ~= 100`.
+
+[Barth-Maron et al., 2018]: https://arxiv.org/abs/1804.08617
+[Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887
+[Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971
+[Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14
diff --git a/acme/acme/agents/jax/d4pg/__init__.py b/acme/acme/agents/jax/d4pg/__init__.py
new file mode 100644
index 00000000..a9ea271e
--- /dev/null
+++ b/acme/acme/agents/jax/d4pg/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of a D4PG agent.""" + +from acme.agents.jax.d4pg.builder import D4PGBuilder +from acme.agents.jax.d4pg.config import D4PGConfig +from acme.agents.jax.d4pg.learning import D4PGLearner +from acme.agents.jax.d4pg.networks import D4PGNetworks +from acme.agents.jax.d4pg.networks import get_default_behavior_policy +from acme.agents.jax.d4pg.networks import get_default_eval_policy +from acme.agents.jax.d4pg.networks import make_networks + diff --git a/acme/acme/agents/jax/d4pg/builder.py b/acme/acme/agents/jax/d4pg/builder.py new file mode 100644 index 00000000..38f20c5f --- /dev/null +++ b/acme/acme/agents/jax/d4pg/builder.py @@ -0,0 +1,173 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""D4PG Builder.""" +from typing import Iterator, List, Optional + +import acme +from acme import adders +from acme import core +from acme import specs +from acme.adders import reverb as adders_reverb +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.d4pg import config as d4pg_config +from acme.agents.jax.d4pg import learning +from acme.agents.jax.d4pg import networks as d4pg_networks +from acme.datasets import reverb as datasets +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import optax +import reverb +from reverb import rate_limiters + + +class D4PGBuilder(builders.ActorLearnerBuilder[d4pg_networks.D4PGNetworks, + actor_core_lib.FeedForwardPolicy, + reverb.ReplaySample]): + """D4PG Builder.""" + + def __init__( + self, + config: d4pg_config.D4PGConfig, + ): + """Creates a D4PG learner, a behavior policy and an eval actor. 
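+
+    The behavior policy perturbs the learned deterministic policy's action
+    with Gaussian noise of scale `config.sigma` (see networks.py); the eval
+    policy acts without added noise.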
+ + Args: + config: a config with D4PG hps + """ + self._config = config + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: d4pg_networks.D4PGNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + policy_optimizer = optax.adam(self._config.learning_rate) + critic_optimizer = optax.adam(self._config.learning_rate) + + if self._config.clipping: + policy_optimizer = optax.chain( + optax.clip_by_global_norm(40.), policy_optimizer) + critic_optimizer = optax.chain( + optax.clip_by_global_norm(40.), critic_optimizer) + + # The learner updates the parameters (and initializes them). + return learning.D4PGLearner( + policy_network=networks.policy_network, + critic_network=networks.critic_network, + random_key=random_key, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=self._config.clipping, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + iterator=dataset, + counter=counter, + logger=logger_fn('learner'), + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicy, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + del policy + # Create the rate limiter. + if self._config.samples_per_insert: + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate * + self._config.samples_per_insert) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + else: + limiter = rate_limiters.MinSize(self._config.min_replay_size) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec)) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=self._config.batch_size * + self._config.num_sgd_steps_per_step, + prefetch_size=self._config.prefetch_size) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.FeedForwardPolicy], + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment.""" + del environment_spec, policy + return adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: 
Optional[adders.Adder] = None,
+  ) -> acme.Actor:
+    del environment_spec
+    assert variable_source is not None
+    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
+    # Inference happens on CPU, so it's better to move variables there too.
+    variable_client = variable_utils.VariableClient(
+        variable_source, 'policy', device='cpu')
+    return actors.GenericActor(
+        actor_core, random_key, variable_client, adder, backend='cpu')
+
+  def make_policy(self,
+                  networks: d4pg_networks.D4PGNetworks,
+                  environment_spec: specs.EnvironmentSpec,
+                  evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy:
+    """Create the policy."""
+    del environment_spec
+    if evaluation:
+      return d4pg_networks.get_default_eval_policy(networks)
+    return d4pg_networks.get_default_behavior_policy(networks, self._config)
diff --git a/acme/acme/agents/jax/d4pg/config.py b/acme/acme/agents/jax/d4pg/config.py
new file mode 100644
index 00000000..338a1abc
--- /dev/null
+++ b/acme/acme/agents/jax/d4pg/config.py
@@ -0,0 +1,45 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config classes for D4PG."""
+import dataclasses
+from typing import Optional
+from acme.adders import reverb as adders_reverb
+
+
+@dataclasses.dataclass
+class D4PGConfig:
+  """Configuration options for D4PG."""
+  sigma: float = 0.3
+  target_update_period: int = 100
+  samples_per_insert: Optional[float] = 32.0
+
+  # Loss options
+  n_step: int = 5
+  discount: float = 0.99
+  batch_size: int = 256
+  learning_rate: float = 1e-4
+  clipping: bool = True
+
+  # Replay options
+  min_replay_size: int = 1000
+  max_replay_size: int = 1000000
+  replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
+  prefetch_size: int = 4
+  # Rate to be used for the SampleToInsertRatio rate limiter tolerance.
+  # See the formula in make_replay_tables for more details.
+  samples_per_insert_tolerance_rate: float = 0.1
+
+  # How many gradient updates to perform per step.
+  num_sgd_steps_per_step: int = 1
diff --git a/acme/acme/agents/jax/d4pg/learning.py b/acme/acme/agents/jax/d4pg/learning.py
new file mode 100644
index 00000000..ab04ad37
--- /dev/null
+++ b/acme/acme/agents/jax/d4pg/learning.py
@@ -0,0 +1,247 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""D4PG learner implementation.""" + +import time +from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple + +import acme +from acme import types +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax +import reverb +import rlax + + +class TrainingState(NamedTuple): + """Contains training state for the learner.""" + policy_params: networks_lib.Params + target_policy_params: networks_lib.Params + critic_params: networks_lib.Params + target_critic_params: networks_lib.Params + policy_opt_state: optax.OptState + critic_opt_state: optax.OptState + steps: int + + +class D4PGLearner(acme.Learner): + """D4PG learner. + + This is the learning component of a D4PG agent. IE it takes a dataset as input + and implements update functionality to learn from this dataset. + """ + + _state: TrainingState + + def __init__(self, + policy_network: networks_lib.FeedForwardNetwork, + critic_network: networks_lib.FeedForwardNetwork, + random_key: networks_lib.PRNGKey, + discount: float, + target_update_period: int, + iterator: Iterator[reverb.ReplaySample], + policy_optimizer: Optional[optax.GradientTransformation] = None, + critic_optimizer: Optional[optax.GradientTransformation] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + jit: bool = True, + num_sgd_steps_per_step: int = 1): + + def critic_mean( + critic_params: networks_lib.Params, + observation: types.NestedArray, + action: types.NestedArray, + ) -> jnp.ndarray: + # We add batch dimension to make sure batch concat in critic_network + # works correctly. + observation = utils.add_batch_dim(observation) + action = utils.add_batch_dim(action) + # Computes the mean action-value estimate. + logits, atoms = critic_network.apply(critic_params, observation, action) + logits = utils.squeeze_batch_dim(logits) + probabilities = jax.nn.softmax(logits) + return jnp.sum(probabilities * atoms, axis=-1) + + def policy_loss( + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + o_t: types.NestedArray, + ) -> jnp.ndarray: + # Computes the discrete policy gradient loss. + dpg_a_t = policy_network.apply(policy_params, o_t) + grad_critic = jax.vmap( + jax.grad(critic_mean, argnums=2), in_axes=(None, 0, 0)) + dq_da = grad_critic(critic_params, o_t, dpg_a_t) + dqda_clipping = 1. if clipping else None + batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0, None)) + loss = batch_dpg_learning(dpg_a_t, dq_da, dqda_clipping) + return jnp.mean(loss) + + def critic_loss( + critic_params: networks_lib.Params, + state: TrainingState, + transition: types.Transition, + ): + # Computes the distributional critic loss. + q_tm1, atoms_tm1 = critic_network.apply(critic_params, + transition.observation, + transition.action) + a = policy_network.apply(state.target_policy_params, + transition.next_observation) + q_t, atoms_t = critic_network.apply(state.target_critic_params, + transition.next_observation, a) + batch_td_learning = jax.vmap( + rlax.categorical_td_learning, in_axes=(None, 0, 0, 0, None, 0)) + loss = batch_td_learning(atoms_tm1, q_tm1, transition.reward, + discount * transition.discount, atoms_t, q_t) + return jnp.mean(loss) + + def sgd_step( + state: TrainingState, + transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + # TODO(jaslanides): Use a shared forward pass for efficiency. 
+ policy_loss_and_grad = jax.value_and_grad(policy_loss) + critic_loss_and_grad = jax.value_and_grad(critic_loss) + + # Compute losses and their gradients. + policy_loss_value, policy_gradients = policy_loss_and_grad( + state.policy_params, state.critic_params, + transitions.next_observation) + critic_loss_value, critic_gradients = critic_loss_and_grad( + state.critic_params, state, transitions) + + # Get optimizer updates and state. + policy_updates, policy_opt_state = policy_optimizer.update( # pytype: disable=attribute-error + policy_gradients, state.policy_opt_state) + critic_updates, critic_opt_state = critic_optimizer.update( # pytype: disable=attribute-error + critic_gradients, state.critic_opt_state) + + # Apply optimizer updates to parameters. + policy_params = optax.apply_updates(state.policy_params, policy_updates) + critic_params = optax.apply_updates(state.critic_params, critic_updates) + + steps = state.steps + 1 + + # Periodically update target networks. + target_policy_params, target_critic_params = optax.periodic_update( + (policy_params, critic_params), + (state.target_policy_params, state.target_critic_params), steps, + self._target_update_period) + + new_state = TrainingState( + policy_params=policy_params, + critic_params=critic_params, + target_policy_params=target_policy_params, + target_critic_params=target_critic_params, + policy_opt_state=policy_opt_state, + critic_opt_state=critic_opt_state, + steps=steps, + ) + + metrics = { + 'policy_loss': policy_loss_value, + 'critic_loss': critic_loss_value, + } + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Necessary to track when to update target networks. + self._target_update_period = target_update_period + + # Create prefetching dataset iterator. + self._iterator = iterator + + # Maybe use the JIT compiler. + sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) + self._sgd_step = jax.jit(sgd_step) if jit else sgd_step + + # Create the network parameters and copy into the target network parameters. + key_policy, key_critic = jax.random.split(random_key) + initial_policy_params = policy_network.init(key_policy) + initial_critic_params = critic_network.init(key_critic) + initial_target_policy_params = initial_policy_params + initial_target_critic_params = initial_critic_params + + # Create optimizers if they aren't given. + critic_optimizer = critic_optimizer or optax.adam(1e-4) + policy_optimizer = policy_optimizer or optax.adam(1e-4) + + # Initialize optimizers. + initial_policy_opt_state = policy_optimizer.init(initial_policy_params) # pytype: disable=attribute-error + initial_critic_opt_state = critic_optimizer.init(initial_critic_params) # pytype: disable=attribute-error + + # Create initial state. + self._state = TrainingState( + policy_params=initial_policy_params, + target_policy_params=initial_target_policy_params, + critic_params=initial_critic_params, + target_critic_params=initial_target_critic_params, + policy_opt_state=initial_policy_opt_state, + critic_opt_state=initial_critic_opt_state, + steps=0, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. 
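+    # (For the same reason, elapsed_time is reported as 0 on the first call
+    # to step().)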
+ self._timestamp = None + + def step(self): + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + sample = next(self._iterator) + transitions = types.Transition(*sample.data) + + self._state, metrics = self._sgd_step(self._state, transitions) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + variables = { + 'policy': self._state.target_policy_params, + 'critic': self._state.target_critic_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/acme/agents/jax/d4pg/networks.py b/acme/acme/agents/jax/d4pg/networks.py new file mode 100644 index 00000000..d598035e --- /dev/null +++ b/acme/acme/agents/jax/d4pg/networks.py @@ -0,0 +1,109 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""D4PG networks definition.""" + +import dataclasses +from typing import Sequence + +from acme import specs +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax.d4pg import config as d4pg_config +from acme.jax import networks as networks_lib +from acme.jax import utils +import haiku as hk +import jax.numpy as jnp +import numpy as np +import rlax + + +@dataclasses.dataclass +class D4PGNetworks: + """Network and pure functions for the D4PG agent..""" + policy_network: networks_lib.FeedForwardNetwork + critic_network: networks_lib.FeedForwardNetwork + + +def get_default_behavior_policy( + networks: D4PGNetworks, + config: d4pg_config.D4PGConfig) -> actor_core_lib.FeedForwardPolicy: + """Selects action according to the training policy.""" + def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey, + observation: types.NestedArray): + action = networks.policy_network.apply(params, observation) + if config.sigma != 0: + action = rlax.add_gaussian_noise(key, action, config.sigma) + return action + + return behavior_policy + + +def get_default_eval_policy( + networks: D4PGNetworks) -> actor_core_lib.FeedForwardPolicy: + """Selects action according to the training policy.""" + def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey, + observation: types.NestedArray): + del key + action = networks.policy_network.apply(params, observation) + return action + return behavior_policy + + +def make_networks( + spec: specs.EnvironmentSpec, + policy_layer_sizes: Sequence[int] = (300, 200), + critic_layer_sizes: Sequence[int] = (400, 300), + vmin: float = -150., + vmax: float = 150., + num_atoms: int = 51, +) -> D4PGNetworks: + """Creates networks used by the agent.""" + + action_spec = spec.actions + + num_dimensions = np.prod(action_spec.shape, dtype=int) + critic_atoms = jnp.linspace(vmin, vmax, num_atoms) + + def _actor_fn(obs): + network = hk.Sequential([ + utils.batch_concat, + networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks_lib.NearZeroInitializedLinear(num_dimensions), + networks_lib.TanhToSpec(action_spec), + ]) + return network(obs) + + def _critic_fn(obs, action): + network = hk.Sequential([ + utils.batch_concat, + networks_lib.LayerNormMLP(layer_sizes=[*critic_layer_sizes, num_atoms]), + ]) + value = network([obs, action]) + return value, critic_atoms + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + + # Create dummy observations and actions to create network parameters. + dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return D4PGNetworks( + policy_network=networks_lib.FeedForwardNetwork( + lambda rng: policy.init(rng, dummy_obs), policy.apply), + critic_network=networks_lib.FeedForwardNetwork( + lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply)) diff --git a/acme/acme/agents/jax/dqn/README.md b/acme/acme/agents/jax/dqn/README.md new file mode 100644 index 00000000..37f9c55a --- /dev/null +++ b/acme/acme/agents/jax/dqn/README.md @@ -0,0 +1,37 @@ +# Deep Q-Networks (DQN) + +This folder contains an implementation of the DQN algorithm +([Mnih et al., 2013], [Mnih et al., 2015]), with extras bells & whistles, +similar to Rainbow DQN ([Hessel et al., 2017]). + +* Q-learning with neural network function approximation. 
+  the Huber loss applied to the temporal difference error.
+* Target Q' network updated periodically ([Mnih et al., 2015]).
+* N-step bootstrapping ([Sutton & Barto, 2018]).
+* Double Q-learning ([van Hasselt et al., 2015]).
+* Prioritized experience replay ([Schaul et al., 2015]).
+
+This DQN implementation has a configurable loss. In losses.py, you can find
+ready-to-use implementations of other methods related to DQN.
+
+* Vanilla Deep Q-learning [Mnih et al., 2013], with two optimization tweaks
+  (Adam instead of RMSProp, square loss instead of Huber, as suggested e.g. by
+  [Obando-Ceron et al., 2020]).
+* Quantile regression DQN (QrDQN) [Dabney et al., 2017]
+* Categorical DQN (C51) [Bellemare et al., 2017]
+* Munchausen DQN [Vieillard et al., 2020]
+* Regularized DQN (DQNReg) [Co-Reyes et al., 2021]
+
+
+[Mnih et al., 2013]: https://arxiv.org/abs/1312.5602
+[Mnih et al., 2015]: https://www.nature.com/articles/nature14236
+[van Hasselt et al., 2015]: https://arxiv.org/abs/1509.06461
+[Schaul et al., 2015]: https://arxiv.org/abs/1511.05952
+[Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887
+[Dabney et al., 2017]: https://arxiv.org/abs/1710.10044
+[Hessel et al., 2017]: https://arxiv.org/abs/1710.02298
+[Horgan et al., 2018]: https://arxiv.org/abs/1803.00933
+[Sutton & Barto, 2018]: http://incompleteideas.net/book/the-book.html
+[Obando-Ceron et al., 2020]: https://arxiv.org/abs/2011.14826
+[Vieillard et al., 2020]: https://arxiv.org/abs/2007.14430
+[Co-Reyes et al., 2021]: https://arxiv.org/abs/2101.03958
diff --git a/acme/acme/agents/jax/dqn/__init__.py b/acme/acme/agents/jax/dqn/__init__.py
new file mode 100644
index 00000000..d06a4bbe
--- /dev/null
+++ b/acme/acme/agents/jax/dqn/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementation of a deep Q-networks (DQN) agent."""
+
+from acme.agents.jax.dqn.actor import behavior_policy
+from acme.agents.jax.dqn.actor import default_behavior_policy
+from acme.agents.jax.dqn.actor import Epsilon
+from acme.agents.jax.dqn.actor import EpsilonPolicy
+from acme.agents.jax.dqn.builder import DQNBuilder
+from acme.agents.jax.dqn.config import DQNConfig
+from acme.agents.jax.dqn.learning import DQNLearner
+from acme.agents.jax.dqn.learning_lib import LossExtra
+from acme.agents.jax.dqn.learning_lib import LossFn
+from acme.agents.jax.dqn.learning_lib import ReverbUpdate
+from acme.agents.jax.dqn.learning_lib import SGDLearner
+from acme.agents.jax.dqn.losses import PrioritizedDoubleQLearning
+from acme.agents.jax.dqn.losses import QrDqn
diff --git a/acme/acme/agents/jax/dqn/actor.py b/acme/acme/agents/jax/dqn/actor.py
new file mode 100644
index 00000000..c62c2162
--- /dev/null
+++ b/acme/acme/agents/jax/dqn/actor.py
@@ -0,0 +1,111 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DQN actor helpers."""
+
+from typing import Callable, Sequence
+
+from acme.agents.jax import actor_core as actor_core_lib
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+import chex
+import jax
+import jax.numpy as jnp
+import rlax
+
+
+Epsilon = float
+EpsilonPolicy = Callable[
+    [networks_lib.Params, networks_lib.PRNGKey, networks_lib.Observation,
+     Epsilon], networks_lib.Action]
+
+
+@chex.dataclass(frozen=True, mappable_dataclass=False)
+class EpsilonActorState:
+  rng: networks_lib.PRNGKey
+  epsilon: jnp.ndarray
+
+
+def alternating_epsilons_actor_core(
+    policy_network: EpsilonPolicy, epsilons: Sequence[float],
+) -> actor_core_lib.ActorCore[EpsilonActorState, None]:
+  """Returns actor components for alternating epsilon exploration.
+
+  Args:
+    policy_network: A feedforward action selecting function.
+    epsilons: epsilons to alternate per-episode for epsilon-greedy exploration.
+
+  Returns:
+    An ActorCore that samples one of the given epsilons at the start of each
+    episode and uses it for epsilon-greedy action selection.
+  """
+  epsilons = jnp.array(epsilons)
+
+  def apply_and_sample(params: networks_lib.Params,
+                       observation: networks_lib.Observation,
+                       state: EpsilonActorState):
+    random_key, key = jax.random.split(state.rng)
+    actions = policy_network(params, key, observation, state.epsilon)
+    return (actions.astype(jnp.int32),
+            EpsilonActorState(rng=random_key, epsilon=state.epsilon))
+
+  def policy_init(random_key: networks_lib.PRNGKey):
+    random_key, key = jax.random.split(random_key)
+    epsilon = jax.random.choice(key, epsilons)
+    return EpsilonActorState(rng=random_key, epsilon=epsilon)
+
+  return actor_core_lib.ActorCore(
+      init=policy_init, select_action=apply_and_sample,
+      get_extras=lambda _: None)
+
+
+def behavior_policy(network: networks_lib.FeedForwardNetwork
+                    ) -> EpsilonPolicy:
+  """An epsilon-greedy policy whose epsilon is supplied at call time."""
+
+  def apply_and_sample(params: networks_lib.Params, key: networks_lib.PRNGKey,
+                       observation: networks_lib.Observation, epsilon: Epsilon
+                       ) -> networks_lib.Action:
+    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
+    observation = utils.add_batch_dim(observation)
+    action_values = network.apply(params, observation)
+    action_values = utils.squeeze_batch_dim(action_values)
+    return rlax.epsilon_greedy(epsilon).sample(key, action_values)
+
+  return apply_and_sample
+
+
+def default_behavior_policy(network: networks_lib.FeedForwardNetwork,
+                            epsilon: Epsilon) -> EpsilonPolicy:
+  """An epsilon-greedy policy with a fixed epsilon.
+
+  DEPRECATED: use behavior_policy instead.
+
+  Args:
+    network: network producing observation -> action values or logits.
+    epsilon: sampling parameter that overrides the one in EpsilonPolicy.
+
+  Returns:
+    epsilon-greedy behavior policy with fixed epsilon.
+  """
+  # TODO(lukstafi): remove this function and migrate its users.
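+  # For reference, the replacement usage looks like this (illustrative only):
+  #   policy = behavior_policy(network)
+  #   action = policy(params, key, observation, epsilon)
+  # i.e. epsilon becomes a call-time argument rather than being fixed here.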
+ + def apply_and_sample(params: networks_lib.Params, key: networks_lib.PRNGKey, + observation: networks_lib.Observation, _: Epsilon + ) -> networks_lib.Action: + # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. + observation = utils.add_batch_dim(observation) + action_values = network.apply(params, observation) + action_values = utils.squeeze_batch_dim(action_values) + return rlax.epsilon_greedy(epsilon).sample(key, action_values) + + return apply_and_sample diff --git a/acme/acme/agents/jax/dqn/builder.py b/acme/acme/agents/jax/dqn/builder.py new file mode 100644 index 00000000..a31d6427 --- /dev/null +++ b/acme/acme/agents/jax/dqn/builder.py @@ -0,0 +1,182 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DQN Builder.""" +from typing import Iterator, List, Optional, Sequence + +from acme import adders +from acme import core +from acme import specs +from acme.adders import reverb as adders_reverb +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.dqn import actor as dqn_actor +from acme.agents.jax.dqn import config as dqn_config +from acme.agents.jax.dqn import learning_lib +from acme.datasets import reverb as datasets +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import optax +import reverb +from reverb import rate_limiters + + +class DQNBuilder(builders.ActorLearnerBuilder[networks_lib.FeedForwardNetwork, + dqn_actor.EpsilonPolicy, + reverb.ReplaySample]): + """DQN Builder.""" + + def __init__(self, + config: dqn_config.DQNConfig, + loss_fn: learning_lib.LossFn, + actor_backend: Optional[str] = 'cpu'): + """Creates DQN learner and the behavior policies. + + Args: + config: DQN config. + loss_fn: A loss function. + actor_backend: Which backend to use when jitting the policy. 
+    """
+    self._config = config
+    self._loss_fn = loss_fn
+    self._actor_backend = actor_backend
+
+  def make_learner(
+      self,
+      random_key: networks_lib.PRNGKey,
+      networks: networks_lib.FeedForwardNetwork,
+      dataset: Iterator[reverb.ReplaySample],
+      logger_fn: loggers.LoggerFactory,
+      environment_spec: Optional[specs.EnvironmentSpec],
+      replay_client: Optional[reverb.Client] = None,
+      counter: Optional[counting.Counter] = None,
+  ) -> core.Learner:
+    del environment_spec
+
+    return learning_lib.SGDLearner(
+        network=networks,
+        random_key=random_key,
+        optimizer=optax.adam(
+            self._config.learning_rate, eps=self._config.adam_eps),
+        target_update_period=self._config.target_update_period,
+        data_iterator=dataset,
+        loss_fn=self._loss_fn,
+        replay_client=replay_client,
+        replay_table_name=self._config.replay_table_name,
+        counter=counter,
+        num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
+        logger=logger_fn('learner'))
+
+  def make_actor(
+      self,
+      random_key: networks_lib.PRNGKey,
+      policy: dqn_actor.EpsilonPolicy,
+      environment_spec: Optional[specs.EnvironmentSpec],
+      variable_source: Optional[core.VariableSource] = None,
+      adder: Optional[adders.Adder] = None,
+  ) -> core.Actor:
+    del environment_spec
+    assert variable_source is not None
+    # Inference happens on CPU, so it's better to move variables there too.
+    variable_client = variable_utils.VariableClient(
+        variable_source, '', device='cpu')
+    epsilon = self._config.epsilon
+    epsilons = epsilon if isinstance(epsilon, Sequence) else (epsilon,)
+    actor_core = dqn_actor.alternating_epsilons_actor_core(
+        policy, epsilons=epsilons)
+    return actors.GenericActor(
+        actor=actor_core,
+        random_key=random_key,
+        variable_client=variable_client,
+        adder=adder,
+        backend=self._actor_backend)
+
+  def make_replay_tables(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      policy: dqn_actor.EpsilonPolicy,
+  ) -> List[reverb.Table]:
+    """Creates reverb tables for the algorithm."""
+    del policy
+    samples_per_insert_tolerance = (
+        self._config.samples_per_insert_tolerance_rate *
+        self._config.samples_per_insert)
+    error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
+    limiter = rate_limiters.SampleToInsertRatio(
+        min_size_to_sample=self._config.min_replay_size,
+        samples_per_insert=self._config.samples_per_insert,
+        error_buffer=error_buffer)
+    return [
+        reverb.Table(
+            name=self._config.replay_table_name,
+            sampler=reverb.selectors.Prioritized(
+                self._config.priority_exponent),
+            remover=reverb.selectors.Fifo(),
+            max_size=self._config.max_replay_size,
+            rate_limiter=limiter,
+            signature=adders_reverb.NStepTransitionAdder.signature(
+                environment_spec))
+    ]
+
+  @property
+  def batch_size_per_device(self) -> int:
+    """Splits the batch size across local devices."""
+    # Account for the number of SGD steps per step: the learner consumes one
+    # large batch and splits it into num_sgd_steps_per_step minibatches.
+    batch_size = self._config.batch_size * self._config.num_sgd_steps_per_step
+    num_devices = jax.local_device_count()
+    if batch_size % num_devices != 0:
+      raise ValueError(
+          'The DQN learner received a batch size that is not divisible by the '
+          f'number of available learner devices. Got: batch_size={batch_size}, '
+          f'num_devices={num_devices}.')
+    batch_size_per_device = batch_size // num_devices
+    return batch_size_per_device
+
+  def make_dataset_iterator(
+      self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
+    """Creates a dataset iterator to use for learning."""
+    dataset = datasets.make_reverb_dataset(
+        table=self._config.replay_table_name,
+        server_address=replay_client.server_address,
+        batch_size=self.batch_size_per_device,
+        prefetch_size=self._config.prefetch_size)
+    return utils.multi_device_put(dataset.as_numpy_iterator(),
+                                  jax.local_devices())
+
+  def make_adder(
+      self,
+      replay_client: reverb.Client,
+      environment_spec: Optional[specs.EnvironmentSpec],
+      policy: Optional[dqn_actor.EpsilonPolicy],
+  ) -> Optional[adders.Adder]:
+    """Creates an adder which handles observations."""
+    del environment_spec, policy
+    return adders_reverb.NStepTransitionAdder(
+        priority_fns={self._config.replay_table_name: None},
+        client=replay_client,
+        n_step=self._config.n_step,
+        discount=self._config.discount)
+
+  def make_policy(self,
+                  networks: networks_lib.FeedForwardNetwork,
+                  environment_spec: specs.EnvironmentSpec,
+                  evaluation: bool = False) -> dqn_actor.EpsilonPolicy:
+    """Creates the policy."""
+    del environment_spec, evaluation
+    return dqn_actor.behavior_policy(networks)
diff --git a/acme/acme/agents/jax/dqn/config.py b/acme/acme/agents/jax/dqn/config.py
new file mode 100644
index 00000000..ab90368a
--- /dev/null
+++ b/acme/acme/agents/jax/dqn/config.py
@@ -0,0 +1,85 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DQN config."""
+
+import dataclasses
+from typing import Callable, Sequence, Union
+
+from acme.adders import reverb as adders_reverb
+import jax.numpy as jnp
+import numpy as np
+
+
+@dataclasses.dataclass
+class DQNConfig:
+  """Configuration options for DQN agent.
+
+  Attributes:
+    epsilon: for use by epsilon-greedy policies. If multiple, the epsilons are
+      alternated randomly per-episode.
+    seed: Random seed.
+    learning_rate: Learning rate for Adam optimizer. Could be a number or a
+      function defining a schedule.
+    adam_eps: Epsilon for Adam optimizer.
+    discount: Discount rate applied to value per timestep.
+    n_step: N-step TD learning.
+    target_update_period: Update target network every period.
+    max_gradient_norm: For gradient clipping.
+    batch_size: Number of transitions per batch.
+    min_replay_size: Minimum replay size.
+    max_replay_size: Maximum replay size.
+    replay_table_name: Reverb table, defaults to DEFAULT_PRIORITY_TABLE.
+    importance_sampling_exponent: Importance sampling for replay.
+    priority_exponent: Priority exponent for replay.
+    prefetch_size: Prefetch size for reverb replay performance.
+    samples_per_insert: Ratio of learning samples to insert.
+    samples_per_insert_tolerance_rate: Rate to be used for
+      the SampleToInsertRatio rate limiter tolerance.
+      See a formula in make_replay_tables for more details.
+ num_sgd_steps_per_step: How many gradient updates to perform per learner + step. + """ + epsilon: Union[float, Sequence[float]] = 0.05 + # TODO(b/191706065): update all clients and remove this field. + seed: int = 1 + + # Learning rule + learning_rate: Union[float, Callable[[int], float]] = 1e-3 + adam_eps: float = 1e-8 # Eps for Adam optimizer. + discount: float = 0.99 # Discount rate applied to value per timestep. + n_step: int = 5 # N-step TD learning. + target_update_period: int = 100 # Update target network every period. + max_gradient_norm: float = np.inf # For gradient clipping. + + # Replay options + batch_size: int = 256 + min_replay_size: int = 1_000 + max_replay_size: int = 1_000_000 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + importance_sampling_exponent: float = 0.2 + priority_exponent: float = 0.6 + prefetch_size: int = 4 + samples_per_insert: float = 0.5 + samples_per_insert_tolerance_rate: float = 0.1 + + num_sgd_steps_per_step: int = 1 + + +def logspace_epsilons(num_epsilons: int, epsilon: float = 0.017 + ) -> Sequence[float]: + """`num_epsilons` of logspace-distributed values, with median `epsilon`.""" + if num_epsilons <= 1: + return (epsilon,) + return jnp.logspace(1, 8, num_epsilons, base=epsilon ** (2./9.)) diff --git a/acme/acme/agents/jax/dqn/learning.py b/acme/acme/agents/jax/dqn/learning.py new file mode 100644 index 00000000..da4fa6f4 --- /dev/null +++ b/acme/acme/agents/jax/dqn/learning.py @@ -0,0 +1,72 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DQN learner implementation.""" + +from typing import Iterator, Optional + +from acme.adders import reverb as adders +from acme.agents.jax.dqn import learning_lib +from acme.agents.jax.dqn import losses +from acme.jax import networks as networks_lib +from acme.utils import counting +from acme.utils import loggers +import optax +import reverb + + +class DQNLearner(learning_lib.SGDLearner): + """DQN learner. + + We are in the process of migrating towards a more general SGDLearner to allow + for easy configuration of the loss. This is maintained now for compatibility. 
+ """ + + def __init__(self, + network: networks_lib.FeedForwardNetwork, + discount: float, + importance_sampling_exponent: float, + target_update_period: int, + iterator: Iterator[reverb.ReplaySample], + optimizer: optax.GradientTransformation, + random_key: networks_lib.PRNGKey, + stochastic_network: bool = False, + max_abs_reward: float = 1., + huber_loss_parameter: float = 1., + replay_client: Optional[reverb.Client] = None, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1): + """Initializes the learner.""" + loss_fn = losses.PrioritizedDoubleQLearning( + discount=discount, + importance_sampling_exponent=importance_sampling_exponent, + max_abs_reward=max_abs_reward, + huber_loss_parameter=huber_loss_parameter, + stochastic_network=stochastic_network, + ) + super().__init__( + network=network, + loss_fn=loss_fn, + optimizer=optimizer, + data_iterator=iterator, + target_update_period=target_update_period, + random_key=random_key, + replay_client=replay_client, + replay_table_name=replay_table_name, + counter=counter, + logger=logger, + num_sgd_steps_per_step=num_sgd_steps_per_step, + ) diff --git a/acme/acme/agents/jax/dqn/learning_lib.py b/acme/acme/agents/jax/dqn/learning_lib.py new file mode 100644 index 00000000..13eda1f2 --- /dev/null +++ b/acme/acme/agents/jax/dqn/learning_lib.py @@ -0,0 +1,211 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SgdLearner takes steps of SGD on a LossFn.""" + +import functools +import time +from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple + +import acme +from acme.adders import reverb as adders +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import async_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax +import reverb +import tree +import typing_extensions + + +# The pmap axis name. Data means data parallelization. 
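+# Gradients and the loss are averaged with jax.lax.pmean over this axis inside
+# the pmapped SGD step below, which keeps parameters in sync across devices.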
+PMAP_AXIS_NAME = 'data' + + +class ReverbUpdate(NamedTuple): + """Tuple for updating reverb priority information.""" + keys: jnp.ndarray + priorities: jnp.ndarray + + +class LossExtra(NamedTuple): + """Extra information that is returned along with loss value.""" + metrics: Dict[str, jnp.DeviceArray] + reverb_update: Optional[ReverbUpdate] = None + + +class LossFn(typing_extensions.Protocol): + """A LossFn calculates a loss on a single batch of data.""" + + def __call__(self, + network: networks_lib.FeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, LossExtra]: + """Calculates a loss on a single batch of data.""" + + +class TrainingState(NamedTuple): + """Holds the agent's training state.""" + params: networks_lib.Params + target_params: networks_lib.Params + opt_state: optax.OptState + steps: int + rng_key: networks_lib.PRNGKey + + +class SGDLearner(acme.Learner): + """An Acme learner based around SGD on batches. + + This learner currently supports optional prioritized replay and assumes a + TrainingState as described above. + """ + + def __init__(self, + network: networks_lib.FeedForwardNetwork, + loss_fn: LossFn, + optimizer: optax.GradientTransformation, + data_iterator: Iterator[reverb.ReplaySample], + target_update_period: int, + random_key: networks_lib.PRNGKey, + replay_client: Optional[reverb.Client] = None, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1): + """Initialize the SGD learner.""" + self.network = network + + # Internalize the loss_fn with network. + self._loss = jax.jit(functools.partial(loss_fn, self.network)) + + # SGD performs the loss, optimizer update and periodic target net update. + def sgd_step(state: TrainingState, + batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]: + next_rng_key, rng_key = jax.random.split(state.rng_key) + # Implements one SGD step of the loss and updates training state + (loss, extra), grads = jax.value_and_grad( + self._loss, has_aux=True)(state.params, state.target_params, batch, + rng_key) + + loss = jax.lax.pmean(loss, axis_name=PMAP_AXIS_NAME) + # Average gradients over pmap replicas before optimizer update. + grads = jax.lax.pmean(grads, axis_name=PMAP_AXIS_NAME) + # Apply the optimizer updates + updates, new_opt_state = optimizer.update(grads, state.opt_state) + new_params = optax.apply_updates(state.params, updates) + + extra.metrics.update({'total_loss': loss}) + + # Periodically update target networks. 
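+      # optax.periodic_update returns `new_params` whenever `steps` is a
+      # multiple of `target_update_period`, and the old target otherwise.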
+ steps = state.steps + 1 + target_params = optax.periodic_update(new_params, state.target_params, + steps, target_update_period) + + new_training_state = TrainingState( + new_params, target_params, new_opt_state, steps, next_rng_key) + return new_training_state, extra + + def postprocess_aux(extra: LossExtra) -> LossExtra: + reverb_update = jax.tree_util.tree_map( + lambda a: jnp.reshape(a, (-1, *a.shape[2:])), extra.reverb_update) + return extra._replace( + metrics=jax.tree_util.tree_map(jnp.mean, extra.metrics), + reverb_update=reverb_update) + + self._num_sgd_steps_per_step = num_sgd_steps_per_step + sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step, + postprocess_aux) + self._sgd_step = jax.pmap( + sgd_step, axis_name=PMAP_AXIS_NAME, devices=jax.local_devices()) + + # Internalise agent components + self._data_iterator = data_iterator + self._target_update_period = target_update_period + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + # Initialize the network parameters + key_params, key_target, key_state = jax.random.split(random_key, 3) + initial_params = self.network.init(key_params) + initial_target_params = self.network.init(key_target) + state = TrainingState( + params=initial_params, + target_params=initial_target_params, + opt_state=optimizer.init(initial_params), + steps=0, + rng_key=key_state, + ) + self._state = utils.replicate_in_all_devices(state, jax.local_devices()) + + # Update replay priorities + def update_priorities(reverb_update: ReverbUpdate) -> None: + if replay_client is None: + return + keys, priorities = tree.map_structure( + # Fetch array and combine device and batch dimensions. + lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]), + (reverb_update.keys, reverb_update.priorities)) + replay_client.mutate_priorities( + table=replay_table_name, + updates=dict(zip(keys, priorities))) + self._replay_client = replay_client + self._async_priority_updater = async_utils.AsyncExecutor(update_priorities) + + def step(self): + """Takes one SGD step on the learner.""" + with jax.profiler.StepTraceAnnotation('step', + step_num=self._state.steps): + batch = next(self._data_iterator) + self._state, extra = self._sgd_step(self._state, batch) + + # Compute elapsed time. + timestamp = time.time() + elapsed = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + if self._replay_client and extra.reverb_update: + reverb_update = extra.reverb_update._replace(keys=batch.info.key) + self._async_priority_updater.put(reverb_update) + + steps_per_sec = (self._num_sgd_steps_per_step / elapsed) if elapsed else 0 + metrics = utils.get_from_first_device(extra.metrics) + metrics['steps_per_second'] = steps_per_sec + + # Update our counts and record it. + result = self._counter.increment( + steps=self._num_sgd_steps_per_step, walltime=elapsed) + result.update(metrics) + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + # Return first replica of parameters. + return utils.get_from_first_device([self._state.params]) + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. 
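+    # (restore() below re-replicates the saved state across all local devices.)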
+ return utils.get_from_first_device(self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state, jax.local_devices()) diff --git a/acme/acme/agents/jax/dqn/losses.py b/acme/acme/agents/jax/dqn/losses.py new file mode 100644 index 00000000..abb181d0 --- /dev/null +++ b/acme/acme/agents/jax/dqn/losses.py @@ -0,0 +1,325 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DQN losses.""" +import dataclasses +from typing import Tuple + +from acme import types +from acme.agents.jax.dqn import learning_lib +from acme.jax import networks as networks_lib +import jax +import jax.numpy as jnp +import reverb +import rlax + + +@dataclasses.dataclass +class PrioritizedDoubleQLearning(learning_lib.LossFn): + """Clipped double q learning with prioritization on TD error.""" + discount: float = 0.99 + importance_sampling_exponent: float = 0.2 + max_abs_reward: float = 1. + huber_loss_parameter: float = 1. + stochastic_network: bool = False + + def __call__( + self, + network: networks_lib.FeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data + keys, probs, *_ = batch.info + + # Forward pass. + if self.stochastic_network: + q_tm1 = network.apply(params, key, transitions.observation) + q_t_value = network.apply(target_params, key, + transitions.next_observation) + q_t_selector = network.apply(params, key, transitions.next_observation) + else: + q_tm1 = network.apply(params, transitions.observation) + q_t_value = network.apply(target_params, transitions.next_observation) + q_t_selector = network.apply(params, transitions.next_observation) + + # Cast and clip rewards. + d_t = (transitions.discount * self.discount).astype(jnp.float32) + r_t = jnp.clip(transitions.reward, -self.max_abs_reward, + self.max_abs_reward).astype(jnp.float32) + + # Compute double Q-learning n-step TD-error. + batch_error = jax.vmap(rlax.double_q_learning) + td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value, + q_t_selector) + batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter) + + # Importance weighting. + importance_weights = (1. / probs).astype(jnp.float32) + importance_weights **= self.importance_sampling_exponent + importance_weights /= jnp.max(importance_weights) + + # Reweight. + loss = jnp.mean(importance_weights * batch_loss) # [] + reverb_update = learning_lib.ReverbUpdate( + keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64)) + extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update) + return loss, extra + + +@dataclasses.dataclass +class QrDqn(learning_lib.LossFn): + """Quantile Regression DQN. 
+
+  https://arxiv.org/abs/1710.10044
+  """
+  num_atoms: int = 51
+  huber_param: float = 1.0
+
+  def __call__(
+      self,
+      network: networks_lib.FeedForwardNetwork,
+      params: networks_lib.Params,
+      target_params: networks_lib.Params,
+      batch: reverb.ReplaySample,
+      key: networks_lib.PRNGKey,
+  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
+    """Calculate a loss on a single batch of data."""
+    del key
+    transitions: types.Transition = batch.data
+    dist_q_tm1 = network.apply(params,
+                               transitions.observation)['q_dist']
+    dist_q_target_t = network.apply(target_params,
+                                    transitions.next_observation)['q_dist']
+    # Swap distribution and action dimension, since
+    # rlax.quantile_q_learning expects it that way.
+    dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2)
+    dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2)
+    quantiles = (
+        (jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5) / self.num_atoms)
+    batch_quantile_q_learning = jax.vmap(
+        rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
+    losses = batch_quantile_q_learning(
+        dist_q_tm1,
+        quantiles,
+        transitions.action,
+        transitions.reward,
+        transitions.discount,
+        dist_q_target_t,  # No double Q-learning here.
+        dist_q_target_t,
+        self.huber_param,
+    )
+    loss = jnp.mean(losses)
+    extra = learning_lib.LossExtra(metrics={'mean_loss': loss})
+    return loss, extra
+
+
+@dataclasses.dataclass
+class PrioritizedCategoricalDoubleQLearning(learning_lib.LossFn):
+  """Categorical double q learning with prioritization on TD error."""
+  discount: float = 0.99
+  importance_sampling_exponent: float = 0.2
+  max_abs_reward: float = 1.
+
+  def __call__(
+      self,
+      network: networks_lib.FeedForwardNetwork,
+      params: networks_lib.Params,
+      target_params: networks_lib.Params,
+      batch: reverb.ReplaySample,
+      key: networks_lib.PRNGKey,
+  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
+    """Calculate a loss on a single batch of data."""
+    del key
+    transitions: types.Transition = batch.data
+    keys, probs, *_ = batch.info
+
+    # Forward pass.
+    _, logits_tm1, atoms_tm1 = network.apply(params, transitions.observation)
+    _, logits_t, atoms_t = network.apply(target_params,
+                                         transitions.next_observation)
+    q_t_selector, _, _ = network.apply(params, transitions.next_observation)
+
+    # Cast and clip rewards.
+    d_t = (transitions.discount * self.discount).astype(jnp.float32)
+    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
+                   self.max_abs_reward).astype(jnp.float32)
+
+    # Compute categorical double Q-learning loss.
+    batch_loss_fn = jax.vmap(
+        rlax.categorical_double_q_learning,
+        in_axes=(None, 0, 0, 0, 0, None, 0, 0))
+    batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action, r_t,
+                               d_t, atoms_t, logits_t, q_t_selector)
+
+    # Importance weighting.
+    importance_weights = (1. / probs).astype(jnp.float32)
+    importance_weights **= self.importance_sampling_exponent
+    importance_weights /= jnp.max(importance_weights)
+
+    # Reweight.
+    loss = jnp.mean(importance_weights * batch_loss)  # []
+    reverb_update = learning_lib.ReverbUpdate(
+        keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
+    extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
+    return loss, extra
+
+
+@dataclasses.dataclass
+class QLearning(learning_lib.LossFn):
+  """Deep q learning.
+
+  This matches the original DQN loss: https://arxiv.org/abs/1312.5602.
+  It differs in two aspects that improve it on the optimization side:
+  - it uses Adam instead of RMSProp as an optimizer;
+  - it uses a square loss instead of the Huber one.
+  """
+  discount: float = 0.99
+  max_abs_reward: float = 1.
+
+  def __call__(
+      self,
+      network: networks_lib.FeedForwardNetwork,
+      params: networks_lib.Params,
+      target_params: networks_lib.Params,
+      batch: reverb.ReplaySample,
+      key: networks_lib.PRNGKey,
+  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
+    """Calculate a loss on a single batch of data."""
+    del key
+    transitions: types.Transition = batch.data
+
+    # Forward pass.
+    q_tm1 = network.apply(params, transitions.observation)
+    q_t = network.apply(target_params, transitions.next_observation)
+
+    # Cast and clip rewards.
+    d_t = (transitions.discount * self.discount).astype(jnp.float32)
+    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
+                   self.max_abs_reward).astype(jnp.float32)
+
+    # Compute Q-learning TD-error.
+    batch_error = jax.vmap(rlax.q_learning)
+    td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t)
+    batch_loss = jnp.square(td_error)
+
+    loss = jnp.mean(batch_loss)
+    extra = learning_lib.LossExtra(metrics={})
+    return loss, extra
+
+
+@dataclasses.dataclass
+class RegularizedQLearning(learning_lib.LossFn):
+  """Regularized Q-learning.
+
+  Implements the DQNReg loss function: https://arxiv.org/abs/2101.03958.
+  This is almost identical to QLearning except that it: 1) adds a
+  regularization term; 2) uses the vanilla TD error without Huber loss;
+  3) does no reward clipping.
+  """
+  discount: float = 0.99
+  regularizer_coeff: float = 0.1
+
+  def __call__(
+      self,
+      network: networks_lib.FeedForwardNetwork,
+      params: networks_lib.Params,
+      target_params: networks_lib.Params,
+      batch: reverb.ReplaySample,
+      key: networks_lib.PRNGKey,
+  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
+    """Calculate a loss on a single batch of data."""
+    del key
+    transitions: types.Transition = batch.data
+
+    # Forward pass.
+    q_tm1 = network.apply(params, transitions.observation)
+    q_t = network.apply(target_params, transitions.next_observation)
+
+    d_t = (transitions.discount * self.discount).astype(jnp.float32)
+
+    # Compute Q-learning TD-error.
+    batch_error = jax.vmap(rlax.q_learning)
+    td_error = batch_error(
+        q_tm1, transitions.action, transitions.reward, d_t, q_t)
+    td_error = 0.5 * jnp.square(td_error)
+
+    def select(qtm1, action):
+      return qtm1[action]
+    q_regularizer = jax.vmap(select)(q_tm1, transitions.action)
+
+    loss = self.regularizer_coeff * jnp.mean(q_regularizer) + jnp.mean(td_error)
+    extra = learning_lib.LossExtra(metrics={})
+    return loss, extra
+
+
+@dataclasses.dataclass
+class MunchausenQLearning(learning_lib.LossFn):
+  """Munchausen q learning.
+
+  Implements M-DQN: https://arxiv.org/abs/2007.14430.
+  """
+  entropy_temperature: float = 0.03  # tau parameter
+  munchausen_coefficient: float = 0.9  # alpha parameter
+  clip_value_min: float = -1e3
+  discount: float = 0.99
+  max_abs_reward: float = 1.
+  huber_loss_parameter: float = 1.
+
+  def __call__(
+      self,
+      network: networks_lib.FeedForwardNetwork,
+      params: networks_lib.Params,
+      target_params: networks_lib.Params,
+      batch: reverb.ReplaySample,
+      key: networks_lib.PRNGKey,
+  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
+    """Calculate a loss on a single batch of data."""
+    del key
+    transitions: types.Transition = batch.data
+
+    # Forward pass.
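+    # The lines below build the M-DQN target:
+    #   r_t + alpha * clip(tau * log_pi(a_t|s_t), clip_value_min, 0)
+    #       + d_t * tau * logsumexp(q(s_{t+1}, .) / tau)
+    # where pi is the softmax policy induced by the target network.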
+    q_online_s = network.apply(params, transitions.observation)
+    action_one_hot = jax.nn.one_hot(transitions.action, q_online_s.shape[-1])
+    q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
+    q_target_s = network.apply(target_params, transitions.observation)
+    q_target_next = network.apply(target_params, transitions.next_observation)
+
+    # Cast and clip rewards.
+    d_t = (transitions.discount * self.discount).astype(jnp.float32)
+    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
+                   self.max_abs_reward).astype(jnp.float32)
+
+    # Munchausen term: tau * log_pi(a|s)
+    munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
+        q_target_s / self.entropy_temperature, axis=-1)
+    munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
+    munchausen_term_a = jnp.clip(munchausen_term_a,
+                                 a_min=self.clip_value_min,
+                                 a_max=0.)
+
+    # Soft Bellman operator applied to q.
+    next_v = self.entropy_temperature * jax.nn.logsumexp(
+        q_target_next / self.entropy_temperature, axis=-1)
+    target_q = jax.lax.stop_gradient(r_t + self.munchausen_coefficient *
+                                     munchausen_term_a + d_t * next_v)
+
+    batch_loss = rlax.huber_loss(target_q - q_online_sa,
+                                 self.huber_loss_parameter)
+    loss = jnp.mean(batch_loss)
+
+    extra = learning_lib.LossExtra(metrics={})
+    return loss, extra
diff --git a/acme/acme/agents/jax/dqn/rainbow.py b/acme/acme/agents/jax/dqn/rainbow.py
new file mode 100644
index 00000000..c5773ea2
--- /dev/null
+++ b/acme/acme/agents/jax/dqn/rainbow.py
@@ -0,0 +1,95 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines Rainbow DQN, using JAX."""
+
+import dataclasses
+from typing import Callable
+
+
+from acme import specs
+from acme.agents.jax.dqn import actor as dqn_actor
+from acme.agents.jax.dqn import builder
+from acme.agents.jax.dqn import config as dqn_config
+from acme.agents.jax.dqn import losses
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+import rlax
+
+NetworkFactory = Callable[[specs.EnvironmentSpec],
+                          networks_lib.FeedForwardNetwork]
+
+
+@dataclasses.dataclass
+class RainbowConfig(dqn_config.DQNConfig):
+  """(Additional) configuration options for RainbowDQN."""
+  max_abs_reward: float = 1.0  # For clipping reward
+
+
+def apply_policy_and_sample(
+    network: networks_lib.FeedForwardNetwork,) -> dqn_actor.EpsilonPolicy:
+  """Returns a function that computes actions.
+
+  Note that this differs from default_behavior_policy in that it expects a
+  c51-style network head, which returns a tuple whose first entry represents
+  the q-values.
+
+  Args:
+    network: A c51-style feedforward network.
+
+  Returns:
+    A feedforward policy.
+  """
+
+  def apply_and_sample(params, key, obs, epsilon):
+    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
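+    # The c51-style head returns a tuple; its first entry holds the Q-values
+    # used for epsilon-greedy action selection below.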
+    obs = utils.add_batch_dim(obs)
+    action_values = network.apply(params, obs)[0]
+    action_values = utils.squeeze_batch_dim(action_values)
+    return rlax.epsilon_greedy(epsilon).sample(key, action_values)
+
+  return apply_and_sample
+
+
+def eval_policy(network: networks_lib.FeedForwardNetwork,
+                eval_epsilon: float) -> dqn_actor.EpsilonPolicy:
+  """Returns a function that computes actions.
+
+  Note that this differs from default_behavior_policy in that it expects a
+  c51-style network head, which returns a tuple whose first entry represents
+  the q-values.
+
+  Args:
+    network: A c51-style feedforward network.
+    eval_epsilon: for epsilon-greedy exploration.
+
+  Returns:
+    A feedforward policy.
+  """
+  policy = apply_policy_and_sample(network)
+
+  def apply_and_sample(params, key, obs, _):
+    return policy(params, key, obs, eval_epsilon)
+
+  return apply_and_sample
+
+
+def make_builder(config: RainbowConfig):
+  """Returns a DQNBuilder with a pre-built loss function."""
+  loss_fn = losses.PrioritizedCategoricalDoubleQLearning(
+      discount=config.discount,
+      importance_sampling_exponent=config.importance_sampling_exponent,
+      max_abs_reward=config.max_abs_reward,
+  )
+  return builder.DQNBuilder(config, loss_fn=loss_fn)
diff --git a/acme/acme/agents/jax/impala/__init__.py b/acme/acme/agents/jax/impala/__init__.py
new file mode 100644
index 00000000..4fb9b755
--- /dev/null
+++ b/acme/acme/agents/jax/impala/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Importance-weighted actor-learner architecture (IMPALA) agent."""
+
+from acme.agents.jax.impala.builder import IMPALABuilder
+from acme.agents.jax.impala.config import IMPALAConfig
+from acme.agents.jax.impala.learning import IMPALALearner
+from acme.agents.jax.impala.networks import IMPALANetworks
+from acme.agents.jax.impala.networks import make_atari_networks
+from acme.agents.jax.impala.networks import make_haiku_networks
diff --git a/acme/acme/agents/jax/impala/acting.py b/acme/acme/agents/jax/impala/acting.py
new file mode 100644
index 00000000..2fefb53c
--- /dev/null
+++ b/acme/acme/agents/jax/impala/acting.py
@@ -0,0 +1,104 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""IMPALA actor implementation."""
+
+from typing import Optional
+
+from acme import adders
+from acme import core
+from acme.agents.jax.impala import types
+from acme.jax import variable_utils
+import dm_env
+import haiku as hk
+import jax
+import jax.numpy as jnp
+
+
+class IMPALAActor(core.Actor):
+  """A recurrent actor."""
+
+  _state: hk.LSTMState
+  _prev_state: hk.LSTMState
+  _prev_logits: jnp.ndarray
+
+  def __init__(
+      self,
+      forward_fn: types.PolicyValueFn,
+      initial_state_fn: types.RecurrentStateFn,
+      rng: hk.PRNGSequence,
+      variable_client: Optional[variable_utils.VariableClient] = None,
+      adder: Optional[adders.Adder] = None,
+  ):
+
+    # Store these for later use.
+    self._adder = adder
+    self._variable_client = variable_client
+    self._forward = forward_fn
+    self._reset_fn_or_none = getattr(forward_fn, 'reset', None)
+    self._rng = rng
+
+    self._initial_state = initial_state_fn(next(self._rng))
+
+  def select_action(self, observation: types.Observation) -> types.Action:
+
+    if self._state is None:
+      self._state = self._initial_state
+
+    # Forward.
+    (logits, _), new_state = self._forward(self._params, observation,
+                                           self._state)
+
+    self._prev_logits = logits
+    self._prev_state = self._state
+    self._state = new_state
+
+    action = jax.random.categorical(next(self._rng), logits)
+
+    return action
+
+  def observe_first(self, timestep: dm_env.TimeStep):
+    if self._adder:
+      self._adder.add_first(timestep)
+
+    # Set the state to None so that we re-initialize at the next policy call.
+    self._state = None
+
+    # Reset state of inference functions that employ stateful wrappers
+    # (e.g. BIT) at the start of the episode.
+    if self._reset_fn_or_none is not None:
+      self._reset_fn_or_none()
+
+  def observe(
+      self,
+      action: types.Action,
+      next_timestep: dm_env.TimeStep,
+  ):
+    if not self._adder:
+      return
+
+    extras = {'logits': self._prev_logits, 'core_state': self._prev_state}
+    self._adder.add(action, next_timestep, extras)
+
+  def update(self, wait: bool = False):
+    if self._variable_client is not None:
+      self._variable_client.update(wait)
+
+  @property
+  def _params(self) -> Optional[hk.Params]:
+    if self._variable_client is None:
+      # If self._variable_client is None then we assume self._forward does not
+      # use the parameters it is passed and just return None.
+      return None
+    return self._variable_client.params
diff --git a/acme/acme/agents/jax/impala/builder.py b/acme/acme/agents/jax/impala/builder.py
new file mode 100644
index 00000000..c5552c0d
--- /dev/null
+++ b/acme/acme/agents/jax/impala/builder.py
@@ -0,0 +1,207 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""IMPALA Builder."""
+
+from typing import Any, Callable, Iterator, List, Optional
+
+import acme
+from acme import adders
+from acme import core
+from acme import specs
+from acme.adders import reverb as reverb_adders
+from acme.agents.jax import builders
+from acme.agents.jax.impala import acting
+from acme.agents.jax.impala import config as impala_config
+from acme.agents.jax.impala import learning
+from acme.agents.jax.impala import networks as impala_networks
+from acme.datasets import reverb as datasets
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.jax import variable_utils
+from acme.utils import counting
+from acme.utils import loggers
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import optax
+import reverb
+
+
+class IMPALABuilder(builders.ActorLearnerBuilder[impala_networks.IMPALANetworks,
+                                                 impala_networks.IMPALANetworks,
+                                                 reverb.ReplaySample]):
+  """IMPALA Builder."""
+
+  def __init__(
+      self,
+      config: impala_config.IMPALAConfig,
+      core_state_spec: hk.LSTMState,
+      table_extension: Optional[Callable[[], Any]] = None,
+  ):
+    """Creates an IMPALA builder."""
+    self._config = config
+    self._core_state_spec = core_state_spec
+    self._sequence_length = self._config.sequence_length
+    self._table_extension = table_extension
+
+  def make_replay_tables(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      policy: impala_networks.IMPALANetworks,
+  ) -> List[reverb.Table]:
+    """Creates the Reverb queue that passes sequences from actors to the learner."""
+    del policy
+    num_actions = environment_spec.actions.num_values
+    extra_spec = {
+        'core_state': self._core_state_spec,
+        'logits': jnp.ones(shape=(num_actions,), dtype=jnp.float32)
+    }
+    signature = reverb_adders.SequenceAdder.signature(
+        environment_spec,
+        extra_spec,
+        sequence_length=self._config.sequence_length)
+
+    # Maybe create rate limiter.
+    # Setting the samples_per_insert ratio less than the default of 1.0 allows
+    # the agent to drop data for the benefit of using data from the most
+    # up-to-date policies to compute its learner updates.
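+    # For example, samples_per_insert=0.5 lets the learner sample one sequence
+    # for every two inserted, so roughly half of the actor data is never used
+    # for updates.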
+ samples_per_insert = self._config.samples_per_insert + if samples_per_insert: + if samples_per_insert > 1.0 or samples_per_insert <= 0.0: + raise ValueError( + 'Impala requires a samples_per_insert ratio in the range (0, 1],' + f' but received {samples_per_insert}.') + limiter = reverb.rate_limiters.SampleToInsertRatio( + samples_per_insert=samples_per_insert, + min_size_to_sample=1, + error_buffer=self._config.batch_size) + else: + limiter = reverb.rate_limiters.MinSize(1) + + table_extensions = [] + if self._table_extension is not None: + table_extensions = [self._table_extension()] + queue = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_queue_size, + max_times_sampled=1, + rate_limiter=limiter, + extensions=table_extensions, + signature=signature) + return [queue] + + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: + """Creates a dataset.""" + batch_size_per_learner = self._config.batch_size // jax.process_count() + batch_size_per_device, ragged = divmod(self._config.batch_size, + jax.device_count()) + if ragged: + raise ValueError( + 'Learner batch size must be divisible by total number of devices!') + + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=batch_size_per_device, + num_parallel_calls=None, + max_in_flight_samples_per_worker=2 * batch_size_per_learner) + + return utils.multi_device_put(dataset.as_numpy_iterator(), + jax.local_devices()) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[impala_networks.IMPALANetworks], + ) -> Optional[adders.Adder]: + """Creates an adder which handles observations.""" + del environment_spec, policy + # Note that the last transition in the sequence is used for bootstrapping + # only and is ignored otherwise. So we need to make sure that sequences + # overlap on one transition, thus "-1" in the period length computation. 
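+    # E.g. with sequence_length=20 the default period is 19, so consecutive
+    # sequences share exactly one timestep.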
+ return reverb_adders.SequenceAdder( + client=replay_client, + priority_fns={self._config.replay_table_name: None}, + period=self._config.sequence_period or (self._sequence_length - 1), + sequence_length=self._sequence_length, + ) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: impala_networks.IMPALANetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + optimizer = optax.chain( + optax.clip_by_global_norm(self._config.max_gradient_norm), + optax.adam( + self._config.learning_rate, + b1=self._config.adam_momentum_decay, + b2=self._config.adam_variance_decay, + eps=self._config.adam_eps, + eps_root=self._config.adam_eps_root)) + + return learning.IMPALALearner( + networks=networks, + iterator=dataset, + optimizer=optimizer, + random_key=random_key, + discount=self._config.discount, + entropy_cost=self._config.entropy_cost, + baseline_cost=self._config.baseline_cost, + max_abs_reward=self._config.max_abs_reward, + counter=counter, + logger=logger_fn('learner'), + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: impala_networks.IMPALANetworks, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + variable_client = variable_utils.VariableClient( + client=variable_source, + key='network', + update_period=self._config.variable_update_period, + device='cpu') + return acting.IMPALAActor( + forward_fn=policy.forward_fn, + initial_state_fn=policy.initial_state_fn, + variable_client=variable_client, + adder=adder, + rng=hk.PRNGSequence(random_key), + ) + + def make_policy( + self, + networks: impala_networks.IMPALANetworks[Any], + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> impala_networks.IMPALANetworks[Any]: + del environment_spec, evaluation + return networks diff --git a/acme/acme/agents/jax/impala/config.py b/acme/acme/agents/jax/impala/config.py new file mode 100644 index 00000000..161dd9c5 --- /dev/null +++ b/acme/acme/agents/jax/impala/config.py @@ -0,0 +1,63 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IMPALA config.""" +import dataclasses +from typing import Optional, Union + +from acme import types +from acme.adders import reverb as adders_reverb +import numpy as np +import optax + + +@dataclasses.dataclass +class IMPALAConfig: + """Configuration options for IMPALA.""" + seed: int = 0 + discount: float = 0.99 + sequence_length: int = 20 + sequence_period: Optional[int] = None + variable_update_period: int = 1000 + + # Optimizer configuration. 
+  batch_size: int = 32
+  learning_rate: Union[float, optax.Schedule] = 2e-4
+  adam_momentum_decay: float = 0.0
+  adam_variance_decay: float = 0.99
+  adam_eps: float = 1e-8
+  adam_eps_root: float = 0.0
+  max_gradient_norm: float = 40.0
+
+  # Loss configuration.
+  baseline_cost: float = 0.5
+  entropy_cost: float = 0.01
+  max_abs_reward: float = np.inf
+
+  # Replay options
+  replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
+  num_prefetch_threads: Optional[int] = None
+  samples_per_insert: Optional[float] = 1.0
+  max_queue_size: Union[int, types.Batches] = types.Batches(10)
+
+  def __post_init__(self):
+    if isinstance(self.max_queue_size, types.Batches):
+      self.max_queue_size *= self.batch_size
+    assert self.max_queue_size > self.batch_size + 1, ("""
+        max_queue_size must be strictly larger than the batch size:
+        - during the last step in an episode we might write 2 sequences to
+          Reverb at once (that's how SequenceAdder works)
+        - Reverb does insertion/sampling in multiple threads, so data is
+          added asynchronously at unpredictable times. Therefore we need
+          additional buffer size in order to avoid deadlocks.""")
diff --git a/acme/acme/agents/jax/impala/learning.py b/acme/acme/agents/jax/impala/learning.py
new file mode 100644
index 00000000..8f068aea
--- /dev/null
+++ b/acme/acme/agents/jax/impala/learning.py
@@ -0,0 +1,160 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Learner for the IMPALA actor-critic agent."""
+
+import time
+from typing import Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple
+
+from absl import logging
+import acme
+from acme.agents.jax.impala import networks as impala_networks
+from acme.jax import losses
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import loggers
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import reverb
+
+_PMAP_AXIS_NAME = 'data'
+
+
+class TrainingState(NamedTuple):
+  """Training state consists of network parameters and optimiser state."""
+  params: networks_lib.Params
+  opt_state: optax.OptState
+
+
+class IMPALALearner(acme.Learner):
+  """Learner for an importance-weighted advantage actor-critic."""
+
+  def __init__(
+      self,
+      networks: impala_networks.IMPALANetworks,
+      iterator: Iterator[reverb.ReplaySample],
+      optimizer: optax.GradientTransformation,
+      random_key: networks_lib.PRNGKey,
+      discount: float = 0.99,
+      entropy_cost: float = 0.,
+      baseline_cost: float = 1.,
+      max_abs_reward: float = np.inf,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      devices: Optional[Sequence[jax.xla.Device]] = None,
+      prefetch_size: int = 2,
+  ):
+    local_devices = jax.local_devices()
+    process_id = jax.process_index()
+    logging.info('Learner process id: %s. Devices passed: %s', process_id,
+                 devices)
+    logging.info('Learner process id: %s. 
Local devices from JAX API: %s', + process_id, local_devices) + self._devices = devices or local_devices + self._local_devices = [d for d in self._devices if d in local_devices] + + self._iterator = iterator + + loss_fn = losses.impala_loss( + networks.unroll_fn, + discount=discount, + max_abs_reward=max_abs_reward, + baseline_cost=baseline_cost, + entropy_cost=entropy_cost) + + @jax.jit + def sgd_step( + state: TrainingState, sample: reverb.ReplaySample + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + """Computes an SGD step, returning new state and metrics for logging.""" + + # Compute gradients. + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss_value, metrics), gradients = grad_fn(state.params, sample) + + # Average gradients over pmap replicas before optimizer update. + gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME) + + # Apply updates. + updates, new_opt_state = optimizer.update(gradients, state.opt_state) + new_params = optax.apply_updates(state.params, updates) + + metrics.update({ + 'loss': loss_value, + 'param_norm': optax.global_norm(new_params), + 'param_updates_norm': optax.global_norm(updates), + }) + + new_state = TrainingState(params=new_params, opt_state=new_opt_state) + + return new_state, metrics + + def make_initial_state(key: jnp.ndarray) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + key, key_initial_state = jax.random.split(key) + # Note: parameters do not depend on the batch size, so initial_state below + # does not need a batch dimension. + # TODO(jferret): as it stands, we do not yet support + # training the initial state params. + initial_state = networks.initial_state_fn(key_initial_state) + + initial_params = networks.unroll_init_fn(key, initial_state) + initial_opt_state = optimizer.init(initial_params) + return TrainingState( + params=initial_params, opt_state=initial_opt_state) + + # Initialise training state (parameters and optimiser state). + state = make_initial_state(random_key) + self._state = utils.replicate_in_all_devices(state, self._local_devices) + + self._sgd_step = jax.pmap( + sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices) + + # Set up logging/counting. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', steps_key=self._counter.get_steps_key()) + + def step(self): + """Does a step of SGD and logs the results.""" + samples = next(self._iterator) + + # Do a batch of SGD. + start = time.time() + self._state, results = self._sgd_step(self._state, samples) + + # Take results from first replica. + # NOTE: This measure will be a noisy estimate for the purposes of the logs + # as it does not pmean over all devices. + results = utils.get_from_first_device(results) + + # Update our counts and record them. + counts = self._counter.increment(steps=1, time_elapsed=time.time() - start) + + # Maybe write logs. + self._logger.write({**results, **counts}) + + def get_variables(self, names: Sequence[str]) -> List[networks_lib.Params]: + # Return first replica of parameters. + return [utils.get_from_first_device(self._state.params, as_numpy=False)] + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. 
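+    # jax.tree_map applies utils.get_from_first_device leaf by leaf, so the
+    # leading replica axis is stripped from every parameter and optimizer
+    # tensor before it is checkpointed.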
+ return jax.tree_map(utils.get_from_first_device, self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state, self._local_devices) diff --git a/acme/acme/agents/jax/impala/networks.py b/acme/acme/agents/jax/impala/networks.py new file mode 100644 index 00000000..a86bb66e --- /dev/null +++ b/acme/acme/agents/jax/impala/networks.py @@ -0,0 +1,99 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IMPALA networks definition.""" + +import dataclasses +from typing import Any, Generic, Optional, Tuple + +from acme import specs +from acme.agents.jax.impala import types +from acme.jax import networks as networks_lib +from acme.jax import utils +import haiku as hk + + +@dataclasses.dataclass +class IMPALANetworks(Generic[types.RecurrentState]): + + """Pure functions representing IMPALA's recurrent network components. + + Attributes: + forward_fn: Selects next action using the network at the given recurrent + state. + + unroll_init_fn: Initializes params for forward_fn and unroll_fn. + + unroll_fn: Applies the unrolled network to a sequence of observations, for + learning. + initial_state_fn: Recurrent state at the beginning of an episode. + """ + forward_fn: types.PolicyValueFn + unroll_init_fn: types.PolicyValueInitFn + unroll_fn: types.PolicyValueFn + initial_state_fn: types.RecurrentStateFn + + +def make_haiku_networks( + env_spec: specs.EnvironmentSpec, + forward_fn: Any, + initial_state_fn: Any, + unroll_fn: Any) -> IMPALANetworks[types.RecurrentState]: + """Builds functional impala network from recurrent model definitions.""" + # Make networks purely functional. + forward_hk = hk.without_apply_rng(hk.transform(forward_fn)) + initial_state_hk = hk.without_apply_rng(hk.transform(initial_state_fn)) + unroll_hk = hk.without_apply_rng(hk.transform(unroll_fn)) + + # Note: batch axis is not needed for the actors. 
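+  # The zero-valued dummy observation below is used purely for shape inference
+  # when initialising the unroll function's parameters; its values are never
+  # read.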
+ dummy_obs = utils.zeros_like(env_spec.observations) + dummy_obs_sequence = utils.add_batch_dim(dummy_obs) + def unroll_init_fn( + rng: networks_lib.PRNGKey, + initial_state: types.RecurrentState) -> hk.Params: + return unroll_hk.init(rng, dummy_obs_sequence, initial_state) + + return IMPALANetworks( + forward_fn=forward_hk.apply, + unroll_init_fn=unroll_init_fn, + unroll_fn=unroll_hk.apply, + initial_state_fn=( + lambda rng: initial_state_hk.apply(initial_state_hk.init(rng)))) + + +HaikuLSTMOutputs = Tuple[Tuple[networks_lib.Logits, networks_lib.Value], + hk.LSTMState] + + +def make_atari_networks(env_spec: specs.EnvironmentSpec + ) -> IMPALANetworks[hk.LSTMState]: + """Builds default IMPALA networks for Atari games.""" + + def forward_fn(inputs: types.Observation, state: hk.LSTMState + ) -> HaikuLSTMOutputs: + model = networks_lib.DeepIMPALAAtariNetwork(env_spec.actions.num_values) + return model(inputs, state) + + def initial_state_fn(batch_size: Optional[int] = None) -> hk.LSTMState: + model = networks_lib.DeepIMPALAAtariNetwork(env_spec.actions.num_values) + return model.initial_state(batch_size) + + def unroll_fn(inputs: types.Observation, state: hk.LSTMState + ) -> HaikuLSTMOutputs: + model = networks_lib.DeepIMPALAAtariNetwork(env_spec.actions.num_values) + return model.unroll(inputs, state) + + return make_haiku_networks( + env_spec=env_spec, forward_fn=forward_fn, + initial_state_fn=initial_state_fn, unroll_fn=unroll_fn) diff --git a/acme/acme/agents/jax/impala/types.py b/acme/acme/agents/jax/impala/types.py new file mode 100644 index 00000000..6763a71d --- /dev/null +++ b/acme/acme/agents/jax/impala/types.py @@ -0,0 +1,31 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Some types/assumptions used in the IMPALA agent.""" +from typing import Callable, Tuple + +from acme.agents.jax.actor_core import RecurrentState +from acme.jax import networks +from acme.jax import types as jax_types +import jax.numpy as jnp + +# Only simple observations & discrete action spaces for now. +Observation = jnp.ndarray +Action = int +Outputs = Tuple[Tuple[networks.Logits, networks.Value], RecurrentState] +PolicyValueInitFn = Callable[[networks.PRNGKey, RecurrentState], + networks.Params] +PolicyValueFn = Callable[[networks.Params, Observation, RecurrentState], + Outputs] +RecurrentStateFn = Callable[[jax_types.PRNGKey], RecurrentState] diff --git a/acme/acme/agents/jax/lfd/__init__.py b/acme/acme/agents/jax/lfd/__init__.py new file mode 100644 index 00000000..873ce23a --- /dev/null +++ b/acme/acme/agents/jax/lfd/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lfd agents.""" + +from acme.agents.jax.lfd.builder import LfdBuilder +from acme.agents.jax.lfd.builder import LfdStep +from acme.agents.jax.lfd.config import LfdConfig +from acme.agents.jax.lfd.sacfd import SACfDBuilder +from acme.agents.jax.lfd.sacfd import SACfDConfig +from acme.agents.jax.lfd.td3fd import TD3fDBuilder +from acme.agents.jax.lfd.td3fd import TD3fDConfig diff --git a/acme/acme/agents/jax/lfd/builder.py b/acme/acme/agents/jax/lfd/builder.py new file mode 100644 index 00000000..2f544f41 --- /dev/null +++ b/acme/acme/agents/jax/lfd/builder.py @@ -0,0 +1,80 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Builder enabling off-policy algorithms to learn from demonstrations.""" + +from typing import Any, Callable, Generic, Iterator, Tuple + +from acme.agents.jax import builders +from acme.agents.jax.lfd import config as lfd_config +from acme.agents.jax.lfd import lfd_adder +import dm_env + + +LfdStep = Tuple[Any, dm_env.TimeStep] + + +class LfdBuilder(builders.ActorLearnerBuilder[builders.Networks, + builders.Policy, + builders.Sample,], + Generic[builders.Networks, builders.Policy, builders.Sample]): + """Builder that enables Learning From demonstrations. + + This builder is not self contained and requires an underlying builder + implementing an off-policy algorithm. + """ + + def __init__(self, builder: builders.ActorLearnerBuilder[builders.Networks, + builders.Policy, + builders.Sample], + demonstrations_factory: Callable[[], Iterator[LfdStep]], + config: lfd_config.LfdConfig): + """LfdBuilder constructor. + + Args: + builder: The underlying builder implementing the off-policy algorithm. + demonstrations_factory: Factory returning an infinite stream (as an + iterator) of (action, next_timesteps). Episode boundaries in this stream + are given by timestep.first() and timestep.last(). Note that in the + distributed version of this algorithm, each actor is mixing the same + demonstrations with its online experience. This effectively results in + the demonstrations being replicated in the replay buffer as many times + as the number of actors being used. + config: LfD configuration. 
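+
+    Example (illustrative sketch; the wrapped SAC builder and the
+    `make_demonstrations` factory are assumptions, not part of this API):
+      lfd_builder = LfdBuilder(
+          builder=sac.SACBuilder(sac_config),
+          demonstrations_factory=make_demonstrations,
+          config=lfd_config.LfdConfig(demonstration_ratio=0.1))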
+ """ + self._builder = builder + self._demonstrations_factory = demonstrations_factory + self._config = config + + def make_replay_tables(self, *args, **kwargs): + return self._builder.make_replay_tables(*args, **kwargs) + + def make_dataset_iterator(self, *args, **kwargs): + return self._builder.make_dataset_iterator(*args, **kwargs) + + def make_adder(self, *args, **kwargs): + demonstrations = self._demonstrations_factory() + return lfd_adder.LfdAdder(self._builder.make_adder(*args, **kwargs), + demonstrations, + self._config.initial_insert_count, + self._config.demonstration_ratio) + + def make_actor(self, *args, **kwargs): + return self._builder.make_actor(*args, **kwargs) + + def make_learner(self, *args, **kwargs): + return self._builder.make_learner(*args, **kwargs) + + def make_policy(self, *args, **kwargs): + return self._builder.make_policy(*args, **kwargs) diff --git a/acme/acme/agents/jax/lfd/config.py b/acme/acme/agents/jax/lfd/config.py new file mode 100644 index 00000000..2d6caf30 --- /dev/null +++ b/acme/acme/agents/jax/lfd/config.py @@ -0,0 +1,37 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LfD config.""" + +import dataclasses + + +@dataclasses.dataclass +class LfdConfig: + """Configuration options for LfD. + + Attributes: + initial_insert_count: Number of steps of demonstrations to add to the replay + buffer before adding any step of the collected episodes. Note that since + only full episodes can be added, this number of steps is only a target. + demonstration_ratio: Ratio of demonstration steps to add to the replay + buffer. ratio = num_demonstration_steps_added / total_num_steps_added. + The ratio must be in [0, 1). + Note that this ratio is the desired ratio in the steady behavior and does + not account for the initial demonstrations inserts. + Note also that this ratio is only a target ratio since the granularity + is the episode. + """ + initial_insert_count: int = 0 + demonstration_ratio: float = 0.01 diff --git a/acme/acme/agents/jax/lfd/lfd_adder.py b/acme/acme/agents/jax/lfd/lfd_adder.py new file mode 100644 index 00000000..9d53c96e --- /dev/null +++ b/acme/acme/agents/jax/lfd/lfd_adder.py @@ -0,0 +1,117 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An adder useful in the context of Learning From Demonstrations. 
+
+This adder mixes the collected episodes with demonstrations coming from an
+offline dataset.
+
+TODO(damienv): Mixing demonstrations and collected episodes could also be
+  done when reading from the replay buffer. In that case, all the processing
+  applied by reverb should also be applied to the demonstrations.
+  Design-wise, both solutions make equal sense. The alternative solution
+  could then be implemented later as well.
+"""
+
+from typing import Any, Iterator, Tuple
+
+from acme import adders
+from acme import types
+import dm_env
+
+
+class LfdAdder(adders.Adder):
+  """Adder which periodically adds demonstrations.
+
+  LfD stands for Learning from Demonstrations; the mixing technique is the
+  same as the one used in R2D3.
+  """
+
+  def __init__(self,
+               adder: adders.Adder,
+               demonstrations: Iterator[Tuple[Any, dm_env.TimeStep]],
+               initial_insert_count: int,
+               demonstration_ratio: float):
+    """LfdAdder constructor.
+
+    Args:
+      adder: The underlying adder used to add mixed episodes.
+      demonstrations: An iterator over an infinite stream of
+        (action, next_timestep) pairs. Episode boundaries are defined by
+        timestep.first() and timestep.last() markers. Note that the first
+        action of an episode is ignored. Note also that proper uniform
+        sampling of demonstrations is the responsibility of the iterator.
+      initial_insert_count: Number of steps of demonstrations to add before
+        adding any step of the collected episodes. Note that since only full
+        episodes can be added, this number of steps is only a target.
+      demonstration_ratio: Ratio of demonstration steps to add to the
+        underlying adder. ratio = num_demonstration_steps_added /
+        total_num_steps_added and must be in [0, 1).
+        Note that this ratio is the desired ratio in the steady behavior
+        and does not account for the initial inserts of demonstrations.
+        Note also that this ratio is only a target ratio since the
+        granularity is the episode.
+    """
+    self._adder = adder
+    self._demonstrations = demonstrations
+    self._demonstration_ratio = demonstration_ratio
+    if demonstration_ratio < 0 or demonstration_ratio >= 1.:
+      raise ValueError('Invalid demonstration ratio.')
+
+    # Number of demonstration steps that should have been added to the replay
+    # buffer to meet the target demonstration ratio, minus what has really
+    # been added. As a consequence:
+    # - when this delta is zero, the effective ratio exactly matches the
+    #   desired ratio;
+    # - when it is positive, more demonstrations need to be added to
+    #   reestablish the balance.
+    # Each demonstration step decreases the delta by (1 - ratio) and each
+    # collected step increases it by ratio, so with e.g. ratio = 0.2 the
+    # counter returns to zero after one demonstration step per four collected
+    # steps, i.e. demonstrations then make up exactly 20% of the added steps.
+    # The initial value is set so that after exactly initial_insert_count
+    # inserts of demonstration steps, _delta_demonstration_step_count will be
+    # zero.
+    self._delta_demonstration_step_count = (
+        (1. - self._demonstration_ratio) * initial_insert_count)
+
+  def reset(self):
+    self._adder.reset()
+
+  def _add_demonstration_episode(self):
+    _, timestep = next(self._demonstrations)
+    if not timestep.first():
+      raise ValueError('Expecting the start of an episode.')
+    self._adder.add_first(timestep)
+    self._delta_demonstration_step_count -= (1. - self._demonstration_ratio)
+    while not timestep.last():
+      action, timestep = next(self._demonstrations)
+      self._adder.add(action, timestep)
+      self._delta_demonstration_step_count -= (1. - self._demonstration_ratio)
+
+    # Reset is being called periodically to reset the connection to reverb.
+    # TODO(damienv, bshahr): Make the reset an internal detail of the reverb
+    # adder and remove it from the adder API.
+ self._adder.reset() + + def add_first(self, timestep: dm_env.TimeStep): + while self._delta_demonstration_step_count > 0.: + self._add_demonstration_episode() + + self._adder.add_first(timestep) + self._delta_demonstration_step_count += self._demonstration_ratio + + def add(self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = ()): + self._adder.add(action, next_timestep) + self._delta_demonstration_step_count += self._demonstration_ratio diff --git a/acme/acme/agents/jax/lfd/lfd_adder_test.py b/acme/acme/agents/jax/lfd/lfd_adder_test.py new file mode 100644 index 00000000..1e8f926e --- /dev/null +++ b/acme/acme/agents/jax/lfd/lfd_adder_test.py @@ -0,0 +1,143 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests of the LfD adder.""" + +import collections + +from acme import adders +from acme import types +from acme.agents.jax.lfd import lfd_adder +import dm_env +import numpy as np + +from absl.testing import absltest + + +class TestStatisticsAdder(adders.Adder): + + def __init__(self): + self.counts = collections.defaultdict(int) + + def reset(self): + pass + + def add_first(self, timestep: dm_env.TimeStep): + self.counts[int(timestep.observation[0])] += 1 + + def add(self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = ()): + del action + del extras + self.counts[int(next_timestep.observation[0])] += 1 + + +class LfdAdderTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self._demonstration_episode_type = 1 + self._demonstration_episode_length = 10 + self._collected_episode_type = 2 + self._collected_episode_length = 5 + + def generate_episode(self, episode_type, episode_index, length): + episode = [] + action_dim = 8 + obs_dim = 16 + for k in range(length): + if k == 0: + action = None + else: + action = np.concatenate([ + np.asarray([episode_type, episode_index], dtype=np.float32), + np.random.uniform(0., 1., (action_dim - 2,))]) + observation = np.concatenate([ + np.asarray([episode_type, episode_index], dtype=np.float32), + np.random.uniform(0., 1., (obs_dim - 2,))]) + if k == 0: + timestep = dm_env.restart(observation) + elif k == length - 1: + timestep = dm_env.termination(0., observation) + else: + timestep = dm_env.transition(0., observation, 1.) 
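+        # (Above: dm_env.transition(reward, observation, discount), i.e. a
+        # mid-episode step with zero reward and discount 1.)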
+ episode.append((action, timestep)) + return episode + + def generate_demonstration(self): + episode_index = 0 + while True: + episode = self.generate_episode(self._demonstration_episode_type, + episode_index, + self._demonstration_episode_length) + for x in episode: + yield x + episode_index += 1 + + def test_adder(self): + stats_adder = TestStatisticsAdder() + demonstration_ratio = 0.2 + initial_insert_count = 50 + adder = lfd_adder.LfdAdder( + stats_adder, + self.generate_demonstration(), + initial_insert_count=initial_insert_count, + demonstration_ratio=demonstration_ratio) + + num_episodes = 100 + for episode_index in range(num_episodes): + episode = self.generate_episode(self._collected_episode_type, + episode_index, + self._collected_episode_length) + for k, (action, timestep) in enumerate(episode): + if k == 0: + adder.add_first(timestep) + if episode_index == 0: + self.assertGreaterEqual( + stats_adder.counts[self._demonstration_episode_type], + initial_insert_count - self._demonstration_episode_length) + self.assertLessEqual( + stats_adder.counts[self._demonstration_episode_type], + initial_insert_count + self._demonstration_episode_length) + else: + adder.add(action, timestep) + + # Only 2 types of episodes. + self.assertLen(stats_adder.counts, 2) + + total_count = (stats_adder.counts[self._demonstration_episode_type] + + stats_adder.counts[self._collected_episode_type]) + # The demonstration ratio does not account for the initial demonstration + # insertion. Computes a ratio that takes it into account. + target_ratio = ( + demonstration_ratio * (float)(total_count - initial_insert_count) + + initial_insert_count) / (float)(total_count) + # Effective ratio of demonstrations. + effective_ratio = ( + float(stats_adder.counts[self._demonstration_episode_type]) / + float(total_count)) + # Only full episodes can be fed to the adder so the effective ratio + # might be slightly different from the requested demonstration ratio. + min_ratio = (target_ratio - + self._demonstration_episode_length / float(total_count)) + max_ratio = (target_ratio + + self._demonstration_episode_length / float(total_count)) + self.assertGreaterEqual(effective_ratio, min_ratio) + self.assertLessEqual(effective_ratio, max_ratio) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/lfd/sacfd.py b/acme/acme/agents/jax/lfd/sacfd.py new file mode 100644 index 00000000..bd4c2d49 --- /dev/null +++ b/acme/acme/agents/jax/lfd/sacfd.py @@ -0,0 +1,47 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SAC agent learning from demonstrations.""" + +import dataclasses +from typing import Callable, Iterator + +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import sac +from acme.agents.jax.lfd import builder +from acme.agents.jax.lfd import config +import reverb + + +@dataclasses.dataclass +class SACfDConfig: + """Configuration options specific to SAC with demonstrations. 
+ + Attributes: + lfd_config: LfD config. + sac_config: SAC config. + """ + lfd_config: config.LfdConfig + sac_config: sac.SACConfig + + +class SACfDBuilder(builder.LfdBuilder[sac.SACNetworks, + actor_core_lib.FeedForwardPolicy, + reverb.ReplaySample]): + """Builder for SAC agent learning from demonstrations.""" + + def __init__(self, sac_fd_config: SACfDConfig, + lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]): + sac_builder = sac.SACBuilder(sac_fd_config.sac_config) + super().__init__(sac_builder, lfd_iterator_fn, sac_fd_config.lfd_config) diff --git a/acme/acme/agents/jax/lfd/td3fd.py b/acme/acme/agents/jax/lfd/td3fd.py new file mode 100644 index 00000000..531bfe35 --- /dev/null +++ b/acme/acme/agents/jax/lfd/td3fd.py @@ -0,0 +1,47 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TD3 agent learning from demonstrations.""" + +import dataclasses +from typing import Callable, Iterator + +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import td3 +from acme.agents.jax.lfd import builder +from acme.agents.jax.lfd import config +import reverb + + +@dataclasses.dataclass +class TD3fDConfig: + """Configuration options specific to TD3 with demonstrations. + + Attributes: + lfd_config: LfD config. + td3_config: TD3 config. + """ + lfd_config: config.LfdConfig + td3_config: td3.TD3Config + + +class TD3fDBuilder(builder.LfdBuilder[td3.TD3Networks, + actor_core_lib.FeedForwardPolicy, + reverb.ReplaySample]): + """Builder for TD3 agent learning from demonstrations.""" + + def __init__(self, td3_fd_config: TD3fDConfig, + lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]): + td3_builder = td3.TD3Builder(td3_fd_config.td3_config) + super().__init__(td3_builder, lfd_iterator_fn, td3_fd_config.lfd_config) diff --git a/acme/acme/agents/jax/mbop/README.md b/acme/acme/agents/jax/mbop/README.md new file mode 100644 index 00000000..36f003c7 --- /dev/null +++ b/acme/acme/agents/jax/mbop/README.md @@ -0,0 +1,16 @@ +# Model-Based Offline Planning (MBOP) + +This folder contains an implementation of the MBOP algorithm ([Argenson and +Dulac-Arnold, 2021]). It is an offline RL algorithm that generates a model that +can be used to control the system directly through planning. The learning +components, i.e. the world model, policy prior and the n-step return, are simple +supervised ensemble learners. It uses the Model-Predictive Path Integral control +planner. + +The networks assume continuous and flattened observation and action spaces. The +dataset, i.e. demonstrations, should be in timestep-batched format (i.e. triple +transitions of the previous, current and next timesteps) and normalized. See +dataset.py file for helper functions for loading RLDS datasets and +normalization. 
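+
+As a rough sketch (the dataset name below is a placeholder; any RLDS episode
+dataset works), the helpers can be combined as follows:
+
+```python
+import tensorflow_datasets as tfds
+
+from acme.agents.jax.mbop import dataset as mbop_dataset
+
+# Load an RLDS dataset of episodes (placeholder name).
+episodes = tfds.load('my_rlds_dataset', split='train')
+# Convert episodes into timestep-batched (previous, current, next) transitions
+# with 10-step Monte-Carlo returns stored in `extras`.
+transitions = mbop_dataset.episodes_to_timestep_batched_transitions(
+    episodes, return_horizon=10)
+# Compute the normalization statistics from batches of these transitions.
+iterator = transitions.batch(256).as_numpy_iterator()
+mean_std = mbop_dataset.get_normalization_stats(iterator)
+```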
+ +[Argenson and Dulac-Arnold, 2021]: https://arxiv.org/abs/2008.05556 diff --git a/acme/acme/agents/jax/mbop/__init__.py b/acme/acme/agents/jax/mbop/__init__.py new file mode 100644 index 00000000..f0746a4d --- /dev/null +++ b/acme/acme/agents/jax/mbop/__init__.py @@ -0,0 +1,49 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of the Model-Based Offline Planning (MBOP) agent.""" + +from acme.agents.jax.mbop.acting import ActorCore +from acme.agents.jax.mbop.acting import make_actor +from acme.agents.jax.mbop.acting import make_actor_core +from acme.agents.jax.mbop.acting import make_ensemble_actor_core +from acme.agents.jax.mbop.dataset import EPISODE_RETURN +from acme.agents.jax.mbop.dataset import episodes_to_timestep_batched_transitions +from acme.agents.jax.mbop.dataset import get_normalization_stats +from acme.agents.jax.mbop.dataset import N_STEP_RETURN +from acme.agents.jax.mbop.learning import LoggerFn +from acme.agents.jax.mbop.learning import make_ensemble_regressor_learner +from acme.agents.jax.mbop.learning import MakeNStepReturnLearner +from acme.agents.jax.mbop.learning import MakePolicyPriorLearner +from acme.agents.jax.mbop.learning import MakeWorldModelLearner +from acme.agents.jax.mbop.learning import MBOPLearner +from acme.agents.jax.mbop.learning import TrainingState +from acme.agents.jax.mbop.losses import MBOPLosses +from acme.agents.jax.mbop.losses import policy_prior_loss +from acme.agents.jax.mbop.losses import TransitionLoss +from acme.agents.jax.mbop.losses import world_model_loss +from acme.agents.jax.mbop.models import make_ensemble_n_step_return +from acme.agents.jax.mbop.models import make_ensemble_policy_prior +from acme.agents.jax.mbop.models import make_ensemble_world_model +from acme.agents.jax.mbop.models import MakeNStepReturn +from acme.agents.jax.mbop.models import MakePolicyPrior +from acme.agents.jax.mbop.models import MakeWorldModel +from acme.agents.jax.mbop.mppi import mppi_planner +from acme.agents.jax.mbop.mppi import MPPIConfig +from acme.agents.jax.mbop.mppi import return_top_k_average +from acme.agents.jax.mbop.mppi import return_weighted_average +from acme.agents.jax.mbop.networks import make_networks +from acme.agents.jax.mbop.networks import make_policy_prior_network +from acme.agents.jax.mbop.networks import make_world_model_network +from acme.agents.jax.mbop.networks import MBOPNetworks diff --git a/acme/acme/agents/jax/mbop/acting.py b/acme/acme/agents/jax/mbop/acting.py new file mode 100644 index 00000000..8d06b66f --- /dev/null +++ b/acme/acme/agents/jax/mbop/acting.py @@ -0,0 +1,193 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The MPPI-inspired JAX actor.""" + +from typing import List, Mapping, Optional, Tuple + +from acme import adders +from acme import core +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax.mbop import models +from acme.agents.jax.mbop import mppi +from acme.agents.jax.mbop import networks as mbop_networks +from acme.jax import networks as networks_lib +from acme.jax import running_statistics +from acme.jax import variable_utils +import jax +from jax import numpy as jnp + +# Recurrent state is the trajectory. +Trajectory = jnp.ndarray + +ActorCore = actor_core_lib.ActorCore[ + actor_core_lib.SimpleActorCoreRecurrentState[Trajectory], + Mapping[str, jnp.ndarray]] + + +def make_actor_core( + mppi_config: mppi.MPPIConfig, + world_model: models.WorldModel, + policy_prior: models.PolicyPrior, + n_step_return: models.NStepReturn, + environment_spec: specs.EnvironmentSpec, + mean_std: Optional[running_statistics.NestedMeanStd] = None, +) -> ActorCore: + """Creates an actor core wrapping the MBOP-configured MPPI planner. + + Args: + mppi_config: Planner hyperparameters. + world_model: A world model. + policy_prior: A policy prior. + n_step_return: An n-step return. + environment_spec: Used to initialize the initial trajectory data structure. + mean_std: Used to undo normalization if the networks trained normalized. + + Returns: + A recurrent actor core. 
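+
+  Example (sketch only; the model objects are assumptions):
+    actor_core = make_actor_core(mppi.MPPIConfig(), world_model, policy_prior,
+                                 n_step_return, environment_spec)
+    actor = make_actor(actor_core, random_key, variable_source)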
+ """ + + if mean_std is not None: + mean_std_observation = running_statistics.NestedMeanStd( + mean=mean_std.mean.observation, std=mean_std.std.observation) + mean_std_action = running_statistics.NestedMeanStd( + mean=mean_std.mean.action, std=mean_std.std.action) + mean_std_reward = running_statistics.NestedMeanStd( + mean=mean_std.mean.reward, std=mean_std.std.reward) + mean_std_n_step_return = running_statistics.NestedMeanStd( + mean=mean_std.mean.extras['n_step_return'], + std=mean_std.std.extras['n_step_return']) + + def denormalized_world_model( + params: networks_lib.Params, observation_t: networks_lib.Observation, + action_t: networks_lib.Action + ) -> Tuple[networks_lib.Observation, networks_lib.Value]: + """Denormalizes the reward for proper weighting in the planner.""" + observation_tp1, normalized_reward_t = world_model( + params, observation_t, action_t) + reward_t = running_statistics.denormalize(normalized_reward_t, + mean_std_reward) + return observation_tp1, reward_t + + planner_world_model = denormalized_world_model + + def denormalized_n_step_return( + params: networks_lib.Params, observation_t: networks_lib.Observation, + action_t: networks_lib.Action) -> networks_lib.Value: + """Denormalize the n-step return for proper weighting in the planner.""" + normalized_n_step_return_t = n_step_return(params, observation_t, + action_t) + return running_statistics.denormalize(normalized_n_step_return_t, + mean_std_n_step_return) + + planner_n_step_return = denormalized_n_step_return + else: + planner_world_model = world_model + planner_n_step_return = n_step_return + + def recurrent_policy( + params_list: List[networks_lib.Params], + random_key: networks_lib.PRNGKey, + observation: networks_lib.Observation, + previous_trajectory: Trajectory, + ) -> Tuple[networks_lib.Action, Trajectory]: + # Note that splitting the random key is handled by GenericActor. + if mean_std is not None: + observation = running_statistics.normalize( + observation, mean_std=mean_std_observation) + trajectory = mppi.mppi_planner( + config=mppi_config, + world_model=planner_world_model, + policy_prior=policy_prior, + n_step_return=planner_n_step_return, + world_model_params=params_list[0], + policy_prior_params=params_list[1], + n_step_return_params=params_list[2], + random_key=random_key, + observation=observation, + previous_trajectory=previous_trajectory) + action = trajectory[0, ...] + if mean_std is not None: + action = running_statistics.denormalize(action, mean_std=mean_std_action) + return (action, trajectory) + + batched_policy = jax.vmap(recurrent_policy, in_axes=(None, None, 0, 0)) + batched_policy = jax.jit(batched_policy) + + initial_trajectory = mppi.get_initial_trajectory( + config=mppi_config, env_spec=environment_spec) + initial_trajectory = jnp.expand_dims(initial_trajectory, axis=0) + + return actor_core_lib.batched_recurrent_to_actor_core(batched_policy, + initial_trajectory) + + +def make_ensemble_actor_core( + networks: mbop_networks.MBOPNetworks, + mppi_config: mppi.MPPIConfig, + environment_spec: specs.EnvironmentSpec, + mean_std: Optional[running_statistics.NestedMeanStd] = None, + use_round_robin: bool = True, +) -> ActorCore: + """Creates an actor core that uses ensemble models. + + Args: + networks: MBOP networks. + mppi_config: Planner hyperparameters. + environment_spec: Used to initialize the initial trajectory data structure. + mean_std: Used to undo normalization if the networks trained normalized. 
+ use_round_robin: Whether to use round robin or mean to calculate the policy + prior over the ensemble members. + + Returns: + A recurrent actor core. + """ + world_model = models.make_ensemble_world_model(networks.world_model_network) + policy_prior = models.make_ensemble_policy_prior( + networks.policy_prior_network, + environment_spec, + use_round_robin=use_round_robin) + n_step_return = models.make_ensemble_n_step_return( + networks.n_step_return_network) + + return make_actor_core(mppi_config, world_model, policy_prior, n_step_return, + environment_spec, mean_std) + + +def make_actor(actor_core: ActorCore, + random_key: networks_lib.PRNGKey, + variable_source: core.VariableSource, + adder: Optional[adders.Adder] = None) -> core.Actor: + """Creates an MBOP actor from an actor core. + + Args: + actor_core: An MBOP actor core. + random_key: JAX Random key. + variable_source: The source to get networks parameters from. + adder: An adder to add experiences to. The `extras` of the adder holds the + state of the recurrent policy. If `has_extras=True` then the `extras` part + returned from the recurrent policy is appended to the state before added + to the adder. + + Returns: + A recurrent actor. + """ + variable_client = variable_utils.VariableClient( + client=variable_source, + key=['world_model-policy', 'policy_prior-policy', 'n_step_return-policy']) + + return actors.GenericActor( + actor_core, random_key, variable_client, adder, backend=None) diff --git a/acme/acme/agents/jax/mbop/agent_test.py b/acme/acme/agents/jax/mbop/agent_test.py new file mode 100644 index 00000000..db0fcadc --- /dev/null +++ b/acme/acme/agents/jax/mbop/agent_test.py @@ -0,0 +1,92 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the MBOP agent.""" + +import functools + +from acme import specs +from acme import types +from acme.agents.jax.mbop import learning +from acme.agents.jax.mbop import losses as mbop_losses +from acme.agents.jax.mbop import networks as mbop_networks +from acme.testing import fakes +from acme.utils import loggers +import chex +import jax +import optax +import rlds + +from absl.testing import absltest + + +class MBOPTest(absltest.TestCase): + + def test_learner(self): + with chex.fake_pmap_and_jit(): + num_sgd_steps_per_step = 1 + num_steps = 5 + num_networks = 7 + + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, observation_dim=3, action_dim=2) + + spec = specs.make_environment_spec(environment) + dataset = fakes.transition_dataset(environment) + + # Add dummy n-step return to the transitions. + def _add_dummy_n_step_return(sample): + return types.Transition(*sample.data)._replace( + extras={'n_step_return': 1.0}) + + dataset = dataset.map(_add_dummy_n_step_return) + # Convert into time-batched format with previous, current and next + # transitions. 
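+      # rlds.transformations.batch with size=3 stacks 3 consecutive
+      # transitions along a new leading axis (previous, current, next),
+      # mimicking the output of episodes_to_timestep_batched_transitions.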
+ dataset = rlds.transformations.batch(dataset, 3) + dataset = dataset.batch(8).as_numpy_iterator() + + # Use the default networks and losses. + networks = mbop_networks.make_networks(spec) + losses = mbop_losses.MBOPLosses() + + def logger_fn(label: str, steps_key: str): + return loggers.make_default_logger(label, steps_key=steps_key) + + def make_learner_fn(name, logger_fn, counter, rng_key, dataset, network, + loss): + return learning.make_ensemble_regressor_learner(name, num_networks, + logger_fn, counter, + rng_key, dataset, + network, loss, + optax.adam(0.01), + num_sgd_steps_per_step) + + learner = learning.MBOPLearner( + networks, losses, dataset, jax.random.PRNGKey(0), logger_fn, + functools.partial(make_learner_fn, 'world_model'), + functools.partial(make_learner_fn, 'policy_prior'), + functools.partial(make_learner_fn, 'n_step_return')) + + # Train the agent + for _ in range(num_steps): + learner.step() + + # Save and restore. + learner_state = learner.save() + learner.restore(learner_state) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/mbop/dataset.py b/acme/acme/agents/jax/mbop/dataset.py new file mode 100644 index 00000000..22bfdf20 --- /dev/null +++ b/acme/acme/agents/jax/mbop/dataset.py @@ -0,0 +1,220 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset related definitions and methods.""" + +import functools +import itertools +from typing import Iterator, Optional + +from acme import types +from acme.jax import running_statistics +import jax +import jax.numpy as jnp +import rlds +import tensorflow as tf +import tree + +# Keys in extras dictionary of the transitions. +# Total return over n-steps. +N_STEP_RETURN: str = 'n_step_return' +# Total return of the episode that the transition belongs to. +EPISODE_RETURN: str = 'episode_return' + +# Indices of the time-batched transitions. +PREVIOUS: int = 0 +CURRENT: int = 1 +NEXT: int = 2 + + +def _append_n_step_return(output, n_step_return): + """Append n-step return to an output step.""" + output[N_STEP_RETURN] = n_step_return + return output + + +def _append_episode_return(output, episode_return): + """Append episode return to an output step.""" + output[EPISODE_RETURN] = episode_return + return output + + +def _expand_scalars(output): + """If rewards are scalar, expand them.""" + return tree.map_structure(tf.experimental.numpy.atleast_1d, output) + + +def episode_to_timestep_batch( + episode: rlds.BatchedStep, + return_horizon: int = 0, + drop_return_horizon: bool = False, + calculate_episode_return: bool = False) -> tf.data.Dataset: + """Converts an episode into multi-timestep batches. + + Args: + episode: Batched steps as provided directly by RLDS. + return_horizon: int describing the horizon to which we should accumulate the + return. + drop_return_horizon: bool whether we should drop the last `return_horizon` + steps to avoid mis-calculated returns near the end of the episode. 
+    calculate_episode_return: Whether to calculate episode return. Can be an
+      expensive operation on datasets with many episodes.
+
+  Returns:
+    A tf.data.Dataset of 3-batched transitions, with scalar rewards expanded
+    to 1D rewards.
+
+  This means that for every step, the corresponding elements will be a batch of
+  size 3, with the first batched element corresponding to *_t-1, the second to
+  *_t and the third to *_t+1, e.g. you can access the previous observation as:
+  ```
+  o_tm1 = el[types.OBSERVATION][0]
+  ```
+  Two additional keys can be added: N_STEP_RETURN ('n_step_return'), which
+  corresponds to the undiscounted return for horizon `return_horizon` from
+  time t (always present), and EPISODE_RETURN ('episode_return'), which
+  corresponds to the total return of the associated episode (if
+  `calculate_episode_return` is True). Rewards are converted to be (at least)
+  one-dimensional, prior to batching (to avoid ()-shaped elements).
+
+  In this example, 0-valued observations correspond to o_{t-1}, 1-valued
+  observations correspond to o_t, and 2-valued observations correspond to
+  o_{t+1}. This same structure is true for all keys, except N_STEP_RETURN and
+  EPISODE_RETURN, which contain (expanded) scalar returns.
+  ```
+  ipdb> el[types.OBSERVATION]
+  <tf.Tensor: shape=(3, ...), dtype=...>
+  ```
+  """
+  steps = episode[rlds.STEPS]
+
+  if drop_return_horizon:
+    episode_length = steps.cardinality()
+    steps = steps.take(episode_length - return_horizon)
+
+  # Calculate n-step return:
+  rewards = steps.map(lambda step: step[rlds.REWARD])
+  batched_rewards = rlds.transformations.batch(
+      rewards, size=return_horizon, shift=1, stride=1, drop_remainder=True)
+  returns = batched_rewards.map(tf.math.reduce_sum)
+  output = tf.data.Dataset.zip((steps, returns)).map(_append_n_step_return)
+
+  # Calculate the total episode return for potential filtering, using the
+  # total number of steps to calculate the return.
+  if calculate_episode_return:
+    dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
+    # Need to redefine this here to avoid a tf.data crash.
+    rewards = steps.map(lambda step: step[rlds.REWARD])
+    episode_return = rewards.reduce(dtype(0), lambda x, y: x + y)
+    output = output.map(
+        functools.partial(
+            _append_episode_return, episode_return=episode_return))
+
+  output = output.map(_expand_scalars)
+
+  output = rlds.transformations.batch(
+      output, size=3, shift=1, drop_remainder=True)
+  return output
+
+
+def _step_to_transition(rlds_step: rlds.BatchedStep) -> types.Transition:
+  """Converts batched RLDS steps to batched transitions."""
+  return types.Transition(
+      observation=rlds_step[rlds.OBSERVATION],
+      action=rlds_step[rlds.ACTION],
+      reward=rlds_step[rlds.REWARD],
+      discount=rlds_step[rlds.DISCOUNT],
+      # We provide next_observation if an algorithm needs it, however note that
+      # it will only contain s_t and s_t+1, so will be one element short of all
+      # other attributes (which contain s_t-1, s_t, s_t+1).
+      next_observation=tree.map_structure(lambda x: x[1:],
+                                          rlds_step[rlds.OBSERVATION]),
+      extras={
+          N_STEP_RETURN: rlds_step[N_STEP_RETURN],
+      })
+
+
+def episodes_to_timestep_batched_transitions(
+    episode_dataset: tf.data.Dataset,
+    return_horizon: int = 10,
+    drop_return_horizon: bool = False,
+    min_return_filter: Optional[float] = None) -> tf.data.Dataset:
+  """Processes an existing dataset of episodes, converting it to 3-transitions.
+
+  A 3-transition is a Transition with each attribute having an extra dimension
+  of size 3, representing 3 consecutive timesteps. The 3-step objects will be
+  in random order relative to each other. See `episode_to_timestep_batch` for
+  more information.
+
+  Args:
+    episode_dataset: An RLDS dataset to process.
+    return_horizon: The horizon we want to calculate Monte-Carlo returns to.
+    drop_return_horizon: Whether we should drop the last `return_horizon`
+      steps.
+    min_return_filter: Minimum episode return below which we drop an episode.
+
+  Returns:
+    A tf.data.Dataset of 3-transitions.
+  """
+  dataset = episode_dataset.interleave(
+      functools.partial(
+          episode_to_timestep_batch,
+          return_horizon=return_horizon,
+          drop_return_horizon=drop_return_horizon,
+          calculate_episode_return=min_return_filter is not None),
+      num_parallel_calls=tf.data.experimental.AUTOTUNE,
+      deterministic=False)
+
+  if min_return_filter is not None:
+
+    def filter_on_return(step):
+      return step[EPISODE_RETURN][0][0] > min_return_filter
+
+    dataset = dataset.filter(filter_on_return)
+
+  dataset = dataset.map(
+      _step_to_transition, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+  return dataset
+
+
+def get_normalization_stats(
+    iterator: Iterator[types.Transition],
+    num_normalization_batches: int = 50
+) -> running_statistics.RunningStatisticsState:
+  """Precomputes normalization statistics over a fixed number of batches.
+
+  The iterator should contain batches of 3-transitions, i.e. with two leading
+  dimensions, the first one denoting the batch dimension and the second one the
+  previous, current and next timesteps. The statistics are calculated using the
+  data of the previous timestep.
+
+  Args:
+    iterator: Iterator of batches of 3-transitions.
+    num_normalization_batches: Number of batches used to calculate the
+      statistics.
+
+  Returns:
+    RunningStatisticsState containing the normalization statistics.
+  """
+  # Set up normalization:
+  example = next(iterator)
+  unbatched_single_example = jax.tree_map(lambda x: x[0, PREVIOUS, :], example)
+  mean_std = running_statistics.init_state(unbatched_single_example)
+
+  for batch in itertools.islice(iterator, num_normalization_batches - 1):
+    example = jax.tree_map(lambda x: x[:, PREVIOUS, :], batch)
+    mean_std = running_statistics.update(mean_std, example)
+
+  return mean_std
diff --git a/acme/acme/agents/jax/mbop/dataset_test.py b/acme/acme/agents/jax/mbop/dataset_test.py
new file mode 100644
index 00000000..e0f1a938
--- /dev/null
+++ b/acme/acme/agents/jax/mbop/dataset_test.py
@@ -0,0 +1,194 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for dataset.""" + +from acme.agents.jax.mbop import dataset as dataset_lib +import rlds +from rlds.transformations import transformations_testlib +import tensorflow as tf + +from absl.testing import absltest + + +def sample_episode() -> rlds.Episode: + """Returns a sample episode.""" + steps = { + rlds.OBSERVATION: [ + [1, 1], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + ], + rlds.ACTION: [[1], [2], [3], [4], [5]], + rlds.REWARD: [1.0, 2.0, 3.0, 4.0, 5.0], + rlds.DISCOUNT: [1, 1, 1, 1, 1], + rlds.IS_FIRST: [True, False, False, False, False], + rlds.IS_LAST: [False, False, False, False, True], + rlds.IS_TERMINAL: [False, False, False, False, True], + } + return {rlds.STEPS: tf.data.Dataset.from_tensor_slices(steps)} + + +class DatasetTest(transformations_testlib.TransformationsTest): + + def test_episode_to_timestep_batch(self): + batched = dataset_lib.episode_to_timestep_batch( + sample_episode(), return_horizon=2) + + # Scalars should be expanded and the n-step return should be present. Each + # element of a step should be a triplet containing the previous, current and + # next values of the corresponding fields. Since the return horizon is 2 and + # the number of steps in the episode is 5, there can be only 2 triplets for + # time steps 1 and 2. + expected_steps = { + rlds.OBSERVATION: [ + [[1, 1], [2, 2], [3, 3]], + [[2, 2], [3, 3], [4, 4]], + ], + rlds.ACTION: [ + [[1], [2], [3]], + [[2], [3], [4]], + ], + rlds.REWARD: [ + [[1.0], [2.0], [3.0]], + [[2.0], [3.0], [4.0]], + ], + rlds.DISCOUNT: [ + [[1], [1], [1]], + [[1], [1], [1]], + ], + rlds.IS_FIRST: [ + [[True], [False], [False]], + [[False], [False], [False]], + ], + rlds.IS_LAST: [ + [[False], [False], [False]], + [[False], [False], [False]], + ], + rlds.IS_TERMINAL: [ + [[False], [False], [False]], + [[False], [False], [False]], + ], + dataset_lib.N_STEP_RETURN: [ + [[3.0], [5.0], [7.0]], + [[5.0], [7.0], [9.0]], + ], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps)) + + def test_episode_to_timestep_batch_episode_return(self): + batched = dataset_lib.episode_to_timestep_batch( + sample_episode(), return_horizon=3, calculate_episode_return=True) + + expected_steps = { + rlds.OBSERVATION: [[[1, 1], [2, 2], [3, 3]]], + rlds.ACTION: [[[1], [2], [3]]], + rlds.REWARD: [[[1.0], [2.0], [3.0]]], + rlds.DISCOUNT: [[[1], [1], [1]]], + rlds.IS_FIRST: [[[True], [False], [False]]], + rlds.IS_LAST: [[[False], [False], [False]]], + rlds.IS_TERMINAL: [[[False], [False], [False]]], + dataset_lib.N_STEP_RETURN: [[[6.0], [9.0], [12.0]]], + # This should match to the sum of the rewards in the input. 
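+        # (The sample episode's rewards sum to 1+2+3+4+5 = 15; this episode
+        # return is broadcast to each of the three batched timesteps.)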
+ dataset_lib.EPISODE_RETURN: [[[15.0], [15.0], [15.0]]], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps)) + + def test_episode_to_timestep_batch_no_return_horizon(self): + batched = dataset_lib.episode_to_timestep_batch( + sample_episode(), return_horizon=1) + + expected_steps = { + rlds.OBSERVATION: [ + [[1, 1], [2, 2], [3, 3]], + [[2, 2], [3, 3], [4, 4]], + [[3, 3], [4, 4], [5, 5]], + ], + rlds.ACTION: [ + [[1], [2], [3]], + [[2], [3], [4]], + [[3], [4], [5]], + ], + rlds.REWARD: [ + [[1.0], [2.0], [3.0]], + [[2.0], [3.0], [4.0]], + [[3.0], [4.0], [5.0]], + ], + rlds.DISCOUNT: [ + [[1], [1], [1]], + [[1], [1], [1]], + [[1], [1], [1]], + ], + rlds.IS_FIRST: [ + [[True], [False], [False]], + [[False], [False], [False]], + [[False], [False], [False]], + ], + rlds.IS_LAST: [ + [[False], [False], [False]], + [[False], [False], [False]], + [[False], [False], [True]], + ], + rlds.IS_TERMINAL: [ + [[False], [False], [False]], + [[False], [False], [False]], + [[False], [False], [True]], + ], + # n-step return should be equal to the rewards. + dataset_lib.N_STEP_RETURN: [ + [[1.0], [2.0], [3.0]], + [[2.0], [3.0], [4.0]], + [[3.0], [4.0], [5.0]], + ], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps)) + + def test_episode_to_timestep_batch_drop_return_horizon(self): + steps = { + rlds.OBSERVATION: [[1], [2], [3], [4], [5], [6]], + rlds.REWARD: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + } + episode = {rlds.STEPS: tf.data.Dataset.from_tensor_slices(steps)} + + batched = dataset_lib.episode_to_timestep_batch( + episode, + return_horizon=2, + calculate_episode_return=True, + drop_return_horizon=True) + + # The two steps of the episode should be dropped. There will be 4 steps left + # and since the return horizon is 2, only a single 3-batched step should be + # emitted. The episode return should be the sum of the rewards of the first + # 4 steps. + expected_steps = { + rlds.OBSERVATION: [[[1], [2], [3]]], + rlds.REWARD: [[[1.0], [2.0], [3.0]]], + dataset_lib.N_STEP_RETURN: [[[3.0], [5.0], [7.0]]], + dataset_lib.EPISODE_RETURN: [[[10.0], [10.0], [10.0]]], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/mbop/ensemble.py b/acme/acme/agents/jax/mbop/ensemble.py new file mode 100644 index 00000000..c7ccc412 --- /dev/null +++ b/acme/acme/agents/jax/mbop/ensemble.py @@ -0,0 +1,166 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module to provide ensembling support on top of a base network.""" +import functools +from typing import (Any, Callable) + +from acme.jax import networks +import jax +import jax.numpy as jnp + + +def _split_batch_dimension(new_batch: int, data: jnp.ndarray) -> jnp.ndarray: + """Splits the batch dimension and introduces new one with size `new_batch`. 
+
+  The result has two batch dimensions, the first of size `new_batch` and the
+  second of size `data.shape[0]/new_batch`. It expects that `data.shape[0]` is
+  divisible by `new_batch`.
+
+  Args:
+    new_batch: Size of the new outer batch dimension.
+    data: jnp.ndarray to be reshaped.
+
+  Returns:
+    jnp.ndarray with an extra batch dimension at the start and an updated
+    second dimension.
+  """
+  # The first dimension is used for allocating to a specific ensemble member,
+  # the second dimension is the parallelized batch dimension, and the
+  # remaining dimensions are passed as-is to the wrapped network.
+  # We use Fortran (F) order so that each input batch element i is allocated
+  # to ensemble member k = i % new_batch.
+  return jnp.reshape(data, (new_batch, -1) + data.shape[1:], order='F')
+
+
+def _repeat_n(new_batch: int, data: jnp.ndarray) -> jnp.ndarray:
+  """Creates a new batch dimension of size `new_batch` by repeating `data`."""
+  return jnp.broadcast_to(data, (new_batch,) + data.shape)
+
+
+def ensemble_init(base_init: Callable[[networks.PRNGKey], networks.Params],
+                  num_networks: int, rnd: jnp.ndarray):
+  """Initializes the ensemble parameters.
+
+  Args:
+    base_init: An init function that takes only a PRNGKey. If a network's init
+      function requires other parameters, such as example inputs, they need to
+      have been bound beforehand, e.g. with functools.partial using kwargs.
+    num_networks: Number of networks to generate parameters for.
+    rnd: PRNGKey to split from when generating parameters.
+
+  Returns:
+    `params` for the set of ensemble networks.
+  """
+  rnds = jax.random.split(rnd, num_networks)
+  return jax.vmap(base_init)(rnds)
+
+
+def apply_round_robin(base_apply: Callable[[networks.Params, Any], Any],
+                      params: networks.Params, *args, **kwargs) -> Any:
+  """Passes the input in a round-robin manner.
+
+  The round-robin application means that each element of the input batch will
+  be passed through a single ensemble member in a deterministic round-robin
+  manner, i.e. element_i -> member_k where k = i % num_networks.
+
+  It expects that:
+  * `base_apply(member_params, *member_args, **member_kwargs)` is a valid call,
+    where:
+    * `member_params.shape = params.shape[1:]`
+    * `member_args` and `member_kwargs` have the same structure as `args` and
+      `kwargs`.
+  * `params[k]` contains the params of the k-th member of the ensemble.
+  * All jax arrays in `args` and `kwargs` have a batch dimension at axis 0 of
+    the same size, which is divisible by `params.shape[0]`.
+
+  Args:
+    base_apply: Base network `apply` function that will be applied in a
+      round-robin manner.
+      NOTE -- This will not work with mutable/stateful apply functions. --
+    params: Model parameters. The number of networks is deduced from this.
+    *args: Allows for arbitrary call signatures for `base_apply`.
+    **kwargs: Allows for arbitrary call signatures for `base_apply`.
+
+  Returns:
+    pytree of the round-robin application.
+    Output shapes will be [initial_batch_size, ...].
+  """
+  # `num_networks` is the size of the batch dimension in `params`.
+  num_networks = jax.tree_util.tree_leaves(params)[0].shape[0]
+
+  # Reshape args and kwargs for the round-robin:
+  args = jax.tree_map(
+      functools.partial(_split_batch_dimension, num_networks), args)
+  kwargs = jax.tree_map(
+      functools.partial(_split_batch_dimension, num_networks), kwargs)
+  # `out.shape` is `(num_networks, initial_batch_size/num_networks, ...)`.
+  out = jax.vmap(base_apply)(params, *args, **kwargs)
+  # Reshape back to [initial_batch_size, ...]. Using the 'F' order
+  # restores the original order of the batch elements.
+  return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out)
+
+
+def apply_all(base_apply: Callable[[networks.Params, Any], Any],
+              params: networks.Params, *args, **kwargs) -> Any:
+  """Passes the input to all ensemble members.
+
+  Inputs can either have a batch dimension, which will get implicitly vmapped
+  over, or be a single vector, which will get sent to all ensemble members,
+  e.g. shapes [...] or [batch_size, ...].
+
+  Args:
+    base_apply: Base network `apply` function that will be applied to every
+      ensemble member.
+      NOTE -- This will not work with mutable/stateful apply functions. --
+    params: Model parameters. The number of networks is deduced from this.
+    *args: Allows for arbitrary call signatures for `base_apply`.
+    **kwargs: Allows for arbitrary call signatures for `base_apply`.
+
+  Returns:
+    pytree of the resulting output of passing the input to all ensemble
+    members. Output shape will be [num_networks, batch_size, ...].
+  """
+  # `num_networks` is the size of the batch dimension in `params`.
+  num_networks = jax.tree_util.tree_leaves(params)[0].shape[0]
+
+  args = jax.tree_map(functools.partial(_repeat_n, num_networks), args)
+  kwargs = jax.tree_map(functools.partial(_repeat_n, num_networks), kwargs)
+  # `out` is of shape `(num_networks, batch_size, ...)`.
+  return jax.vmap(base_apply)(params, *args, **kwargs)
+
+
+def apply_mean(base_apply: Callable[[networks.Params, Any], Any],
+               params: networks.Params, *args, **kwargs) -> Any:
+  """Calculates the mean over all ensemble members for each batch element.
+
+  Args:
+    base_apply: Base network `apply` function that will be used for averaging.
+      NOTE -- This will not work with mutable/stateful apply functions. --
+    params: Model parameters. The number of networks is deduced from this.
+    *args: Allows for arbitrary call signatures for `base_apply`.
+    **kwargs: Allows for arbitrary call signatures for `base_apply`.
+
+  Returns:
+    pytree of the average over all ensemble members for each batch element.
+    Output shape will be [batch_size, ...].
+  """
+  out = apply_all(base_apply, params, *args, **kwargs)
+  return jax.tree_map(functools.partial(jnp.mean, axis=0), out)
+
+
+def make_ensemble(base_network: networks.FeedForwardNetwork,
+                  ensemble_apply: Callable[..., Any],
+                  num_networks: int) -> networks.FeedForwardNetwork:
+  """Wraps `base_network` as an ensemble with the given apply strategy."""
+  return networks.FeedForwardNetwork(
+      init=functools.partial(ensemble_init, base_network.init, num_networks),
+      apply=functools.partial(ensemble_apply, base_network.apply))
diff --git a/acme/acme/agents/jax/mbop/ensemble_test.py b/acme/acme/agents/jax/mbop/ensemble_test.py
new file mode 100644
index 00000000..9890a781
--- /dev/null
+++ b/acme/acme/agents/jax/mbop/ensemble_test.py
@@ -0,0 +1,329 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for ensemble.""" + +import functools +from typing import Any + +from acme.agents.jax.mbop import ensemble +from acme.jax import networks +from flax import linen as nn +import jax +import jax.numpy as jnp +import numpy as np + +from absl.testing import absltest + + +class RandomFFN(nn.Module): + + @nn.compact + def __call__(self, x): + return nn.Dense(15)(x) + + +def params_adding_ffn(x: jnp.ndarray) -> networks.FeedForwardNetwork: + """Apply adds the parameters to the inputs.""" + return networks.FeedForwardNetwork( + init=lambda key, x=x: jax.random.uniform(key, x.shape), + apply=lambda params, x: params + x) + + +def funny_args_ffn(x: jnp.ndarray) -> networks.FeedForwardNetwork: + """Apply takes additional parameters, returns `params + x + foo - bar`.""" + return networks.FeedForwardNetwork( + init=lambda key, x=x: jax.random.uniform(key, x.shape), + apply=lambda params, x, foo, bar: params + x + foo - bar) + + +def struct_params_adding_ffn(sx: Any) -> networks.FeedForwardNetwork: + """Like params_adding_ffn, but with pytree inputs, preserves structure.""" + + def init_fn(key, sx=sx): + return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx) + + def apply_fn(params, x): + return jax.tree_map(lambda p, v: p + v, params, x) + + return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn) + + +class EnsembleTest(absltest.TestCase): + + def test_ensemble_init(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + self.assertTupleEqual(params.shape, (3,) + x.shape) + + # The ensemble dimension is the lead dimension. + self.assertFalse((params[0, ...] == params[1, ...]).all()) + + def test_apply_all(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((7, 10)) # Batched input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_all, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, x) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose(params, y - jnp.broadcast_to(x, (3,) + x.shape)) + + by = rr_ensemble.apply(params, bx) + # Note: the batch dimension is no longer the leading dimension. + self.assertTupleEqual(by.shape, (3,) + bx.shape) + + def test_apply_round_robin(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((7, 10)) # Batched input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, jnp.broadcast_to(x, (3,) + x.shape)) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose(params, y - x) + + # Note: the ensemble dimension must lead, the batch dimension is no longer + # the leading dimension. + by = rr_ensemble.apply( + params, jnp.broadcast_to(jnp.expand_dims(bx, axis=0), (3,) + bx.shape)) + self.assertTupleEqual(by.shape, (3,) + bx.shape) + + # If num_networks=3, then `round_robin(params, input)[4]` should be equal + # to `apply(params[1], input[4])`, etc. 
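+    # Concretely: with a batch of 6 inputs, rows are dealt out to members
+    # (0, 1, 2, 0, 1, 2), each output row i is params[i % 3] + x, and
+    # stacking `params` twice reproduces the outputs in order.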
+ yy = rr_ensemble.apply(params, jnp.broadcast_to(x, (6,) + x.shape)) + self.assertTupleEqual(yy.shape, (6,) + x.shape) + np.testing.assert_allclose( + jnp.concatenate([params, params], axis=0), + yy - jnp.expand_dims(x, axis=0)) + + def test_apply_mean(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((7, 10)) # Batched input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_mean, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + self.assertFalse((params[0, ...] == params[1, ...]).all()) + + y = rr_ensemble.apply(params, x) + self.assertTupleEqual(y.shape, x.shape) + np.testing.assert_allclose( + jnp.mean(params, axis=0), y - x, atol=1E-5, rtol=1E-5) + + by = rr_ensemble.apply(params, bx) + self.assertTupleEqual(by.shape, bx.shape) + + def test_apply_all_multiargs(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = funny_args_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_all, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, x, 2 * x, x) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose( + params, + y - jnp.broadcast_to(2 * x, (3,) + x.shape), + atol=1E-5, + rtol=1E-5) + + y = rr_ensemble.apply(params, x, bar=x, foo=2 * x) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose( + params, + y - jnp.broadcast_to(2 * x, (3,) + x.shape), + atol=1E-5, + rtol=1E-5) + + def test_apply_all_structured(self): + x = jnp.ones(10) + sx = [(3 * x, 2 * x), 5 * x] # Base input + + wrapped_ffn = struct_params_adding_ffn(sx) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_all, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + y = rr_ensemble.apply(params, sx) + ex = jnp.broadcast_to(x, (3,) + x.shape) + np.testing.assert_allclose(y[0][0], params[0][0] + 3 * ex) + + def test_apply_round_robin_multiargs(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = funny_args_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + ex = jnp.broadcast_to(x, (3,) + x.shape) + y = rr_ensemble.apply(params, ex, 2 * ex, ex) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose( + params, + y - jnp.broadcast_to(2 * x, (3,) + x.shape), + atol=1E-5, + rtol=1E-5) + + y = rr_ensemble.apply(params, ex, bar=ex, foo=2 * ex) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose( + params, + y - jnp.broadcast_to(2 * x, (3,) + x.shape), + atol=1E-5, + rtol=1E-5) + + def test_apply_round_robin_structured(self): + x = jnp.ones(10) + sx = [(3 * x, 2 * x), 5 * x] # Base input + + wrapped_ffn = struct_params_adding_ffn(sx) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + ex = jnp.broadcast_to(x, (3,) + x.shape) + esx = [(3 * ex, 2 * ex), 5 * ex] + y = rr_ensemble.apply(params, esx) + np.testing.assert_allclose(y[0][0], params[0][0] + 3 * ex) + + def test_apply_mean_multiargs(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = funny_args_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + 
wrapped_ffn, ensemble.apply_mean, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, x, 2 * x, x) + self.assertTupleEqual(y.shape, x.shape) + np.testing.assert_allclose( + jnp.mean(params, axis=0), y - 2 * x, atol=1E-5, rtol=1E-5) + + y = rr_ensemble.apply(params, x, bar=x, foo=2 * x) + self.assertTupleEqual(y.shape, x.shape) + np.testing.assert_allclose( + jnp.mean(params, axis=0), y - 2 * x, atol=1E-5, rtol=1E-5) + + def test_apply_mean_structured(self): + x = jnp.ones(10) + sx = [(3 * x, 2 * x), 5 * x] # Base input + + wrapped_ffn = struct_params_adding_ffn(sx) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_mean, num_networks=3) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + y = rr_ensemble.apply(params, sx) + np.testing.assert_allclose( + y[0][0], jnp.mean(params[0][0], axis=0) + 3 * x, atol=1E-5, rtol=1E-5) + + def test_round_robin_random(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((9, 10)) # Batched input + ffn = RandomFFN() + wrapped_ffn = networks.FeedForwardNetwork( + init=functools.partial(ffn.init, x=x), apply=ffn.apply) + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3) + + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + out = rr_ensemble.apply(params, bx) + # The output should be the same every 3 rows: + blocks = jnp.split(out, 3, axis=0) + np.testing.assert_array_equal(blocks[0], blocks[1]) + np.testing.assert_array_equal(blocks[0], blocks[2]) + self.assertTrue((out[0] != out[1]).any()) + + for i in range(9): + np.testing.assert_allclose( + out[i], + ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params), bx[i]), + atol=1E-5, + rtol=1E-5) + + def test_mean_random(self): + x = jnp.ones(10) + bx = jnp.ones((9, 10)) + ffn = RandomFFN() + wrapped_ffn = networks.FeedForwardNetwork( + init=functools.partial(ffn.init, x=x), apply=ffn.apply) + mean_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_mean, num_networks=3) + key = jax.random.PRNGKey(0) + params = mean_ensemble.init(key) + single_output = mean_ensemble.apply(params, x) + self.assertEqual(single_output.shape, (15,)) + batch_output = mean_ensemble.apply(params, bx) + # Make sure all rows are equal: + np.testing.assert_allclose( + jnp.broadcast_to(batch_output[0], batch_output.shape), + batch_output, + atol=1E-5, + rtol=1E-5) + + # Check results explicitly: + all_members = jnp.concatenate([ + jnp.expand_dims( + ffn.apply(jax.tree_map(lambda p, i=i: p[i], params), bx), axis=0) + for i in range(3) + ]) + batch_means = jnp.mean(all_members, axis=0) + np.testing.assert_allclose(batch_output, batch_means, atol=1E-5, rtol=1E-5) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/mbop/learning.py b/acme/acme/agents/jax/mbop/learning.py new file mode 100644 index 00000000..6dfc0e9f --- /dev/null +++ b/acme/acme/agents/jax/mbop/learning.py @@ -0,0 +1,235 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Combined offline learning of world model, policy and N-step return.""" + +import dataclasses +import functools +import itertools +import time +from typing import Any, Callable, Iterator, List, Optional + +from acme import core +from acme import types +from acme.agents.jax import bc +from acme.agents.jax.mbop import ensemble +from acme.agents.jax.mbop import losses as mbop_losses +from acme.agents.jax.mbop import networks as mbop_networks +from acme.jax import networks as networks_lib +from acme.jax import types as jax_types +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax + + +@dataclasses.dataclass +class TrainingState: + """States of the world model, policy prior and n-step return learners.""" + world_model: Any + policy_prior: Any + n_step_return: Any + + +LoggerFn = Callable[[str, str], loggers.Logger] + +# Creates a world model learner. +MakeWorldModelLearner = Callable[[ + LoggerFn, + counting.Counter, + jax_types.PRNGKey, + Iterator[types.Transition], + mbop_networks.WorldModelNetwork, + mbop_losses.TransitionLoss, +], core.Learner] + +# Creates a policy prior learner. +MakePolicyPriorLearner = Callable[[ + LoggerFn, + counting.Counter, + jax_types.PRNGKey, + Iterator[types.Transition], + mbop_networks.PolicyPriorNetwork, + mbop_losses.TransitionLoss, +], core.Learner] + +# Creates an n-step return model learner. +MakeNStepReturnLearner = Callable[[ + LoggerFn, + counting.Counter, + jax_types.PRNGKey, + Iterator[types.Transition], + mbop_networks.NStepReturnNetwork, + mbop_losses.TransitionLoss, +], core.Learner] + + +def make_ensemble_regressor_learner( + name: str, + num_networks: int, + logger_fn: loggers.LoggerFactory, + counter: counting.Counter, + rng_key: jnp.ndarray, + iterator: Iterator[types.Transition], + base_network: networks_lib.FeedForwardNetwork, + loss: mbop_losses.TransitionLoss, + optimizer: optax.GradientTransformation, + num_sgd_steps_per_step: int, +): + """Creates an ensemble regressor learner from the base network. + + Args: + name: Name of the learner used for logging and counters. + num_networks: Number of networks in the ensemble. + logger_fn: Constructs a logger for a label. + counter: Parent counter object. + rng_key: Random key. + iterator: An iterator of time-batched transitions used to train the + networks. + base_network: Base network for the ensemble. + loss: Training loss to use. + optimizer: Optax optimizer. + num_sgd_steps_per_step: Number of gradient updates per step. + + Returns: + An ensemble regressor learner. 
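+
+  For example, a factory matching `MakeWorldModelLearner` (see `MBOPLearner`)
+  can be obtained by binding the remaining arguments up front; the optimizer
+  shown here is purely illustrative:
+
+    make_world_model = functools.partial(
+        make_ensemble_regressor_learner, 'world_model', num_networks,
+        optimizer=optax.adam(1e-3), num_sgd_steps_per_step=1)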
+ """ + mbop_ensemble = ensemble.make_ensemble(base_network, ensemble.apply_all, + num_networks) + local_counter = counting.Counter(parent=counter, prefix=name) + local_logger = logger_fn(name, + local_counter.get_steps_key()) if logger_fn else None + + def loss_fn(networks: bc.BCNetworks, params: networks_lib.Params, + key: jax_types.PRNGKey, + transitions: types.Transition) -> jnp.ndarray: + del key + return loss( + functools.partial(networks.policy_network.apply, params), transitions) + + bc_policy_network = bc.convert_to_bc_network(mbop_ensemble) + bc_networks = bc.BCNetworks(bc_policy_network) + + # This is effectively a regressor learner. + return bc.BCLearner( + bc_networks, + rng_key, + loss_fn, + optimizer, + iterator, + num_sgd_steps_per_step, + logger=local_logger, + counter=local_counter) + + +class MBOPLearner(core.Learner): + """Model-Based Offline Planning (MBOP) learner. + + See https://arxiv.org/abs/2008.05556 for more information. + """ + + def __init__(self, + networks: mbop_networks.MBOPNetworks, + losses: mbop_losses.MBOPLosses, + iterator: Iterator[types.Transition], + rng_key: jax_types.PRNGKey, + logger_fn: LoggerFn, + make_world_model_learner: MakeWorldModelLearner, + make_policy_prior_learner: MakePolicyPriorLearner, + make_n_step_return_learner: MakeNStepReturnLearner, + counter: Optional[counting.Counter] = None): + """Creates an MBOP learner. + + Args: + networks: One network per model. + losses: One loss per model. + iterator: An iterator of time-batched transitions used to train the + networks. + rng_key: Random key. + logger_fn: Constructs a logger for a label. + make_world_model_learner: Function to create the world model learner. + make_policy_prior_learner: Function to create the policy prior learner. + make_n_step_return_learner: Function to create the n-step return learner. + counter: Parent counter object. + """ + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger_fn('', 'steps') + + # Prepare iterators for the learners, to not split the data (preserve sample + # efficiency). + sharded_prefetching_dataset = utils.sharded_prefetch(iterator) + world_model_iterator, policy_prior_iterator, n_step_return_iterator = ( + itertools.tee(sharded_prefetching_dataset, 3)) + + world_model_key, policy_prior_key, n_step_return_key = jax.random.split( + rng_key, 3) + + self._world_model = make_world_model_learner(logger_fn, self._counter, + world_model_key, + world_model_iterator, + networks.world_model_network, + losses.world_model_loss) + self._policy_prior = make_policy_prior_learner( + logger_fn, self._counter, policy_prior_key, policy_prior_iterator, + networks.policy_prior_network, losses.policy_prior_loss) + self._n_step_return = make_n_step_return_learner( + logger_fn, self._counter, n_step_return_key, n_step_return_iterator, + networks.n_step_return_network, losses.n_step_return_loss) + # Start recording timestamps after the first learning step to not report + # "warmup" time. + self._timestamp = None + self._learners = { + 'world_model': self._world_model, + 'policy_prior': self._policy_prior, + 'n_step_return': self._n_step_return + } + + def step(self): + # Step the world model, policy learner and n-step return learners. + self._world_model.step() + self._policy_prior.step() + self._n_step_return.step() + + # Compute the elapsed time. 
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + # Increment counts and record the current time. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + # Attempt to write the logs. + self._logger.write({**counts}) + + def get_variables(self, names: List[str]) -> List[types.NestedArray]: + variables = [] + for name in names: + # Variables will be prefixed by the learner names. If separator is not + # found, learner_name=name, which is OK. + learner_name, _, variable_name = name.partition('-') + learner = self._learners[learner_name] + variables.extend(learner.get_variables([variable_name])) + return variables + + def save(self) -> TrainingState: + return TrainingState( + world_model=self._world_model.save(), + policy_prior=self._policy_prior.save(), + n_step_return=self._n_step_return.save()) + + def restore(self, state: TrainingState): + self._world_model.restore(state.world_model) + self._policy_prior.restore(state.policy_prior) + self._n_step_return.restore(state.n_step_return) diff --git a/acme/acme/agents/jax/mbop/losses.py b/acme/acme/agents/jax/mbop/losses.py new file mode 100644 index 00000000..4ec911f4 --- /dev/null +++ b/acme/acme/agents/jax/mbop/losses.py @@ -0,0 +1,126 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loss function wrappers, assuming a leading batch axis.""" + +import dataclasses +from typing import Any, Callable, Optional, Tuple, Union + +from acme import types +from acme.agents.jax.mbop import dataset +from acme.jax import networks +import jax +import jax.numpy as jnp + +# The apply function takes an observation (and an action) as arguments, and is +# usually a network with bound parameters. +TransitionApplyFn = Callable[[networks.Observation, networks.Action], Any] +ObservationOnlyTransitionApplyFn = Callable[[networks.Observation], Any] +TransitionLoss = Callable[[ + Union[TransitionApplyFn, ObservationOnlyTransitionApplyFn], types.Transition +], jnp.ndarray] + + +def mse(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """MSE distance.""" + return jnp.mean(jnp.square(a - b)) + + +def world_model_loss(apply_fn: Callable[[networks.Observation, networks.Action], + Tuple[networks.Observation, + jnp.ndarray]], + steps: types.Transition) -> jnp.ndarray: + """Returns the loss for the world model. + + Args: + apply_fn: applies a transition model (o_t, a_t) -> (o_t+1, r), expects the + leading axis to index the batch and the second axis to index the + transition triplet (t-1, t, t+1). + steps: RLDS dictionary of transition triplets as prepared by + `rlds_loader.episode_to_timestep_batch`. + + Returns: + A scalar loss value as jnp.ndarray. + """ + observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], + steps.observation) + action_t = steps.action[:, dataset.CURRENT, ...] 
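+  # Each field of `steps` is indexed as [batch, triplet, ...], where
+  # dataset.PREVIOUS, dataset.CURRENT and dataset.NEXT select the
+  # (t-1, t, t+1) slices of the triplet axis.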
+ observation_tp1 = jax.tree_map(lambda obs: obs[:, dataset.NEXT, ...], + steps.observation) + reward_t = steps.reward[:, dataset.CURRENT, ...] + (predicted_observation_tp1, + predicted_reward_t) = apply_fn(observation_t, action_t) + # predicted_* variables may have an extra outer dimension due to ensembling, + # the mse loss still works due to broadcasting however. + if len(observation_tp1.shape) != len(reward_t.shape): + # The rewards from the transitions may not have the last singular dimension. + reward_t = jnp.expand_dims(reward_t, axis=-1) + return mse( + jnp.concatenate([predicted_observation_tp1, predicted_reward_t], axis=-1), + jnp.concatenate([observation_tp1, reward_t], axis=-1)) + + +def policy_prior_loss( + apply_fn: Callable[[networks.Observation, networks.Action], + networks.Action], steps: types.Transition): + """Returns the loss for the policy prior. + + Args: + apply_fn: applies a policy prior (o_t, a_t) -> a_t+1, expects the leading + axis to index the batch and the second axis to index the transition + triplet (t-1, t, t+1). + steps: RLDS dictionary of transition triplets as prepared by + `rlds_loader.episode_to_timestep_batch`. + + Returns: + A scalar loss value as jnp.ndarray. + """ + observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], + steps.observation) + action_tm1 = steps.action[:, dataset.PREVIOUS, ...] + action_t = steps.action[:, dataset.CURRENT, ...] + + predicted_action_t = apply_fn(observation_t, action_tm1) + return mse(predicted_action_t, action_t) + + +def return_loss(apply_fn: Callable[[networks.Observation, networks.Action], + jnp.ndarray], steps: types.Transition): + """Returns the loss for the n-step return model. + + Args: + apply_fn: applies an n-step return model (o_t, a_t) -> r, expects the + leading axis to index the batch and the second axis to index the + transition triplet (t-1, t, t+1). + steps: RLDS dictionary of transition triplets as prepared by + `rlds_loader.episode_to_timestep_batch`. + + Returns: + A scalar loss value as jnp.ndarray. + """ + observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], + steps.observation) + action_t = steps.action[:, dataset.CURRENT, ...] + n_step_return_t = steps.extras[dataset.N_STEP_RETURN][:, dataset.CURRENT, ...] + + predicted_n_step_return_t = apply_fn(observation_t, action_t) + return mse(predicted_n_step_return_t, n_step_return_t) + + +@dataclasses.dataclass +class MBOPLosses: + """Losses for the world model, policy prior and the n-step return.""" + world_model_loss: Optional[TransitionLoss] = world_model_loss + policy_prior_loss: Optional[TransitionLoss] = policy_prior_loss + n_step_return_loss: Optional[TransitionLoss] = return_loss diff --git a/acme/acme/agents/jax/mbop/models.py b/acme/acme/agents/jax/mbop/models.py new file mode 100644 index 00000000..87489574 --- /dev/null +++ b/acme/acme/agents/jax/mbop/models.py @@ -0,0 +1,141 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MBOP models.""" + +import functools +from typing import Callable, Generic, Optional, Tuple + +from acme import specs +from acme.agents.jax import actor_core +from acme.agents.jax.mbop import ensemble +from acme.agents.jax.mbop import networks as mbop_networks +from acme.jax import networks +from acme.jax import utils +import chex +import jax + +# World, policy prior and n-step return models. These are backed by the +# corresponding networks. +WorldModel = Callable[[networks.Params, networks.Observation, networks.Action], + Tuple[networks.Observation, networks.Value]] +MakeWorldModel = Callable[[mbop_networks.WorldModelNetwork], WorldModel] + +PolicyPrior = actor_core.ActorCore +MakePolicyPrior = Callable[ + [mbop_networks.PolicyPriorNetwork, specs.EnvironmentSpec], PolicyPrior] + +NStepReturn = Callable[[networks.Params, networks.Observation, networks.Action], + networks.Value] +MakeNStepReturn = Callable[[mbop_networks.NStepReturnNetwork], NStepReturn] + + +@chex.dataclass(frozen=True, mappable_dataclass=False) +class PolicyPriorState(Generic[actor_core.RecurrentState]): + """State of a policy prior. + + Attributes: + rng: Random key. + action_tm1: Previous action. + recurrent_state: Recurrent state. It will be none for non-recurrent, e.g. + feed forward, policies. + """ + rng: networks.PRNGKey + action_tm1: networks.Action + recurrent_state: Optional[actor_core.RecurrentState] = None + + +FeedForwardPolicyState = PolicyPriorState[actor_core.NoneType] + + +def feed_forward_policy_prior_to_actor_core( + policy: actor_core.RecurrentPolicy, initial_action_tm1: networks.Action +) -> actor_core.ActorCore[PolicyPriorState, actor_core.NoneType]: + """A convenience adaptor from a feed forward policy prior to ActorCore. + + Args: + policy: A feed forward policy prior. In the planner and other components, + the previous action is explicitly passed as an argument to the policy + prior together with the observation to infer the next action. Therefore, + we model feed forward policy priors as recurrent ActorCore policies with + previous action being the recurrent state. + initial_action_tm1: Initial previous action. This will usually be a zero + tensor. + + Returns: + an ActorCore representing the feed forward policy prior. + """ + + def select_action(params: networks.Params, observation: networks.Observation, + state: FeedForwardPolicyState): + rng, policy_rng = jax.random.split(state.rng) + action = policy(params, policy_rng, observation, state.action_tm1) + return action, PolicyPriorState(rng, action) + + def init(rng: networks.PRNGKey) -> FeedForwardPolicyState: + return PolicyPriorState(rng, initial_action_tm1) + + def get_extras(unused_state: FeedForwardPolicyState) -> actor_core.NoneType: + return None + + return actor_core.ActorCore( + init=init, select_action=select_action, get_extras=get_extras) + + +def make_ensemble_world_model( + world_model_network: mbop_networks.WorldModelNetwork) -> WorldModel: + """Creates an ensemble world model from its network.""" + return functools.partial(ensemble.apply_round_robin, + world_model_network.apply) + + +def make_ensemble_policy_prior( + policy_prior_network: mbop_networks.PolicyPriorNetwork, + spec: specs.EnvironmentSpec, + use_round_robin: bool = True) -> PolicyPrior: + """Creates an ensemble policy prior from its network. + + Args: + policy_prior_network: The policy prior network. + spec: Environment specification. + use_round_robin: Whether to use round robin or mean to calculate the policy + prior over the ensemble members. 
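+      With round robin, candidate i in a batch is always evaluated by
+      ensemble member i % num_networks; with mean, every action is averaged
+      over all members.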
+ + Returns: + A policy prior. + """ + + def _policy_prior(params: networks.Params, key: networks.PRNGKey, + observation_t: networks.Observation, + action_tm1: networks.Action) -> networks.Action: + # Regressor policies are deterministic. + del key + apply_fn = ( + ensemble.apply_round_robin if use_round_robin else ensemble.apply_mean) + return apply_fn( + policy_prior_network.apply, + params, + observation_t=observation_t, + action_tm1=action_tm1) + + dummy_action = utils.zeros_like(spec.actions) + dummy_action = utils.add_batch_dim(dummy_action) + + return feed_forward_policy_prior_to_actor_core(_policy_prior, dummy_action) + + +def make_ensemble_n_step_return( + n_step_return_network: mbop_networks.NStepReturnNetwork) -> NStepReturn: + """Creates an ensemble n-step return model from its network.""" + return functools.partial(ensemble.apply_mean, n_step_return_network.apply) diff --git a/acme/acme/agents/jax/mbop/mppi.py b/acme/acme/agents/jax/mbop/mppi.py new file mode 100644 index 00000000..1731a994 --- /dev/null +++ b/acme/acme/agents/jax/mbop/mppi.py @@ -0,0 +1,251 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides the extended MPPI planner used in MBOP [https://arxiv.org/abs/2008.05556]. + +In this context, MPPI refers to Model-Predictive Path Integral control, +originally introduced in "Model Predictive Path Integral Control: From Theory to +Parallel Computation" Grady Williams, Andrew Aldrich and Evangelos A. Theodorou. + +This is a modified implementation of MPPI that adds a policy prior and n-step +return extension as described in Algorithm 2 of "Model-Based Offline Planning" +[https://arxiv.org/abs/2008.05556]. Notation is taken from the paper. This +planner can be 'degraded' to provide both 'basic' MPPI or PDDM-style +[https://arxiv.org/abs/1909.11652] planning by removing the n-step return, +providing a Gaussian policy prior, or single-head ensembles. +""" +import dataclasses +import functools +from typing import Callable, Optional + +from acme import specs +from acme.agents.jax.mbop import models +from acme.jax import networks +import jax +from jax import random +import jax.numpy as jnp + +# Function that takes (n_trajectories, horizon, action_dim) tensor of action +# trajectories and (n_trajectories) vector of corresponding cumulative rewards, +# i.e. returns, for each trajectory as input and returns a single action +# trajectory. +ActionAggregationFn = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] + + +def return_weighted_average(action_trajectories: jnp.ndarray, + cum_reward: jnp.ndarray, + kappa: float) -> jnp.ndarray: + r"""Calculates return-weighted average over all trajectories. + + This will calculate the return-weighted average over a set of trajectories as + defined on l.17 of Alg. 2 in the MBOP paper: + [https://arxiv.org/abs/2008.05556]. + + Note: Clipping will be performed for `cum_reward` values > 80 to avoid NaNs. 
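+
+  For example, with `kappa=1` and two trajectories whose returns are
+  `[0.0, log(2)]`, the max-shifted weights are `exp([-log(2), 0]) = [0.5, 1]`,
+  so the result is `(0.5 * A[0] + 1.0 * A[1]) / 1.5`.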
+
+  Args:
+    action_trajectories: (n_trajectories, horizon, action_dim) tensor of action
+      trajectories, corresponds to `A` in Alg. 2.
+    cum_reward: (n_trajectories) vector of corresponding cumulative rewards
+      (returns) for each trajectory. Corresponds to `\mathcal{R}` in Alg. 2.
+    kappa: `\kappa` constant, changes the 'peakiness' of the exponential
+      averaging.
+
+  Returns:
+    Single action trajectory corresponding to the return-weighted average of
+    the trajectories.
+  """
+  # Subtract the maximum reward to avoid NaNs:
+  cum_reward = cum_reward - cum_reward.max()
+  # Removing the batch dimension of `cum_reward` allows for an implicit
+  # broadcast in jnp.average:
+  exp_cum_reward = jnp.exp(kappa * jnp.squeeze(cum_reward))
+  return jnp.average(action_trajectories, weights=exp_cum_reward, axis=0)
+
+
+def return_top_k_average(action_trajectories: jnp.ndarray,
+                         cum_reward: jnp.ndarray,
+                         k: int = 10) -> jnp.ndarray:
+  r"""Calculates the top-k average over all trajectories.
+
+  This will calculate the top-k average over a set of trajectories as
+  defined in the POIR paper.
+
+  Note: the top-k average is more numerically stable than the weighted average.
+
+  Args:
+    action_trajectories: (n_trajectories, horizon, action_dim) tensor of action
+      trajectories.
+    cum_reward: (n_trajectories) vector of corresponding cumulative rewards
+      (returns) for each trajectory.
+    k: the number of trajectories to average.
+
+  Returns:
+    Single action trajectory corresponding to the average of the k best
+    trajectories.
+  """
+  top_k_trajectories = action_trajectories[jnp.argsort(
+      jnp.squeeze(cum_reward))[-int(k):]]
+  return jnp.mean(top_k_trajectories, axis=0)
+
+
+@dataclasses.dataclass
+class MPPIConfig:
+  """Config dataclass for MPPI-style planning, used in mppi.py.
+
+  These variables correspond to different parameters of `MBOP-Trajopt` as
+  defined in MBOP [https://arxiv.org/abs/2008.05556] (Alg. 2).
+
+  Attributes:
+    sigma: Standard deviation (scale) of the additive action noise.
+    beta: Mixture parameter between the old trajectory and the new action.
+    horizon: Planning horizon, corresponds to H in Alg. 2 line 8.
+    n_trajectories: Number of trajectories used in `mppi_planner` to sample the
+      best action. Corresponds to `N` in Alg. 2 line 5.
+    previous_trajectory_clip: Value to clip the previous_trajectory's actions
+      to. Disabled if None.
+    action_aggregation_fn: Function that aggregates action trajectories and
+      returns a single action trajectory.
+  """
+  sigma: float = 0.8
+  beta: float = 0.2
+  horizon: int = 15
+  n_trajectories: int = 1000
+  previous_trajectory_clip: Optional[float] = None
+  action_aggregation_fn: ActionAggregationFn = (
+      functools.partial(return_weighted_average, kappa=0.5))
+
+
+def get_initial_trajectory(config: MPPIConfig, env_spec: specs.EnvironmentSpec):
+  """Returns the initial empty trajectory `T_0`."""
+  return jnp.zeros((max(1, config.horizon),) + env_spec.actions.shape)
+
+
+def _repeat_n(new_batch: int, data: jnp.ndarray) -> jnp.ndarray:
+  """Creates a new batch dimension of size `new_batch` by repeating `data`."""
+  return jnp.broadcast_to(data, (new_batch,) + data.shape)
+
+
+def mppi_planner(
+    config: MPPIConfig,
+    world_model: models.WorldModel,
+    policy_prior: models.PolicyPrior,
+    n_step_return: models.NStepReturn,
+    world_model_params: networks.Params,
+    policy_prior_params: networks.Params,
+    n_step_return_params: networks.Params,
+    random_key: networks.PRNGKey,
+    observation: networks.Observation,
+    previous_trajectory: jnp.ndarray,
+) -> jnp.ndarray:
+  """MPPI-extended trajectory optimizer.
+ + This implements the trajectory optimizer described in MBOP + [https://arxiv.org/abs/2008.05556] (Alg. 2) which is an extended version of + MPPI that adds support for arbitrary sampling distributions and extends the + return horizon using an approximate model of returns. There are a couple + notation changes for readability: + A -> action_trajectories + T -> action_trajectory + + If the horizon is set to 0, the planner will simply call the policy_prior + and average the action over the ensemble heads. + + Args: + config: Base configuration parameters of MPPI. + world_model: Corresponds to `f_m(s_t, a_t)_s` in Alg. 2. + policy_prior: Corresponds to `f_b(s_t, a_tm1)` in Alg. 2. + n_step_return: Corresponds to `f_R(s_t, a_t)` in Alg. 2. + world_model_params: Parameters for world model. + policy_prior_params: Parameters for policy prior. + n_step_return_params: Parameters for n_step return. + random_key: JAX random key seed. + observation: Normalized current observation from the environment, `s` in + Alg. 2. + previous_trajectory: Normalized previous action trajectory. `T` in Alg 2. + Shape is [n_trajectories, horizon, action_dims]. + + Returns: + jnp.ndarray: Average action trajectory of shape [horizon, action_dims]. + """ + action_trajectory_tm1 = previous_trajectory + policy_prior_state = policy_prior.init(random_key) + + # Broadcast so that we have n_trajectories copies of each: + observation_t = jax.tree_map( + functools.partial(_repeat_n, config.n_trajectories), observation) + action_tm1 = jnp.broadcast_to(action_trajectory_tm1[0], + (config.n_trajectories,) + + action_trajectory_tm1[0].shape) + + if config.previous_trajectory_clip is not None: + action_tm1 = jnp.clip( + action_tm1, + a_min=-config.previous_trajectory_clip, + a_max=config.previous_trajectory_clip) + + # First check if planning is unnecessary: + if config.horizon == 0: + if hasattr(policy_prior_state, 'action_tm1'): + policy_prior_state = policy_prior_state.replace(action_tm1=action_tm1) + action_set, _ = policy_prior.select_action(policy_prior_params, + observation_t, + policy_prior_state) + # Need to re-create an action trajectory from a single action. 
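+    # `action_set` holds one action per candidate trajectory, i.e. shape
+    # [n_trajectories, action_dim]; average them and return a single-step
+    # trajectory of shape [1, action_dim].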
+ return jnp.broadcast_to( + jnp.mean(action_set, axis=0), (1, action_set.shape[-1])) + + # Accumulators for returns and trajectories: + cum_reward = jnp.zeros((config.n_trajectories, 1)) + + # Generate noise once: + random_key, noise_key = random.split(random_key) + action_noise = config.sigma * random.normal(noise_key, ( + (config.horizon,) + action_tm1.shape)) + + # Initialize empty set of action trajectories for concatenation in loop: + action_trajectories = jnp.zeros((config.n_trajectories, 0) + + action_trajectory_tm1[0].shape) + + for t in range(config.horizon): + # Query policy prior for proposed action: + if hasattr(policy_prior_state, 'action_tm1'): + policy_prior_state = policy_prior_state.replace(action_tm1=action_tm1) + action_t, policy_prior_state = policy_prior.select_action( + policy_prior_params, observation_t, policy_prior_state) + # Add action noise: + action_t = action_t + action_noise[t] + # Mix action with previous trajectory's corresponding action: + action_t = (1 - + config.beta) * action_t + config.beta * action_trajectory_tm1[t] + + # Query world model to get next observation and reward: + observation_tp1, reward_t = world_model(world_model_params, observation_t, + action_t) + cum_reward += reward_t + + # Insert actions into trajectory matrix: + action_trajectories = jnp.concatenate( + [action_trajectories, + jnp.expand_dims(action_t, axis=1)], axis=1) + # Bump variable timesteps for next loop: + observation_t = observation_tp1 + action_tm1 = action_t + + # De-normalize and append the final n_step return prediction: + n_step_return_t = n_step_return(n_step_return_params, observation_t, action_t) + cum_reward += n_step_return_t + + # Average the set of `n_trajectories` trajectories into a single trajectory. + return config.action_aggregation_fn(action_trajectories, cum_reward) diff --git a/acme/acme/agents/jax/mbop/mppi_test.py b/acme/acme/agents/jax/mbop/mppi_test.py new file mode 100644 index 00000000..e0d80a8c --- /dev/null +++ b/acme/acme/agents/jax/mbop/mppi_test.py @@ -0,0 +1,155 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for mppi.""" +import functools +from typing import Any + +from acme import specs +from acme.agents.jax.mbop import ensemble +from acme.agents.jax.mbop import models +from acme.agents.jax.mbop import mppi +from acme.jax import networks as networks_lib +import jax +import jax.numpy as jnp +import numpy as np + +from absl.testing import absltest +from absl.testing import parameterized + + +def get_fake_world_model() -> networks_lib.FeedForwardNetwork: + + def apply(params: Any, observation_t: jnp.ndarray, action_t: jnp.ndarray): + del params + return observation_t, jnp.ones(( + action_t.shape[0], + 1, + )) + + return networks_lib.FeedForwardNetwork(init=lambda: None, apply=apply) + + +def get_fake_policy_prior() -> networks_lib.FeedForwardNetwork: + return networks_lib.FeedForwardNetwork( + init=lambda: None, + apply=lambda params, observation_t, action_tm1: action_tm1) + + +def get_fake_n_step_return() -> networks_lib.FeedForwardNetwork: + + def apply(params, observation_t, action_t): + del params, action_t + return jnp.ones((observation_t.shape[0], 1)) + + return networks_lib.FeedForwardNetwork(init=lambda: None, apply=apply) + + +class WeightedAverageTests(parameterized.TestCase): + + @parameterized.parameters((np.array([1, 1, 1]), 1), (np.array([0, 1, 0]), 10), + (np.array([-1, 1, -1]), 4), + (np.array([-10, 30, 0]), -0.5)) + def test_weighted_averages(self, cum_reward, kappa): + """Compares method with a local version of the exp-weighted averaging.""" + action_trajectories = jnp.reshape( + jnp.arange(3 * 10 * 4), (3, 10, 4), order='F') + averaged_trajectory = mppi.return_weighted_average( + action_trajectories=action_trajectories, + cum_reward=cum_reward, + kappa=kappa) + exp_weights = jnp.exp(kappa * cum_reward) + # Verify single-value averaging lines up with the global averaging call: + for i in range(10): + for j in range(4): + np.testing.assert_allclose( + averaged_trajectory[i, j], + jnp.sum(exp_weights * action_trajectories[:, i, j]) / + jnp.sum(exp_weights), + atol=1E-5, + rtol=1E-5) + + +class MPPITest(parameterized.TestCase): + """This tests the MPPI planner to make sure it is correctly rolling out. + + It does not check the actual performance of the planner, as this would be a + bit more complicated to set up. + """ + + # TODO(dulacarnold): Look at how we can check this is actually finding an + # optimal path through the model. 
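+  # Note that the fakes above are deliberately trivial: the world model
+  # echoes its observation input with a constant reward of 1, the policy
+  # prior echoes the previous action and the n-step return is constant, so
+  # these tests exercise shapes and plumbing rather than planning quality.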
+ + def setUp(self): + super().setUp() + self.state_dims = 8 + self.action_dims = 4 + self.params = { + 'world': jnp.ones((3,)), + 'policy': jnp.ones((3,)), + 'value': jnp.ones((3,)) + } + self.env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(self.state_dims,), dtype=float), + actions=specs.Array(shape=(self.action_dims,), dtype=float), + rewards=specs.Array(shape=(1,), dtype=float, name='reward'), + discounts=specs.BoundedArray( + shape=(), dtype=float, minimum=0., maximum=1., name='discount')) + + @parameterized.named_parameters(('NO-PLAN', 0), ('NORMAL', 10)) + def test_planner_init(self, horizon: int): + world_model = get_fake_world_model() + rr_world_model = functools.partial(ensemble.apply_round_robin, + world_model.apply) + policy_prior = get_fake_policy_prior() + + def _rr_policy_prior(params, key, observation_t, action_tm1): + del key + return ensemble.apply_round_robin( + policy_prior.apply, + params, + observation_t=observation_t, + action_tm1=action_tm1) + + rr_policy_prior = models.feed_forward_policy_prior_to_actor_core( + _rr_policy_prior, jnp.zeros((1, self.action_dims))) + + n_step_return = get_fake_n_step_return() + n_step_return = functools.partial(ensemble.apply_mean, n_step_return.apply) + + config = mppi.MPPIConfig( + sigma=1, + beta=0.2, + horizon=horizon, + n_trajectories=9, + action_aggregation_fn=functools.partial( + mppi.return_weighted_average, kappa=1)) + previous_trajectory = mppi.get_initial_trajectory(config, self.env_spec) + key = jax.random.PRNGKey(0) + for _ in range(5): + previous_trajectory = mppi.mppi_planner( + config, + world_model=rr_world_model, + policy_prior=rr_policy_prior, + n_step_return=n_step_return, + world_model_params=self.params, + policy_prior_params=self.params, + n_step_return_params=self.params, + random_key=key, + observation=jnp.ones(self.state_dims), + previous_trajectory=previous_trajectory) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/mbop/networks.py b/acme/acme/agents/jax/mbop/networks.py new file mode 100644 index 00000000..76967d62 --- /dev/null +++ b/acme/acme/agents/jax/mbop/networks.py @@ -0,0 +1,138 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MBOP networks definitions.""" + +import dataclasses +from typing import Any, Tuple + +from acme import specs +from acme.jax import networks +from acme.jax import utils +import haiku as hk +import jax.numpy as jnp +import numpy as np + +# The term network is used in a general sense, e.g. for the CRR policy prior, it +# will be a dataclass that encapsulates the networks used by the CRR (learner). 
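+# The aliases below are therefore intentionally loose; the factories in this
+# module instantiate them as plain `networks.FeedForwardNetwork`s.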
+WorldModelNetwork = Any
+PolicyPriorNetwork = Any
+NStepReturnNetwork = Any
+
+
+@dataclasses.dataclass
+class MBOPNetworks:
+  """Container class to hold MBOP networks."""
+  world_model_network: WorldModelNetwork
+  policy_prior_network: PolicyPriorNetwork
+  n_step_return_network: NStepReturnNetwork
+
+
+def make_network_from_module(
+    module: hk.Transformed,
+    spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
+  """Creates a network with dummy init arguments using the specified module.
+
+  Args:
+    module: Module that expects one batch axis and one features axis for its
+      inputs.
+    spec: EnvironmentSpec shapes to derive dummy inputs.
+
+  Returns:
+    FeedForwardNetwork whose `init` method only takes a random key, and `apply`
+    takes an observation and action and produces an output.
+  """
+  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
+  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
+  return networks.FeedForwardNetwork(
+      lambda key: module.init(key, dummy_obs, dummy_action), module.apply)
+
+
+def make_world_model_network(
+    spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (64, 64)
+) -> networks.FeedForwardNetwork:
+  """Creates a world model network used by the agent."""
+
+  observation_size = np.prod(spec.observations.shape, dtype=int)
+
+  def _world_model_fn(observation_t, action_t, is_training=False, key=None):
+    # is_training and key allow defining train/test-dependent modules
+    # like dropout.
+    del is_training
+    del key
+    network = hk.nets.MLP(hidden_layer_sizes + (observation_size + 1,))
+    # World model returns both an observation and a reward.
+    observation_tp1, reward_t = jnp.split(
+        network(jnp.concatenate([observation_t, action_t], axis=-1)),
+        [observation_size],
+        axis=-1)
+    return observation_tp1, reward_t
+
+  world_model = hk.without_apply_rng(hk.transform(_world_model_fn))
+  return make_network_from_module(world_model, spec)
+
+
+def make_policy_prior_network(
+    spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (64, 64)
+) -> networks.FeedForwardNetwork:
+  """Creates a policy prior network used by the agent."""
+
+  action_size = np.prod(spec.actions.shape, dtype=int)
+
+  def _policy_prior_fn(observation_t, action_tm1, is_training=False, key=None):
+    # is_training and key allow defining train/test-dependent modules
+    # like dropout.
+    del is_training
+    del key
+    network = hk.nets.MLP(hidden_layer_sizes + (action_size,))
+    # Policy prior returns an action.
+    return network(jnp.concatenate([observation_t, action_tm1], axis=-1))
+
+  policy_prior = hk.without_apply_rng(hk.transform(_policy_prior_fn))
+  return make_network_from_module(policy_prior, spec)
+
+
+def make_n_step_return_network(
+    spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (64, 64)
+) -> networks.FeedForwardNetwork:
+  """Creates an N-step return network used by the agent."""
+
+  def _n_step_return_fn(observation_t, action_t, is_training=False, key=None):
+    # is_training and key allow defining train/test-dependent modules
+    # like dropout.
+    del is_training
+    del key
+    network = hk.nets.MLP(hidden_layer_sizes + (1,))
+    return network(jnp.concatenate([observation_t, action_t], axis=-1))
+
+  n_step_return = hk.without_apply_rng(hk.transform(_n_step_return_fn))
+  return make_network_from_module(n_step_return, spec)
+
+
+def make_networks(
+    spec: specs.EnvironmentSpec,
+    hidden_layer_sizes: Tuple[int, ...]
= (64, 64), +) -> MBOPNetworks: + """Creates networks used by the agent.""" + world_model_network = make_world_model_network( + spec, hidden_layer_sizes=hidden_layer_sizes) + policy_prior_network = make_policy_prior_network( + spec, hidden_layer_sizes=hidden_layer_sizes) + n_step_return_network = make_n_step_return_network( + spec, hidden_layer_sizes=hidden_layer_sizes) + + return MBOPNetworks( + world_model_network=world_model_network, + policy_prior_network=policy_prior_network, + n_step_return_network=n_step_return_network) diff --git a/acme/acme/agents/jax/multiagent/__init__.py b/acme/acme/agents/jax/multiagent/__init__.py new file mode 100644 index 00000000..e5ed9e6d --- /dev/null +++ b/acme/acme/agents/jax/multiagent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiagent implementations.""" diff --git a/acme/acme/agents/jax/multiagent/decentralized/README.md b/acme/acme/agents/jax/multiagent/decentralized/README.md new file mode 100644 index 00000000..c560a319 --- /dev/null +++ b/acme/acme/agents/jax/multiagent/decentralized/README.md @@ -0,0 +1,15 @@ +# Decentralized Multiagent Learning + +This folder contains an implementation of decentralized multiagent learning. +The current implementation supports homogeneous sub-agents (i.e., all agents +running identical sub-algorithms). + +The underlying multiagent environment should produce observations and rewards +that are each a dict, with keys corresponding to string IDs for the agents that +map to their respective local observation and rewards. Rewards can be +heterogeneous (e.g., for non-cooperative environments). + +The environment step() should consume dict-style actions, with key:value pairs +corresponding to agent:action, as above. + +Discounts are assumed shared between agents (i.e., should be a single scalar). diff --git a/acme/acme/agents/jax/multiagent/decentralized/__init__.py b/acme/acme/agents/jax/multiagent/decentralized/__init__.py new file mode 100644 index 00000000..b4e10eb4 --- /dev/null +++ b/acme/acme/agents/jax/multiagent/decentralized/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Decentralized multiagent configuration.""" + +from acme.agents.jax.multiagent.decentralized.agents import DecentralizedMultiAgent +from acme.agents.jax.multiagent.decentralized.agents import DistributedDecentralizedMultiAgent +from acme.agents.jax.multiagent.decentralized.builder import DecentralizedMultiAgentBuilder +from acme.agents.jax.multiagent.decentralized.config import DecentralizedMultiagentConfig +from acme.agents.jax.multiagent.decentralized.factories import builder_factory +from acme.agents.jax.multiagent.decentralized.factories import default_config_factory +from acme.agents.jax.multiagent.decentralized.factories import default_logger_factory +from acme.agents.jax.multiagent.decentralized.factories import DefaultSupportedAgent +from acme.agents.jax.multiagent.decentralized.factories import network_factory +from acme.agents.jax.multiagent.decentralized.factories import policy_network_factory diff --git a/acme/acme/agents/jax/multiagent/decentralized/actor.py b/acme/acme/agents/jax/multiagent/decentralized/actor.py new file mode 100644 index 00000000..6f5df02c --- /dev/null +++ b/acme/acme/agents/jax/multiagent/decentralized/actor.py @@ -0,0 +1,58 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decentralized multiagent actor.""" + +from typing import Dict + +from acme import core +from acme.jax import networks +from acme.multiagent import types as ma_types +from acme.multiagent import utils as ma_utils +import dm_env + + +class SimultaneousActingMultiAgentActor(core.Actor): + """Simultaneous-move actor (see README.md for expected environment interface).""" + + def __init__(self, actors: Dict[ma_types.AgentID, core.Actor]): + """Initializer. + + Args: + actors: a dict specifying sub-actors. + """ + self._actors = actors + + def select_action( + self, observation: Dict[ma_types.AgentID, networks.Observation] + ) -> Dict[ma_types.AgentID, networks.Action]: + return { + actor_id: actor.select_action(observation[actor_id]) + for actor_id, actor in self._actors.items() + } + + def observe_first(self, timestep: dm_env.TimeStep): + for actor_id, actor in self._actors.items(): + sub_timestep = ma_utils.get_agent_timestep(timestep, actor_id) + actor.observe_first(sub_timestep) + + def observe(self, actions: Dict[ma_types.AgentID, networks.Action], + next_timestep: dm_env.TimeStep): + for actor_id, actor in self._actors.items(): + sub_next_timestep = ma_utils.get_agent_timestep(next_timestep, actor_id) + actor.observe(actions[actor_id], sub_next_timestep) + + def update(self, wait: bool = False): + for actor in self._actors.values(): + actor.update(wait=wait) diff --git a/acme/acme/agents/jax/multiagent/decentralized/agents.py b/acme/acme/agents/jax/multiagent/decentralized/agents.py new file mode 100644 index 00000000..c3022a6f --- /dev/null +++ b/acme/acme/agents/jax/multiagent/decentralized/agents.py @@ -0,0 +1,199 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines distributed and local multiagent decentralized agents.""" + +import functools +from typing import Any, Dict, Optional, Sequence, Tuple + +from acme import specs +from acme.agents.jax.multiagent.decentralized import builder as decentralized_builders +from acme.agents.jax.multiagent.decentralized import config as decentralized_config +from acme.agents.jax.multiagent.decentralized import factories as decentralized_factories +from acme.jax import types +from acme.jax import utils +from acme.jax.layouts import distributed_layout +from acme.jax.layouts import local_layout +from acme.multiagent import types as ma_types +from acme.utils import counting + + +class DistributedDecentralizedMultiAgent(distributed_layout.DistributedLayout): + """Distributed program definition for decentralized multiagent learning.""" + + def __init__( + self, + agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent], + environment_factory: types.EnvironmentFactory, + network_factory: ma_types.NetworkFactory, + policy_factory: ma_types.PolicyFactory, + builder_factory: ma_types.BuilderFactory, + config: decentralized_config.DecentralizedMultiagentConfig, + seed: int, + num_parallel_actors_per_agent: int, + environment_spec: Optional[specs.EnvironmentSpec] = None, + max_number_of_steps: Optional[int] = None, + log_to_bigtable: bool = False, + log_every: float = 10.0, + evaluator_factories: Optional[Sequence[ + distributed_layout.EvaluatorFactory]] = None, + ): + assert len(set(agent_types.values())) == 1, ( + f'Sub-agent types must be identical, but are {agent_types}.') + + learner_logger_fns = decentralized_factories.default_logger_factory( + agent_types=agent_types, + base_label='learner', + save_data=log_to_bigtable, + time_delta=log_every, + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key='learner_steps') + builders = builder_factory(agent_types, config.sub_agent_configs) + + train_policy_factory = functools.partial(policy_factory, eval_mode=False) + if evaluator_factories is None: + eval_network_fn = functools.partial(policy_factory, eval_mode=True) + evaluator_factories = [ + distributed_layout.default_evaluator_factory( + environment_factory=environment_factory, + network_factory=network_factory, + policy_factory=eval_network_fn, + save_logs=log_to_bigtable) + ] + self.builder = decentralized_builders.DecentralizedMultiAgentBuilder( + builders) + # pytype: disable=wrong-arg-types + super().__init__( + seed=seed, + environment_factory=environment_factory, + network_factory=network_factory, + builder=self.builder, + policy_network=train_policy_factory, + evaluator_factories=evaluator_factories, + num_actors=num_parallel_actors_per_agent, + environment_spec=environment_spec, + max_number_of_steps=max_number_of_steps, + prefetch_size=config.prefetch_size, + save_logs=log_to_bigtable, + actor_logger_fn=distributed_layout.get_default_logger_fn( + log_to_bigtable, log_every), + learner_logger_fn=learner_logger_fns + ) + # pytype: 
enable=wrong-arg-types
+
+
+class DecentralizedMultiAgent(local_layout.LocalLayout):
+  """Local definition for decentralized multiagent learning."""
+
+  def __init__(
+      self,
+      agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent],
+      spec: specs.EnvironmentSpec,
+      builder_factory: ma_types.BuilderFactory,
+      networks: ma_types.MultiAgentNetworks,
+      policy_networks: ma_types.MultiAgentPolicyNetworks,
+      config: decentralized_config.DecentralizedMultiagentConfig,
+      seed: int,
+      workdir: Optional[str] = '~/acme',
+      counter: Optional[counting.Counter] = None,
+      save_data: bool = True
+  ):
+    assert len(set(agent_types.values())) == 1, (
+        f'Sub-agent types must be identical, but are {agent_types}.')
+    # TODO(somidshafiei): add input normalizer. However, this may require
+    # adding some helper utilities for each single-agent algorithm, as
+    # batch_dims for NormalizationBuilder are algorithm-dependent (e.g., see
+    # PPO vs. SAC JAX agents).
+
+    learner_logger_fns = decentralized_factories.default_logger_factory(
+        agent_types=agent_types,
+        base_label='learner',
+        save_data=save_data,
+        steps_key='learner_steps')
+    learner_loggers = {agent_id: learner_logger_fns[agent_id]()
+                       for agent_id in agent_types.keys()}
+    builders = builder_factory(agent_types, config.sub_agent_configs)
+    self.builder = decentralized_builders.DecentralizedMultiAgentBuilder(
+        builders)
+    # pytype: disable=wrong-arg-types
+    super().__init__(
+        seed=seed,
+        environment_spec=spec,
+        builder=self.builder,
+        networks=networks,
+        policy_network=policy_networks,
+        prefetch_size=config.prefetch_size,
+        learner_logger=learner_loggers,
+        batch_size=config.batch_size,
+        workdir=workdir,
+        counter=counter,
+    )
+    # pytype: enable=wrong-arg-types
+
+
+def init_decentralized_multiagent(
+    agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent],
+    environment_spec: specs.EnvironmentSpec,
+    seed: int,
+    batch_size: int,
+    workdir: Optional[str] = '~/acme',
+    init_network_fn: Optional[ma_types.InitNetworkFn] = None,
+    init_policy_network_fn: Optional[ma_types.InitPolicyNetworkFn] = None,
+    save_data: bool = True,
+    config_overrides: Optional[Dict[ma_types.AgentID, Dict[str, Any]]] = None
+) -> Tuple[DecentralizedMultiAgent, ma_types.MultiAgentPolicyNetworks]:
+  """Returns decentralized multiagent LocalLayout instance.
+
+  Intended to be used as a helper function to more readily instantiate and
+  experiment with multiagent setups. For full functionality, use
+  DecentralizedMultiAgent or DistributedDecentralizedMultiAgent directly.
+
+  Args:
+    agent_types: a dict specifying the agent identifiers and their types
+      (e.g., {'0': factories.DefaultSupportedAgent.PPO, '1': ...}).
+    environment_spec: environment spec.
+    seed: seed.
+    batch_size: the batch size (used for each sub-agent).
+    workdir: working directory (e.g., used for checkpointing).
+    init_network_fn: optional custom network initializer function.
+    init_policy_network_fn: optional custom policy network initializer
+      function.
+    save_data: whether to save data throughout training.
+    config_overrides: a dict specifying agent-specific configuration overrides.
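+
+  Returns:
+    A tuple of the instantiated DecentralizedMultiAgent and the corresponding
+    evaluation policy networks.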
+ """ + configs = decentralized_factories.default_config_factory( + agent_types, batch_size, config_overrides) + networks = decentralized_factories.network_factory(environment_spec, + agent_types, + init_network_fn) + policy_networks = decentralized_factories.policy_network_factory( + networks, environment_spec, agent_types, configs, eval_mode=False, + init_policy_network_fn=init_policy_network_fn) + eval_policy_networks = decentralized_factories.policy_network_factory( + networks, environment_spec, agent_types, configs, eval_mode=True, + init_policy_network_fn=init_policy_network_fn) + config = decentralized_config.DecentralizedMultiagentConfig( + batch_size=batch_size, sub_agent_configs=configs) + decentralized_multi_agent = DecentralizedMultiAgent( + agent_types=agent_types, + spec=environment_spec, + builder_factory=decentralized_factories.builder_factory, + networks=networks, + policy_networks=policy_networks, + seed=seed, + config=config, + workdir=workdir, + save_data=save_data + ) + return decentralized_multi_agent, eval_policy_networks diff --git a/acme/acme/agents/jax/multiagent/decentralized/agents_test.py b/acme/acme/agents/jax/multiagent/decentralized/agents_test.py new file mode 100644 index 00000000..84d8b6a2 --- /dev/null +++ b/acme/acme/agents/jax/multiagent/decentralized/agents_test.py @@ -0,0 +1,60 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for agents.""" + +import acme +from acme.agents.jax.multiagent.decentralized import agents +from acme.agents.jax.multiagent.decentralized import factories +from acme.testing import fakes +from acme.testing import multiagent_fakes +from absl.testing import absltest + + +class AgentsTest(absltest.TestCase): + + def test_init_decentralized_multiagent(self): + batch_size = 5 + agent_indices = ['a', '99', 'Z'] + environment_spec = multiagent_fakes.make_multiagent_environment_spec( + agent_indices) + env = fakes.Environment(environment_spec, episode_length=4) + agent_types = { + agent_id: factories.DefaultSupportedAgent.TD3 + for agent_id in agent_indices + } + agt_configs = {'sigma': 0.3, 'target_sigma': 0.3} + config_overrides = { + k: agt_configs for k, v in agent_types.items() + if v == factories.DefaultSupportedAgent.TD3 + } + + agent, _ = agents.init_decentralized_multiagent( + agent_types=agent_types, + environment_spec=environment_spec, + seed=1, + batch_size=batch_size, + workdir=None, + init_network_fn=None, + save_data=False, + config_overrides=config_overrides) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. 
+ loop = acme.EnvironmentLoop(env, agent) + loop.run(num_episodes=10) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/multiagent/decentralized/builder.py b/acme/acme/agents/jax/multiagent/decentralized/builder.py new file mode 100644 index 00000000..8db6f661 --- /dev/null +++ b/acme/acme/agents/jax/multiagent/decentralized/builder.py @@ -0,0 +1,196 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX multiagent builders.""" + +from typing import Dict, Iterator, List, Optional, Sequence + +from acme import adders +from acme import core +from acme import specs +from acme import types +from acme.agents.jax import builders as acme_builders +from acme.agents.jax.multiagent.decentralized import actor +from acme.agents.jax.multiagent.decentralized import learner_set +from acme.jax import networks as networks_lib +from acme.multiagent import types as ma_types +from acme.multiagent import utils as ma_utils +from acme.utils import counting +from acme.utils import iterator_utils +from acme.utils import loggers as acme_loggers +import jax +import reverb + + +VARIABLE_SEPARATOR = '-' + + +class PrefixedVariableSource(core.VariableSource): + """Wraps a variable source to add a pre-defined prefix to all names.""" + + def __init__(self, source: core.VariableSource, prefix: str): + self._source = source + self._prefix = prefix + + def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: + return self._source.get_variables([self._prefix + name for name in names]) + + +class DecentralizedMultiAgentBuilder( + acme_builders.GenericActorLearnerBuilder[ + ma_types.MultiAgentNetworks, + ma_types.MultiAgentPolicyNetworks, + ma_types.MultiAgentSample]): + """Builder for decentralized multiagent setup.""" + + def __init__(self, builders: Dict[ma_types.AgentID, + acme_builders.GenericActorLearnerBuilder]): + """Initializer. + + Args: + builders: a dict specifying the builders for all sub-agents. + """ + + self._builders = builders + self._num_agents = len(self._builders) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: ma_types.MultiAgentPolicyNetworks, + ) -> List[reverb.Table]: + """Returns replay tables for all agents. + + Args: + environment_spec: the (multiagent) environment spec, which will be + factorized into single-agent specs for replay table initialization. + policy: the (multiagent) mapping from agent ID to the corresponding + agent's policy, used to get the correct extras_spec. 
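+
+    Returns:
+      A list of replay tables aggregated across all sub-agents (each
+      sub-builder may contribute one or more tables).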
+ """ + replay_tables = [] + for agent_id, builder in self._builders.items(): + single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) + replay_tables += builder.make_replay_tables(single_agent_spec, + policy[agent_id]) + return replay_tables + + def make_dataset_iterator( + self, + replay_client: reverb.Client) -> Iterator[ma_types.MultiAgentSample]: + # Zipping stores sub-iterators in the order dictated by + # self._builders.values(), which are insertion-ordered in Python3.7+. + # Hence, later unzipping (in make_learner()) and accessing the iterators + # via the same self._builders.items() dict ordering should be safe. + return zip(*[ + b.make_dataset_iterator(replay_client) for b in self._builders.values() + ]) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: ma_types.MultiAgentNetworks, + dataset: Iterator[ma_types.MultiAgentSample], + logger_fn: acme_loggers.LoggerFactory, + environment_spec: Optional[specs.EnvironmentSpec] = None, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None + ) -> learner_set.SynchronousDecentralizedLearnerSet: + """Returns multiagent learner set. + + Args: + random_key: random key. + networks: dict of networks, one per learner. Networks can be heterogeneous + (i.e., distinct in architecture) across learners. + dataset: list of iterators over samples from replay, one per learner. + logger_fn: factory providing loggers used for logging progress. + environment_spec: the (multiagent) environment spec, which will be + factorized into single-agent specs for replay table initialization. + replay_client: replay client that is shared amongst the sub-learners. + counter: a Counter which allows for recording of counts (learner steps, + actor steps, etc.) distributed throughout the agent. + """ + def _make_logger_fn(agent_id: ma_types.AgentID): + """Returns agent logger function while avoiding cell-var-from-loop bugs.""" + return ( + lambda label, steps_key=None, task_instance=None: loggers[agent_id]) + + loggers = logger_fn(label='') # label is unused at the parent level + if loggers is None: + loggers = {k: None for k in self._builders.keys()} + sub_learners = {} + unzipped_dataset = iterator_utils.unzip_iterators( + dataset, num_sub_iterators=self._num_agents) + for i_dataset, (agent_id, builder) in enumerate(self._builders.items()): + single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) + random_key, learner_key = jax.random.split(random_key) + sub_learners[agent_id] = builder.make_learner( + learner_key, + networks[agent_id], + unzipped_dataset[i_dataset], + logger_fn=_make_logger_fn(agent_id), + environment_spec=single_agent_spec, + replay_client=replay_client, + counter=counter) + return learner_set.SynchronousDecentralizedLearnerSet( + sub_learners, separator=VARIABLE_SEPARATOR) + + def make_adder( # Internal pytype check. + self, replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec] = None, + policy: Optional[ma_types.MultiAgentPolicyNetworks] = None, + ) -> Dict[ma_types.AgentID, Optional[adders.Adder]]: + del environment_spec, policy # Unused. 
+    return {
+        agent_id:
+            b.make_adder(replay_client, environment_spec=None, policy=None)
+        for agent_id, b in self._builders.items()
+    }
+
+  def make_actor(
+      self,
+      random_key: networks_lib.PRNGKey,
+      policy_networks: ma_types.MultiAgentPolicyNetworks,
+      environment_spec: specs.EnvironmentSpec,
+      variable_source: Optional[core.VariableSource] = None,
+      adder: Optional[Dict[ma_types.AgentID, adders.Adder]] = None,
+  ) -> core.Actor:
+    """Returns simultaneous-acting multiagent actor instance.
+
+    Args:
+      random_key: random key.
+      policy_networks: dict of policy networks, one for each actor. Networks
+        can be heterogeneous (i.e., distinct in architecture) across actors.
+      environment_spec: the (multiagent) environment spec, which will be
+        factorized into single-agent specs for the sub-actors.
+      variable_source: an optional LearnerSet. Each sub-actor pulls its local
+        variables from variable_source.
+      adder: how data is recorded (e.g., added to replay) for each actor.
+    """
+    if adder is None:
+      adder = {agent_id: None for agent_id in policy_networks.keys()}
+
+    sub_actors = {}
+    for agent_id, builder in self._builders.items():
+      single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id)
+      random_key, actor_key = jax.random.split(random_key)
+      # Adds a prefix to each sub-actor's variable names to ensure the correct
+      # sub-learner is queried for variables.
+      sub_variable_source = PrefixedVariableSource(
+          variable_source, f'{agent_id}{VARIABLE_SEPARATOR}')
+      sub_actors[agent_id] = builder.make_actor(actor_key,
+                                                policy_networks[agent_id],
+                                                single_agent_spec,
+                                                sub_variable_source,
+                                                adder[agent_id])
+    return actor.SimultaneousActingMultiAgentActor(sub_actors)
diff --git a/acme/acme/agents/jax/multiagent/decentralized/config.py b/acme/acme/agents/jax/multiagent/decentralized/config.py
new file mode 100644
index 00000000..9c0cc569
--- /dev/null
+++ b/acme/acme/agents/jax/multiagent/decentralized/config.py
@@ -0,0 +1,28 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decentralized multiagent config."""
+
+import dataclasses
+from typing import Dict
+
+from acme.multiagent import types
+
+
+@dataclasses.dataclass
+class DecentralizedMultiagentConfig:
+  """Configuration options for decentralized multiagent."""
+  sub_agent_configs: Dict[types.AgentID, types.AgentConfig]
+  batch_size: int = 256
+  prefetch_size: int = 2
diff --git a/acme/acme/agents/jax/multiagent/decentralized/factories.py b/acme/acme/agents/jax/multiagent/decentralized/factories.py
new file mode 100644
index 00000000..23e6dbcf
--- /dev/null
+++ b/acme/acme/agents/jax/multiagent/decentralized/factories.py
@@ -0,0 +1,229 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decentralized multiagent factories. + +Used to unify agent initialization for both local and distributed layouts. +""" + +import enum +import functools +from typing import Any, Callable, Dict, Mapping, Optional + +from acme import specs +from acme.adders import reverb as adders_reverb +from acme.agents.jax import builders as jax_builders +from acme.agents.jax import ppo +from acme.agents.jax import sac +from acme.agents.jax import td3 +from acme.multiagent import types as ma_types +from acme.multiagent import utils as ma_utils +from acme.utils import loggers + + +class DefaultSupportedAgent(enum.Enum): + """Agents which have default initializers supported below.""" + TD3 = 'TD3' + SAC = 'SAC' + PPO = 'PPO' + + +def init_default_network( + agent_type: DefaultSupportedAgent, + agent_spec: specs.EnvironmentSpec) -> ma_types.Networks: + """Returns default networks for a single agent.""" + if agent_type == DefaultSupportedAgent.TD3: + return td3.make_networks(agent_spec) + elif agent_type == DefaultSupportedAgent.SAC: + return sac.make_networks(agent_spec) + elif agent_type == DefaultSupportedAgent.PPO: + return ppo.make_networks(agent_spec) + else: + raise ValueError(f'Unsupported agent type: {agent_type}.') + + +def init_default_policy_network( + agent_type: DefaultSupportedAgent, + network: ma_types.Networks, + agent_spec: specs.EnvironmentSpec, + config: ma_types.AgentConfig, + eval_mode: ma_types.EvalMode = False) -> ma_types.PolicyNetwork: + """Returns default policy network for a single agent.""" + if agent_type == DefaultSupportedAgent.TD3: + sigma = 0. 
if eval_mode else config.sigma
+    return td3.get_default_behavior_policy(
+        network, agent_spec.actions, sigma=sigma)
+  elif agent_type == DefaultSupportedAgent.SAC:
+    return sac.apply_policy_and_sample(network, eval_mode=eval_mode)
+  elif agent_type == DefaultSupportedAgent.PPO:
+    return ppo.make_inference_fn(network, evaluation=eval_mode)
+  else:
+    raise ValueError(f'Unsupported agent type: {agent_type}.')
+
+
+def init_default_builder(
+    agent_type: DefaultSupportedAgent,
+    agent_config: ma_types.AgentConfig,
+) -> jax_builders.GenericActorLearnerBuilder:
+  """Returns default builder for a single agent."""
+  if agent_type == DefaultSupportedAgent.TD3:
+    assert isinstance(agent_config, td3.TD3Config)
+    return td3.TD3Builder(agent_config)
+  elif agent_type == DefaultSupportedAgent.SAC:
+    assert isinstance(agent_config, sac.SACConfig)
+    return sac.SACBuilder(agent_config)
+  elif agent_type == DefaultSupportedAgent.PPO:
+    assert isinstance(agent_config, ppo.PPOConfig)
+    return ppo.PPOBuilder(agent_config)
+  else:
+    raise ValueError(f'Unsupported agent type: {agent_type}.')
+
+
+def init_default_config(
+    agent_type: DefaultSupportedAgent,
+    config_overrides: Dict[str, Any]) -> ma_types.AgentConfig:
+  """Returns default config for a single agent."""
+  if agent_type == DefaultSupportedAgent.TD3:
+    return td3.TD3Config(**config_overrides)
+  elif agent_type == DefaultSupportedAgent.SAC:
+    return sac.SACConfig(**config_overrides)
+  elif agent_type == DefaultSupportedAgent.PPO:
+    return ppo.PPOConfig(**config_overrides)
+  else:
+    raise ValueError(f'Unsupported agent type: {agent_type}.')
+
+
+def default_logger_factory(
+    agent_types: Dict[ma_types.AgentID, DefaultSupportedAgent],
+    base_label: str,
+    save_data: bool,
+    time_delta: float = 1.0,
+    asynchronous: bool = False,
+    print_fn: Optional[Callable[[str], None]] = None,
+    serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = None,
+    steps_key: str = 'steps',
+) -> ma_types.MultiAgentLoggerFn:
+  """Returns a dict of default logger factories, one per agent."""
+  logger_fns = {}
+  for agent_id in agent_types.keys():
+    logger_fns[agent_id] = functools.partial(
+        loggers.make_default_logger,
+        f'{base_label}{agent_id}',
+        save_data=save_data,
+        time_delta=time_delta,
+        asynchronous=asynchronous,
+        print_fn=print_fn,
+        serialize_fn=serialize_fn,
+        steps_key=steps_key,
+    )
+  return logger_fns
+
+
+def default_config_factory(
+    agent_types: Dict[ma_types.AgentID, DefaultSupportedAgent],
+    batch_size: int,
+    config_overrides: Optional[Dict[ma_types.AgentID, Dict[str, Any]]] = None
+) -> Dict[ma_types.AgentID, ma_types.AgentConfig]:
+  """Returns default configs for all agents.
+
+  Args:
+    agent_types: dict mapping agent IDs to their type.
+    batch_size: shared batch size for all agents.
+    config_overrides: dict mapping (potentially a subset of) agent IDs to their
+      config overrides. This should include any mandatory config parameters for
+      the agents that do not have default values.
+  """
+  configs = {}
+  for agent_id, agent_type in agent_types.items():
+    agent_config_overrides = dict(
+        # batch_size is required by LocalLayout, which is shared amongst
+        # the agents. Hence, we enforce a shared batch_size in builders.
+        batch_size=batch_size,
+        # Unique replay_table_name per agent.
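+        # This keeps each sub-agent's replay data in its own table, e.g.
+        # f'{DEFAULT_PRIORITY_TABLE}_agent0' for agent ID '0'.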
+        replay_table_name=f'{adders_reverb.DEFAULT_PRIORITY_TABLE}_agent{agent_id}'
+    )
+    if config_overrides is not None and agent_id in config_overrides:
+      agent_config_overrides = {
+          **config_overrides[agent_id],
+          **agent_config_overrides  # Comes second to ensure batch_size override.
+      }
+    configs[agent_id] = init_default_config(agent_type, agent_config_overrides)
+  return configs
+
+
+def network_factory(
+    environment_spec: specs.EnvironmentSpec,
+    agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent],
+    init_network_fn: Optional[ma_types.InitNetworkFn] = None
+) -> ma_types.MultiAgentNetworks:
+  """Returns networks for all agents.
+
+  Args:
+    environment_spec: environment spec.
+    agent_types: dict mapping agent IDs to their type.
+    init_network_fn: optional callable that handles the network initialization
+      for all sub-agents. If this is not supplied, a default network
+      initializer is used (if it is supported for the designated agent type).
+  """
+  init_fn = init_network_fn or init_default_network
+  networks = {}
+  for agent_id, agent_type in agent_types.items():
+    single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id)
+    networks[agent_id] = init_fn(agent_type, single_agent_spec)
+  return networks
+
+
+def policy_network_factory(
+    networks: ma_types.MultiAgentNetworks,
+    environment_spec: specs.EnvironmentSpec,
+    agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent],
+    agent_configs: Dict[ma_types.AgentID, ma_types.AgentConfig],
+    eval_mode: ma_types.EvalMode,
+    init_policy_network_fn: Optional[ma_types.InitPolicyNetworkFn] = None
+) -> ma_types.MultiAgentPolicyNetworks:
+  """Returns default policy networks for all agents.
+
+  Args:
+    networks: dict mapping agent IDs to their networks.
+    environment_spec: environment spec.
+    agent_types: dict mapping agent IDs to their type.
+    agent_configs: dict mapping agent IDs to their config.
+    eval_mode: whether the policies should be initialized in evaluation mode.
+    init_policy_network_fn: optional callable that handles the policy network
+      initialization for all sub-agents. If this is not supplied, a default
+      policy network initializer is used (if it is supported for the
+      designated agent type).
+  """
+  init_fn = init_policy_network_fn or init_default_policy_network
+  policy_networks = {}
+  for agent_id, agent_type in agent_types.items():
+    single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id)
+    policy_networks[agent_id] = init_fn(agent_type, networks[agent_id],
+                                        single_agent_spec,
+                                        agent_configs[agent_id], eval_mode)
+  return policy_networks
+
+
+def builder_factory(
+    agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent],
+    agent_configs: Dict[ma_types.AgentID, ma_types.AgentConfig],
+    init_builder_fn: Optional[ma_types.InitBuilderFn] = None
+) -> Dict[ma_types.AgentID, jax_builders.GenericActorLearnerBuilder]:
+  """Returns default builders for all agents."""
+  init_fn = init_builder_fn or init_default_builder
+  builders = {}
+  for agent_id, agent_type in agent_types.items():
+    builders[agent_id] = init_fn(agent_type, agent_configs[agent_id])
+  return builders
diff --git a/acme/acme/agents/jax/multiagent/decentralized/learner_set.py b/acme/acme/agents/jax/multiagent/decentralized/learner_set.py
new file mode 100644
index 00000000..110bb3b9
--- /dev/null
+++ b/acme/acme/agents/jax/multiagent/decentralized/learner_set.py
@@ -0,0 +1,85 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decentralized multiagent learner set."""
+
+import dataclasses
+from typing import Any, Dict, List
+
+from acme import core
+from acme import types
+
+from acme.multiagent import types as ma_types
+
+LearnerState = Any
+
+
+@dataclasses.dataclass
+class SynchronousDecentralizedLearnerSetState:
+  """State of a SynchronousDecentralizedLearnerSet."""
+  # States of the learners keyed by their names.
+  learner_states: Dict[ma_types.AgentID, LearnerState]
+
+
+class SynchronousDecentralizedLearnerSet(core.Learner):
+  """Creates a composed learner which wraps a set of local agent learners."""
+
+  def __init__(self,
+               learners: Dict[ma_types.AgentID, core.Learner],
+               separator: str = '-'):
+    """Initializer.
+
+    Args:
+      learners: a dict specifying the learners for all sub-agents.
+      separator: separator character used to disambiguate sub-learner
+        variables.
+    """
+    self._learners = learners
+    self._separator = separator
+
+  def step(self):
+    for learner in self._learners.values():
+      learner.step()
+
+  def get_variables(self, names: List[str]) -> List[types.NestedArray]:
+    """Returns the named variables as a collection of (nested) numpy arrays.
+
+    Args:
+      names: args where each name is a string identifying a predefined subset
+        of the variables. Each name should be prefixed with the sub-learner's
+        name followed by the separator specified in the constructor, e.g.
+        'learner-var' if the separator is '-'.
+
+    Returns:
+      A list of (nested) numpy arrays `variables` such that `variables[i]`
+      corresponds to the collection named by `names[i]`.
+    """
+    variables = []
+    for name in names:
+      # Note: if the separator is not found, learner_id == name, which is OK.
+      learner_id, _, variable_name = name.partition(self._separator)
+      learner = self._learners[learner_id]
+      variables.extend(learner.get_variables([variable_name]))
+    return variables
+
+  def save(self) -> SynchronousDecentralizedLearnerSetState:
+    return SynchronousDecentralizedLearnerSetState(learner_states={
+        name: learner.save() for name, learner in self._learners.items()
+    })
+
+  def restore(self, state: SynchronousDecentralizedLearnerSetState):
+    for name, learner in self._learners.items():
+      learner.restore(state.learner_states[name])
diff --git a/acme/acme/agents/jax/normalization.py b/acme/acme/agents/jax/normalization.py
new file mode 100644
index 00000000..1cd85e16
--- /dev/null
+++ b/acme/acme/agents/jax/normalization.py
@@ -0,0 +1,251 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility classes for input normalization.""" + +import dataclasses +import functools +from typing import Any, Callable, Generic, Iterator, List, Optional, Tuple + +import acme +from acme import adders +from acme import core +from acme import specs +from acme import types +from acme.agents.jax import builders +from acme.jax import networks as networks_lib +from acme.jax import running_statistics +from acme.jax import variable_utils +from acme.jax.types import Networks, PolicyNetwork # pylint: disable=g-multiple-import +from acme.utils import counting +from acme.utils import loggers +import dm_env +import jax +import reverb + +_NORMALIZATION_VARIABLES = 'normalization_variables' + + +# Wrapping the network instead might look more straightforward, but then +# different implementations would be needed for feed-forward and +# recurrent networks. +class NormalizationActorWrapper(core.Actor): + """An actor wrapper that normalizes observations before applying policy.""" + + def __init__(self, + wrapped_actor: core.Actor, + variable_source: core.VariableSource, + max_abs_observation: Optional[float], + update_period: int = 1, + backend: Optional[str] = None): + self._wrapped_actor = wrapped_actor + self._variable_client = variable_utils.VariableClient( + variable_source, + key=_NORMALIZATION_VARIABLES, + update_period=update_period, + device=backend) + self._apply_normalization = jax.jit( + functools.partial( + running_statistics.normalize, max_abs_value=max_abs_observation), + backend=backend) + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + self._variable_client.update() + observation_stats = self._variable_client.params + observation = self._apply_normalization(observation, observation_stats) + return self._wrapped_actor.select_action(observation) + + def observe_first(self, timestep: dm_env.TimeStep): + return self._wrapped_actor.observe_first(timestep) + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + return self._wrapped_actor.observe(action, next_timestep) + + def update(self, wait: bool = False): + return self._wrapped_actor.update(wait) + + +@dataclasses.dataclass +class NormalizationLearnerWrapperState: + wrapped_learner_state: Any + observation_running_statistics: running_statistics.RunningStatisticsState + + +class NormalizationLearnerWrapper(core.Learner, core.Saveable): + """A learner wrapper that normalizes observations using running statistics.""" + + def __init__(self, learner_factory: Callable[[Iterator[reverb.ReplaySample]], + acme.Learner], + iterator: Iterator[reverb.ReplaySample], + environment_spec: specs.EnvironmentSpec, is_sequence_based: bool, + batch_dims: Optional[Tuple[int, ...]], + max_abs_observation: Optional[float]): + + def normalize_sample( + observation_statistics: running_statistics.RunningStatisticsState, + sample: reverb.ReplaySample + ) -> Tuple[running_statistics.RunningStatisticsState, reverb.ReplaySample]: + observation = sample.data.observation + observation_statistics = running_statistics.update( + observation_statistics, observation) + observation = 
running_statistics.normalize(
+          observation,
+          observation_statistics,
+          max_abs_value=max_abs_observation)
+      if is_sequence_based:
+        assert not hasattr(sample.data, 'next_observation')
+        sample = reverb.ReplaySample(
+            sample.info, sample.data._replace(observation=observation))
+      else:
+        next_observation = running_statistics.normalize(
+            sample.data.next_observation,
+            observation_statistics,
+            max_abs_value=max_abs_observation)
+        sample = reverb.ReplaySample(
+            sample.info,
+            sample.data._replace(
+                observation=observation, next_observation=next_observation))
+
+      return observation_statistics, sample
+
+    self._observation_running_statistics = running_statistics.init_state(
+        environment_spec.observations)
+    self._normalize_sample = jax.jit(normalize_sample)
+
+    normalizing_iterator = (
+        self._normalize_sample_and_update(sample) for sample in iterator)
+    self._wrapped_learner = learner_factory(normalizing_iterator)
+
+  def _normalize_sample_and_update(
+      self, sample: reverb.ReplaySample) -> reverb.ReplaySample:
+    self._observation_running_statistics, sample = self._normalize_sample(
+        self._observation_running_statistics, sample)
+    return sample
+
+  def step(self):
+    self._wrapped_learner.step()
+
+  def get_variables(self, names: List[str]) -> List[types.NestedArray]:
+    stats = self._observation_running_statistics
+    # Make sure to only pass mean and std to minimize traffic.
+    mean_std = running_statistics.NestedMeanStd(mean=stats.mean, std=stats.std)
+    normalization_variables = {_NORMALIZATION_VARIABLES: mean_std}
+
+    learner_names = [
+        name for name in names if name not in normalization_variables
+    ]
+    learner_variables = dict(
+        zip(learner_names, self._wrapped_learner.get_variables(
+            learner_names))) if learner_names else {}
+
+    return [
+        normalization_variables.get(name, learner_variables.get(name, None))
+        for name in names
+    ]
+
+  def save(self) -> NormalizationLearnerWrapperState:
+    return NormalizationLearnerWrapperState(
+        wrapped_learner_state=self._wrapped_learner.save(),
+        observation_running_statistics=self._observation_running_statistics)
+
+  def restore(self, state: NormalizationLearnerWrapperState):
+    self._wrapped_learner.restore(state.wrapped_learner_state)
+    self._observation_running_statistics = state.observation_running_statistics
+
+
+@dataclasses.dataclass
+class NormalizationBuilder(Generic[Networks, PolicyNetwork],
+                           builders.ActorLearnerBuilder[Networks,
+                                                        PolicyNetwork,
+                                                        reverb.ReplaySample]):
+  """Builder wrapper that normalizes observations using running mean/std."""
+  builder: builders.ActorLearnerBuilder[Networks, PolicyNetwork,
+                                        reverb.ReplaySample]
+  is_sequence_based: bool
+  batch_dims: Optional[Tuple[int, ...]]
+  max_abs_observation: Optional[float] = 10.0
+  statistics_update_period: int = 100
+
+  def make_replay_tables(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      policy: PolicyNetwork,
+  ) -> List[reverb.Table]:
+    return self.builder.make_replay_tables(environment_spec, policy)
+
+  def make_dataset_iterator(
+      self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
+    return self.builder.make_dataset_iterator(replay_client)
+
+  def make_adder(self, replay_client: reverb.Client,
+                 environment_spec: Optional[specs.EnvironmentSpec],
+                 policy: Optional[PolicyNetwork]) -> Optional[adders.Adder]:
+    return self.builder.make_adder(replay_client, environment_spec, policy)
+
+  def make_learner(
+      self,
+      random_key: networks_lib.PRNGKey,
+      networks: Networks,
+      dataset: Iterator[reverb.ReplaySample],
+      logger_fn: loggers.LoggerFactory,
      environment_spec: specs.EnvironmentSpec,
+      replay_client: Optional[reverb.Client] = None,
+      counter: Optional[counting.Counter] = None,
+  ) -> core.Learner:
+
+    learner_factory = functools.partial(
+        self.builder.make_learner,
+        random_key,
+        networks,
+        logger_fn=logger_fn,
+        environment_spec=environment_spec,
+        replay_client=replay_client,
+        counter=counter)
+
+    return NormalizationLearnerWrapper(
+        learner_factory=learner_factory,
+        iterator=dataset,
+        environment_spec=environment_spec,
+        is_sequence_based=self.is_sequence_based,
+        batch_dims=self.batch_dims,
+        max_abs_observation=self.max_abs_observation)
+
+  def make_actor(
+      self,
+      random_key: networks_lib.PRNGKey,
+      policy: PolicyNetwork,
+      environment_spec: specs.EnvironmentSpec,
+      variable_source: Optional[core.VariableSource] = None,
+      adder: Optional[adders.Adder] = None,
+  ) -> core.Actor:
+    actor = self.builder.make_actor(random_key, policy, environment_spec,
+                                    variable_source, adder)
+    return NormalizationActorWrapper(
+        actor,
+        variable_source,
+        max_abs_observation=self.max_abs_observation,
+        update_period=self.statistics_update_period,
+        backend='cpu')
+
+  def make_policy(self,
+                  networks: Networks,
+                  environment_spec: specs.EnvironmentSpec,
+                  evaluation: bool = False) -> PolicyNetwork:
+    return self.builder.make_policy(
+        networks=networks,
+        environment_spec=environment_spec,
+        evaluation=evaluation)
diff --git a/acme/acme/agents/jax/ppo/README.md b/acme/acme/agents/jax/ppo/README.md
new file mode 100644
index 00000000..4e8cf4bd
--- /dev/null
+++ b/acme/acme/agents/jax/ppo/README.md
@@ -0,0 +1,13 @@
+# Proximal Policy Optimization (PPO)
+
+This folder contains an implementation of the PPO algorithm
+([Schulman et al., 2017]) with clipped surrogate objective.
+
+Implementation notes:
+ - PPO is not a strictly on-policy algorithm. In each call to the learner's
+   step function, a batch of transitions is taken from the Reverb replay
+   buffer, and N epochs of updates are performed on the data in the batch.
+   Using larger values for num_epochs and num_minibatches makes the algorithm
+   "more off-policy".
+
+[Schulman et al., 2017]: https://arxiv.org/abs/1707.06347
diff --git a/acme/acme/agents/jax/ppo/__init__.py b/acme/acme/agents/jax/ppo/__init__.py
new file mode 100644
index 00000000..d6a617be
--- /dev/null
+++ b/acme/acme/agents/jax/ppo/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""PPO agent.""" + +from acme.agents.jax.ppo.builder import PPOBuilder +from acme.agents.jax.ppo.config import PPOConfig +from acme.agents.jax.ppo.learning import PPOLearner +from acme.agents.jax.ppo.networks import make_categorical_ppo_networks +from acme.agents.jax.ppo.networks import make_continuous_networks +from acme.agents.jax.ppo.networks import make_discrete_networks +from acme.agents.jax.ppo.networks import make_inference_fn +from acme.agents.jax.ppo.networks import make_mvn_diag_ppo_networks +from acme.agents.jax.ppo.networks import make_networks +from acme.agents.jax.ppo.networks import make_ppo_networks +from acme.agents.jax.ppo.networks import PPONetworks diff --git a/acme/acme/agents/jax/ppo/builder.py b/acme/acme/agents/jax/ppo/builder.py new file mode 100644 index 00000000..e8cd8205 --- /dev/null +++ b/acme/acme/agents/jax/ppo/builder.py @@ -0,0 +1,211 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PPO Builder.""" +from typing import Iterator, List, Optional + +from acme import adders +from acme import core +from acme import specs +from acme.adders import reverb as adders_reverb +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.ppo import config as ppo_config +from acme.agents.jax.ppo import learning +from acme.agents.jax.ppo import networks as ppo_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import numpy as np +import optax +import reverb + + +class PPOBuilder( + builders.ActorLearnerBuilder[ppo_networks.PPONetworks, + actor_core_lib.FeedForwardPolicyWithExtra, + reverb.ReplaySample]): + """PPO Builder.""" + + def __init__( + self, + config: ppo_config.PPOConfig, + ): + """Creates PPO builder.""" + self._config = config + + # An extra step is used for bootstrapping when computing advantages. + self._sequence_length = config.unroll_length + 1 + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicyWithExtra, + ) -> List[reverb.Table]: + """Creates reverb tables for the algorithm.""" + del policy + extra_spec = { + 'log_prob': np.ones(shape=(), dtype=np.float32), + } + signature = adders_reverb.SequenceAdder.signature( + environment_spec, extra_spec, sequence_length=self._sequence_length) + return [ + reverb.Table.queue( + name=self._config.replay_table_name, + max_size=self._config.batch_size, + signature=signature) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: + """Creates a dataset. 
+
+    The iterator batch size is computed as follows:
+
+    Let:
+      B := learner batch size (config.batch_size)
+      H := number of hosts (jax.process_count())
+      D := number of local devices per host
+
+    The Reverb iterator will load batches of size B // (H * D). After wrapping
+    the iterator with utils.multi_device_put, this will result in an iterable
+    that provides B // H samples per item, with B // (H * D) samples placed on
+    each local device. In a multi-host setup, each host has its own learner
+    node and builds its own instance of the iterator. This will result
+    in a total batch size of H * (B // H) == B being consumed per learner
+    step (since the learner is pmapped across all devices). For example, with
+    B = 256, H = 2 and D = 4, each host's iterator yields 128 samples per item,
+    placed in chunks of 32 on each of its local devices, for a global batch of
+    256. Note that jax.device_count() returns the total number of devices
+    across hosts, i.e. H * D.
+
+    Args:
+      replay_client: the reverb replay client
+
+    Returns:
+      A replay buffer iterator to be used by the local devices.
+    """
+    iterator_batch_size, ragged = divmod(self._config.batch_size,
+                                         jax.device_count())
+    if ragged:
+      raise ValueError(
+          'Learner batch size must be divisible by total number of devices!')
+
+    # We don't use datasets.make_reverb_dataset() here to avoid interleaving
+    # and prefetching, which don't work well with the can_sample() check on
+    # update.
+    # NOTE: Value for max_in_flight_samples_per_worker comes from a
+    # recommendation here: https://git.io/JYzXB
+    dataset = reverb.TrajectoryDataset.from_table_signature(
+        server_address=replay_client.server_address,
+        table=self._config.replay_table_name,
+        max_in_flight_samples_per_worker=(2 * self._config.batch_size /
+                                          jax.process_count()))
+    dataset = dataset.batch(iterator_batch_size, drop_remainder=True)
+    dataset = dataset.as_numpy_iterator()
+    return utils.multi_device_put(iterable=dataset, devices=jax.local_devices())
+
+  def make_adder(
+      self,
+      replay_client: reverb.Client,
+      environment_spec: Optional[specs.EnvironmentSpec],
+      policy: Optional[actor_core_lib.FeedForwardPolicyWithExtra],
+  ) -> Optional[adders.Adder]:
+    """Creates an adder which handles observations."""
+    del environment_spec, policy
+    # Note that the last transition in the sequence is used for bootstrapping
+    # only and is ignored otherwise. So we need to make sure that sequences
+    # overlap by one transition, hence the "-1" in the period computation.
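+    # For example, with unroll_length=8, sequence_length is 9 and period is 8,
+    # so consecutive sequences cover steps [0..8], [8..16], ...: each sequence
+    # starts on the step that served as the previous sequence's bootstrap
+    # transition.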
+ return adders_reverb.SequenceAdder( + client=replay_client, + priority_fns={self._config.replay_table_name: None}, + period=self._sequence_length - 1, + sequence_length=self._sequence_length, + ) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: ppo_networks.PPONetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + if callable(self._config.learning_rate): + optimizer = optax.chain( + optax.clip_by_global_norm(self._config.max_gradient_norm), + optax.scale_by_adam(eps=self._config.adam_epsilon), + optax.scale_by_schedule(self._config.learning_rate), optax.scale(-1)) + else: + optimizer = optax.chain( + optax.clip_by_global_norm(self._config.max_gradient_norm), + optax.scale_by_adam(eps=self._config.adam_epsilon), + optax.scale(-self._config.learning_rate)) + + return learning.PPOLearner( + ppo_networks=networks, + iterator=dataset, + discount=self._config.discount, + entropy_cost=self._config.entropy_cost, + value_cost=self._config.value_cost, + ppo_clipping_epsilon=self._config.ppo_clipping_epsilon, + normalize_advantage=self._config.normalize_advantage, + normalize_value=self._config.normalize_value, + normalization_ema_tau=self._config.normalization_ema_tau, + clip_value=self._config.clip_value, + value_clipping_epsilon=self._config.value_clipping_epsilon, + max_abs_reward=self._config.max_abs_reward, + gae_lambda=self._config.gae_lambda, + counter=counter, + random_key=random_key, + optimizer=optimizer, + num_epochs=self._config.num_epochs, + num_minibatches=self._config.num_minibatches, + logger=logger_fn('learner'), + log_global_norm_metrics=self._config.log_global_norm_metrics, + metrics_logging_period=self._config.metrics_logging_period, + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicyWithExtra, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( + policy) + variable_client = variable_utils.VariableClient( + variable_source, + 'network', + device='cpu', + update_period=self._config.variable_update_period) + return actors.GenericActor( + actor, random_key, variable_client, adder, backend='cpu') + + def make_policy( + self, + networks: ppo_networks.PPONetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> actor_core_lib.FeedForwardPolicyWithExtra: + del environment_spec + return ppo_networks.make_inference_fn(networks, evaluation) diff --git a/acme/acme/agents/jax/ppo/config.py b/acme/acme/agents/jax/ppo/config.py new file mode 100644 index 00000000..270a1e4c --- /dev/null +++ b/acme/acme/agents/jax/ppo/config.py @@ -0,0 +1,79 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PPO config."""
+import dataclasses
+from typing import Callable, Union, Optional
+
+from acme.adders import reverb as adders_reverb
+
+
+@dataclasses.dataclass
+class PPOConfig:
+  """Configuration options for PPO.
+
+  Attributes:
+    unroll_length: Length of sequences added to the replay buffer.
+    num_minibatches: The number of minibatches to split an epoch into, i.e.,
+      minibatch size = batch_size * unroll_length / num_minibatches.
+    num_epochs: How many times to loop over the set of minibatches.
+    batch_size: Number of trajectory segments of length unroll_length to gather
+      for use in a call to the learner's step function.
+    replay_table_name: Replay table name.
+    ppo_clipping_epsilon: PPO clipping epsilon.
+    normalize_advantage: Whether to normalize the advantages in the batch.
+    normalize_value: Whether the critic should predict normalized values.
+    normalization_ema_tau: Float tau for the exponential moving average used to
+      maintain statistics for normalizing advantages and values.
+    clip_value: Whether to clip the values as described in "What Matters in
+      On-Policy Reinforcement Learning?".
+    value_clipping_epsilon: Epsilon for value clipping.
+    max_abs_reward: If provided, clips the rewards in the trajectory to have
+      absolute value less than or equal to max_abs_reward.
+    gae_lambda: Lambda parameter in Generalized Advantage Estimation.
+    discount: Discount factor.
+    learning_rate: Learning rate for updating the policy and critic networks.
+    adam_epsilon: Adam epsilon parameter.
+    entropy_cost: Weight of the entropy regularizer term in policy
+      optimization.
+    value_cost: Weight of the value loss term in optimization.
+    max_gradient_norm: Threshold for clipping the gradient norm.
+    variable_update_period: Determines how frequently actors pull the
+      parameters from the learner.
+    log_global_norm_metrics: Whether to log the global norm of gradients and
+      updates.
+    metrics_logging_period: How often metrics should be aggregated to host and
+      logged.
+  """
+  unroll_length: int = 8
+  num_minibatches: int = 8
+  num_epochs: int = 2
+  batch_size: int = 256
+  replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
+  ppo_clipping_epsilon: float = 0.2
+  normalize_advantage: bool = False
+  normalize_value: bool = False
+  normalization_ema_tau: float = 0.995
+  clip_value: bool = False
+  value_clipping_epsilon: float = 0.2
+  max_abs_reward: Optional[float] = None
+  gae_lambda: float = 0.95
+  discount: float = 0.99
+  learning_rate: Union[float, Callable[[int], float]] = 3e-4
+  adam_epsilon: float = 1e-7
+  entropy_cost: float = 3e-4
+  value_cost: float = 1.
+  max_gradient_norm: float = 0.5
+  variable_update_period: int = 1
+  log_global_norm_metrics: bool = False
+  metrics_logging_period: int = 100
diff --git a/acme/acme/agents/jax/ppo/learning.py b/acme/acme/agents/jax/ppo/learning.py
new file mode 100644
index 00000000..cecfbf0e
--- /dev/null
+++ b/acme/acme/agents/jax/ppo/learning.py
@@ -0,0 +1,460 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Learner for the PPO agent.""" + +from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple + +import acme +from acme import types +from acme.agents.jax.ppo import networks +from acme.jax import networks as networks_lib +from acme.jax.utils import get_from_first_device +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax +import reverb +import rlax + + +class Batch(NamedTuple): + """A batch of data; all shapes are expected to be [B, ...].""" + observations: types.NestedArray + actions: jnp.ndarray + advantages: jnp.ndarray + + # Target value estimate used to bootstrap the value function. + target_values: jnp.ndarray + + # Value estimate and action log-prob at behavior time. + behavior_values: jnp.ndarray + behavior_log_probs: jnp.ndarray + + +class TrainingState(NamedTuple): + """Training state for the PPO learner.""" + params: networks_lib.Params + opt_state: optax.OptState + random_key: networks_lib.PRNGKey + + # Optional counter used for exponential moving average zero debiasing + ema_counter: Optional[jnp.int32] = None + + # Optional parameter for maintaining a running estimate of the scale of + # advantage estimates + biased_advantage_scale: Optional[networks_lib.Params] = None + advantage_scale: Optional[networks_lib.Params] = None + + # Optional parameter for maintaining a running estimate of the mean and + # standard deviation of value estimates + biased_value_first_moment: Optional[networks_lib.Params] = None + biased_value_second_moment: Optional[networks_lib.Params] = None + value_mean: Optional[networks_lib.Params] = None + value_std: Optional[networks_lib.Params] = None + + +class PPOLearner(acme.Learner): + """Learner for PPO.""" + + def __init__( + self, + ppo_networks: networks.PPONetworks, + iterator: Iterator[reverb.ReplaySample], + optimizer: optax.GradientTransformation, + random_key: networks_lib.PRNGKey, + ppo_clipping_epsilon: float = 0.2, + normalize_advantage: bool = True, + normalize_value: bool = False, + normalization_ema_tau: float = 0.995, + clip_value: bool = False, + value_clipping_epsilon: float = 0.2, + max_abs_reward: Optional[float] = None, + gae_lambda: float = 0.95, + discount: float = 0.99, + entropy_cost: float = 0., + value_cost: float = 1., + num_epochs: int = 4, + num_minibatches: int = 1, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + log_global_norm_metrics: bool = False, + metrics_logging_period: int = 100, + ): + self.local_learner_devices = jax.local_devices() + self.num_local_learner_devices = jax.local_device_count() + self.learner_devices = jax.devices() + self.num_epochs = num_epochs + self.num_minibatches = num_minibatches + self.metrics_logging_period = metrics_logging_period + self._num_full_update_steps = 0 + self._iterator = iterator + + # Set up logging/counting. 
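+    # The counter is incremented by the number of SGD steps on every call to
+    # step(); metrics are written once every `metrics_logging_period` calls.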
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.make_default_logger('learner')
+
+    def ppo_loss(
+        params: networks_lib.Params,
+        observations: networks_lib.Observation,
+        actions: networks_lib.Action,
+        advantages: jnp.ndarray,
+        target_values: networks_lib.Value,
+        behavior_values: networks_lib.Value,
+        behavior_log_probs: networks_lib.LogProb,
+        value_mean: jnp.ndarray,
+        value_std: jnp.ndarray,
+        key: networks_lib.PRNGKey,
+    ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
+      """PPO loss for the policy and the critic."""
+      distribution_params, values = ppo_networks.network.apply(
+          params, observations)
+      if normalize_value:
+        # values = values * jnp.fmax(value_std, 1e-6) + value_mean
+        target_values = (target_values - value_mean) / jnp.fmax(value_std, 1e-6)
+      policy_log_probs = ppo_networks.log_prob(distribution_params, actions)
+      key, sub_key = jax.random.split(key)  # pylint: disable=unused-variable
+      policy_entropies = ppo_networks.entropy(distribution_params)
+
+      # Compute the policy losses
+      rhos = jnp.exp(policy_log_probs - behavior_log_probs)
+      clipped_ppo_policy_loss = rlax.clipped_surrogate_pg_loss(
+          rhos, advantages, ppo_clipping_epsilon)
+      policy_entropy_loss = -jnp.mean(policy_entropies)
+      total_policy_loss = (
+          clipped_ppo_policy_loss + entropy_cost * policy_entropy_loss)
+
+      # Compute the critic losses
+      unclipped_value_loss = (values - target_values)**2
+
+      if clip_value:
+        # Clip values to reduce variability during critic training.
+        clipped_values = behavior_values + jnp.clip(values - behavior_values,
+                                                    -value_clipping_epsilon,
+                                                    value_clipping_epsilon)
+        clipped_value_error = target_values - clipped_values
+        clipped_value_loss = clipped_value_error ** 2
+        value_loss = jnp.mean(jnp.fmax(unclipped_value_loss,
+                                       clipped_value_loss))
+      else:
+        # For Mujoco envs clipping hurts a lot.
+        # Evidenced by Figure 43 in
+        # https://arxiv.org/pdf/2006.05990.pdf
+        value_loss = jnp.mean(unclipped_value_loss)
+
+      total_ppo_loss = total_policy_loss + value_cost * value_loss
+      return total_ppo_loss, {
+          'loss_total': total_ppo_loss,
+          'loss_policy_total': total_policy_loss,
+          'loss_policy_pg': clipped_ppo_policy_loss,
+          'loss_policy_entropy': policy_entropy_loss,
+          'loss_critic': value_loss,
+      }
+
+    ppo_loss_grad = jax.grad(ppo_loss, has_aux=True)
+
+    def sgd_step(state: TrainingState, minibatch: Batch):
+      observations = minibatch.observations
+      actions = minibatch.actions
+      advantages = minibatch.advantages
+      target_values = minibatch.target_values
+      behavior_values = minibatch.behavior_values
+      behavior_log_probs = minibatch.behavior_log_probs
+      key, sub_key = jax.random.split(state.random_key)
+
+      loss_grad, metrics = ppo_loss_grad(
+          state.params,
+          observations,
+          actions,
+          advantages,
+          target_values,
+          behavior_values,
+          behavior_log_probs,
+          state.value_mean,
+          state.value_std,
+          sub_key,
+      )
+
+      # Apply updates
+      loss_grad = jax.lax.pmean(loss_grad, axis_name='devices')
+      updates, opt_state = optimizer.update(loss_grad, state.opt_state)
+      params = optax.apply_updates(state.params, updates)
+
+      if log_global_norm_metrics:
+        metrics['norm_grad'] = optax.global_norm(loss_grad)
+        metrics['norm_updates'] = optax.global_norm(updates)
+
+      new_state = state._replace(
+          params=params, opt_state=opt_state, random_key=key)
+
+      return new_state, metrics
+
+    def epoch_update(
+        carry: Tuple[TrainingState, Batch],
+        unused_t: Tuple[()],
+    ):
+      state, carry_batch = carry
+
+      # Shuffling into minibatches
+      batch_size = carry_batch.advantages.shape[0]
+      key, sub_key = jax.random.split(state.random_key)
+      # TODO(kamyar) For efficiency could use same permutation for all epochs
+      permuted_batch = jax.tree_util.tree_map(
+          lambda x: jax.random.permutation(  # pylint: disable=g-long-lambda
+              sub_key,
+              x,
+              axis=0,
+              independent=False),
+          carry_batch)
+      state = state._replace(random_key=key)
+      minibatches = jax.tree_util.tree_map(
+          lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
+              x,
+              [  # pylint: disable=g-long-lambda
+                  num_minibatches, batch_size // num_minibatches
+              ] + list(x.shape[1:])),
+          permuted_batch)
+
+      # Scan over the minibatches
+      state, metrics = jax.lax.scan(
+          sgd_step, state, minibatches, length=num_minibatches)
+      metrics = jax.tree_util.tree_map(jnp.mean, metrics)
+
+      return (state, carry_batch), metrics
+
+    vmapped_network_apply = jax.vmap(
+        ppo_networks.network.apply, in_axes=(None, 0), out_axes=0)
+
+    def single_device_update(
+        state: TrainingState,
+        trajectories: types.NestedArray,
+    ):
+      # Update the EMA counter and obtain the zero debiasing multiplier
+      if normalize_advantage or normalize_value:
+        ema_counter = state.ema_counter + 1
+        state = state._replace(ema_counter=ema_counter)
+        zero_debias = 1. / (1. - jnp.power(normalization_ema_tau, ema_counter))
+
+      # Extract the data.
+      data = trajectories.data
+      observations, actions, rewards, termination, extra = (data.observation,
+                                                            data.action,
+                                                            data.reward,
+                                                            data.discount,
+                                                            data.extras)
+      if max_abs_reward is not None:
+        # Apply reward clipping.
+        rewards = jnp.clip(rewards, -1.
* max_abs_reward, max_abs_reward) + discounts = termination * discount + behavior_log_probs = extra['log_prob'] + _, behavior_values = vmapped_network_apply(state.params, observations) + + if normalize_value: + batch_value_first_moment = jnp.mean(behavior_values) + batch_value_second_moment = jnp.mean(behavior_values**2) + batch_value_first_moment, batch_value_second_moment = jax.lax.pmean( + (batch_value_first_moment, batch_value_second_moment), + axis_name='devices') + + biased_value_first_moment = ( + normalization_ema_tau * state.biased_value_first_moment + + (1. - normalization_ema_tau) * batch_value_first_moment) + biased_value_second_moment = ( + normalization_ema_tau * state.biased_value_second_moment + + (1. - normalization_ema_tau) * batch_value_second_moment) + + value_mean = biased_value_first_moment * zero_debias + value_second_moment = biased_value_second_moment * zero_debias + value_std = jnp.sqrt(jax.nn.relu(value_second_moment - value_mean**2)) + + state = state._replace( + biased_value_first_moment=biased_value_first_moment, + biased_value_second_moment=biased_value_second_moment, + value_mean=value_mean, + value_std=value_std, + ) + + behavior_values = behavior_values * jnp.fmax(state.value_std, + 1e-6) + state.value_mean + + behavior_values = jax.lax.stop_gradient(behavior_values) + + # Compute GAE using rlax + vmapped_rlax_truncated_generalized_advantage_estimation = jax.vmap( + rlax.truncated_generalized_advantage_estimation, + in_axes=(0, 0, None, 0)) + advantages = vmapped_rlax_truncated_generalized_advantage_estimation( + rewards[:, :-1], discounts[:, :-1], gae_lambda, behavior_values) + advantages = jax.lax.stop_gradient(advantages) + target_values = behavior_values[:, :-1] + advantages + target_values = jax.lax.stop_gradient(target_values) + + # Exclude the last step - it was only used for bootstrapping. + # The shape is [num_sequences, num_steps, ..] + observations, actions, behavior_log_probs, behavior_values = jax.tree_util.tree_map( + lambda x: x[:, :-1], + (observations, actions, behavior_log_probs, behavior_values)) + + # Shuffle the data and break into minibatches + batch_size = advantages.shape[0] * advantages.shape[1] + batch = Batch( + observations=observations, + actions=actions, + advantages=advantages, + target_values=target_values, + behavior_values=behavior_values, + behavior_log_probs=behavior_log_probs) + batch = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, [batch_size] + list(x.shape[2:])), batch) + + if normalize_advantage: + batch_advantage_scale = jnp.mean(jnp.abs(batch.advantages)) + batch_advantage_scale = jax.lax.pmean(batch_advantage_scale, 'devices') + + # update the running statistics + biased_advantage_scale = ( + normalization_ema_tau * state.biased_advantage_scale + + (1. 
- normalization_ema_tau) * batch_advantage_scale) + advantage_scale = biased_advantage_scale * zero_debias + state = state._replace( + biased_advantage_scale=biased_advantage_scale, + advantage_scale=advantage_scale) + + # scale the advantages + scaled_advantages = batch.advantages / jnp.fmax(state.advantage_scale, + 1e-6) + batch = batch._replace(advantages=scaled_advantages) + + # Scan desired number of epoch updates + (state, _), metrics = jax.lax.scan( + epoch_update, (state, batch), (), length=num_epochs) + metrics = jax.tree_util.tree_map(jnp.mean, metrics) + + if normalize_advantage: + metrics['advantage_scale'] = state.advantage_scale + + if normalize_value: + metrics['value_mean'] = value_mean + metrics['value_std'] = value_std + + return state, metrics + + pmapped_update_step = jax.pmap( + single_device_update, axis_name='devices', devices=self.learner_devices) + + def full_update_step( + state: TrainingState, + trajectories: types.NestedArray, + ): + state, metrics = pmapped_update_step(state, trajectories) + return state, metrics + + self._full_update_step = full_update_step + + def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + all_keys = jax.random.split(key, num=self.num_local_learner_devices + 1) + key_init, key_state = all_keys[0], all_keys[1:] + key_state = [key_state[i] for i in range(self.num_local_learner_devices)] + key_state = jax.device_put_sharded(key_state, self.local_learner_devices) + + initial_params = ppo_networks.network.init(key_init) + initial_opt_state = optimizer.init(initial_params) + + initial_params = jax.device_put_replicated(initial_params, + self.local_learner_devices) + initial_opt_state = jax.device_put_replicated(initial_opt_state, + self.local_learner_devices) + + ema_counter = jnp.int32(0) + ema_counter = jax.device_put_replicated(ema_counter, + self.local_learner_devices) + + init_state = TrainingState( + params=initial_params, + opt_state=initial_opt_state, + random_key=key_state, + ema_counter=ema_counter, + ) + + if normalize_advantage: + biased_advantage_scale = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices) + advantage_scale = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices) + + init_state = init_state._replace( + biased_advantage_scale=biased_advantage_scale, + advantage_scale=advantage_scale) + + if normalize_value: + biased_value_first_moment = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices) + value_mean = biased_value_first_moment + + biased_value_second_moment = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices) + value_second_moment = biased_value_second_moment + value_std = jnp.sqrt(jax.nn.relu(value_second_moment - value_mean**2)) + + init_state = init_state._replace( + biased_value_first_moment=biased_value_first_moment, + biased_value_second_moment=biased_value_second_moment, + value_mean=value_mean, + value_std=value_std) + + return init_state + + # Initialise training state (parameters and optimizer state). + self._state = make_initial_state(random_key) + + def step(self): + """Does a learner step and logs the results. + + One learner step consists of (possibly multiple) epochs of PPO updates on + a batch of NxT steps collected by the actors. + """ + sample = next(self._iterator) + self._state, results = self._full_update_step(self._state, sample) + + # Update our counts and record it. 
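+    # One full update performs num_epochs * num_minibatches SGD steps.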
+    counts = self._counter.increment(steps=self.num_epochs *
+                                     self.num_minibatches)
+
+    # Snapshot and attempt to write logs.
+    if self._num_full_update_steps % self.metrics_logging_period == 0:
+      results = jax.tree_util.tree_map(jnp.mean, results)
+      self._logger.write({**results, **counts})
+
+    self._num_full_update_steps += 1
+
+  def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
+    params = get_from_first_device(self._state.params, as_numpy=False)
+    return [params]
+
+  def save(self) -> TrainingState:
+    return get_from_first_device(self._state, as_numpy=False)
+
+  def restore(self, state: TrainingState):
+    # TODO(kamyar) Should the random_key come from self._state instead?
+    random_key = state.random_key
+    random_key = jax.random.split(
+        random_key, num=self.num_local_learner_devices)
+    random_key = jax.device_put_sharded(
+        [random_key[i] for i in range(self.num_local_learner_devices)],
+        self.local_learner_devices)
+
+    state = jax.device_put_replicated(state, self.local_learner_devices)
+    state = state._replace(random_key=random_key)
+    self._state = state
diff --git a/acme/acme/agents/jax/ppo/networks.py b/acme/acme/agents/jax/ppo/networks.py
new file mode 100644
index 00000000..d901e58d
--- /dev/null
+++ b/acme/acme/agents/jax/ppo/networks.py
@@ -0,0 +1,358 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PPO network definitions."""
+
+import dataclasses
+from typing import Any, Callable, Optional, Sequence, NamedTuple
+
+from acme import specs
+from acme.agents.jax import actor_core as actor_core_lib
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+import tensorflow_probability
+
+tfp = tensorflow_probability.substrates.jax
+tfd = tfp.distributions
+
+EntropyFn = Callable[[Any], jnp.ndarray]
+
+
+class MVNDiagParams(NamedTuple):
+  """Parameters for a diagonal multi-variate normal distribution."""
+  loc: jnp.ndarray
+  scale_diag: jnp.ndarray
+
+
+class TanhNormalParams(NamedTuple):
+  """Parameters for a tanh squashed diagonal MVN distribution."""
+  loc: jnp.ndarray
+  scale: jnp.ndarray
+
+
+class CategoricalParams(NamedTuple):
+  """Parameters for a categorical distribution."""
+  logits: jnp.ndarray
+
+
+@dataclasses.dataclass
+class PPONetworks:
+  """Network and pure functions for the PPO agent.
+
+  If 'network' returns tfd.Distribution, you can use make_ppo_networks() to
+  create this object properly.
+  If you are building this object manually, the 'network' is free to return
+  anything that is later passed as input to the log_prob/entropy/sample
+  functions to perform the corresponding computations.
+  An example scenario where you would want to do this is when tfd.Distribution
+  does not play nicely with jax.vmap. Please refer to
+  make_continuous_networks() for an example where the network does not return a
+  tfd.Distribution object.
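+
+  As a minimal sketch (assuming `my_forward_fn` builds a Haiku module that
+  returns a (tfd.Distribution, value) tuple and `dummy_obs` is a batched dummy
+  observation):
+
+    forward_fn = hk.without_apply_rng(hk.transform(my_forward_fn))
+    network = networks_lib.FeedForwardNetwork(
+        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
+    ppo_networks = make_ppo_networks(network)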
+ """ + network: networks_lib.FeedForwardNetwork + log_prob: networks_lib.LogProbFn + entropy: EntropyFn + sample: networks_lib.SampleFn + sample_eval: Optional[networks_lib.SampleFn] = None + + +def make_inference_fn( + ppo_networks: PPONetworks, + evaluation: bool = False) -> actor_core_lib.FeedForwardPolicyWithExtra: + """Returns a function to be used for inference by a PPO actor.""" + + def inference(params: networks_lib.Params, key: networks_lib.PRNGKey, + observations: networks_lib.Observation): + dist_params, _ = ppo_networks.network.apply(params, observations) + if evaluation and ppo_networks.sample_eval: + actions = ppo_networks.sample_eval(dist_params, key) + else: + actions = ppo_networks.sample(dist_params, key) + if evaluation: + return actions, {} + log_prob = ppo_networks.log_prob(dist_params, actions) + return actions, {'log_prob': log_prob} + + return inference + + +def make_networks( + spec: specs.EnvironmentSpec, hidden_layer_sizes: Sequence[int] = (256, 256) +) -> PPONetworks: + if isinstance(spec.actions, specs.DiscreteArray): + return make_discrete_networks(spec, hidden_layer_sizes) + else: + return make_continuous_networks( + spec, + policy_layer_sizes=hidden_layer_sizes, + value_layer_sizes=hidden_layer_sizes) + + +def make_ppo_networks(network: networks_lib.FeedForwardNetwork) -> PPONetworks: + """Constructs a PPONetworks instance from the given FeedForwardNetwork. + + This method assumes that the network returns a tfd.Distribution. Sometimes it + may be preferable to have networks that do not return tfd.Distribution + objects, for example, due to tfd.Distribution not playing nice with jax.vmap. + Please refer to the make_continuous_networks() for an example where the + network does not return a tfd.Distribution object. + + Args: + network: a transformed Haiku network that takes in observations and returns + the action distribution and value. + + Returns: + A PPONetworks instance with pure functions wrapping the input network. + """ + return PPONetworks( + network=network, + log_prob=lambda distribution, action: distribution.log_prob(action), + entropy=lambda distribution, key=None: distribution.entropy(), + sample=lambda distribution, key: distribution.sample(seed=key), + sample_eval=lambda distribution, key: distribution.mode()) + + +def make_mvn_diag_ppo_networks( + network: networks_lib.FeedForwardNetwork) -> PPONetworks: + """Constructs a PPONetworks for MVN Diag policy from the FeedForwardNetwork. + + Args: + network: a transformed Haiku network (or equivalent in other libraries) that + takes in observations and returns the action distribution and value. + + Returns: + A PPONetworks instance with pure functions wrapping the input network. 
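+
+  The network's apply function is assumed to return a (MVNDiagParams, value)
+  tuple; see make_continuous_networks below for a network that does this.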
+ """ + + def log_prob(params: MVNDiagParams, action): + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag).log_prob(action) + + def entropy(params: MVNDiagParams): + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag).entropy() + + def sample(params: MVNDiagParams, key: networks_lib.PRNGKey): + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag).sample(seed=key) + + def sample_eval(params: MVNDiagParams, key: networks_lib.PRNGKey): + del key + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag).mode() + + return PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval) + + +def make_tanh_normal_ppo_networks( + network: networks_lib.FeedForwardNetwork) -> PPONetworks: + """Constructs a PPONetworks for Tanh MVN Diag policy from the FeedForwardNetwork. + + Args: + network: a transformed Haiku network (or equivalent in other libraries) that + takes in observations and returns the action distribution and value. + + Returns: + A PPONetworks instance with pure functions wrapping the input network. + """ + + def build_distribution(params: TanhNormalParams): + distribution = tfd.Normal(loc=params.loc, scale=params.scale) + distribution = tfd.Independent( + networks_lib.TanhTransformedDistribution(distribution), + reinterpreted_batch_ndims=1) + return distribution + + def log_prob(params: TanhNormalParams, action): + distribution = build_distribution(params) + return distribution.log_prob(action) + + def entropy(params: TanhNormalParams): + distribution = build_distribution(params) + return distribution.entropy(seed=jax.random.PRNGKey(42)) + + def sample(params: TanhNormalParams, key: networks_lib.PRNGKey): + distribution = build_distribution(params) + return distribution.sample(seed=key) + + def sample_eval(params: TanhNormalParams, key: networks_lib.PRNGKey): + del key + distribution = build_distribution(params) + return distribution.mode() + + return PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval) + + +def make_discrete_networks( + environment_spec: specs.EnvironmentSpec, + hidden_layer_sizes: Sequence[int] = (512,), + use_conv: bool = False, +) -> PPONetworks: + """Creates networks used by the agent for discrete action environments. + + Args: + environment_spec: Environment spec used to define number of actions. + hidden_layer_sizes: Network definition. + use_conv: Whether to use a conv or MLP feature extractor. + Returns: + PPONetworks + """ + + num_actions = environment_spec.actions.num_values + + def forward_fn(inputs): + layers = [] + if use_conv: + layers.extend([networks_lib.AtariTorso()]) + layers.extend([hk.nets.MLP(hidden_layer_sizes, activate_final=True)]) + trunk = hk.Sequential(layers) + h = trunk(inputs) + logits = hk.Linear(num_actions)(h) + values = hk.Linear(1)(h) + values = jnp.squeeze(values, axis=-1) + return (CategoricalParams(logits=logits), values) + + forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) + dummy_obs = utils.zeros_like(environment_spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. + network = networks_lib.FeedForwardNetwork( + lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply) + # Create PPONetworks to add functionality required by the agent. 
+ return make_categorical_ppo_networks(network) # pylint:disable=undefined-variable + + +def make_categorical_ppo_networks( + network: networks_lib.FeedForwardNetwork) -> PPONetworks: + """Constructs a PPONetworks for Categorical Policy from FeedForwardNetwork. + + Args: + network: a transformed Haiku network (or equivalent in other libraries) that + takes in observations and returns the action distribution and value. + + Returns: + A PPONetworks instance with pure functions wrapping the input network. + """ + + def log_prob(params: CategoricalParams, action): + return tfd.Categorical(logits=params.logits).log_prob(action) + + def entropy(params: CategoricalParams): + return tfd.Categorical(logits=params.logits).entropy() + + def sample(params: CategoricalParams, key: networks_lib.PRNGKey): + return tfd.Categorical(logits=params.logits).sample(seed=key) + + def sample_eval(params: CategoricalParams, key: networks_lib.PRNGKey): + del key + return tfd.Categorical(logits=params.logits).mode() + + return PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval) + + +def make_continuous_networks( + environment_spec: specs.EnvironmentSpec, + policy_layer_sizes: Sequence[int] = (64, 64), + value_layer_sizes: Sequence[int] = (64, 64), + use_tanh_gaussian_policy: bool = True, +) -> PPONetworks: + """Creates PPONetworks to be used for continuous action environments.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(environment_spec.actions.shape, dtype=int) + + def forward_fn(inputs: networks_lib.Observation): + + def _policy_network(obs: networks_lib.Observation): + h = utils.batch_concat(obs) + h = hk.nets.MLP(policy_layer_sizes, activate_final=True)(h) + + # tfd distributions have a weird bug in jax when vmapping is used, so the + # safer implementation in general is for the policy network to output the + # distribution parameters, and for the distribution to be constructed + # in a method such as make_ppo_networks above + if not use_tanh_gaussian_policy: + # Following networks_lib.MultivariateNormalDiagHead + init_scale = 0.3 + min_scale = 1e-6 + w_init = hk.initializers.VarianceScaling(1e-4) + b_init = hk.initializers.Constant(0.) + loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + + loc = loc_layer(h) + scale = jax.nn.softplus(scale_layer(h)) + scale *= init_scale / jax.nn.softplus(0.) + scale += min_scale + + return MVNDiagParams(loc=loc, scale_diag=scale) + + # Following networks_lib.NormalTanhDistribution + min_scale = 1e-3 + w_init = hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform') + b_init = hk.initializers.Constant(0.) + loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + + loc = loc_layer(h) + scale = scale_layer(h) + scale = jax.nn.softplus(scale) + min_scale + + return TanhNormalParams(loc=loc, scale=scale) + + value_network = hk.Sequential([ + utils.batch_concat, + hk.nets.MLP(value_layer_sizes, activate_final=True), + hk.Linear(1), + lambda x: jnp.squeeze(x, axis=-1) + ]) + + policy_output = _policy_network(inputs) + value = value_network(inputs) + return (policy_output, value) + + # Transform into pure functions. 
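+  # hk.without_apply_rng drops the RNG argument from apply; the forward pass
+  # here is deterministic so no RNG is needed.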
+ forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) + + dummy_obs = utils.zeros_like(environment_spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. + network = networks_lib.FeedForwardNetwork( + lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply) + + # Create PPONetworks to add functionality required by the agent. + + if not use_tanh_gaussian_policy: + return make_mvn_diag_ppo_networks(network) + + return make_tanh_normal_ppo_networks(network) diff --git a/acme/acme/agents/jax/pwil/README.md b/acme/acme/agents/jax/pwil/README.md new file mode 100644 index 00000000..05f7c291 --- /dev/null +++ b/acme/acme/agents/jax/pwil/README.md @@ -0,0 +1,14 @@ +# PWIL + +This folder contains an implementation of the PWIL algorithm +([R.Dadashi et al., 2020]). + +The description of PWIL in ([R.Dadashi et al., 2020]) leaves the behavior +unspecified when the episode lengths are not fixed in advance. Here, we assign +zero reward when a trajectory exceeds the desired length, and keep the partial +return unaffected when a trajectory is shorter than the desired length. + +We prefill the replay buffer in a concurrent thread of the learner, to avoid +potential Reverb deadlocks. + +[R.Dadashi et al., 2020]: https://arxiv.org/abs/2006.04678 diff --git a/acme/acme/agents/jax/pwil/__init__.py b/acme/acme/agents/jax/pwil/__init__.py new file mode 100644 index 00000000..95b492f3 --- /dev/null +++ b/acme/acme/agents/jax/pwil/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PWIL agent.""" + +from acme.agents.jax.pwil.builder import PWILBuilder +from acme.agents.jax.pwil.config import PWILConfig +from acme.agents.jax.pwil.config import PWILDemonstrations diff --git a/acme/acme/agents/jax/pwil/adder.py b/acme/acme/agents/jax/pwil/adder.py new file mode 100644 index 00000000..3aba5ec1 --- /dev/null +++ b/acme/acme/agents/jax/pwil/adder.py @@ -0,0 +1,49 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Reward-substituting adder wrapper."""
+
+from acme import adders
+from acme import types
+from acme.agents.jax.pwil import rewarder
+import dm_env
+
+
+class PWILAdder(adders.Adder):
+  """Adder wrapper substituting PWIL rewards."""
+
+  def __init__(self, direct_rl_adder: adders.Adder,
+               pwil_rewarder: rewarder.WassersteinDistanceRewarder):
+    self._adder = direct_rl_adder
+    self._rewarder = pwil_rewarder
+    self._latest_observation = None
+
+  def add_first(self, timestep: dm_env.TimeStep):
+    self._rewarder.reset()
+    self._latest_observation = timestep.observation
+    self._adder.add_first(timestep)
+
+  def add(self,
+          action: types.NestedArray,
+          next_timestep: dm_env.TimeStep,
+          extras: types.NestedArray = ()):
+    updated_timestep = next_timestep._replace(
+        reward=self._rewarder.append_and_compute_reward(
+            observation=self._latest_observation, action=action))
+    self._latest_observation = next_timestep.observation
+    self._adder.add(action, updated_timestep, extras)
+
+  def reset(self):
+    self._latest_observation = None
+    self._adder.reset()
diff --git a/acme/acme/agents/jax/pwil/builder.py b/acme/acme/agents/jax/pwil/builder.py
new file mode 100644
index 00000000..e77feeb6
--- /dev/null
+++ b/acme/acme/agents/jax/pwil/builder.py
@@ -0,0 +1,203 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PWIL agent implementation, using JAX."""
+
+import threading
+from typing import Callable, Generic, Iterator, List, Optional, Sequence
+
+from acme import adders
+from acme import core
+from acme import specs
+from acme import types
+from acme.agents.jax import builders
+from acme.agents.jax.pwil import adder as pwil_adder
+from acme.agents.jax.pwil import config as pwil_config
+from acme.agents.jax.pwil import rewarder
+from acme.jax import networks as networks_lib
+from acme.jax.imitation_learning_types import DirectPolicyNetwork, DirectRLNetworks  # pylint: disable=g-multiple-import
+from acme.jax.types import PRNGKey
+from acme.utils import counting
+from acme.utils import loggers
+import dm_env
+import numpy as np
+import reverb
+
+
+def _prefill_with_demonstrations(adder: adders.Adder,
+                                 demonstrations: Sequence[types.Transition],
+                                 reward: Optional[float],
+                                 min_num_transitions: int = 0) -> None:
+  """Fill the adder's replay buffer with expert transitions.
+
+  Assumes that the demonstrations dataset stores transitions in order.
+
+  Args:
+    adder: the adder that writes the demonstrations to the replay buffer.
+    demonstrations: the expert demonstrations to iterate over.
+    reward: if non-None, populates the environment reward entry of transitions.
+    min_num_transitions: the lower bound on transitions processed; the dataset
+      will be iterated over multiple times if needed. Once at least
+      min_num_transitions are added, the processing is interrupted at the
+      nearest episode end.
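+      For example, with 3 demonstration transitions and min_num_transitions=5,
+      the demonstrations are iterated over twice.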
+ """ + if not demonstrations: + return + + reward = np.float32(reward) if reward is not None else reward + remaining_transitions = min_num_transitions + step_type = None + action = None + ts = dm_env.TimeStep(None, None, None, None) # Unused. + while remaining_transitions > 0: + # In case we share the adder or demonstrations don't end with + # end-of-episode, reset the adder prior to add_first. + adder.reset() + for transition_num, transition in enumerate(demonstrations): + remaining_transitions -= 1 + discount = np.float32(1.0) + ts_reward = reward if reward is not None else transition.reward + if step_type == dm_env.StepType.LAST or transition_num == 0: + ts = dm_env.TimeStep(dm_env.StepType.FIRST, ts_reward, discount, + transition.observation) + adder.add_first(ts) + + observation = transition.next_observation + action = transition.action + if transition.discount == 0. or transition_num == len(demonstrations) - 1: + step_type = dm_env.StepType.LAST + discount = np.float32(0.0) + else: + step_type = dm_env.StepType.MID + ts = dm_env.TimeStep(step_type, ts_reward, discount, observation) + adder.add(action, ts) + if remaining_transitions <= 0: + # Note: we could check `step_type == dm_env.StepType.LAST` to stop at an + # episode end if possible. + break + + # Explicitly finalize the Reverb client writes. + adder.reset() + + +class PWILBuilder(builders.ActorLearnerBuilder[DirectRLNetworks, + DirectPolicyNetwork, + reverb.ReplaySample], + Generic[DirectRLNetworks, DirectPolicyNetwork]): + """PWIL Agent builder.""" + + def __init__(self, + rl_agent: builders.ActorLearnerBuilder[DirectRLNetworks, + DirectPolicyNetwork, + reverb.ReplaySample], + config: pwil_config.PWILConfig, + demonstrations_fn: Callable[[], pwil_config.PWILDemonstrations]): + """Initialize the agent. + + Args: + rl_agent: the standard RL algorithm. + config: PWIL-specific configuration. + demonstrations_fn: A function that returns an iterator over contiguous + demonstration transitions, and the average demonstration episode length. + """ + self._rl_agent = rl_agent + self._config = config + self._demonstrations_fn = demonstrations_fn + super().__init__() + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: DirectRLNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + return self._rl_agent.make_learner( + random_key=random_key, + networks=networks, + dataset=dataset, + logger_fn=logger_fn, + environment_spec=environment_spec, + replay_client=replay_client, + counter=counter) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: DirectPolicyNetwork, + ) -> List[reverb.Table]: + return self._rl_agent.make_replay_tables(environment_spec, policy) + + def make_dataset_iterator( + self, + replay_client: reverb.Client) -> Optional[Iterator[reverb.ReplaySample]]: + # make_dataset_iterator is only called once (per learner), to pass the + # iterator to make_learner. By using adders we ensure the transition types + # (e.g. n-step transitions) that the direct RL agent expects. + if self._config.num_transitions_rb > 0: + + def prefill_thread(): + # Populating the replay buffer with the direct RL agent guarantees that + # a constant reward will be used, not the imitation reward. 
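+        # (Note: alpha is the maximum attainable PWIL reward, since the
+        # imitation reward alpha * exp(-sigma * cost) peaks at zero cost; see
+        # rewarder.py.)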
+ prefill_reward = ( + self._config.alpha + if self._config.prefill_constant_reward else None) + _prefill_with_demonstrations( + adder=self._rl_agent.make_adder(replay_client, None, None), + demonstrations=list(self._demonstrations_fn().demonstrations), + min_num_transitions=self._config.num_transitions_rb, + reward=prefill_reward) + # Populate the replay buffer in a separate thread, so that the learner + # can sample from the buffer, to avoid blocking on the buffer being full. + threading.Thread(target=prefill_thread, daemon=True).start() + + return self._rl_agent.make_dataset_iterator(replay_client) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[DirectPolicyNetwork], + ) -> Optional[adders.Adder]: + """Creates the adder substituting imitation reward.""" + pwil_demonstrations = self._demonstrations_fn() + return pwil_adder.PWILAdder( + direct_rl_adder=self._rl_agent.make_adder(replay_client, + environment_spec, policy), + pwil_rewarder=rewarder.WassersteinDistanceRewarder( + demonstrations_it=pwil_demonstrations.demonstrations, + episode_length=pwil_demonstrations.episode_length, + use_actions_for_distance=self._config.use_actions_for_distance, + alpha=self._config.alpha, + beta=self._config.beta)) + + def make_actor( + self, + random_key: PRNGKey, + policy: DirectPolicyNetwork, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor(random_key, policy, environment_spec, + variable_source, adder) + + def make_policy(self, + networks: DirectRLNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> DirectPolicyNetwork: + return self._rl_agent.make_policy(networks, environment_spec, evaluation) diff --git a/acme/acme/agents/jax/pwil/config.py b/acme/acme/agents/jax/pwil/config.py new file mode 100644 index 00000000..83cdd120 --- /dev/null +++ b/acme/acme/agents/jax/pwil/config.py @@ -0,0 +1,55 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PWIL config.""" +import dataclasses +from typing import Iterator + +from acme import types + + +@dataclasses.dataclass +class PWILConfig: + """Configuration options for PWIL. + + The default values correspond to the experiment setup from the PWIL + publication http://arxiv.org/abs/2006.04678. + """ + + # Number of transitions to fill the replay buffer with for pretraining. + num_transitions_rb: int = 50000 + + # If False, uses only observations for computing the distance; if True, also + # uses the actions. + use_actions_for_distance: bool = True + + # Scaling for the reward function, see equation (6) in + # http://arxiv.org/abs/2006.04678. + alpha: float = 5. + + # Controls the kernel size of the reward function, see equation (6) + # in http://arxiv.org/abs/2006.04678. + beta: float = 5. 
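+
+  # Together, alpha and beta define the imitation reward computed in
+  # rewarder.py: r = alpha * exp(-(beta * T / sqrt(dim)) * cost), where T is
+  # the episode length and dim is the flattened observation(-action)
+  # dimensionality.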
+
+  # When False, uses the reward signal from the dataset during prefilling.
+  prefill_constant_reward: bool = True
+
+  num_sgd_steps_per_step: int = 1
+
+
+@dataclasses.dataclass
+class PWILDemonstrations:
+  """Unbatched, unshuffled transitions with approximate episode length."""
+  demonstrations: Iterator[types.Transition]
+  episode_length: int
diff --git a/acme/acme/agents/jax/pwil/rewarder.py b/acme/acme/agents/jax/pwil/rewarder.py
new file mode 100644
index 00000000..b94a41bd
--- /dev/null
+++ b/acme/acme/agents/jax/pwil/rewarder.py
@@ -0,0 +1,157 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Rewarder class implementation."""
+
+from typing import Iterator
+
+from acme import types
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+class WassersteinDistanceRewarder:
+  """Computes PWIL rewards along a trajectory.
+
+  The rewards measure similarity to the demonstration transitions and are based
+  on a greedy approximation to the Wasserstein distance between trajectories.
+  """
+
+  def __init__(self,
+               demonstrations_it: Iterator[types.Transition],
+               episode_length: int,
+               use_actions_for_distance: bool = False,
+               alpha: float = 5.,
+               beta: float = 5.):
+    """Initializes the rewarder.
+
+    Args:
+      demonstrations_it: An iterator over acme.types.Transition.
+      episode_length: a target episode length (policies will be encouraged by
+        the imitation reward to have that length).
+      use_actions_for_distance: whether to use actions to compute the reward.
+      alpha: float scaling the reward function.
+      beta: float controlling the kernel size of the reward function.
+    """
+    self._episode_length = episode_length
+
+    self._use_actions_for_distance = use_actions_for_distance
+    self._vectorized_demonstrations = self._vectorize(demonstrations_it)
+
+    # Observations and actions are flat.
+    atom_dims = self._vectorized_demonstrations.shape[1]
+    self._reward_sigma = beta * self._episode_length / np.sqrt(atom_dims)
+    self._reward_scale = alpha
+
+    self._std = np.std(self._vectorized_demonstrations, axis=0, dtype='float64')
+    # The std is set to 1 if the observation values are below a threshold.
+    # This prevents normalizing observation values that are constant (which can
+    # be problematic with e.g. demonstrations coming from a different version
+    # of the environment and where the constant values are slightly different).
+    self._std = (self._std < 1e-6) + self._std
+
+    self.expert_atoms = self._vectorized_demonstrations / self._std
+    self._compute_norm = jax.jit(lambda a, b: jnp.linalg.norm(a - b, axis=1),
+                                 device=jax.devices('cpu')[0])
+
+  def _vectorize(self,
+                 demonstrations_it: Iterator[types.Transition]) -> np.ndarray:
+    """Converts filtered expert demonstrations to numpy array.
+
+    Args:
+      demonstrations_it: an iterator over expert demonstrations
+
+    Returns:
+      numpy array with dimension:
+      [num_expert_transitions, dim_observation] if not use_actions_for_distance
+      [num_expert_transitions, (dim_observation + dim_action)] otherwise
+    """
+    if self._use_actions_for_distance:
+      demonstrations = [
+          np.concatenate([t.observation, t.action]) for t in demonstrations_it
+      ]
+    else:
+      demonstrations = [t.observation for t in demonstrations_it]
+    return np.array(demonstrations)
+
+  def reset(self) -> None:
+    """Makes all expert transitions available and initializes weights."""
+    num_expert_atoms = len(self.expert_atoms)
+    self._all_expert_weights_zero = False
+    self.expert_weights = np.ones(num_expert_atoms) / num_expert_atoms
+
+  def append_and_compute_reward(self, observation: jnp.ndarray,
+                                action: jnp.ndarray) -> np.float32:
+    """Computes reward and updates state, advancing it along a trajectory.
+
+    Subsequent calls to append_and_compute_reward assume inputs are subsequent
+    trajectory points.
+
+    Args:
+      observation: observation on a trajectory, to compare with the expert
+        demonstration(s).
+      action: the action following the observation on the trajectory.
+
+    Returns:
+      the reward value: the return contribution from the trajectory point.
+
+    """
+    # If we run out of demonstrations, penalize further action.
+    if self._all_expert_weights_zero:
+      return np.float32(0.)
+
+    # Scale observation and action.
+    if self._use_actions_for_distance:
+      agent_atom = np.concatenate([observation, action])
+    else:
+      agent_atom = observation
+    agent_atom /= self._std
+
+    cost = 0.
+    # A special marker for records with zero expert weight. Has to be large so
+    # that argmin will not return it.
+    DELETED = 1e10  # pylint: disable=invalid-name
+    # As we match the expert's weights with the agent's weights, we might
+    # raise an error due to float precision, so we subtract a small epsilon
+    # from the agent's weights to prevent that.
+    weight = 1. / self._episode_length - 1e-6
+    norms = np.array(self._compute_norm(self.expert_atoms, agent_atom))
+    # We need to mask out states with zero weight, so that 'argmin' would not
+    # return them.
+    adjusted_norms = (1 - np.sign(self.expert_weights)) * DELETED + norms
+    while weight > 0:
+      # Get closest expert state action to agent's state action.
+      argmin = adjusted_norms.argmin()
+      effective_weight = min(weight, self.expert_weights[argmin])
+
+      if adjusted_norms[argmin] >= DELETED:
+        self._all_expert_weights_zero = True
+        break
+
+      # Update cost and weights.
+      weight -= effective_weight
+      self.expert_weights[argmin] -= effective_weight
+      cost += effective_weight * norms[argmin]
+      adjusted_norms[argmin] = DELETED
+
+    if weight > 0:
+      # We have a 'partial' cost if we ran out of demonstrations in the reward
+      # computation loop. We assign a high cost (infinite) in this case which
+      # makes the reward equal to 0.
+      reward = np.array(0.)
+    else:
+      reward = self._reward_scale * np.exp(-self._reward_sigma * cost)
+
+    return reward.astype('float32')
diff --git a/acme/acme/agents/jax/r2d2/__init__.py b/acme/acme/agents/jax/r2d2/__init__.py
new file mode 100644
index 00000000..3a285bab
--- /dev/null
+++ b/acme/acme/agents/jax/r2d2/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of an R2D2 agent.""" + +from acme.agents.jax.r2d2.actor import EpsilonRecurrentPolicy +from acme.agents.jax.r2d2.actor import make_behavior_policy +from acme.agents.jax.r2d2.builder import R2D2Builder +from acme.agents.jax.r2d2.config import R2D2Config +from acme.agents.jax.r2d2.learning import R2D2Learner +from acme.agents.jax.r2d2.learning import R2D2ReplaySample +from acme.agents.jax.r2d2.networks import make_atari_networks +from acme.agents.jax.r2d2.networks import make_networks +from acme.agents.jax.r2d2.networks import R2D2Networks diff --git a/acme/acme/agents/jax/r2d2/actor.py b/acme/acme/agents/jax/r2d2/actor.py new file mode 100644 index 00000000..e3b4137d --- /dev/null +++ b/acme/acme/agents/jax/r2d2/actor.py @@ -0,0 +1,109 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""R2D2 actor.""" + +from typing import Callable, Generic, Mapping, Optional, Tuple + +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax.r2d2 import config as r2d2_config +from acme.agents.jax.r2d2 import networks as r2d2_networks +from acme.jax import networks as networks_lib +import chex +import jax +import jax.numpy as jnp +import numpy as np +import rlax + +Epsilon = float +R2D2Extras = Mapping[str, jnp.ndarray] +EpsilonRecurrentPolicy = Callable[[ + networks_lib.Params, networks_lib.PRNGKey, networks_lib + .Observation, actor_core_lib.RecurrentState, Epsilon +], Tuple[networks_lib.Action, actor_core_lib.RecurrentState]] + + +@chex.dataclass(frozen=True, mappable_dataclass=False) +class R2D2ActorState(Generic[actor_core_lib.RecurrentState]): + rng: networks_lib.PRNGKey + epsilon: jnp.ndarray + recurrent_state: actor_core_lib.RecurrentState + + +R2D2Policy = actor_core_lib.ActorCore[ + R2D2ActorState[actor_core_lib.RecurrentState], R2D2Extras] + + +def get_actor_core( + networks: r2d2_networks.R2D2Networks, + num_epsilons: Optional[int], + evaluation_epsilon: Optional[float] = None, +) -> R2D2Policy: + """Returns ActorCore for R2D2.""" + + if (not num_epsilons and evaluation_epsilon is None) or (num_epsilons and + evaluation_epsilon): + raise ValueError( + 'Exactly one of `num_epsilons` or `evaluation_epsilon` must be ' + f'specified. 
Received num_epsilons={num_epsilons} and '
+        f'evaluation_epsilon={evaluation_epsilon}.')
+
+  def select_action(params: networks_lib.Params,
+                    observation: networks_lib.Observation,
+                    state: R2D2ActorState[actor_core_lib.RecurrentState]):
+    rng, policy_rng = jax.random.split(state.rng)
+
+    q_values, recurrent_state = networks.forward.apply(params, policy_rng,
+                                                       observation,
+                                                       state.recurrent_state)
+    action = rlax.epsilon_greedy(state.epsilon).sample(policy_rng, q_values)
+
+    return action, R2D2ActorState(rng, state.epsilon, recurrent_state)
+
+  def init(
+      rng: networks_lib.PRNGKey
+  ) -> R2D2ActorState[actor_core_lib.RecurrentState]:
+    rng, epsilon_rng, state_rng = jax.random.split(rng, 3)
+    if num_epsilons:
+      epsilon = jax.random.choice(epsilon_rng,
+                                  np.logspace(1, 8, num_epsilons, base=0.4))
+    else:
+      epsilon = evaluation_epsilon
+    initial_core_state = networks.initial_state.apply(None, state_rng, None)
+    return R2D2ActorState(rng, epsilon, initial_core_state)
+
+  def get_extras(
+      state: R2D2ActorState[actor_core_lib.RecurrentState]) -> R2D2Extras:
+    return {'core_state': state.recurrent_state}
+
+  return actor_core_lib.ActorCore(init=init, select_action=select_action,
+                                  get_extras=get_extras)
+
+
+# TODO(bshahr): Deprecate this in favour of R2D2Builder.make_policy.
+def make_behavior_policy(networks: r2d2_networks.R2D2Networks,
+                         config: r2d2_config.R2D2Config,
+                         evaluation: bool = False) -> EpsilonRecurrentPolicy:
+  """Selects action according to the policy."""
+
+  def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey,
+                      observation: types.NestedArray,
+                      core_state: types.NestedArray, epsilon: float):
+    q_values, core_state = networks.forward.apply(params, key, observation,
+                                                  core_state)
+    epsilon = config.evaluation_epsilon if evaluation else epsilon
+    return rlax.epsilon_greedy(epsilon).sample(key, q_values), core_state
+
+  return behavior_policy
diff --git a/acme/acme/agents/jax/r2d2/builder.py b/acme/acme/agents/jax/r2d2/builder.py
new file mode 100644
index 00000000..8ce024a0
--- /dev/null
+++ b/acme/acme/agents/jax/r2d2/builder.py
@@ -0,0 +1,270 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""R2D2 Builder."""
+from typing import Generic, Iterator, List, Optional
+
+import acme
+from acme import adders
+from acme import core
+from acme import specs
+from acme.adders import reverb as adders_reverb
+from acme.adders.reverb import base as reverb_base
+from acme.adders.reverb import structured
+from acme.agents.jax import actor_core as actor_core_lib
+from acme.agents.jax import actors
+from acme.agents.jax import builders
+from acme.agents.jax.r2d2 import actor as r2d2_actor
+from acme.agents.jax.r2d2 import config as r2d2_config
+from acme.agents.jax.r2d2 import learning as r2d2_learning
+from acme.agents.jax.r2d2 import networks as r2d2_networks
+from acme.datasets import reverb as datasets
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.jax import variable_utils
+from acme.utils import counting
+from acme.utils import loggers
+import jax
+import optax
+import reverb
+from reverb import structured_writer as sw
+import tensorflow as tf
+import tree
+
+# TODO(b/450949030): extract the private functions to a library once other agents
+# reuse them.
+
+# TODO(b/450949030): add support to add all the final subsequences of
+# length < sequence_length at the end of the episode and pad them with zeros.
+# We have to check if this requires moving _zero_pad to the adder.
+
+
+def _build_sequence(length: int,
+                    step_spec: reverb_base.Step) -> reverb_base.Trajectory:
+  """Constructs the sequence using only the first value of core_state."""
+  step_dict = step_spec._asdict()
+  extras_dict = step_dict.pop('extras')
+  return reverb_base.Trajectory(
+      **tree.map_structure(lambda x: x[-length:], step_dict),
+      extras=tree.map_structure(lambda x: x[-length], extras_dict))
+
+
+def _zero_pad(sequence_length: int) -> datasets.Transform:
+  """Adds zero padding to the right so all samples have the same length."""
+
+  def _zero_pad_transform(sample: reverb.ReplaySample) -> reverb.ReplaySample:
+    trajectory: reverb_base.Trajectory = sample.data
+
+    # Split steps and extras data (the extras won't be padded as they only
+    # contain one element)
+    trajectory_steps = trajectory._asdict()
+    trajectory_extras = trajectory_steps.pop('extras')
+
+    unpadded_length = len(tree.flatten(trajectory_steps)[0])
+
+    # Do nothing if the sequence is already full.
+    if unpadded_length != sequence_length:
+      to_pad = sequence_length - unpadded_length
+      pad = lambda x: tf.pad(x, [[0, to_pad]] + [[0, 0]] * (len(x.shape) - 1))
+
+      trajectory_steps = tree.map_structure(pad, trajectory_steps)
+
+    # Set the shape to be statically known, and checks it at runtime.
+    def _ensure_shape(x):
+      shape = tf.TensorShape([sequence_length]).concatenate(x.shape[1:])
+      return tf.ensure_shape(x, shape)
+
+    trajectory_steps = tree.map_structure(_ensure_shape, trajectory_steps)
+    return reverb.ReplaySample(
+        info=sample.info,
+        data=reverb_base.Trajectory(
+            **trajectory_steps, extras=trajectory_extras))
+
+  return _zero_pad_transform
+
+
+def _make_adder_config(step_spec: reverb_base.Step, seq_len: int,
+                       seq_period: int) -> list[sw.Config]:
+  return structured.create_sequence_config(
+      step_spec=step_spec,
+      sequence_length=seq_len,
+      period=seq_period,
+      end_of_episode_behavior=adders_reverb.EndBehavior.TRUNCATE,
+      sequence_pattern=_build_sequence)
+
+
+class R2D2Builder(Generic[actor_core_lib.RecurrentState],
+                  builders.ActorLearnerBuilder[r2d2_networks.R2D2Networks,
+                                               r2d2_actor.R2D2Policy,
+                                               r2d2_learning.R2D2ReplaySample]):
+  """R2D2 Builder.
+
+  This constructs all of the components for Recurrent Experience Replay in
+  Distributed Reinforcement Learning (Kapturowski et al.)
+  https://openreview.net/pdf?id=r1lyTjAqYX.
+  """
+
+  def __init__(self, config: r2d2_config.R2D2Config):
+    """Creates an R2D2 learner, a behavior policy and an eval actor."""
+    self._config = config
+    self._sequence_length = (
+        self._config.burn_in_length + self._config.trace_length + 1)
+
+  @property
+  def _batch_size_per_device(self) -> int:
+    """Splits batch size across all learner devices evenly."""
+    # TODO(bshahr): Using jax.device_count will not be valid when colocating
+    # learning and inference.
+    return self._config.batch_size // jax.device_count()
+
+  def make_learner(
+      self,
+      random_key: networks_lib.PRNGKey,
+      networks: r2d2_networks.R2D2Networks,
+      dataset: Iterator[r2d2_learning.R2D2ReplaySample],
+      logger_fn: loggers.LoggerFactory,
+      environment_spec: specs.EnvironmentSpec,
+      replay_client: Optional[reverb.Client] = None,
+      counter: Optional[counting.Counter] = None,
+  ) -> core.Learner:
+    del environment_spec
+
+    # The learner updates the parameters (and initializes them).
+    return r2d2_learning.R2D2Learner(
+        unroll=networks.unroll,
+        initial_state=networks.initial_state,
+        batch_size=self._batch_size_per_device,
+        random_key=random_key,
+        burn_in_length=self._config.burn_in_length,
+        discount=self._config.discount,
+        importance_sampling_exponent=(
+            self._config.importance_sampling_exponent),
+        max_priority_weight=self._config.max_priority_weight,
+        target_update_period=self._config.target_update_period,
+        iterator=dataset,
+        optimizer=optax.adam(self._config.learning_rate),
+        bootstrap_n=self._config.bootstrap_n,
+        tx_pair=self._config.tx_pair,
+        clip_rewards=self._config.clip_rewards,
+        replay_client=replay_client,
+        counter=counter,
+        logger=logger_fn('learner'))
+
+  def make_replay_tables(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      policy: r2d2_actor.R2D2Policy,
+  ) -> List[reverb.Table]:
+    """Create tables to insert data into."""
+    dummy_actor_state = policy.init(jax.random.PRNGKey(0))
+    extras_spec = policy.get_extras(dummy_actor_state)
+    step_spec = structured.create_step_spec(
+        environment_spec=environment_spec, extras_spec=extras_spec)
+    if self._config.samples_per_insert:
+      samples_per_insert_tolerance = (
+          self._config.samples_per_insert_tolerance_rate *
+          self._config.samples_per_insert)
+      error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
+      limiter = reverb.rate_limiters.SampleToInsertRatio(
+          min_size_to_sample=self._config.min_replay_size,
+          samples_per_insert=self._config.samples_per_insert,
+          error_buffer=error_buffer)
+    else:
+      limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size)
+    return [
+        reverb.Table(
+            name=self._config.replay_table_name,
+            sampler=reverb.selectors.Prioritized(
+                self._config.priority_exponent),
+            remover=reverb.selectors.Fifo(),
+            max_size=self._config.max_replay_size,
+            rate_limiter=limiter,
+            signature=sw.infer_signature(
+                configs=_make_adder_config(step_spec, self._sequence_length,
                                           self._config.sequence_period),
+                step_spec=step_spec))
+    ]
+
+  def make_dataset_iterator(
+      self,
+      replay_client: reverb.Client) -> Iterator[r2d2_learning.R2D2ReplaySample]:
+    """Create a dataset iterator to use for learning/updating the agent."""
+    batch_size_per_learner = self._config.batch_size // jax.process_count()
+    dataset = datasets.make_reverb_dataset(
+        table=self._config.replay_table_name,
+        server_address=replay_client.server_address,
+        batch_size=self._batch_size_per_device,
+        num_parallel_calls=None,
+        max_in_flight_samples_per_worker=2 * batch_size_per_learner,
+        postprocess=_zero_pad(self._sequence_length),
+    )
+
+    # We split samples in two outputs, the keys which need to be kept on-host
+    # since int64 arrays are not supported in TPUs, and the entire sample
+    # separately so it can be sent to the sgd_step method.
+    def split_sample(sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
+      return utils.PrefetchingSplit(host=sample.info.key, device=sample)
+
+    return utils.multi_device_put(
+        dataset.as_numpy_iterator(),
+        devices=jax.local_devices(),
+        split_fn=split_sample)
+
+  def make_adder(
+      self, replay_client: reverb.Client,
+      environment_spec: Optional[specs.EnvironmentSpec],
+      policy: Optional[r2d2_actor.R2D2Policy]) -> Optional[adders.Adder]:
+    """Create an adder which records data generated by the actor/environment."""
+    if environment_spec is None or policy is None:
+      raise ValueError('`environment_spec` and `policy` cannot be None.')
+    dummy_actor_state = policy.init(jax.random.PRNGKey(0))
+    extras_spec = policy.get_extras(dummy_actor_state)
+    step_spec = structured.create_step_spec(
+        environment_spec=environment_spec, extras_spec=extras_spec)
+    return structured.StructuredAdder(
+        client=replay_client,
+        max_in_flight_items=5,
+        configs=_make_adder_config(step_spec, self._sequence_length,
+                                   self._config.sequence_period),
+        step_spec=step_spec)
+
+  def make_actor(
+      self,
+      random_key: networks_lib.PRNGKey,
+      policy: r2d2_actor.R2D2Policy,
+      environment_spec: specs.EnvironmentSpec,
+      variable_source: Optional[core.VariableSource] = None,
+      adder: Optional[adders.Adder] = None,
+  ) -> acme.Actor:
+    del environment_spec
+    # Create variable client.
+    variable_client = variable_utils.VariableClient(
+        variable_source,
+        key='actor_variables',
+        update_period=self._config.variable_update_period)
+
+    return actors.GenericActor(
+        policy, random_key, variable_client, adder, backend='cpu')
+
+  def make_policy(self,
+                  networks: r2d2_networks.R2D2Networks,
+                  environment_spec: specs.EnvironmentSpec,
+                  evaluation: bool = False) -> r2d2_actor.R2D2Policy:
+    if evaluation:
+      return r2d2_actor.get_actor_core(
+          networks,
+          num_epsilons=None,
+          evaluation_epsilon=self._config.evaluation_epsilon)
+    else:
+      return r2d2_actor.get_actor_core(networks, self._config.num_epsilons)
diff --git a/acme/acme/agents/jax/r2d2/config.py b/acme/acme/agents/jax/r2d2/config.py
new file mode 100644
index 00000000..2fc52d90
--- /dev/null
+++ b/acme/acme/agents/jax/r2d2/config.py
@@ -0,0 +1,53 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""R2D2 config."""
+import dataclasses
+
+from acme.adders import reverb as adders_reverb
+import rlax
+
+
+@dataclasses.dataclass
+class R2D2Config:
+  """Configuration options for R2D2 agent."""
+  discount: float = 0.997
+  target_update_period: int = 2500
+  evaluation_epsilon: float = 0.
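+  # Number of distinct epsilons in the actors' epsilon-greedy ladder; each
+  # actor draws one fixed value from np.logspace(1, 8, num_epsilons, base=0.4)
+  # in actor.get_actor_core, i.e. (as a quick sketch of the numbers) from
+  # 0.4**1 = 0.4 down to 0.4**8 ~= 6.6e-4.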
+ num_epsilons: int = 256 + variable_update_period: int = 400 + + # Learner options + burn_in_length: int = 40 + trace_length: int = 80 + sequence_period: int = 40 + learning_rate: float = 1e-3 + bootstrap_n: int = 5 + clip_rewards: bool = False + tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR + + # Replay options + samples_per_insert_tolerance_rate: float = 0.1 + samples_per_insert: float = 4.0 + min_replay_size: int = 50_000 + max_replay_size: int = 100_000 + batch_size: int = 64 + prefetch_size: int = 2 + num_parallel_calls: int = 16 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + + # Priority options + importance_sampling_exponent: float = 0.6 + priority_exponent: float = 0.9 + max_priority_weight: float = 0.9 diff --git a/acme/acme/agents/jax/r2d2/learning.py b/acme/acme/agents/jax/r2d2/learning.py new file mode 100644 index 00000000..f8181e3e --- /dev/null +++ b/acme/acme/agents/jax/r2d2/learning.py @@ -0,0 +1,279 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""R2D2 learner implementation.""" + +import functools +import time +from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple + +from absl import logging +import acme +from acme.adders import reverb as adders +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import async_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax +import reverb +import rlax +import tree + +_PMAP_AXIS_NAME = 'data' +# This type allows splitting a sample between the host and device, which avoids +# putting item keys (uint64) on device for the purposes of priority updating. 
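+# A PrefetchingSplit carries a `host` field (here the uint64 Reverb item keys,
+# which stay on the host) and a `device` field (the full sample, prefetched to
+# device); see split_sample in builder.make_dataset_iterator for where this
+# split is constructed.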
+R2D2ReplaySample = utils.PrefetchingSplit + + +class TrainingState(NamedTuple): + """Holds the agent's training state.""" + params: networks_lib.Params + target_params: networks_lib.Params + opt_state: optax.OptState + steps: int + random_key: networks_lib.PRNGKey + + +class R2D2Learner(acme.Learner): + """R2D2 learner.""" + + def __init__(self, + unroll: networks_lib.FeedForwardNetwork, + initial_state: networks_lib.FeedForwardNetwork, + batch_size: int, + random_key: networks_lib.PRNGKey, + burn_in_length: int, + discount: float, + importance_sampling_exponent: float, + max_priority_weight: float, + target_update_period: int, + iterator: Iterator[R2D2ReplaySample], + optimizer: optax.GradientTransformation, + bootstrap_n: int = 5, + tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR, + clip_rewards: bool = False, + max_abs_reward: float = 1., + use_core_state: bool = True, + prefetch_size: int = 2, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None): + """Initializes the learner.""" + + random_key, key_initial_1, key_initial_2 = jax.random.split(random_key, 3) + initial_state_params = initial_state.init(key_initial_1, batch_size) + initial_state = initial_state.apply(initial_state_params, key_initial_2, + batch_size) + + def loss( + params: networks_lib.Params, + target_params: networks_lib.Params, + key_grad: networks_lib.PRNGKey, + sample: reverb.ReplaySample + ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: + """Computes mean transformed N-step loss for a batch of sequences.""" + + # Get core state & warm it up on observations for a burn-in period. + if use_core_state: + # Replay core state. + # NOTE: We may need to recover the type of the hk.LSTMState if the user + # specifies a dynamically unrolled RNN as it will strictly enforce the + # match between input/output state types. + online_state = utils.maybe_recover_lstm_type( + sample.data.extras.get('core_state')) + else: + online_state = initial_state + target_state = online_state + + # Convert sample data to sequence-major format [T, B, ...]. + data = utils.batch_to_sequence(sample.data) + + # Maybe burn the core state in. + if burn_in_length: + burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation) + key_grad, key1, key2 = jax.random.split(key_grad, 3) + _, online_state = unroll.apply(params, key1, burn_obs, online_state) + _, target_state = unroll.apply(target_params, key2, burn_obs, + target_state) + + # Only get data to learn on from after the end of the burn in period. + data = jax.tree_map(lambda seq: seq[burn_in_length:], data) + + # Unroll on sequences to get online and target Q-Values. + key1, key2 = jax.random.split(key_grad) + online_q, _ = unroll.apply(params, key1, data.observation, online_state) + target_q, _ = unroll.apply(target_params, key2, data.observation, + target_state) + + # Get value-selector actions from online Q-values for double Q-learning. + selector_actions = jnp.argmax(online_q, axis=-1) + # Preprocess discounts & rewards. + discounts = (data.discount * discount).astype(online_q.dtype) + rewards = data.reward + if clip_rewards: + rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) + rewards = rewards.astype(online_q.dtype) + + # Get N-step transformed TD error and loss. 
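+      # A sketch of what rlax.transformed_n_step_q_learning computes per
+      # sequence, assuming the default signed-hyperbolic tx_pair with
+      # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x:
+      #   td_t = h(r_t + ... + gamma^n * h^-1(q_target[t + n])) - q_tm1[t]
+      # following Pohlen et al. (2018), https://arxiv.org/abs/1805.11593.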
+ batch_td_error_fn = jax.vmap( + functools.partial( + rlax.transformed_n_step_q_learning, + n=bootstrap_n, + tx_pair=tx_pair), + in_axes=1, + out_axes=1) + # TODO(b/183945808): when this bug is fixed, truncations of actions, + # rewards, and discounts will no longer be necessary. + batch_td_error = batch_td_error_fn( + online_q[:-1], + data.action[:-1], + target_q[1:], + selector_actions[1:], + rewards[:-1], + discounts[:-1]) + batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0) + + # Importance weighting. + probs = sample.info.probability + importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype) + importance_weights **= importance_sampling_exponent + importance_weights /= jnp.max(importance_weights) + mean_loss = jnp.mean(importance_weights * batch_loss) + + # Calculate priorities as a mixture of max and mean sequence errors. + abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype) + max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0) + mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0) + priorities = (max_priority + mean_priority) + + return mean_loss, priorities + + def sgd_step( + state: TrainingState, + samples: reverb.ReplaySample + ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]: + """Performs an update step, averaging over pmap replicas.""" + + # Compute loss and gradients. + grad_fn = jax.value_and_grad(loss, has_aux=True) + key, key_grad = jax.random.split(state.random_key) + (loss_value, priorities), gradients = grad_fn(state.params, + state.target_params, + key_grad, + samples) + + # Average gradients over pmap replicas before optimizer update. + gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME) + + # Apply optimizer updates. + updates, new_opt_state = optimizer.update(gradients, state.opt_state) + new_params = optax.apply_updates(state.params, updates) + + # Periodically update target networks. + steps = state.steps + 1 + target_params = optax.periodic_update(new_params, state.target_params, + steps, self._target_update_period) + + new_state = TrainingState( + params=new_params, + target_params=target_params, + opt_state=new_opt_state, + steps=steps, + random_key=key) + return new_state, priorities, {'loss': loss_value} + + def update_priorities( + keys_and_priorities: Tuple[jnp.ndarray, jnp.ndarray]): + keys, priorities = keys_and_priorities + keys, priorities = tree.map_structure( + # Fetch array and combine device and batch dimensions. + lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]), + (keys, priorities)) + replay_client.mutate_priorities( # pytype: disable=attribute-error + table=adders.DEFAULT_PRIORITY_TABLE, + updates=dict(zip(keys, priorities))) + + # Internalise components, hyperparameters, logger, counter, and methods. + self._iterator = iterator + self._replay_client = replay_client + self._target_update_period = target_update_period + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + time_delta=1., + steps_key=self._counter.get_steps_key()) + + self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) + self._async_priority_updater = async_utils.AsyncExecutor(update_priorities) + + # Initialise and internalise training state (parameters/optimiser state). 
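+    # The training state created below is replicated across all local devices
+    # (via utils.replicate_in_all_devices at the end of this constructor) so
+    # that the pmapped sgd_step above runs one synchronous update per device;
+    # get_variables/save later read back only the first replica.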
+ random_key, key_init = jax.random.split(random_key) + initial_params = unroll.init(key_init, initial_state) + opt_state = optimizer.init(initial_params) + + # Log how many parameters the network has. + sizes = tree.map_structure(jnp.size, initial_params) + logging.info('Total number of params: %d', + sum(tree.flatten(sizes.values()))) + + state = TrainingState( + params=initial_params, + target_params=initial_params, + opt_state=opt_state, + steps=jnp.array(0), + random_key=random_key) + # Replicate parameters. + self._state = utils.replicate_in_all_devices(state) + + def step(self): + prefetching_split = next(self._iterator) + # The split_sample method passed to utils.sharded_prefetch specifies what + # parts of the objects returned by the original iterator are kept in the + # host and what parts are prefetched on-device. + # In this case the host property of the prefetching split contains only the + # replay keys and the device property is the prefetched full original + # sample. + keys = prefetching_split.host + samples: reverb.ReplaySample = prefetching_split.device + + # Do a batch of SGD. + start = time.time() + self._state, priorities, metrics = self._sgd_step(self._state, samples) + # Take metrics from first replica. + metrics = utils.get_from_first_device(metrics) + # Update our counts and record it. + counts = self._counter.increment(steps=1, time_elapsed=time.time() - start) + + # Update priorities in replay. + if self._replay_client: + self._async_priority_updater.put((keys, priorities)) + + # Attempt to write logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + # Return first replica of parameters. + return [utils.get_from_first_device(self._state.params)] + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. + return jax.tree_map(utils.get_from_first_device, self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state) diff --git a/acme/acme/agents/jax/r2d2/networks.py b/acme/acme/agents/jax/r2d2/networks.py new file mode 100644 index 00000000..ea2f40a9 --- /dev/null +++ b/acme/acme/agents/jax/r2d2/networks.py @@ -0,0 +1,91 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
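+# A summary of the container defined below: `forward` applies a single
+# recurrent step when acting, `unroll` applies the same core over a whole
+# [T, B, ...] sequence when learning, and `initial_state` builds the starting
+# recurrent state. A hypothetical usage sketch:
+#
+#   nets = make_atari_networks(batch_size=1, env_spec=spec)
+#   params = nets.initial_state.init(jax.random.PRNGKey(0), 1)
+#   state = nets.initial_state.apply(params, jax.random.PRNGKey(1), 1)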
+
+"""R2D2 Networks."""
+
+import dataclasses
+from typing import Any, Optional
+
+from acme import specs
+from acme.jax import networks as networks_lib
+from acme.jax import types as jax_types
+from acme.jax import utils
+import haiku as hk
+import jax
+
+
+@dataclasses.dataclass
+class R2D2Networks:
+  """Network and pure functions for the R2D2 agent."""
+  forward: networks_lib.FeedForwardNetwork
+  unroll: networks_lib.FeedForwardNetwork
+  initial_state: networks_lib.FeedForwardNetwork
+
+
+def make_networks(env_spec: specs.EnvironmentSpec, forward_fn: Any,
+                  initial_state_fn: Any, unroll_fn: Any,
+                  batch_size: int) -> R2D2Networks:
+  """Builds functional r2d2 network from recurrent model definitions."""
+  del batch_size
+
+  # Make networks purely functional.
+  forward_hk = hk.transform(forward_fn)
+  initial_state_hk = hk.transform(initial_state_fn)
+  unroll_hk = hk.transform(unroll_fn)
+
+  # Define networks init functions.
+  def unroll_init_fn(rng: jax_types.PRNGKey,
+                     initial_state: hk.LSTMState) -> hk.Params:
+    del initial_state
+    init_state_params_rng, init_state_rng, unroll_rng = jax.random.split(rng, 3)
+    init_state_params = initial_state_hk.init(init_state_params_rng)
+    dummy_initial_state = initial_state_hk.apply(init_state_params,
+                                                 init_state_rng, 1)
+    dummy_obs = utils.zeros_like(env_spec.observations)
+    for _ in ('batch', 'time'):  # Add time and batch dimensions.
+      dummy_obs = utils.add_batch_dim(dummy_obs)
+    return unroll_hk.init(unroll_rng, dummy_obs, dummy_initial_state)
+
+  # Make FeedForwardNetworks.
+  forward = networks_lib.FeedForwardNetwork(
+      init=forward_hk.init, apply=forward_hk.apply)
+  unroll = networks_lib.FeedForwardNetwork(
+      init=unroll_init_fn, apply=unroll_hk.apply)
+  initial_state = networks_lib.FeedForwardNetwork(*initial_state_hk)
+  return R2D2Networks(
+      forward=forward, unroll=unroll, initial_state=initial_state)
+
+
+def make_atari_networks(batch_size: int,
+                        env_spec: specs.EnvironmentSpec) -> R2D2Networks:
+  """Builds default R2D2 networks for Atari games."""
+
+  def make_model() -> networks_lib.R2D2AtariNetwork:
+    return networks_lib.R2D2AtariNetwork(env_spec.actions.num_values)
+
+  def forward_fn(x, s):
+    return make_model()(x, s)
+
+  def initial_state_fn(batch_size: Optional[int] = None):
+    return make_model().initial_state(batch_size)
+
+  def unroll_fn(inputs, state):
+    return make_model().unroll(inputs, state)
+
+  return make_networks(
+      env_spec=env_spec,
+      forward_fn=forward_fn,
+      initial_state_fn=initial_state_fn,
+      unroll_fn=unroll_fn,
+      batch_size=batch_size)
diff --git a/acme/acme/agents/jax/rnd/README.md b/acme/acme/agents/jax/rnd/README.md
new file mode 100644
index 00000000..8b73121d
--- /dev/null
+++ b/acme/acme/agents/jax/rnd/README.md
@@ -0,0 +1,12 @@
+# Random Network Distillation (RND)
+
+This folder contains an implementation of the RND algorithm
+([Burda et al., 2018]).
+
+RND requires an RL algorithm to work, passed in as an `ActorLearnerBuilder`.
+
+By default this implementation ignores the original reward: the agent is trained
+only on the intrinsic exploration reward. To also use extrinsic reward,
+intrinsic and extrinsic reward weights can be passed into `make_networks`.
+
+[Burda et al., 2018]: https://arxiv.org/abs/1810.12894
diff --git a/acme/acme/agents/jax/rnd/__init__.py b/acme/acme/agents/jax/rnd/__init__.py
new file mode 100644
index 00000000..cb09a0ec
--- /dev/null
+++ b/acme/acme/agents/jax/rnd/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RND agent.""" + +from acme.agents.jax.rnd.builder import RNDBuilder +from acme.agents.jax.rnd.config import RNDConfig +from acme.agents.jax.rnd.learning import rnd_loss +from acme.agents.jax.rnd.learning import rnd_update_step +from acme.agents.jax.rnd.learning import RNDLearner +from acme.agents.jax.rnd.learning import RNDTrainingState +from acme.agents.jax.rnd.networks import compute_rnd_reward +from acme.agents.jax.rnd.networks import make_networks +from acme.agents.jax.rnd.networks import rnd_reward_fn +from acme.agents.jax.rnd.networks import RNDNetworks diff --git a/acme/acme/agents/jax/rnd/builder.py b/acme/acme/agents/jax/rnd/builder.py new file mode 100644 index 00000000..f7d5c261 --- /dev/null +++ b/acme/acme/agents/jax/rnd/builder.py @@ -0,0 +1,135 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RND Builder.""" + +from typing import Callable, Generic, Iterator, List, Optional + +from acme import adders +from acme import core +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import builders +from acme.agents.jax.rnd import config as rnd_config +from acme.agents.jax.rnd import learning as rnd_learning +from acme.agents.jax.rnd import networks as rnd_networks +from acme.jax import networks as networks_lib +from acme.jax.types import PolicyNetwork +from acme.utils import counting +from acme.utils import loggers +import jax +import optax +import reverb + + +class RNDBuilder(Generic[rnd_networks.DirectRLNetworks, PolicyNetwork], + builders.ActorLearnerBuilder[rnd_networks.RNDNetworks, + PolicyNetwork, + reverb.ReplaySample]): + """RND Builder.""" + + def __init__( + self, + rl_agent: builders.ActorLearnerBuilder[rnd_networks.DirectRLNetworks, + PolicyNetwork, + reverb.ReplaySample], + config: rnd_config.RNDConfig, + logger_fn: Callable[[], loggers.Logger] = lambda: None, + ): + """Implements a builder for RND using rl_agent as forward RL algorithm. + + Args: + rl_agent: The standard RL agent used by RND to optimize the generator. + config: A config with RND HPs. + logger_fn: a logger factory for the rl_agent's learner. 
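+
+    A hypothetical composition (names here are illustrative only):
+
+      builder = RNDBuilder(
+          rl_agent=some_direct_rl_builder,
+          config=rnd_config.RNDConfig(predictor_learning_rate=1e-4))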
+ """ + self._rl_agent = rl_agent + self._config = config + self._logger_fn = logger_fn + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: rnd_networks.RNDNetworks[rnd_networks.DirectRLNetworks], + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + direct_rl_learner_key, rnd_learner_key = jax.random.split(random_key) + + counter = counter or counting.Counter() + direct_rl_counter = counting.Counter(counter, 'direct_rl') + + def direct_rl_learner_factory( + networks: rnd_networks.DirectRLNetworks, + dataset: Iterator[reverb.ReplaySample]) -> core.Learner: + return self._rl_agent.make_learner( + direct_rl_learner_key, + networks, + dataset, + logger_fn=lambda name: self._logger_fn(), + environment_spec=environment_spec, + replay_client=replay_client, + counter=direct_rl_counter) + + optimizer = optax.adam(learning_rate=self._config.predictor_learning_rate) + + return rnd_learning.RNDLearner( + direct_rl_learner_factory=direct_rl_learner_factory, + iterator=dataset, + optimizer=optimizer, + rnd_network=networks, + rng_key=rnd_learner_key, + is_sequence_based=self._config.is_sequence_based, + grad_updates_per_batch=self._config.num_sgd_steps_per_step, + counter=counter, + logger=logger_fn('learner')) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: PolicyNetwork, + ) -> List[reverb.Table]: + return self._rl_agent.make_replay_tables(environment_spec, policy) + + def make_dataset_iterator( + self, + replay_client: reverb.Client) -> Optional[Iterator[reverb.ReplaySample]]: + return self._rl_agent.make_dataset_iterator(replay_client) + + def make_adder(self, replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[PolicyNetwork]) -> Optional[adders.Adder]: + return self._rl_agent.make_adder(replay_client, environment_spec, policy) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: PolicyNetwork, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor(random_key, policy, environment_spec, + variable_source, adder) + + def make_policy(self, + networks: rnd_networks.RNDNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + return self._rl_agent.make_policy(networks.direct_rl_networks, + environment_spec, evaluation) diff --git a/acme/acme/agents/jax/rnd/config.py b/acme/acme/agents/jax/rnd/config.py new file mode 100644 index 00000000..db50c843 --- /dev/null +++ b/acme/acme/agents/jax/rnd/config.py @@ -0,0 +1,30 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
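+# How these options are consumed (a summary of builder.py and learning.py in
+# this directory): predictor_learning_rate feeds an optax.adam optimizer for
+# the predictor network, is_sequence_based controls how replay samples are
+# flattened into SARS transitions, and num_sgd_steps_per_step folds several
+# predictor updates into a single learner step.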
+
+"""RND config."""
+import dataclasses
+
+
+@dataclasses.dataclass
+class RNDConfig:
+  """Configuration options for RND."""
+
+  # Learning rate for the predictor.
+  predictor_learning_rate: float = 1e-4
+
+  # If True, the direct rl algorithm uses the SequenceAdder data format.
+  is_sequence_based: bool = False
+
+  # How many gradient updates to perform per step.
+  num_sgd_steps_per_step: int = 1
diff --git a/acme/acme/agents/jax/rnd/learning.py b/acme/acme/agents/jax/rnd/learning.py
new file mode 100644
index 00000000..168c2132
--- /dev/null
+++ b/acme/acme/agents/jax/rnd/learning.py
@@ -0,0 +1,229 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""RND learner implementation."""
+
+import functools
+import time
+from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple
+
+import acme
+from acme import types
+from acme.agents.jax.rnd import networks as rnd_networks
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import loggers
+from acme.utils import reverb_utils
+import jax
+import jax.numpy as jnp
+import optax
+import reverb
+
+
+class RNDTrainingState(NamedTuple):
+  """Contains training state for the learner."""
+  optimizer_state: optax.OptState
+  params: networks_lib.Params
+  target_params: networks_lib.Params
+  steps: int
+
+
+class GlobalTrainingState(NamedTuple):
+  """Contains training state of the RND learner."""
+  rewarder_state: RNDTrainingState
+  learner_state: Any
+
+
+RNDLoss = Callable[[networks_lib.Params, networks_lib.Params, types.Transition],
+                   float]
+
+
+def rnd_update_step(
+    state: RNDTrainingState, transitions: types.Transition,
+    loss_fn: RNDLoss, optimizer: optax.GradientTransformation
+) -> Tuple[RNDTrainingState, Dict[str, jnp.ndarray]]:
+  """Runs an update step on the given transitions.
+
+  Args:
+    state: The learner state.
+    transitions: Transitions to update on.
+    loss_fn: The loss function.
+    optimizer: The optimizer of the predictor network.
+
+  Returns:
+    A new state and metrics.
+  """
+  loss, grads = jax.value_and_grad(loss_fn)(
+      state.params,
+      state.target_params,
+      transitions=transitions)
+
+  update, optimizer_state = optimizer.update(grads, state.optimizer_state)
+  params = optax.apply_updates(state.params, update)
+
+  new_state = RNDTrainingState(
+      optimizer_state=optimizer_state,
+      params=params,
+      target_params=state.target_params,
+      steps=state.steps + 1,
+  )
+  return new_state, {'rnd_loss': loss}
+
+
+def rnd_loss(
+    predictor_params: networks_lib.Params,
+    target_params: networks_lib.Params,
+    transitions: types.Transition,
+    networks: rnd_networks.RNDNetworks,
+) -> float:
+  """The Random Network Distillation loss.
+
+  See https://arxiv.org/pdf/1810.12894.pdf A.2.
+
+  Args:
+    predictor_params: Parameters of the predictor.
+    target_params: Parameters of the target.
+    transitions: Transitions to compute the loss on.
+ networks: RND networks + + Returns: + The MSE loss as a float. + """ + target_output = networks.target.apply(target_params, + transitions.observation, + transitions.action) + predictor_output = networks.predictor.apply(predictor_params, + transitions.observation, + transitions.action) + return jnp.mean(jnp.square(target_output - predictor_output)) + + +class RNDLearner(acme.Learner): + """RND learner.""" + + def __init__( + self, + direct_rl_learner_factory: Callable[[Any, Iterator[reverb.ReplaySample]], + acme.Learner], + iterator: Iterator[reverb.ReplaySample], + optimizer: optax.GradientTransformation, + rnd_network: rnd_networks.RNDNetworks, + rng_key: jnp.ndarray, + grad_updates_per_batch: int, + is_sequence_based: bool, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None): + self._is_sequence_based = is_sequence_based + + target_key, predictor_key = jax.random.split(rng_key) + target_params = rnd_network.target.init(target_key) + predictor_params = rnd_network.predictor.init(predictor_key) + optimizer_state = optimizer.init(predictor_params) + + self._state = RNDTrainingState( + optimizer_state=optimizer_state, + params=predictor_params, + target_params=target_params, + steps=0) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + loss = functools.partial(rnd_loss, networks=rnd_network) + self._update = functools.partial(rnd_update_step, + loss_fn=loss, + optimizer=optimizer) + self._update = utils.process_multiple_batches(self._update, + grad_updates_per_batch) + self._update = jax.jit(self._update) + + self._get_reward = jax.jit( + functools.partial( + rnd_networks.compute_rnd_reward, networks=rnd_network)) + + # Generator expression that works the same as an iterator. + # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions + updated_iterator = (self._process_sample(sample) for sample in iterator) + + self._direct_rl_learner = direct_rl_learner_factory( + rnd_network.direct_rl_networks, updated_iterator) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: + """Uses the replay sample to train and update its reward. + + Args: + sample: Replay sample to train on. + + Returns: + The sample replay sample with an updated reward. + """ + transitions = reverb_utils.replay_sample_to_sars_transition( + sample, is_sequence=self._is_sequence_based) + self._state, metrics = self._update(self._state, transitions) + rewards = self._get_reward(self._state.params, self._state.target_params, + transitions) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. 
+    self._logger.write({**metrics, **counts})
+
+    return sample._replace(data=sample.data._replace(reward=rewards))
+
+  def step(self):
+    self._direct_rl_learner.step()
+
+  def get_variables(self, names: List[str]) -> List[Any]:
+    rnd_variables = {
+        'target_params': self._state.target_params,
+        'predictor_params': self._state.params
+    }
+
+    learner_names = [name for name in names if name not in rnd_variables]
+    learner_dict = {}
+    if learner_names:
+      learner_dict = dict(
+          zip(learner_names,
+              self._direct_rl_learner.get_variables(learner_names)))
+
+    variables = [
+        rnd_variables.get(name, learner_dict.get(name, None)) for name in names
+    ]
+    return variables
+
+  def save(self) -> GlobalTrainingState:
+    return GlobalTrainingState(
+        rewarder_state=self._state,
+        learner_state=self._direct_rl_learner.save())
+
+  def restore(self, state: GlobalTrainingState):
+    self._state = state.rewarder_state
+    self._direct_rl_learner.restore(state.learner_state)
diff --git a/acme/acme/agents/jax/rnd/networks.py b/acme/acme/agents/jax/rnd/networks.py
new file mode 100644
index 00000000..c81ebc1e
--- /dev/null
+++ b/acme/acme/agents/jax/rnd/networks.py
@@ -0,0 +1,124 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Networks definitions for the RND agent."""
+
+import dataclasses
+import functools
+from typing import Callable, Generic, Tuple, TypeVar
+
+from acme import specs
+from acme import types
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+import haiku as hk
+import jax.numpy as jnp
+
+
+DirectRLNetworks = TypeVar('DirectRLNetworks')
+
+
+@dataclasses.dataclass
+class RNDNetworks(Generic[DirectRLNetworks]):
+  """Container of RND networks factories."""
+  target: networks_lib.FeedForwardNetwork
+  predictor: networks_lib.FeedForwardNetwork
+  # Function from predictor output, target output, and original reward to
+  # reward.
+  get_reward: Callable[
+      [networks_lib.NetworkOutput, networks_lib.NetworkOutput, jnp.ndarray],
+      jnp.ndarray]
+  direct_rl_networks: DirectRLNetworks = None
+
+
+# See Appendix A.2 of https://arxiv.org/pdf/1810.12894.pdf
+def rnd_reward_fn(
+    predictor_output: networks_lib.NetworkOutput,
+    target_output: networks_lib.NetworkOutput,
+    original_reward: jnp.ndarray,
+    intrinsic_reward_coefficient: float = 1.0,
+    extrinsic_reward_coefficient: float = 0.0,
+) -> jnp.ndarray:
+  intrinsic_reward = jnp.mean(
+      jnp.square(predictor_output - target_output), axis=-1)
+  return (intrinsic_reward_coefficient * intrinsic_reward +
+          extrinsic_reward_coefficient * original_reward)
+
+
+def make_networks(
+    spec: specs.EnvironmentSpec,
+    direct_rl_networks: DirectRLNetworks,
+    layer_sizes: Tuple[int, ...] = (256, 256),
+    intrinsic_reward_coefficient: float = 1.0,
+    extrinsic_reward_coefficient: float = 0.0,
+) -> RNDNetworks[DirectRLNetworks]:
+  """Creates networks used by the agent and returns RNDNetworks.
+
+  Args:
+    spec: Environment spec.
+    direct_rl_networks: Networks used by a direct rl algorithm.
+ layer_sizes: Layer sizes. + intrinsic_reward_coefficient: Multiplier on intrinsic reward. + extrinsic_reward_coefficient: Multiplier on extrinsic reward. + + Returns: + The RND networks. + """ + + def _rnd_fn(obs, act): + # RND does not use the action but other variants like RED do. + del act + network = networks_lib.LayerNormMLP(list(layer_sizes)) + return network(obs) + + target = hk.without_apply_rng(hk.transform(_rnd_fn)) + predictor = hk.without_apply_rng(hk.transform(_rnd_fn)) + + # Create dummy observations and actions to create network parameters. + dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return RNDNetworks( + target=networks_lib.FeedForwardNetwork( + lambda key: target.init(key, dummy_obs, ()), target.apply), + predictor=networks_lib.FeedForwardNetwork( + lambda key: predictor.init(key, dummy_obs, ()), predictor.apply), + direct_rl_networks=direct_rl_networks, + get_reward=functools.partial( + rnd_reward_fn, + intrinsic_reward_coefficient=intrinsic_reward_coefficient, + extrinsic_reward_coefficient=extrinsic_reward_coefficient)) + + +def compute_rnd_reward(predictor_params: networks_lib.Params, + target_params: networks_lib.Params, + transitions: types.Transition, + networks: RNDNetworks) -> jnp.ndarray: + """Computes the intrinsic RND reward for a given transition. + + Args: + predictor_params: Parameters of the predictor network. + target_params: Parameters of the target network. + transitions: The sample to compute rewards for. + networks: RND networks + + Returns: + The rewards as an ndarray. + """ + target_output = networks.target.apply(target_params, transitions.observation, + transitions.action) + predictor_output = networks.predictor.apply(predictor_params, + transitions.observation, + transitions.action) + return networks.get_reward(predictor_output, target_output, + transitions.reward) diff --git a/acme/acme/agents/jax/sac/README.md b/acme/acme/agents/jax/sac/README.md new file mode 100644 index 00000000..a09985c3 --- /dev/null +++ b/acme/acme/agents/jax/sac/README.md @@ -0,0 +1,19 @@ +# Soft Actor-Critic (SAC) + +This folder contains an implementation of the SAC algorithm +([Haarnoja et al., 2018]) with automatic tuning of the temperature +([Haarnoja et al., 2019]). + +This is an actor-critic method with: + + - a stochastic policy optimization (as opposed to, e.g., DPG) with a maximum entropy regularization; and + - two critics to mitigate the over-estimation bias in policy evaluation ([Fujimoto et al., 2018]). + +For the maximum entropy regularization, we provide a commonly used heuristic for specifying entropy target (`target_entropy_from_env_spec`). +The heuristic returns `-num_actions` by default or `num_actions * target_entropy_per_dimension` +if `target_entropy_per_dimension` is specified. + + +[Haarnoja et al., 2018]: https://arxiv.org/abs/1801.01290 +[Haarnoja et al., 2019]: https://arxiv.org/abs/1812.05905 +[Fujimoto et al., 2018]: https://arxiv.org/abs/1802.09477 diff --git a/acme/acme/agents/jax/sac/__init__.py b/acme/acme/agents/jax/sac/__init__.py new file mode 100644 index 00000000..38f38d47 --- /dev/null +++ b/acme/acme/agents/jax/sac/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SAC agent.""" + +from acme.agents.jax.sac.builder import SACBuilder +from acme.agents.jax.sac.config import SACConfig +from acme.agents.jax.sac.config import target_entropy_from_env_spec +from acme.agents.jax.sac.learning import SACLearner +from acme.agents.jax.sac.networks import apply_policy_and_sample +from acme.agents.jax.sac.networks import default_models_to_snapshot +from acme.agents.jax.sac.networks import make_networks +from acme.agents.jax.sac.networks import SACNetworks diff --git a/acme/acme/agents/jax/sac/builder.py b/acme/acme/agents/jax/sac/builder.py new file mode 100644 index 00000000..072c04cb --- /dev/null +++ b/acme/acme/agents/jax/sac/builder.py @@ -0,0 +1,160 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SAC Builder.""" +from typing import Iterator, List, Optional + +import acme +from acme import adders +from acme import core +from acme import specs +from acme.adders import reverb as adders_reverb +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.sac import config as sac_config +from acme.agents.jax.sac import learning +from acme.agents.jax.sac import networks as sac_networks +from acme.datasets import reverb as datasets +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import optax +import reverb +from reverb import rate_limiters + + +class SACBuilder(builders.ActorLearnerBuilder[sac_networks.SACNetworks, + actor_core_lib.FeedForwardPolicy, + reverb.ReplaySample]): + """SAC Builder.""" + + def __init__( + self, + config: sac_config.SACConfig, + ): + """Creates a SAC learner, a behavior policy and an eval actor. 
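+
+    A hypothetical instantiation (values are illustrative only):
+
+      builder = SACBuilder(sac_config.SACConfig(batch_size=256, n_step=1))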
+ + Args: + config: a config with SAC hps + """ + self._config = config + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: sac_networks.SACNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + # Create optimizers + policy_optimizer = optax.adam(learning_rate=self._config.learning_rate) + q_optimizer = optax.adam(learning_rate=self._config.learning_rate) + + return learning.SACLearner( + networks=networks, + tau=self._config.tau, + discount=self._config.discount, + entropy_coefficient=self._config.entropy_coefficient, + target_entropy=self._config.target_entropy, + rng=random_key, + reward_scale=self._config.reward_scale, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + policy_optimizer=policy_optimizer, + q_optimizer=q_optimizer, + iterator=dataset, + logger=logger_fn('learner'), + counter=counter) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, 'policy', device='cpu') + return actors.GenericActor( + actor_core, random_key, variable_client, adder, backend='cpu') + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicy, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + del policy + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate * + self._config.samples_per_insert) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec)) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=(self._config.batch_size * + self._config.num_sgd_steps_per_step), + prefetch_size=self._config.prefetch_size) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + + def make_adder( + self, replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.FeedForwardPolicy] + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment.""" + del environment_spec, policy + return adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount) + + def 
make_policy(self,
+                  networks: sac_networks.SACNetworks,
+                  environment_spec: specs.EnvironmentSpec,
+                  evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy:
+    """Construct the policy."""
+    del environment_spec
+    return sac_networks.apply_policy_and_sample(networks, eval_mode=evaluation)
diff --git a/acme/acme/agents/jax/sac/config.py b/acme/acme/agents/jax/sac/config.py
new file mode 100644
index 00000000..8682b719
--- /dev/null
+++ b/acme/acme/agents/jax/sac/config.py
@@ -0,0 +1,96 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SAC config."""
+import dataclasses
+from typing import Any, Optional
+
+from acme import specs
+from acme.adders import reverb as adders_reverb
+import numpy as onp
+
+
+@dataclasses.dataclass
+class SACConfig:
+  """Configuration options for SAC."""
+  # Loss options
+  batch_size: int = 256
+  learning_rate: float = 3e-4
+  reward_scale: float = 1
+  discount: float = 0.99
+  n_step: int = 1
+  # Coefficient applied to the entropy bonus. If None, an adaptive
+  # coefficient will be used.
+  entropy_coefficient: Optional[float] = None
+  target_entropy: float = 0.0
+  # Target smoothing coefficient.
+  tau: float = 0.005
+
+  # Replay options
+  min_replay_size: int = 10000
+  max_replay_size: int = 1000000
+  replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
+  prefetch_size: int = 4
+  samples_per_insert: float = 256
+  # Rate to be used for the SampleToInsertRatio rate limiter tolerance.
+  # See a formula in make_replay_tables for more details.
+  samples_per_insert_tolerance_rate: float = 0.1
+
+  # How many gradient updates to perform per step.
+  num_sgd_steps_per_step: int = 1
+
+
+def target_entropy_from_env_spec(
+    spec: specs.EnvironmentSpec,
+    target_entropy_per_dimension: Optional[float] = None,
+) -> float:
+  """A heuristic to determine a target entropy.
+
+  If target_entropy_per_dimension is not specified, the target entropy is
+  computed as "-num_actions", otherwise it is
+  "target_entropy_per_dimension * num_actions".
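+
+  For example (a sketch): a BoundedArray action spec of shape (6,) with bounds
+  [-1, 1] yields -6.0 by default, or 6 * target_entropy_per_dimension when that
+  argument is given.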
+ + Args: + spec: environment spec + target_entropy_per_dimension: None or target entropy per action dimension + + Returns: + target entropy + """ + + def get_num_actions(action_spec: Any) -> float: + """Returns a number of actions in the spec.""" + if isinstance(action_spec, specs.BoundedArray): + return onp.prod(action_spec.shape, dtype=int) + elif isinstance(action_spec, tuple): + return sum(get_num_actions(subspace) for subspace in action_spec) + else: + raise ValueError('Unknown action space type.') + + num_actions = get_num_actions(spec.actions) + if target_entropy_per_dimension is None: + if not isinstance(spec.actions, specs.BoundedArray) or isinstance( + spec.actions, specs.DiscreteArray): + raise ValueError('Only accept BoundedArrays for automatic ' + f'target_entropy, got: {spec.actions}') + if not onp.all(spec.actions.minimum == -1.): + raise ValueError( + f'Minimum expected to be -1, got: {spec.actions.minimum}') + if not onp.all(spec.actions.maximum == 1.): + raise ValueError( + f'Maximum expected to be 1, got: {spec.actions.maximum}') + + return -num_actions + else: + return target_entropy_per_dimension * num_actions diff --git a/acme/acme/agents/jax/sac/learning.py b/acme/acme/agents/jax/sac/learning.py new file mode 100644 index 00000000..c10b11e1 --- /dev/null +++ b/acme/acme/agents/jax/sac/learning.py @@ -0,0 +1,289 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SAC learner implementation.""" + +import time +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple + +import acme +from acme import types +from acme.agents.jax.sac import networks as sac_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax +import reverb + + +class TrainingState(NamedTuple): + """Contains training state for the learner.""" + policy_optimizer_state: optax.OptState + q_optimizer_state: optax.OptState + policy_params: networks_lib.Params + q_params: networks_lib.Params + target_q_params: networks_lib.Params + key: networks_lib.PRNGKey + alpha_optimizer_state: Optional[optax.OptState] = None + alpha_params: Optional[networks_lib.Params] = None + + +class SACLearner(acme.Learner): + """SAC learner.""" + + _state: TrainingState + + def __init__( + self, + networks: sac_networks.SACNetworks, + rng: jnp.ndarray, + iterator: Iterator[reverb.ReplaySample], + policy_optimizer: optax.GradientTransformation, + q_optimizer: optax.GradientTransformation, + tau: float = 0.005, + reward_scale: float = 1.0, + discount: float = 0.99, + entropy_coefficient: Optional[float] = None, + target_entropy: float = 0, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1): + """Initialize the SAC learner. + + Args: + networks: SAC networks + rng: a key for random number generation. 
+      iterator: an iterator over training data.
+      policy_optimizer: the policy optimizer.
+      q_optimizer: the Q-function optimizer.
+      tau: target smoothing coefficient.
+      reward_scale: reward scale.
+      discount: discount to use for TD updates.
+      entropy_coefficient: coefficient applied to the entropy bonus. If None,
+        an adaptive coefficient will be used.
+      target_entropy: Used to normalize entropy. Only used when
+        entropy_coefficient is None.
+      counter: counter object used to keep track of steps.
+      logger: logger object to be used by learner.
+      num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'.
+    """
+    adaptive_entropy_coefficient = entropy_coefficient is None
+    if adaptive_entropy_coefficient:
+      # alpha is the temperature parameter that determines the relative
+      # importance of the entropy term versus the reward.
+      log_alpha = jnp.asarray(0., dtype=jnp.float32)
+      alpha_optimizer = optax.adam(learning_rate=3e-4)
+      alpha_optimizer_state = alpha_optimizer.init(log_alpha)
+    else:
+      if target_entropy:
+        raise ValueError('target_entropy should not be set when '
+                         'entropy_coefficient is provided')
+
+    def alpha_loss(log_alpha: jnp.ndarray,
+                   policy_params: networks_lib.Params,
+                   transitions: types.Transition,
+                   key: networks_lib.PRNGKey) -> jnp.ndarray:
+      """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
+      dist_params = networks.policy_network.apply(
+          policy_params, transitions.observation)
+      action = networks.sample(dist_params, key)
+      log_prob = networks.log_prob(dist_params, action)
+      alpha = jnp.exp(log_alpha)
+      alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy)
+      return jnp.mean(alpha_loss)
+
+    def critic_loss(q_params: networks_lib.Params,
+                    policy_params: networks_lib.Params,
+                    target_q_params: networks_lib.Params,
+                    alpha: jnp.ndarray,
+                    transitions: types.Transition,
+                    key: networks_lib.PRNGKey) -> jnp.ndarray:
+      q_old_action = networks.q_network.apply(
+          q_params, transitions.observation, transitions.action)
+      next_dist_params = networks.policy_network.apply(
+          policy_params, transitions.next_observation)
+      next_action = networks.sample(next_dist_params, key)
+      next_log_prob = networks.log_prob(next_dist_params, next_action)
+      next_q = networks.q_network.apply(
+          target_q_params, transitions.next_observation, next_action)
+      next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob
+      target_q = jax.lax.stop_gradient(transitions.reward * reward_scale +
+                                       transitions.discount * discount * next_v)
+      q_error = q_old_action - jnp.expand_dims(target_q, -1)
+      q_loss = 0.5 * jnp.mean(jnp.square(q_error))
+      return q_loss
+
+    def actor_loss(policy_params: networks_lib.Params,
+                   q_params: networks_lib.Params,
+                   alpha: jnp.ndarray,
+                   transitions: types.Transition,
+                   key: networks_lib.PRNGKey) -> jnp.ndarray:
+      dist_params = networks.policy_network.apply(
+          policy_params, transitions.observation)
+      action = networks.sample(dist_params, key)
+      log_prob = networks.log_prob(dist_params, action)
+      q_action = networks.q_network.apply(
+          q_params, transitions.observation, action)
+      min_q = jnp.min(q_action, axis=-1)
+      actor_loss = alpha * log_prob - min_q
+      return jnp.mean(actor_loss)
+
+    alpha_grad = jax.value_and_grad(alpha_loss)
+    critic_grad = jax.value_and_grad(critic_loss)
+    actor_grad = jax.value_and_grad(actor_loss)
+
+    def update_step(
+        state: TrainingState,
+        transitions: types.Transition,
+    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
+
+      key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)
+      if
adaptive_entropy_coefficient: + alpha_loss, alpha_grads = alpha_grad(state.alpha_params, + state.policy_params, transitions, + key_alpha) + alpha = jnp.exp(state.alpha_params) + else: + alpha = entropy_coefficient + critic_loss, critic_grads = critic_grad(state.q_params, + state.policy_params, + state.target_q_params, alpha, + transitions, key_critic) + actor_loss, actor_grads = actor_grad(state.policy_params, state.q_params, + alpha, transitions, key_actor) + + # Apply policy gradients + actor_update, policy_optimizer_state = policy_optimizer.update( + actor_grads, state.policy_optimizer_state) + policy_params = optax.apply_updates(state.policy_params, actor_update) + + # Apply critic gradients + critic_update, q_optimizer_state = q_optimizer.update( + critic_grads, state.q_optimizer_state) + q_params = optax.apply_updates(state.q_params, critic_update) + + new_target_q_params = jax.tree_map(lambda x, y: x * (1 - tau) + y * tau, + state.target_q_params, q_params) + + metrics = { + 'critic_loss': critic_loss, + 'actor_loss': actor_loss, + } + + new_state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + q_optimizer_state=q_optimizer_state, + policy_params=policy_params, + q_params=q_params, + target_q_params=new_target_q_params, + key=key, + ) + if adaptive_entropy_coefficient: + # Apply alpha gradients + alpha_update, alpha_optimizer_state = alpha_optimizer.update( + alpha_grads, state.alpha_optimizer_state) + alpha_params = optax.apply_updates(state.alpha_params, alpha_update) + metrics.update({ + 'alpha_loss': alpha_loss, + 'alpha': jnp.exp(alpha_params), + }) + new_state = new_state._replace( + alpha_optimizer_state=alpha_optimizer_state, + alpha_params=alpha_params) + + metrics['rewards_mean'] = jnp.mean( + jnp.abs(jnp.mean(transitions.reward, axis=0))) + metrics['rewards_std'] = jnp.std(transitions.reward, axis=0) + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Iterator on demonstration transitions. + self._iterator = iterator + + update_step = utils.process_multiple_batches(update_step, + num_sgd_steps_per_step) + # Use the JIT compiler. + self._update_step = jax.jit(update_step) + + def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + key_policy, key_q, key = jax.random.split(key, 3) + + policy_params = networks.policy_network.init(key_policy) + policy_optimizer_state = policy_optimizer.init(policy_params) + + q_params = networks.q_network.init(key_q) + q_optimizer_state = q_optimizer.init(q_params) + + state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + q_optimizer_state=q_optimizer_state, + policy_params=policy_params, + q_params=q_params, + target_q_params=q_params, + key=key) + + if adaptive_entropy_coefficient: + state = state._replace(alpha_optimizer_state=alpha_optimizer_state, + alpha_params=log_alpha) + return state + + # Create initial state. + self._state = make_initial_state(rng) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. 
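+    # (Note on the bookkeeping below: the first call to step() reports an
+    # elapsed time of 0; subsequent calls measure the wall-clock time between
+    # consecutive learner steps.)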
+ self._timestamp = None + + def step(self): + sample = next(self._iterator) + transitions = types.Transition(*sample.data) + + self._state, metrics = self._update_step(self._state, transitions) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[Any]: + variables = { + 'policy': self._state.policy_params, + 'critic': self._state.q_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/acme/agents/jax/sac/networks.py b/acme/acme/agents/jax/sac/networks.py new file mode 100644 index 00000000..10ebfbb2 --- /dev/null +++ b/acme/acme/agents/jax/sac/networks.py @@ -0,0 +1,143 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SAC networks definition.""" + +import dataclasses +from typing import Optional, Tuple + +from acme import core +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.jax import networks as networks_lib +from acme.jax import types +from acme.jax import utils +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +@dataclasses.dataclass +class SACNetworks: + """Network and pure functions for the SAC agent..""" + policy_network: networks_lib.FeedForwardNetwork + q_network: networks_lib.FeedForwardNetwork + log_prob: networks_lib.LogProbFn + sample: networks_lib.SampleFn + sample_eval: Optional[networks_lib.SampleFn] = None + + +def default_models_to_snapshot( + networks: SACNetworks, + spec: specs.EnvironmentSpec): + """Defines default models to be snapshotted.""" + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.zeros_like(spec.actions) + dummy_key = jax.random.PRNGKey(0) + + def q_network( + source: core.VariableSource) -> types.ModelToSnapshot: + params = source.get_variables(['critic'])[0] + return types.ModelToSnapshot( + networks.q_network.apply, params, + {'obs': dummy_obs, 'action': dummy_action}) + + def default_training_actor( + source: core.VariableSource) -> types.ModelToSnapshot: + params = source.get_variables(['policy'])[0] + return types.ModelToSnapshot(apply_policy_and_sample(networks, False), + params, + {'key': dummy_key, 'obs': dummy_obs}) + + def default_eval_actor( + source: core.VariableSource) -> types.ModelToSnapshot: + params = source.get_variables(['policy'])[0] + return types.ModelToSnapshot( + apply_policy_and_sample(networks, True), params, + {'key': dummy_key, 'obs': dummy_obs}) + + return { + 'q_network': q_network, + 'default_training_actor': default_training_actor, + 'default_eval_actor': 
default_eval_actor, + } + + +def apply_policy_and_sample( + networks: SACNetworks, + eval_mode: bool = False) -> actor_core_lib.FeedForwardPolicy: + """Returns a function that computes actions.""" + sample_fn = networks.sample if not eval_mode else networks.sample_eval + if not sample_fn: + raise ValueError('sample function is not provided') + + def apply_and_sample(params, key, obs): + return sample_fn(networks.policy_network.apply(params, obs), key) + return apply_and_sample + + +def make_networks( + spec: specs.EnvironmentSpec, + hidden_layer_sizes: Tuple[int, ...] = (256, 256)) -> SACNetworks: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(spec.actions.shape, dtype=int) + + def _actor_fn(obs): + network = hk.Sequential([ + hk.nets.MLP( + list(hidden_layer_sizes), + w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), + activation=jax.nn.relu, + activate_final=True), + networks_lib.NormalTanhDistribution(num_dimensions), + ]) + return network(obs) + + def _critic_fn(obs, action): + network1 = hk.Sequential([ + hk.nets.MLP( + list(hidden_layer_sizes) + [1], + w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), + activation=jax.nn.relu), + ]) + network2 = hk.Sequential([ + hk.nets.MLP( + list(hidden_layer_sizes) + [1], + w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), + activation=jax.nn.relu), + ]) + input_ = jnp.concatenate([obs, action], axis=-1) + value1 = network1(input_) + value2 = network2(input_) + return jnp.concatenate([value1, value2], axis=-1) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + + # Create dummy observations and actions to create network parameters. + dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return SACNetworks( + policy_network=networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply), + q_network=networks_lib.FeedForwardNetwork( + lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply), + log_prob=lambda params, actions: params.log_prob(actions), + sample=lambda params, key: params.sample(seed=key), + sample_eval=lambda params, key: params.mode()) diff --git a/acme/acme/agents/jax/sqil/README.md b/acme/acme/agents/jax/sqil/README.md new file mode 100644 index 00000000..add8d056 --- /dev/null +++ b/acme/acme/agents/jax/sqil/README.md @@ -0,0 +1,9 @@ +# Soft Q imitation learning (SQIL) + +This folder contains an implementation of the SQIL algorithm +([Reddy et al., 2019]) + +SQIL requires an off-policy RL algorithm to work, passed in as an +`ActorLearnerBuilder`. + +[Reddy et al., 2019]: https://arxiv.org/abs/1905.11108 diff --git a/acme/acme/agents/jax/sqil/__init__.py b/acme/acme/agents/jax/sqil/__init__.py new file mode 100644 index 00000000..440575b9 --- /dev/null +++ b/acme/acme/agents/jax/sqil/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQIL agent.""" + +from acme.agents.jax.sqil.builder import SQILBuilder diff --git a/acme/acme/agents/jax/sqil/builder.py b/acme/acme/agents/jax/sqil/builder.py new file mode 100644 index 00000000..11458458 --- /dev/null +++ b/acme/acme/agents/jax/sqil/builder.py @@ -0,0 +1,170 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQIL Builder (https://arxiv.org/pdf/1905.11108.pdf).""" + +from typing import Callable, Generic, Iterator, List, Optional + +from acme import adders +from acme import core +from acme import specs +from acme import types +from acme.agents.jax import builders +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax.imitation_learning_types import DirectPolicyNetwork, DirectRLNetworks # pylint: disable=g-multiple-import +from acme.utils import counting +from acme.utils import loggers +import jax +import numpy as np +import reverb +import tree + + +def _generate_sqil_samples( + demonstration_iterator: Iterator[types.Transition], + replay_iterator: Iterator[reverb.ReplaySample] +) -> Iterator[reverb.ReplaySample]: + """Generator which creates the sample iterator for SQIL. + + Args: + demonstration_iterator: Iterator of demonstrations. + replay_iterator: Replay buffer sample iterator. + + Yields: + Samples having a mix of demonstrations with reward 1 and replay samples with + reward 0. + """ + for demonstrations, replay_sample in zip(demonstration_iterator, + replay_iterator): + demonstrations = demonstrations._replace( + reward=np.ones_like(demonstrations.reward)) + + replay_transitions = replay_sample.data + replay_transitions = replay_transitions._replace( + reward=np.zeros_like(replay_transitions.reward)) + + double_batch = tree.map_structure(lambda x, y: np.concatenate([x, y]), + demonstrations, replay_transitions) + + # Split the double batch in an interleaving fashion. 
+    # e.g. [1, 2, 3, 4, 5, 6] -> [1, 3, 5] and [2, 4, 6]
+    yield reverb.ReplaySample(
+        info=replay_sample.info,
+        data=tree.map_structure(lambda x: x[0::2], double_batch))
+    yield reverb.ReplaySample(
+        info=replay_sample.info,
+        data=tree.map_structure(lambda x: x[1::2], double_batch))
+
+
+class SQILBuilder(Generic[DirectRLNetworks, DirectPolicyNetwork],
+                  builders.ActorLearnerBuilder[DirectRLNetworks,
+                                               DirectPolicyNetwork,
+                                               reverb.ReplaySample]):
+  """SQIL Builder (https://openreview.net/pdf?id=S1xKd24twB)."""
+
+  def __init__(self,
+               rl_agent: builders.ActorLearnerBuilder[DirectRLNetworks,
+                                                      DirectPolicyNetwork,
+                                                      reverb.ReplaySample],
+               rl_agent_batch_size: int,
+               make_demonstrations: Callable[[int],
+                                             Iterator[types.Transition]]):
+    """Builds a SQIL agent.
+
+    Args:
+      rl_agent: An off-policy direct RL agent.
+      rl_agent_batch_size: The batch size of the above algorithm.
+      make_demonstrations: A function that returns an infinite iterator with
+        demonstrations.
+    """
+    self._rl_agent = rl_agent
+    self._rl_agent_batch_size = rl_agent_batch_size
+    self._make_demonstrations = make_demonstrations
+
+  def make_learner(
+      self,
+      random_key: networks_lib.PRNGKey,
+      networks: DirectRLNetworks,
+      dataset: Iterator[reverb.ReplaySample],
+      logger_fn: loggers.LoggerFactory,
+      environment_spec: Optional[specs.EnvironmentSpec] = None,
+      replay_client: Optional[reverb.Client] = None,
+      counter: Optional[counting.Counter] = None,
+  ) -> core.Learner:
+    """Creates the learner."""
+    counter = counter or counting.Counter()
+    direct_rl_counter = counting.Counter(counter, 'direct_rl')
+    return self._rl_agent.make_learner(
+        random_key,
+        networks,
+        dataset=dataset,
+        logger_fn=logger_fn,
+        environment_spec=environment_spec,
+        replay_client=replay_client,
+        counter=direct_rl_counter)
+
+  def make_replay_tables(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      policy: DirectPolicyNetwork,
+  ) -> List[reverb.Table]:
+    return self._rl_agent.make_replay_tables(environment_spec, policy)
+
+  def make_dataset_iterator(
+      self,
+      replay_client: reverb.Client) -> Optional[Iterator[reverb.ReplaySample]]:
+    """Returns an iterator yielding batches with both expert and policy data.
+
+    Batch items alternate between expert data and policy data.
+
+    Args:
+      replay_client: Reverb client.
+
+    Returns:
+      The replay sample iterator.
+    """
+    # TODO(eorsini): Make sure we have the exact same format as the rl_agent's
+    # adder writes in.
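+    # The iterator built below pairs each replay batch with a demonstration
+    # batch of the same size; _generate_sqil_samples relabels rewards (1 for
+    # expert data, 0 for policy data) and yields interleaved batches of the
+    # original batch size.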
+ demonstration_iterator = self._make_demonstrations( + self._rl_agent_batch_size) + + rb_iterator = self._rl_agent.make_dataset_iterator(replay_client) + + return utils.device_put( + _generate_sqil_samples(demonstration_iterator, rb_iterator), + jax.devices()[0]) + + def make_adder( + self, replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[DirectPolicyNetwork]) -> Optional[adders.Adder]: + return self._rl_agent.make_adder(replay_client, environment_spec, policy) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: DirectPolicyNetwork, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor(random_key, policy, environment_spec, + variable_source, adder) + + def make_policy(self, + networks: DirectRLNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> DirectPolicyNetwork: + return self._rl_agent.make_policy(networks, environment_spec, evaluation) diff --git a/acme/acme/agents/jax/sqil/builder_test.py b/acme/acme/agents/jax/sqil/builder_test.py new file mode 100644 index 00000000..4a2608f0 --- /dev/null +++ b/acme/acme/agents/jax/sqil/builder_test.py @@ -0,0 +1,44 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the SQIL iterator.""" + +from acme import types +from acme.agents.jax.sqil import builder +import numpy as np +import reverb + +from absl.testing import absltest + + +class BuilderTest(absltest.TestCase): + + def test_sqil_iterator(self): + demonstrations = [ + types.Transition(np.array([[1], [2], [3]]), (), (), (), ()) + ] + replay = [ + reverb.ReplaySample( + info=(), + data=types.Transition(np.array([[4], [5], [6]]), (), (), (), ())) + ] + sqil_it = builder._generate_sqil_samples(iter(demonstrations), iter(replay)) + np.testing.assert_array_equal( + next(sqil_it).data.observation, np.array([[1], [3], [5]])) + np.testing.assert_array_equal( + next(sqil_it).data.observation, np.array([[2], [4], [6]])) + self.assertRaises(StopIteration, lambda: next(sqil_it)) + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/jax/td3/README.md b/acme/acme/agents/jax/td3/README.md new file mode 100644 index 00000000..b8045b31 --- /dev/null +++ b/acme/acme/agents/jax/td3/README.md @@ -0,0 +1,13 @@ +# Twin Delayed Deep Deterministic policy gradient algorithm (TD3) + +This folder contains an implementation of the TD3 algorithm, +[Fujimoto, 2018]. + + +Note the following differences with the original author's implementation: + +* the default network architecture is a LayerNorm MLP, +* there is no initial exploration phase with a random policy, +* the target critic and twin critic updates are not delayed. 
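+
+The agent is exposed through `TD3Builder` (`builder.py`) and configured via
+`TD3Config` (`config.py`); networks are created with `make_networks`
+(`networks.py`).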
+ +[Fujimoto, 2018]: https://arxiv.org/pdf/1802.09477.pdf diff --git a/acme/acme/agents/jax/td3/__init__.py b/acme/acme/agents/jax/td3/__init__.py new file mode 100644 index 00000000..1e3f9387 --- /dev/null +++ b/acme/acme/agents/jax/td3/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TD3 agent.""" + +from acme.agents.jax.td3.builder import TD3Builder +from acme.agents.jax.td3.config import TD3Config +from acme.agents.jax.td3.learning import TD3Learner +from acme.agents.jax.td3.networks import get_default_behavior_policy +from acme.agents.jax.td3.networks import make_networks +from acme.agents.jax.td3.networks import TD3Networks diff --git a/acme/acme/agents/jax/td3/builder.py b/acme/acme/agents/jax/td3/builder.py new file mode 100644 index 00000000..1149bd9d --- /dev/null +++ b/acme/acme/agents/jax/td3/builder.py @@ -0,0 +1,165 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TD3 Builder.""" +from typing import Iterator, List, Optional + +from acme import adders +from acme import core +from acme import specs +from acme.adders import reverb as adders_reverb +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.td3 import config as td3_config +from acme.agents.jax.td3 import learning +from acme.agents.jax.td3 import networks as td3_networks +from acme.datasets import reverb as datasets +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import optax +import reverb +from reverb import rate_limiters + + +class TD3Builder(builders.ActorLearnerBuilder[td3_networks.TD3Networks, + actor_core_lib.FeedForwardPolicy, + reverb.ReplaySample]): + """TD3 Builder.""" + + def __init__( + self, + config: td3_config.TD3Config, + ): + """Creates a TD3 learner, a behavior policy and an eval actor. 
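+
+    A typical use is to construct the builder from a TD3Config and hand it to
+    an experiment runner, which calls the make_* methods below to create the
+    learner, actors and replay machinery.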
+ + Args: + config: a config with TD3 hps + """ + self._config = config + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: td3_networks.TD3Networks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + critic_optimizer = optax.adam(self._config.critic_learning_rate) + twin_critic_optimizer = optax.adam(self._config.critic_learning_rate) + policy_optimizer = optax.adam(self._config.policy_learning_rate) + + if self._config.policy_gradient_clipping is not None: + policy_optimizer = optax.chain( + optax.clip_by_global_norm(self._config.policy_gradient_clipping), + policy_optimizer) + + return learning.TD3Learner( + networks=networks, + random_key=random_key, + discount=self._config.discount, + target_sigma=self._config.target_sigma, + noise_clip=self._config.noise_clip, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + twin_critic_optimizer=twin_critic_optimizer, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + bc_alpha=self._config.bc_alpha, + iterator=dataset, + logger=logger_fn('learner'), + counter=counter) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + # Inference happens on CPU, so it's better to move variables there too. 
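+    # (Pinning the variable client and the actor to the CPU backend avoids a
+    # host-to-device transfer on every environment step.)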
+ variable_client = variable_utils.VariableClient(variable_source, 'policy', + device='cpu') + return actors.GenericActor( + actor_core, random_key, variable_client, adder, backend='cpu') + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicy, + ) -> List[reverb.Table]: + """Creates reverb tables for the algorithm.""" + del policy + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate * + self._config.samples_per_insert) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + return [reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec))] + + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: + """Creates a dataset iterator to use for learning.""" + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=( + self._config.batch_size * self._config.num_sgd_steps_per_step), + prefetch_size=self._config.prefetch_size, + transition_adder=True) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + + def make_adder( + self, replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.FeedForwardPolicy] + ) -> Optional[adders.Adder]: + """Creates an adder which handles observations.""" + del environment_spec, policy + return adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount) + + def make_policy(self, + networks: td3_networks.TD3Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: + """Creates a policy.""" + sigma = 0 if evaluation else self._config.sigma + return td3_networks.get_default_behavior_policy( + networks=networks, action_specs=environment_spec.actions, sigma=sigma) diff --git a/acme/acme/agents/jax/td3/config.py b/acme/acme/agents/jax/td3/config.py new file mode 100644 index 00000000..cb4e51e0 --- /dev/null +++ b/acme/acme/agents/jax/td3/config.py @@ -0,0 +1,60 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TD3 config.""" +import dataclasses +from typing import Optional, Union + +from acme.adders import reverb as adders_reverb +import optax + + +@dataclasses.dataclass +class TD3Config: + """Configuration options for TD3.""" + + # Loss options + batch_size: int = 256 + policy_learning_rate: Union[optax.Schedule, float] = 3e-4 + critic_learning_rate: Union[optax.Schedule, float] = 3e-4 + # Policy gradient clipping is not part of the original TD3 implementation, + # used e.g. in DAC https://arxiv.org/pdf/1809.02925.pdf + policy_gradient_clipping: Optional[float] = None + discount: float = 0.99 + n_step: int = 1 + + # TD3 specific options (https://arxiv.org/pdf/1802.09477.pdf) + sigma: float = 0.1 + delay: int = 2 + target_sigma: float = 0.2 + noise_clip: float = 0.5 + tau: float = 0.005 + + # Replay options + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + prefetch_size: int = 4 + samples_per_insert: float = 256 + # Rate to be used for the SampleToInsertRatio rate limiter tolerance. + # See a formula in make_replay_tables for more details. + samples_per_insert_tolerance_rate: float = 0.1 + + # How many gradient updates to perform per step. + num_sgd_steps_per_step: int = 1 + + # Offline RL options + # if bc_alpha: if given, will add a bc regularization term to the policy loss, + # (https://arxiv.org/pdf/2106.06860.pdf), useful for offline training. + bc_alpha: Optional[float] = None diff --git a/acme/acme/agents/jax/td3/learning.py b/acme/acme/agents/jax/td3/learning.py new file mode 100644 index 00000000..4459c9d7 --- /dev/null +++ b/acme/acme/agents/jax/td3/learning.py @@ -0,0 +1,333 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TD3 learner implementation.""" + +import time +from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple + +import acme +from acme import types +from acme.agents.jax.td3 import networks as td3_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +import jax +import jax.numpy as jnp +import optax +import reverb +import rlax + + +class TrainingState(NamedTuple): + """Contains training state for the learner.""" + policy_params: networks_lib.Params + target_policy_params: networks_lib.Params + critic_params: networks_lib.Params + target_critic_params: networks_lib.Params + twin_critic_params: networks_lib.Params + target_twin_critic_params: networks_lib.Params + policy_opt_state: optax.OptState + critic_opt_state: optax.OptState + twin_critic_opt_state: optax.OptState + steps: int + random_key: networks_lib.PRNGKey + + +class TD3Learner(acme.Learner): + """TD3 learner.""" + + _state: TrainingState + + def __init__(self, + networks: td3_networks.TD3Networks, + random_key: networks_lib.PRNGKey, + discount: float, + iterator: Iterator[reverb.ReplaySample], + policy_optimizer: optax.GradientTransformation, + critic_optimizer: optax.GradientTransformation, + twin_critic_optimizer: optax.GradientTransformation, + delay: int = 2, + target_sigma: float = 0.2, + noise_clip: float = 0.5, + tau: float = 0.005, + use_sarsa_target: bool = False, + bc_alpha: Optional[float] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1): + """Initializes the TD3 learner. + + Args: + networks: TD3 networks. + random_key: a key for random number generation. + discount: discount to use for TD updates + iterator: an iterator over training data. + policy_optimizer: the policy optimizer. + critic_optimizer: the Q-function optimizer. + twin_critic_optimizer: the twin Q-function optimizer. + delay: ratio of policy updates for critic updates (see TD3), + delay=2 means 2 updates of the critic for 1 policy update. + target_sigma: std of zero mean Gaussian added to the action of + the next_state, for critic evaluation (reducing overestimation bias). + noise_clip: hard constraint on target noise. + tau: target parameters smoothing coefficient. + use_sarsa_target: compute on-policy target using iterator's actions rather + than sampled actions. + Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf). + When set to `True`, `target_policy_params` are unused. + This is only working when the learner is used as an offline algorithm. + I.e. TD3Builder does not support adding the SARSA target to the replay + buffer. + bc_alpha: bc_alpha: Implements TD3+BC. + See comments in TD3Config.bc_alpha for details. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'. + """ + + def policy_loss( + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + transition: types.NestedArray, + ) -> jnp.ndarray: + # Computes the discrete policy gradient loss. 
+ action = networks.policy_network.apply( + policy_params, transition.observation) + grad_critic = jax.vmap( + jax.grad(networks.critic_network.apply, argnums=2), + in_axes=(None, 0, 0)) + dq_da = grad_critic(critic_params, transition.observation, action) + batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0)) + loss = jnp.mean(batch_dpg_learning(action, dq_da)) + if bc_alpha is not None: + # BC regularization for offline RL + q_sa = networks.critic_network.apply(critic_params, + transition.observation, action) + bc_factor = jax.lax.stop_gradient(bc_alpha / jnp.mean(jnp.abs(q_sa))) + loss += jnp.mean(jnp.square(action - transition.action)) / bc_factor + return loss + + def critic_loss( + critic_params: networks_lib.Params, + state: TrainingState, + transition: types.Transition, + random_key: jnp.ndarray, + ): + # Computes the critic loss. + q_tm1 = networks.critic_network.apply( + critic_params, transition.observation, transition.action) + + if use_sarsa_target: + # TODO(b/222674779): use N-steps Trajectories to get the next actions. + assert 'next_action' in transition.extras, ( + 'next actions should be given as extras for one step RL.') + action = transition.extras['next_action'] + else: + action = networks.policy_network.apply(state.target_policy_params, + transition.next_observation) + action = networks.add_policy_noise(action, random_key, + target_sigma, noise_clip) + + q_t = networks.critic_network.apply( + state.target_critic_params, + transition.next_observation, + action) + twin_q_t = networks.twin_critic_network.apply( + state.target_twin_critic_params, + transition.next_observation, + action) + + q_t = jnp.minimum(q_t, twin_q_t) + + target_q_tm1 = transition.reward + discount * transition.discount * q_t + td_error = jax.lax.stop_gradient(target_q_tm1) - q_tm1 + + return jnp.mean(jnp.square(td_error)) + + def update_step( + state: TrainingState, + transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + random_key, key_critic, key_twin = jax.random.split(state.random_key, 3) + + # Updates on the critic: compute the gradients, and update using + # Polyak averaging. + critic_loss_and_grad = jax.value_and_grad(critic_loss) + critic_loss_value, critic_gradients = critic_loss_and_grad( + state.critic_params, state, transitions, key_critic) + critic_updates, critic_opt_state = critic_optimizer.update( + critic_gradients, state.critic_opt_state) + critic_params = optax.apply_updates(state.critic_params, critic_updates) + # In the original authors' implementation the critic target update is + # delayed similarly to the policy update which we found empirically to + # perform slightly worse. + target_critic_params = optax.incremental_update( + new_tensors=critic_params, + old_tensors=state.target_critic_params, + step_size=tau) + + # Updates on the twin critic: compute the gradients, and update using + # Polyak averaging. + twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad( + state.twin_critic_params, state, transitions, key_twin) + twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update( + twin_critic_gradients, state.twin_critic_opt_state) + twin_critic_params = optax.apply_updates(state.twin_critic_params, + twin_critic_updates) + # In the original authors' implementation the twin critic target update is + # delayed similarly to the policy update which we found empirically to + # perform slightly worse. 
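+      # optax.incremental_update computes the Polyak average
+      #   new_target = old_target + tau * (online - old_target),
+      # i.e. an exponential moving average of the online parameters.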
+ target_twin_critic_params = optax.incremental_update( + new_tensors=twin_critic_params, + old_tensors=state.target_twin_critic_params, + step_size=tau) + + # Updates on the policy: compute the gradients, and update using + # Polyak averaging (if delay enabled, the update might not be applied). + policy_loss_and_grad = jax.value_and_grad(policy_loss) + policy_loss_value, policy_gradients = policy_loss_and_grad( + state.policy_params, state.critic_params, transitions) + def update_policy_step(): + policy_updates, policy_opt_state = policy_optimizer.update( + policy_gradients, state.policy_opt_state) + policy_params = optax.apply_updates(state.policy_params, policy_updates) + target_policy_params = optax.incremental_update( + new_tensors=policy_params, + old_tensors=state.target_policy_params, + step_size=tau) + return policy_params, target_policy_params, policy_opt_state + + # The update on the policy is applied every `delay` steps. + current_policy_state = (state.policy_params, state.target_policy_params, + state.policy_opt_state) + policy_params, target_policy_params, policy_opt_state = jax.lax.cond( + state.steps % delay == 0, + lambda _: update_policy_step(), + lambda _: current_policy_state, + operand=None) + + steps = state.steps + 1 + + new_state = TrainingState( + policy_params=policy_params, + critic_params=critic_params, + twin_critic_params=twin_critic_params, + target_policy_params=target_policy_params, + target_critic_params=target_critic_params, + target_twin_critic_params=target_twin_critic_params, + policy_opt_state=policy_opt_state, + critic_opt_state=critic_opt_state, + twin_critic_opt_state=twin_critic_opt_state, + steps=steps, + random_key=random_key, + ) + + metrics = { + 'policy_loss': policy_loss_value, + 'critic_loss': critic_loss_value, + 'twin_critic_loss': twin_critic_loss_value, + } + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Create prefetching dataset iterator. + self._iterator = iterator + + # Faster sgd step + update_step = utils.process_multiple_batches(update_step, + num_sgd_steps_per_step) + # Use the JIT compiler. + self._update_step = jax.jit(update_step) + + (key_init_policy, key_init_twin, key_init_target, key_state + ) = jax.random.split(random_key, 4) + # Create the network parameters and copy into the target network parameters. + initial_policy_params = networks.policy_network.init(key_init_policy) + initial_critic_params = networks.critic_network.init(key_init_twin) + initial_twin_critic_params = networks.twin_critic_network.init( + key_init_target) + + initial_target_policy_params = initial_policy_params + initial_target_critic_params = initial_critic_params + initial_target_twin_critic_params = initial_twin_critic_params + + # Initialize optimizers. + initial_policy_opt_state = policy_optimizer.init(initial_policy_params) + initial_critic_opt_state = critic_optimizer.init(initial_critic_params) + initial_twin_critic_opt_state = twin_critic_optimizer.init( + initial_twin_critic_params) + + # Create initial state. 
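+    # (All learner state is held in this NamedTuple of arrays, which keeps
+    # the update step jit-friendly and makes save()/restore() below trivial.)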
+ self._state = TrainingState( + policy_params=initial_policy_params, + target_policy_params=initial_target_policy_params, + critic_params=initial_critic_params, + twin_critic_params=initial_twin_critic_params, + target_critic_params=initial_target_critic_params, + target_twin_critic_params=initial_target_twin_critic_params, + policy_opt_state=initial_policy_opt_state, + critic_opt_state=initial_critic_opt_state, + twin_critic_opt_state=initial_twin_critic_opt_state, + steps=0, + random_key=key_state + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + sample = next(self._iterator) + transitions = types.Transition(*sample.data) + + self._state, metrics = self._update_step(self._state, transitions) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + variables = { + 'policy': self._state.policy_params, + 'critic': self._state.critic_params, + 'twin_critic': self._state.twin_critic_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/acme/agents/jax/td3/networks.py b/acme/acme/agents/jax/td3/networks.py new file mode 100644 index 00000000..aa478b03 --- /dev/null +++ b/acme/acme/agents/jax/td3/networks.py @@ -0,0 +1,118 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TD3 networks definition.""" +import dataclasses +from typing import Callable, Sequence + +from acme import specs +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.jax import networks as networks_lib +from acme.jax import utils +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +@dataclasses.dataclass +class TD3Networks: + """Network and pure functions for the TD3 agent.""" + policy_network: networks_lib.FeedForwardNetwork + critic_network: networks_lib.FeedForwardNetwork + twin_critic_network: networks_lib.FeedForwardNetwork + add_policy_noise: Callable[[types.NestedArray, networks_lib.PRNGKey, + float, float], types.NestedArray] + + +def get_default_behavior_policy( + networks: TD3Networks, action_specs: specs.BoundedArray, + sigma: float) -> actor_core_lib.FeedForwardPolicy: + """Selects action according to the policy.""" + def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey, + observation: types.NestedArray): + action = networks.policy_network.apply(params, observation) + noise = jax.random.normal(key, shape=action.shape) * sigma + noisy_action = jnp.clip(action + noise, + action_specs.minimum, action_specs.maximum) + return noisy_action + return behavior_policy + + +def make_networks( + spec: specs.EnvironmentSpec, + hidden_layer_sizes: Sequence[int] = (256, 256)) -> TD3Networks: + """Creates networks used by the agent. + + The networks used are based on LayerNormMLP, which is different than the + MLP with relu activation described in TD3 (which empirically performs worse). + + Args: + spec: Environment specs + hidden_layer_sizes: list of sizes of hidden layers in actor/critic networks + + Returns: + network: TD3Networks + """ + + action_specs = spec.actions + num_dimensions = np.prod(action_specs.shape, dtype=int) + + def add_policy_noise(action: types.NestedArray, + key: networks_lib.PRNGKey, + target_sigma: float, + noise_clip: float) -> types.NestedArray: + """Adds action noise to bootstrapped Q-value estimate in critic loss.""" + noise = jax.random.normal(key=key, shape=action_specs.shape) * target_sigma + noise = jnp.clip(noise, -noise_clip, noise_clip) + return jnp.clip(action + noise, action_specs.minimum, action_specs.maximum) + + def _actor_fn(obs: types.NestedArray) -> types.NestedArray: + network = hk.Sequential([ + networks_lib.LayerNormMLP(hidden_layer_sizes, + activate_final=True), + networks_lib.NearZeroInitializedLinear(num_dimensions), + networks_lib.TanhToSpec(spec.actions), + ]) + return network(obs) + + def _critic_fn(obs: types.NestedArray, + action: types.NestedArray) -> types.NestedArray: + network1 = hk.Sequential([ + networks_lib.LayerNormMLP(list(hidden_layer_sizes) + [1]), + ]) + input_ = jnp.concatenate([obs, action], axis=-1) + value = network1(input_) + return jnp.squeeze(value) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + + # Create dummy observations and actions to create network parameters. 
+ dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + network = TD3Networks( + policy_network=networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply), + critic_network=networks_lib.FeedForwardNetwork( + lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply), + twin_critic_network=networks_lib.FeedForwardNetwork( + lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply), + add_policy_noise=add_policy_noise) + + return network diff --git a/acme/acme/agents/jax/value_dice/README.md b/acme/acme/agents/jax/value_dice/README.md new file mode 100644 index 00000000..976d0885 --- /dev/null +++ b/acme/acme/agents/jax/value_dice/README.md @@ -0,0 +1,12 @@ +# Value Dice + +This folder contains an implementation of the ValueDice algorithm +([Kostrikov et al., 2019]). + +The implementation supports both: + - offline training (demonstrations only) + - mixed mode + +Offline training is achieved by setting 'nu_reg_scale' and 'alpha' to 0. + +[Kostrikov et al., 2019]: https://arxiv.org/abs/1912.05032 diff --git a/acme/acme/agents/jax/value_dice/__init__.py b/acme/acme/agents/jax/value_dice/__init__.py new file mode 100644 index 00000000..d86640aa --- /dev/null +++ b/acme/acme/agents/jax/value_dice/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ValueDice agent.""" + +from acme.agents.jax.value_dice.builder import ValueDiceBuilder +from acme.agents.jax.value_dice.config import ValueDiceConfig +from acme.agents.jax.value_dice.learning import ValueDiceLearner +from acme.agents.jax.value_dice.networks import apply_policy_and_sample +from acme.agents.jax.value_dice.networks import make_networks +from acme.agents.jax.value_dice.networks import ValueDiceNetworks diff --git a/acme/acme/agents/jax/value_dice/builder.py b/acme/acme/agents/jax/value_dice/builder.py new file mode 100644 index 00000000..c85f9f3b --- /dev/null +++ b/acme/acme/agents/jax/value_dice/builder.py @@ -0,0 +1,151 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""ValueDice agent implementation, using JAX.""" + +from typing import Callable, Iterator, List, Optional + +from acme import adders +from acme import core +from acme import specs +from acme import types +from acme.adders import reverb as adders_reverb +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import builders +from acme.agents.jax.value_dice import config as value_dice_config +from acme.agents.jax.value_dice import learning +from acme.agents.jax.value_dice import networks as value_dice_networks +from acme.datasets import reverb as datasets +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import counting +from acme.utils import loggers +import jax +import optax +import reverb +from reverb import rate_limiters + + +class ValueDiceBuilder( + builders.ActorLearnerBuilder[value_dice_networks.ValueDiceNetworks, + actor_core_lib.FeedForwardPolicy, + reverb.ReplaySample]): + """ValueDice Builder. + + This builder is an entry point for online version of ValueDice. + For offline please use the ValueDiceLearner directly. + """ + + def __init__(self, config: value_dice_config.ValueDiceConfig, + make_demonstrations: Callable[[int], + Iterator[types.Transition]]): + self._make_demonstrations = make_demonstrations + self._config = config + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: value_dice_networks.ValueDiceNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + iterator_demonstration = self._make_demonstrations( + self._config.batch_size * self._config.num_sgd_steps_per_step) + policy_optimizer = optax.adam( + learning_rate=self._config.policy_learning_rate) + nu_optimizer = optax.adam(learning_rate=self._config.nu_learning_rate) + return learning.ValueDiceLearner( + networks=networks, + policy_optimizer=policy_optimizer, + nu_optimizer=nu_optimizer, + discount=self._config.discount, + rng=random_key, + alpha=self._config.alpha, + policy_reg_scale=self._config.policy_reg_scale, + nu_reg_scale=self._config.nu_reg_scale, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + iterator_replay=dataset, + iterator_demonstrations=iterator_demonstration, + logger=logger_fn('learner'), + counter=counter, + ) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicy, + ) -> List[reverb.Table]: + del policy + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate * + self._config.samples_per_insert) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + return [reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec))] + + def make_dataset_iterator( + self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: + """Creates a dataset iterator to use for learning.""" + dataset = 
datasets.make_reverb_dataset(
+        table=self._config.replay_table_name,
+        server_address=replay_client.server_address,
+        batch_size=(
+            self._config.batch_size * self._config.num_sgd_steps_per_step),
+        prefetch_size=self._config.prefetch_size)
+    return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])
+
+  def make_adder(
+      self, replay_client: reverb.Client,
+      environment_spec: Optional[specs.EnvironmentSpec],
+      policy: Optional[actor_core_lib.FeedForwardPolicy]
+  ) -> Optional[adders.Adder]:
+    del environment_spec, policy
+    return adders_reverb.NStepTransitionAdder(
+        priority_fns={self._config.replay_table_name: None},
+        client=replay_client,
+        n_step=1,
+        discount=self._config.discount)
+
+  def make_actor(
+      self,
+      random_key: networks_lib.PRNGKey,
+      policy: actor_core_lib.FeedForwardPolicy,
+      environment_spec: specs.EnvironmentSpec,
+      variable_source: Optional[core.VariableSource] = None,
+      adder: Optional[adders.Adder] = None,
+  ) -> core.Actor:
+    del environment_spec
+    assert variable_source is not None
+    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
+    # Inference happens on CPU, so it's better to move variables there too.
+    variable_client = variable_utils.VariableClient(variable_source, 'policy',
+                                                    device='cpu')
+    return actors.GenericActor(
+        actor_core, random_key, variable_client, adder, backend='cpu')
diff --git a/acme/acme/agents/jax/value_dice/config.py b/acme/acme/agents/jax/value_dice/config.py
new file mode 100644
index 00000000..7f8a28da
--- /dev/null
+++ b/acme/acme/agents/jax/value_dice/config.py
@@ -0,0 +1,45 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ValueDice config."""
+
+import dataclasses
+
+from acme.adders import reverb as adders_reverb
+
+
+@dataclasses.dataclass
+class ValueDiceConfig:
+  """Configuration options for ValueDice."""
+
+  policy_learning_rate: float = 1e-5
+  nu_learning_rate: float = 1e-3
+  discount: float = .99
+  batch_size: int = 256
+  alpha: float = 0.05
+  policy_reg_scale: float = 1e-4
+  nu_reg_scale: float = 10.0
+
+  # Replay options
+  replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
+  samples_per_insert: float = 256 * 4
+  # Rate to be used for the SampleToInsertRatio rate limiter tolerance.
+  # See the formula in make_replay_tables for more details.
+  samples_per_insert_tolerance_rate: float = 0.1
+  min_replay_size: int = 1000
+  max_replay_size: int = 1000000
+  prefetch_size: int = 4
+
+  # How many gradient updates to perform per step.
+  num_sgd_steps_per_step: int = 1
diff --git a/acme/acme/agents/jax/value_dice/learning.py b/acme/acme/agents/jax/value_dice/learning.py
new file mode 100644
index 00000000..64ca881f
--- /dev/null
+++ b/acme/acme/agents/jax/value_dice/learning.py
@@ -0,0 +1,329 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ValueDice learner implementation."""
+
+import functools
+import time
+from typing import Any, Dict, Iterator, List, Mapping, NamedTuple, Optional, Tuple
+
+import acme
+from acme import types
+from acme.agents.jax.value_dice import networks as value_dice_networks
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import loggers
+import jax
+import jax.numpy as jnp
+import optax
+import reverb
+
+
+class TrainingState(NamedTuple):
+  """Contains training state for the learner."""
+  policy_optimizer_state: optax.OptState
+  policy_params: networks_lib.Params
+  nu_optimizer_state: optax.OptState
+  nu_params: networks_lib.Params
+  key: jnp.ndarray
+  steps: int
+
+
+def _orthogonal_regularization_loss(params: networks_lib.Params):
+  """Orthogonal regularization.
+
+  See equation (3) in https://arxiv.org/abs/1809.11096.
+
+  Args:
+    params: Dictionary of parameters to apply regularization to.
+
+  Returns:
+    A regularization loss term.
+  """
+  reg_loss = 0
+  for key in params:
+    if isinstance(params[key], Mapping):
+      reg_loss += _orthogonal_regularization_loss(params[key])
+      continue
+    variable = params[key]
+    assert len(variable.shape) in [1, 2, 4]
+    if len(variable.shape) == 1:
+      # This is a bias so do not apply regularization.
+      continue
+    if len(variable.shape) == 4:
+      # CNN
+      variable = jnp.reshape(variable, (-1, variable.shape[-1]))
+    prod = jnp.matmul(jnp.transpose(variable), variable)
+    reg_loss += jnp.sum(jnp.square(prod * (1 - jnp.eye(prod.shape[0]))))
+  return reg_loss
+
+
+class ValueDiceLearner(acme.Learner):
+  """ValueDice learner."""
+
+  _state: TrainingState
+
+  def __init__(self,
+               networks: value_dice_networks.ValueDiceNetworks,
+               policy_optimizer: optax.GradientTransformation,
+               nu_optimizer: optax.GradientTransformation,
+               discount: float,
+               rng: jnp.ndarray,
+               iterator_replay: Iterator[reverb.ReplaySample],
+               iterator_demonstrations: Iterator[types.Transition],
+               alpha: float = 0.05,
+               policy_reg_scale: float = 1e-4,
+               nu_reg_scale: float = 10.0,
+               num_sgd_steps_per_step: int = 1,
+               counter: Optional[counting.Counter] = None,
+               logger: Optional[loggers.Logger] = None):
+
+    rng, policy_key, nu_key = jax.random.split(rng, 3)
+    policy_init_params = networks.policy_network.init(policy_key)
+    policy_optimizer_state = policy_optimizer.init(policy_init_params)
+
+    nu_init_params = networks.nu_network.init(nu_key)
+    nu_optimizer_state = nu_optimizer.init(nu_init_params)
+
+    def compute_losses(
+        policy_params: networks_lib.Params,
+        nu_params: networks_lib.Params,
+        key: jnp.ndarray,
+        replay_o_tm1: types.NestedArray,
+        replay_a_tm1: types.NestedArray,
+        replay_o_t: types.NestedArray,
+        demo_o_tm1: types.NestedArray,
+        demo_a_tm1: types.NestedArray,
+        demo_o_t: types.NestedArray,
+    ) -> jnp.ndarray:
+      # TODO(damienv, hussenot): what to do with the discounts?
+
+      def policy(obs, key):
+        dist_params = networks.policy_network.apply(policy_params, obs)
+        return networks.sample(dist_params, key)
+
+      key1, key2, key3, key4 = jax.random.split(key, 4)
+
+      # Predicted actions.
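+      # (Demonstration states double as samples of the initial-state
+      # distribution: demo_o_t0 below is simply demo_o_tm1. The loss assembled
+      # from these actions combines a linear term over "initial" and replay
+      # states with a log-average-exp term implemented as a weighted softmax
+      # with stopped gradients; see Kostrikov et al., 2019.)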
+      demo_o_t0 = demo_o_tm1
+      policy_demo_a_t0 = policy(demo_o_t0, key1)
+      policy_demo_a_t = policy(demo_o_t, key2)
+      policy_replay_a_t = policy(replay_o_t, key3)
+
+      replay_a_tm1 = networks.encode_action(replay_a_tm1)
+      demo_a_tm1 = networks.encode_action(demo_a_tm1)
+      policy_demo_a_t0 = networks.encode_action(policy_demo_a_t0)
+      policy_demo_a_t = networks.encode_action(policy_demo_a_t)
+      policy_replay_a_t = networks.encode_action(policy_replay_a_t)
+
+      # "Value function" nu over the expert states.
+      nu_demo_t0 = networks.nu_network.apply(nu_params, demo_o_t0,
+                                             policy_demo_a_t0)
+      nu_demo_tm1 = networks.nu_network.apply(nu_params, demo_o_tm1, demo_a_tm1)
+      nu_demo_t = networks.nu_network.apply(nu_params, demo_o_t,
+                                            policy_demo_a_t)
+      nu_demo_diff = nu_demo_tm1 - discount * nu_demo_t
+
+      # "Value function" nu over the replay buffer states.
+      nu_replay_tm1 = networks.nu_network.apply(nu_params, replay_o_tm1,
+                                                replay_a_tm1)
+      nu_replay_t = networks.nu_network.apply(nu_params, replay_o_t,
+                                              policy_replay_a_t)
+      nu_replay_diff = nu_replay_tm1 - discount * nu_replay_t
+
+      # Linear part of the loss.
+      linear_loss_demo = jnp.mean(nu_demo_t0 * (1.0 - discount))
+      linear_loss_rb = jnp.mean(nu_replay_diff)
+      linear_loss = (linear_loss_demo * (1 - alpha) + linear_loss_rb * alpha)
+
+      # Non-linear part of the loss.
+      nu_replay_demo_diff = jnp.concatenate([nu_demo_diff, nu_replay_diff],
+                                            axis=0)
+      replay_demo_weights = jnp.concatenate([
+          jnp.ones_like(nu_demo_diff) * (1 - alpha),
+          jnp.ones_like(nu_replay_diff) * alpha
+      ],
+                                            axis=0)
+      replay_demo_weights /= jnp.mean(replay_demo_weights)
+      non_linear_loss = jnp.sum(
+          jax.lax.stop_gradient(
+              utils.weighted_softmax(nu_replay_demo_diff, replay_demo_weights,
+                                     axis=0)) *
+          nu_replay_demo_diff)
+
+      # Final loss.
+      loss = (non_linear_loss - linear_loss)
+
+      # Regularized policy loss.
+      if policy_reg_scale > 0.:
+        policy_reg = _orthogonal_regularization_loss(policy_params)
+      else:
+        policy_reg = 0.
+
+      # Gradient penalty on nu.
+      if nu_reg_scale > 0.0:
+        batch_size = demo_o_tm1.shape[0]
+        c = jax.random.uniform(key4, shape=(batch_size,))
+        shape_o = [
+            dim if i == 0 else 1 for i, dim in enumerate(replay_o_tm1.shape)
+        ]
+        shape_a = [
+            dim if i == 0 else 1 for i, dim in enumerate(replay_a_tm1.shape)
+        ]
+        c_o = jnp.reshape(c, shape_o)
+        c_a = jnp.reshape(c, shape_a)
+        mixed_o_tm1 = c_o * demo_o_tm1 + (1 - c_o) * replay_o_tm1
+        mixed_a_tm1 = c_a * demo_a_tm1 + (1 - c_a) * replay_a_tm1
+        mixed_o_t = c_o * demo_o_t + (1 - c_o) * replay_o_t
+        mixed_policy_a_t = c_a * policy_demo_a_t + (1 - c_a) * policy_replay_a_t
+        mixed_o = jnp.concatenate([mixed_o_tm1, mixed_o_t], axis=0)
+        mixed_a = jnp.concatenate([mixed_a_tm1, mixed_policy_a_t], axis=0)
+
+        def sum_nu(o, a):
+          return jnp.sum(networks.nu_network.apply(nu_params, o, a))
+
+        nu_grad_o_fn = jax.grad(sum_nu, argnums=0)
+        nu_grad_a_fn = jax.grad(sum_nu, argnums=1)
+        nu_grad_o = nu_grad_o_fn(mixed_o, mixed_a)
+        nu_grad_a = nu_grad_a_fn(mixed_o, mixed_a)
+        nu_grad = jnp.concatenate([
+            jnp.reshape(nu_grad_o, [batch_size, -1]),
+            jnp.reshape(nu_grad_a, [batch_size, -1])], axis=-1)
+        # TODO(damienv, hussenot): check for the need of eps
+        # (like in the original value dice code).
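+        # (Editor's note: this is a WGAN-GP-style gradient penalty,
+        # E[(||grad nu(o, a)||_2 - 1)^2], evaluated at random interpolations
+        # of expert and replay pairs; the 1e-8 below appears to keep the norm
+        # differentiable when the gradient is zero.)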
+ nu_grad_penalty = jnp.mean( + jnp.square( + jnp.linalg.norm(nu_grad + 1e-8, axis=-1, keepdims=True) - 1)) + else: + nu_grad_penalty = 0.0 + + policy_loss = -loss + policy_reg_scale * policy_reg + nu_loss = loss + nu_reg_scale * nu_grad_penalty + + return policy_loss, nu_loss + + def sgd_step( + state: TrainingState, + data: Tuple[types.Transition, types.Transition] + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + replay_transitions, demo_transitions = data + key, key_loss = jax.random.split(state.key) + compute_losses_with_input = functools.partial( + compute_losses, + replay_o_tm1=replay_transitions.observation, + replay_a_tm1=replay_transitions.action, + replay_o_t=replay_transitions.next_observation, + demo_o_tm1=demo_transitions.observation, + demo_a_tm1=demo_transitions.action, + demo_o_t=demo_transitions.next_observation, + key=key_loss) + (policy_loss_value, nu_loss_value), vjpfun = jax.vjp( + compute_losses_with_input, + state.policy_params, state.nu_params) + policy_gradients, _ = vjpfun((1.0, 0.0)) + _, nu_gradients = vjpfun((0.0, 1.0)) + + # Update optimizers. + policy_update, policy_optimizer_state = policy_optimizer.update( + policy_gradients, state.policy_optimizer_state) + policy_params = optax.apply_updates(state.policy_params, policy_update) + + nu_update, nu_optimizer_state = nu_optimizer.update( + nu_gradients, state.nu_optimizer_state) + nu_params = optax.apply_updates(state.nu_params, nu_update) + + new_state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + policy_params=policy_params, + nu_optimizer_state=nu_optimizer_state, + nu_params=nu_params, + key=key, + steps=state.steps + 1, + ) + + metrics = { + 'policy_loss': policy_loss_value, + 'nu_loss': nu_loss_value, + } + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + 'learner', + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key()) + + # Iterator on demonstration transitions. + self._iterator_demonstrations = iterator_demonstrations + self._iterator_replay = iterator_replay + + self._sgd_step = jax.jit(utils.process_multiple_batches( + sgd_step, num_sgd_steps_per_step)) + + # Create initial state. + self._state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + policy_params=policy_init_params, + nu_optimizer_state=nu_optimizer_state, + nu_params=nu_init_params, + key=rng, + steps=0, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + # TODO(raveman): Add a support for offline training, where we do not consume + # data from the replay buffer. + sample = next(self._iterator_replay) + replay_transitions = types.Transition(*sample.data) + + # Get a batch of Transitions from the demonstration. + demonstration_transitions = next(self._iterator_demonstrations) + + self._state, metrics = self._sgd_step( + self._state, (replay_transitions, demonstration_transitions)) + + # Compute elapsed time. 
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[Any]: + variables = { + 'policy': self._state.policy_params, + 'nu': self._state.nu_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/acme/agents/jax/value_dice/networks.py b/acme/acme/agents/jax/value_dice/networks.py new file mode 100644 index 00000000..479a0f85 --- /dev/null +++ b/acme/acme/agents/jax/value_dice/networks.py @@ -0,0 +1,100 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ValueDice networks definition.""" + + +import dataclasses +from typing import Callable, Optional, Tuple + +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.jax import networks as networks_lib +from acme.jax import utils +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +@dataclasses.dataclass +class ValueDiceNetworks: + """ValueDice networks.""" + policy_network: networks_lib.FeedForwardNetwork + nu_network: networks_lib.FeedForwardNetwork + # Functions for actors and evaluators, resp., to sample actions. + sample: networks_lib.SampleFn + sample_eval: Optional[networks_lib.SampleFn] = None + # Function that transforms an action before a mixture is applied, typically + # the identity for continuous actions and one-hot encoding for discrete + # actions. + encode_action: Callable[[networks_lib.Action], jnp.ndarray] = lambda x: x + + +def apply_policy_and_sample( + networks: ValueDiceNetworks, + eval_mode: bool = False) -> actor_core_lib.FeedForwardPolicy: + """Returns a function that computes actions.""" + sample_fn = networks.sample if not eval_mode else networks.sample_eval + if not sample_fn: + raise ValueError('sample function is not provided') + + def apply_and_sample(params, key, obs): + return sample_fn(networks.policy_network.apply(params, obs), key) + return apply_and_sample + + +def make_networks( + spec: specs.EnvironmentSpec, + hidden_layer_sizes: Tuple[int, ...] 
= (256, 256)) -> ValueDiceNetworks: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(spec.actions.shape, dtype=int) + + def _actor_fn(obs): + network = hk.Sequential([ + hk.nets.MLP( + list(hidden_layer_sizes), + w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), + activation=jax.nn.relu, + activate_final=True), + networks_lib.NormalTanhDistribution(num_dimensions), + ]) + return network(obs) + + def _nu_fn(obs, action): + network = hk.Sequential([ + hk.nets.MLP( + list(hidden_layer_sizes) + [1], + w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), + activation=jax.nn.relu), + ]) + return network(jnp.concatenate([obs, action], axis=-1)) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + nu = hk.without_apply_rng(hk.transform(_nu_fn)) + + # Create dummy observations and actions to create network parameters. + dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return ValueDiceNetworks( + policy_network=networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply), + nu_network=networks_lib.FeedForwardNetwork( + lambda key: nu.init(key, dummy_obs, dummy_action), nu.apply), + sample=lambda params, key: params.sample(seed=key), + sample_eval=lambda params, key: params.mode()) diff --git a/acme/acme/agents/replay.py b/acme/acme/agents/replay.py new file mode 100644 index 00000000..7f275766 --- /dev/null +++ b/acme/acme/agents/replay.py @@ -0,0 +1,168 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common tools for reverb replay.""" + +import dataclasses +from typing import Any, Callable, Dict, Iterator, Optional + +from acme import adders as adders_lib +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as adders +import reverb + + +@dataclasses.dataclass +class ReverbReplay: + server: reverb.Server + adder: adders_lib.Adder + data_iterator: Iterator[reverb.ReplaySample] + client: Optional[reverb.Client] = None + can_sample: Callable[[], bool] = lambda: True + + +def make_reverb_prioritized_nstep_replay( + environment_spec: specs.EnvironmentSpec, + extra_spec: types.NestedSpec = (), + n_step: int = 1, + batch_size: int = 32, + max_replay_size: int = 100_000, + min_replay_size: int = 1, + discount: float = 1., + prefetch_size: int = 4, # TODO(iosband): rationalize prefetch size. + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + priority_exponent: Optional[float] = None, # If None, default to uniform. 
+) -> ReverbReplay:
+  """Creates a single-process replay infrastructure from an environment spec."""
+  # Parse the priority exponent to determine uniform vs prioritized replay.
+  if priority_exponent is None:
+    sampler = reverb.selectors.Uniform()
+    priority_fns = {replay_table_name: lambda x: 1.}
+  else:
+    sampler = reverb.selectors.Prioritized(priority_exponent)
+    priority_fns = None
+
+  # Create a replay server to add data to. This uses no limiter behavior in
+  # order to allow the Agent interface to handle it.
+  replay_table = reverb.Table(
+      name=replay_table_name,
+      sampler=sampler,
+      remover=reverb.selectors.Fifo(),
+      max_size=max_replay_size,
+      rate_limiter=reverb.rate_limiters.MinSize(min_replay_size),
+      signature=adders.NStepTransitionAdder.signature(environment_spec,
+                                                      extra_spec),
+  )
+  server = reverb.Server([replay_table], port=None)
+
+  # The adder is used to insert observations into replay.
+  address = f'localhost:{server.port}'
+  client = reverb.Client(address)
+  adder = adders.NStepTransitionAdder(
+      client, n_step, discount, priority_fns=priority_fns)
+
+  # The dataset provides an interface to sample from replay.
+  data_iterator = datasets.make_reverb_dataset(
+      table=replay_table_name,
+      server_address=address,
+      batch_size=batch_size,
+      prefetch_size=prefetch_size,
+  ).as_numpy_iterator()
+  return ReverbReplay(server, adder, data_iterator, client=client)
+
+
+def make_reverb_online_queue(
+    environment_spec: specs.EnvironmentSpec,
+    extra_spec: Dict[str, Any],
+    max_queue_size: int,
+    sequence_length: int,
+    sequence_period: int,
+    batch_size: int,
+    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
+) -> ReverbReplay:
+  """Creates a single process queue from an environment spec and extra_spec."""
+  signature = adders.SequenceAdder.signature(environment_spec, extra_spec)
+  queue = reverb.Table.queue(
+      name=replay_table_name, max_size=max_queue_size, signature=signature)
+  server = reverb.Server([queue], port=None)
+  can_sample = lambda: queue.can_sample(batch_size)
+
+  # Component to add things into replay.
+  address = f'localhost:{server.port}'
+  adder = adders.SequenceAdder(
+      client=reverb.Client(address),
+      period=sequence_period,
+      sequence_length=sequence_length,
+  )
+
+  # The dataset object to learn from.
+  # We don't use datasets.make_reverb_dataset() here to avoid interleaving
+  # and prefetching, which don't work well with the can_sample() check on
+  # update.
+  dataset = reverb.TrajectoryDataset.from_table_signature(
+      server_address=address,
+      table=replay_table_name,
+      max_in_flight_samples_per_worker=1,
+  )
+  dataset = dataset.batch(batch_size, drop_remainder=True)
+  data_iterator = dataset.as_numpy_iterator()
+  return ReverbReplay(server, adder, data_iterator, can_sample=can_sample)
+
+
+def make_reverb_prioritized_sequence_replay(
+    environment_spec: specs.EnvironmentSpec,
+    extra_spec: types.NestedSpec = (),
+    batch_size: int = 32,
+    max_replay_size: int = 100_000,
+    min_replay_size: int = 1,
+    priority_exponent: float = 0.,
+    burn_in_length: int = 40,
+    sequence_length: int = 80,
+    sequence_period: int = 40,
+    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
+    prefetch_size: int = 4,
+) -> ReverbReplay:
+  """Single-process replay for sequence data from an environment spec."""
+  # Create a replay server to add data to. This uses no limiter behavior in
+  # order to allow the Agent interface to handle it.
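+  # (Editor's note: as computed below, the stored sequences have length
+  # burn_in_length + sequence_length + 1, e.g. 40 + 80 + 1 = 121 steps with
+  # the defaults above.)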
+ replay_table = reverb.Table( + name=replay_table_name, + sampler=reverb.selectors.Prioritized(priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), + signature=adders.SequenceAdder.signature(environment_spec, extra_spec), + ) + server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f'localhost:{server.port}' + client = reverb.Client(address) + sequence_length = burn_in_length + sequence_length + 1 + adder = adders.SequenceAdder( + client=client, + period=sequence_period, + sequence_length=sequence_length, + delta_encoded=True, + ) + + # The dataset provides an interface to sample from replay. + data_iterator = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ).as_numpy_iterator() + return ReverbReplay(server, adder, data_iterator, client) diff --git a/acme/acme/agents/tf/__init__.py b/acme/acme/agents/tf/__init__.py new file mode 100644 index 00000000..240cb715 --- /dev/null +++ b/acme/acme/agents/tf/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/acme/acme/agents/tf/actors.py b/acme/acme/agents/tf/actors.py new file mode 100644 index 00000000..1f934be4 --- /dev/null +++ b/acme/acme/agents/tf/actors.py @@ -0,0 +1,186 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic actor implementation, using TensorFlow and Sonnet.""" + +from typing import Optional, Tuple + +from acme import adders +from acme import core +from acme import types +# Internal imports. +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + +import dm_env +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions + + +class FeedForwardActor(core.Actor): + """A feed-forward actor. + + An actor based on a feed-forward policy which takes non-batched observations + and outputs non-batched actions. It also allows adding experiences to replay + and updating the weights from the policy on the learner. + """ + + def __init__( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): + """Initializes the actor. 
+
+    Args:
+      policy_network: the policy to run.
+      adder: the adder object which allows adding experiences to a
+        dataset/replay buffer.
+      variable_client: object which allows copying weights from the learner
+        copy of the policy to the actor copy (in case they are separate).
+    """
+
+    # Store these for later use.
+    self._adder = adder
+    self._variable_client = variable_client
+    self._policy_network = policy_network
+
+  @tf.function
+  def _policy(self, observation: types.NestedTensor) -> types.NestedTensor:
+    # Add a dummy batch dimension and as a side effect convert numpy to TF.
+    batched_observation = tf2_utils.add_batch_dim(observation)
+
+    # Compute the policy, conditioned on the observation.
+    policy = self._policy_network(batched_observation)
+
+    # Sample from the policy if it is stochastic.
+    action = policy.sample() if isinstance(policy, tfd.Distribution) else policy
+
+    return action
+
+  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
+    # Pass the observation through the policy network.
+    action = self._policy(observation)
+
+    # Return a numpy array with squeezed out batch dimension.
+    return tf2_utils.to_numpy_squeeze(action)
+
+  def observe_first(self, timestep: dm_env.TimeStep):
+    if self._adder:
+      self._adder.add_first(timestep)
+
+  def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
+    if self._adder:
+      self._adder.add(action, next_timestep)
+
+  def update(self, wait: bool = False):
+    if self._variable_client:
+      self._variable_client.update(wait)
+
+
+class RecurrentActor(core.Actor):
+  """A recurrent actor.
+
+  An actor based on a recurrent policy which takes non-batched observations and
+  outputs non-batched actions, and keeps track of the recurrent state inside. It
+  also allows adding experiences to replay and updating the weights from the
+  policy on the learner.
+  """
+
+  def __init__(
+      self,
+      policy_network: snt.RNNCore,
+      adder: Optional[adders.Adder] = None,
+      variable_client: Optional[tf2_variable_utils.VariableClient] = None,
+      store_recurrent_state: bool = True,
+  ):
+    """Initializes the actor.
+
+    Args:
+      policy_network: the (recurrent) policy to run.
+      adder: the adder object which allows adding experiences to a
+        dataset/replay buffer.
+      variable_client: object which allows copying weights from the learner
+        copy of the policy to the actor copy (in case they are separate).
+      store_recurrent_state: Whether to pass the recurrent state to the adder.
+    """
+    # Store these for later use.
+    self._adder = adder
+    self._variable_client = variable_client
+    self._network = policy_network
+    self._state = None
+    self._prev_state = None
+    self._store_recurrent_state = store_recurrent_state
+
+  @tf.function
+  def _policy(
+      self,
+      observation: types.NestedTensor,
+      state: types.NestedTensor,
+  ) -> Tuple[types.NestedTensor, types.NestedTensor]:
+
+    # Add a dummy batch dimension and as a side effect convert numpy to TF.
+    batched_observation = tf2_utils.add_batch_dim(observation)
+
+    # Compute the policy, conditioned on the observation.
+    policy, new_state = self._network(batched_observation, state)
+
+    # Sample from the policy if it is stochastic.
+    action = policy.sample() if isinstance(policy, tfd.Distribution) else policy
+
+    return action, new_state
+
+  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
+    # Initialize the RNN state if necessary.
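+    # (Editor's note: batch size 1 below matches the dummy batch dimension
+    # that _policy adds to the observation.)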
+ if self._state is None: + self._state = self._network.initial_state(1) + + # Step the recurrent policy forward given the current observation and state. + policy_output, new_state = self._policy(observation, self._state) + + # Bookkeeping of recurrent states for the observe method. + self._prev_state = self._state + self._state = new_state + + # Return a numpy array with squeezed out batch dimension. + return tf2_utils.to_numpy_squeeze(policy_output) + + def observe_first(self, timestep: dm_env.TimeStep): + if self._adder: + self._adder.add_first(timestep) + + # Set the state to None so that we re-initialize at the next policy call. + self._state = None + + def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): + if not self._adder: + return + + if not self._store_recurrent_state: + self._adder.add(action, next_timestep) + return + + numpy_state = tf2_utils.to_numpy_squeeze(self._prev_state) + self._adder.add(action, next_timestep, extras=(numpy_state,)) + + def update(self, wait: bool = False): + if self._variable_client: + self._variable_client.update(wait) + +# Internal class 1. +# Internal class 2. diff --git a/acme/acme/agents/tf/actors_test.py b/acme/acme/agents/tf/actors_test.py new file mode 100644 index 00000000..a8d4e35f --- /dev/null +++ b/acme/acme/agents/tf/actors_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
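+# Editor's sketch (not part of the original change): outside of tests, the
+# actors above are typically driven by acme.environment_loop.EnvironmentLoop;
+# a manual episode loop is equivalent to:
+#
+#   timestep = environment.reset()
+#   actor.observe_first(timestep)
+#   while not timestep.last():
+#     action = actor.select_action(timestep.observation)
+#     timestep = environment.step(action)
+#     actor.observe(action, next_timestep=timestep)
+#     actor.update()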
+ +"""Tests for actors_tf2.""" + +from acme import environment_loop +from acme import specs +from acme.agents.tf import actors +from acme.testing import fakes +import dm_env +import numpy as np +import sonnet as snt +import tensorflow as tf + +from absl.testing import absltest + + +def _make_fake_env() -> dm_env.Environment: + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.DiscreteArray(num_values=3), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0., maximum=1.), + ) + return fakes.Environment(env_spec, episode_length=10) + + +class ActorTest(absltest.TestCase): + + def test_feedforward(self): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) + + network = snt.Sequential([ + snt.Flatten(), + snt.Linear(env_spec.actions.num_values), + lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype), + ]) + + actor = actors.FeedForwardActor(network) + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) + + def test_recurrent(self): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) + + network = snt.DeepRNN([ + snt.Flatten(), + snt.Linear(env_spec.actions.num_values), + lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype), + ]) + + actor = actors.RecurrentActor(network) + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/bc/README.md b/acme/acme/agents/tf/bc/README.md new file mode 100644 index 00000000..eca9361e --- /dev/null +++ b/acme/acme/agents/tf/bc/README.md @@ -0,0 +1,8 @@ +# Behavioral Cloning (BC) + +This folder contains an implementation for supervised learning of a policy from +a dataset of observations and target actions. This is an approach known as +Behavioral Cloning, introduced by [Pomerleau, 1989]. There is an example which +generates data for bsuite environment `Deep Sea` using an optimal policy. + +[Pomerleau, 1989]: https://papers.nips.cc/paper/95-alvinn-an-autonomous-land-vehicle-in-a-neural-network.pdf diff --git a/acme/acme/agents/tf/bc/__init__.py b/acme/acme/agents/tf/bc/__init__.py new file mode 100644 index 00000000..d415c73b --- /dev/null +++ b/acme/acme/agents/tf/bc/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of a behavior cloning (BC) agent.""" + +from acme.agents.tf.bc.learning import BCLearner diff --git a/acme/acme/agents/tf/bc/learning.py b/acme/acme/agents/tf/bc/learning.py new file mode 100644 index 00000000..05c692eb --- /dev/null +++ b/acme/acme/agents/tf/bc/learning.py @@ -0,0 +1,123 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BC Learner implementation."""
+
+from typing import Dict, List, Optional
+
+import acme
+from acme import types
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+
+
+class BCLearner(acme.Learner, tf2_savers.TFSaveable):
+  """BC learner.
+
+  This is the learning component of a BC agent, i.e. it takes a dataset as
+  input and implements update functionality to learn from this dataset.
+  """
+
+  def __init__(self,
+               network: snt.Module,
+               learning_rate: float,
+               dataset: tf.data.Dataset,
+               counter: Optional[counting.Counter] = None,
+               logger: Optional[loggers.Logger] = None,
+               checkpoint: bool = True):
+    """Initializes the learner.
+
+    Args:
+      network: the BC network (the one being optimized).
+      learning_rate: learning rate for the cross-entropy update.
+      dataset: dataset to learn from.
+      counter: Counter object for (potentially distributed) counting.
+      logger: Logger object for writing logs to.
+      checkpoint: boolean indicating whether to checkpoint the learner.
+    """
+
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)
+
+    # Get an iterator over the dataset.
+    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
+    # TODO(b/155086959): Fix type stubs and remove.
+
+    self._network = network
+    self._optimizer = snt.optimizers.Adam(learning_rate)
+
+    self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
+    self._num_steps = tf.Variable(0, dtype=tf.int32)
+
+    # Create a snapshotter object.
+    if checkpoint:
+      self._snapshotter = tf2_savers.Snapshotter(
+          objects_to_save={'network': network}, time_delta_minutes=60.)
+    else:
+      self._snapshotter = None
+
+  @tf.function
+  def _step(self) -> Dict[str, tf.Tensor]:
+    """Do a step of SGD and update the priorities."""
+
+    # Pull out the data needed for updates/priorities.
+    inputs = next(self._iterator)
+    transitions: types.Transition = inputs.data
+
+    with tf.GradientTape() as tape:
+      # Evaluate our networks.
+      logits = self._network(transitions.observation)
+      cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+      loss = cce(transitions.action, logits)
+
+    gradients = tape.gradient(loss, self._network.trainable_variables)
+    self._optimizer.apply(gradients, self._network.trainable_variables)
+
+    self._num_steps.assign_add(1)
+
+    # Compute the global norm of the gradients for logging.
+    global_gradient_norm = tf.linalg.global_norm(gradients)
+    fetches = {'loss': loss, 'gradient_norm': global_gradient_norm}
+
+    return fetches
+
+  def step(self):
+    # Do a batch of SGD.
+    result = self._step()
+
+    # Update our counts and record it.
+    counts = self._counter.increment(steps=1)
+    result.update(counts)
+
+    # Snapshot and attempt to write logs.
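+    # (Editor's note: the Snapshotter constructed above is rate-limited by
+    # time_delta_minutes=60., so the save below is usually a no-op.)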
+ if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + 'network': self._network, + 'optimizer': self._optimizer, + 'num_steps': self._num_steps + } diff --git a/acme/acme/agents/tf/bcq/README.md b/acme/acme/agents/tf/bcq/README.md new file mode 100644 index 00000000..18908141 --- /dev/null +++ b/acme/acme/agents/tf/bcq/README.md @@ -0,0 +1,8 @@ +# Discrete Batch-Constrained Deep Q-learning (BCQ) + +This folder contains an implementation of the discrete BCQ algorithm introduced +in ([Fujimoto et al., 2019]), which is a variant of the BCQ algorithm +([Fujimoto et al., 2018]). + +[Fujimoto et al., 2018]: https://arxiv.org/pdf/1812.02900.pdf +[Fujimoto et al., 2019]: https://arxiv.org/pdf/1910.01708.pdf diff --git a/acme/acme/agents/tf/bcq/__init__.py b/acme/acme/agents/tf/bcq/__init__.py new file mode 100644 index 00000000..59bdcaf5 --- /dev/null +++ b/acme/acme/agents/tf/bcq/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Batch-Constrained Deep Q-learning (BCQ).""" + +from acme.agents.tf.bcq.discrete_learning import DiscreteBCQLearner diff --git a/acme/acme/agents/tf/bcq/discrete_learning.py b/acme/acme/agents/tf/bcq/discrete_learning.py new file mode 100644 index 00000000..4b4e2340 --- /dev/null +++ b/acme/acme/agents/tf/bcq/discrete_learning.py @@ -0,0 +1,260 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Discrete BCQ learner implementation. + +As described in https://arxiv.org/pdf/1910.01708.pdf. +""" + +import copy +from typing import Dict, List, Optional + +from acme import core +from acme import types +from acme.adders import reverb as adders +from acme.agents.tf import bc +from acme.tf import losses +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf.networks import discrete as discrete_networks +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + + +class _InternalBCQLearner(core.Learner, tf2_savers.TFSaveable): + """Internal BCQ learner. 
+ + This implements the Q-learning component in the discrete BCQ algorithm. + """ + + def __init__( + self, + network: discrete_networks.DiscreteFilteredQNetwork, + discount: float, + importance_sampling_exponent: float, + learning_rate: float, + target_update_period: int, + dataset: tf.data.Dataset, + huber_loss_parameter: float = 1., + replay_client: Optional[reverb.TFClient] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Initializes the learner. + + Args: + network: BCQ network + discount: discount to use for TD updates. + importance_sampling_exponent: power to which importance weights are raised + before normalizing. + learning_rate: learning rate for the q-network update. + target_update_period: number of learner steps to perform before updating + the target networks. + dataset: dataset to learn from, whether fixed or from a replay buffer (see + `acme.datasets.reverb.make_reverb_dataset` documentation). + huber_loss_parameter: Quadratic-linear boundary for Huber loss. + replay_client: client to replay to allow for updating priorities. + counter: Counter object for (potentially distributed) counting. + logger: Logger object for writing logs to. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + + # Internalise agent components (replay buffer, networks, optimizer). + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._q_network = network.q_network + self._target_q_network = copy.deepcopy(network.q_network) + self._optimizer = snt.optimizers.Adam(learning_rate) + self._replay_client = replay_client + + # Internalise the hyperparameters. + self._discount = discount + self._target_update_period = target_update_period + self._importance_sampling_exponent = importance_sampling_exponent + self._huber_loss_parameter = huber_loss_parameter + + # Learner state. + self._variables = [self._network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # Internalise logging/counting objects. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner', + save_data=False) + + # Create a snapshotter object. + if checkpoint: + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={'network': network}, time_delta_minutes=60.) + else: + self._snapshotter = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" + + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + keys, probs = inputs.info[:2] + + with tf.GradientTape() as tape: + # Evaluate our networks. + q_tm1 = self._q_network(transitions.observation) + q_t_value = self._target_q_network(transitions.next_observation) + q_t_selector = self._network(transitions.next_observation) + + # The rewards and discounts have to have the same type as network values. + r_t = tf.cast(transitions.reward, q_tm1.dtype) + r_t = tf.clip_by_value(r_t, -1., 1.) + d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast( + self._discount, q_tm1.dtype) + + # Compute the loss. + _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t, d_t, + q_t_value, q_t_selector) + loss = losses.huber(extra.td_error, self._huber_loss_parameter) + + # Get the importance weights. + importance_weights = 1. 
/ probs # [B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + + # Reweight. + loss *= tf.cast(importance_weights, loss.dtype) # [B] + loss = tf.reduce_mean(loss, axis=[0]) # [] + + # Do a step of SGD. + gradients = tape.gradient(loss, self._network.trainable_variables) + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Update the priorities in the replay buffer. + if self._replay_client: + priorities = tf.cast(tf.abs(extra.td_error), tf.float64) + self._replay_client.update_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(self._q_network.variables, + self._target_q_network.variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Compute the global norm of the gradients for logging. + global_gradient_norm = tf.linalg.global_norm(gradients) + + # Compute statistics of the Q-values for logging. + max_q = tf.reduce_max(q_t_value) + min_q = tf.reduce_min(q_t_value) + mean_q, var_q = tf.nn.moments(q_t_value, [0, 1]) + + # Report loss & statistics for logging. + fetches = { + 'gradient_norm': global_gradient_norm, + 'loss': loss, + 'max_q': max_q, + 'mean_q': mean_q, + 'min_q': min_q, + 'var_q': var_q, + } + + return fetches + + def step(self): + # Do a batch of SGD. + result = self._step() + + # Update our counts and record it. + counts = self._counter.increment(steps=1) + result.update(counts) + + # Snapshot and attempt to write logs. + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + 'network': self._network, + 'target_q_network': self._target_q_network, + 'optimizer': self._optimizer, + 'num_steps': self._num_steps + } + + +class DiscreteBCQLearner(core.Learner, tf2_savers.TFSaveable): + """Discrete BCQ learner. + + This learner combines supervised BC learning and Q learning to implement the + discrete BCQ algorithm as described in https://arxiv.org/pdf/1910.01708.pdf. + """ + + def __init__(self, + network: discrete_networks.DiscreteFilteredQNetwork, + dataset: tf.data.Dataset, + learning_rate: float, + counter: Optional[counting.Counter] = None, + bc_logger: Optional[loggers.Logger] = None, + bcq_logger: Optional[loggers.Logger] = None, + **bcq_learner_kwargs): + counter = counter or counting.Counter() + self._bc_logger = bc_logger or loggers.TerminalLogger('bc_learner', + time_delta=1.) + self._bcq_logger = bcq_logger or loggers.TerminalLogger('bcq_learner', + time_delta=1.) + + self._bc_learner = bc.BCLearner( + network=network.g_network, + learning_rate=learning_rate, + dataset=dataset, + counter=counting.Counter(counter, 'bc'), + logger=self._bc_logger, + checkpoint=False) + self._bcq_learner = _InternalBCQLearner( + network=network, + learning_rate=learning_rate, + dataset=dataset, + counter=counting.Counter(counter, 'bcq'), + logger=self._bcq_logger, + **bcq_learner_kwargs) + + def get_variables(self, names): + return self._bcq_learner.get_variables(names) + + @property + def state(self): + bc_state = self._bc_learner.state + bc_state.pop('network') # No need to checkpoint the BC network. 
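+    # (Editor's note: network.g_network, which the BC learner trains, is a
+    # submodule of the BCQ learner's 'network' entry, so it is already
+    # checkpointed there, hence the pop above; keys are prefixed below so the
+    # two learners' state entries don't collide.)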
+ bcq_state = self._bcq_learner.state + state = dict() + state.update({f'bc_{k}': v for k, v in bc_state.items()}) + state.update({f'bcq_{k}': v for k, v in bcq_state.items()}) + return state + + def step(self): + self._bc_learner.step() + self._bcq_learner.step() diff --git a/acme/acme/agents/tf/bcq/discrete_learning_test.py b/acme/acme/agents/tf/bcq/discrete_learning_test.py new file mode 100644 index 00000000..8169f10c --- /dev/null +++ b/acme/acme/agents/tf/bcq/discrete_learning_test.py @@ -0,0 +1,81 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for discrete BCQ learner.""" + +from acme import specs +from acme.agents.tf import bcq +from acme.testing import fakes +from acme.tf import utils as tf2_utils +from acme.tf.networks import discrete as discrete_networks +from acme.utils import counting +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def _make_network(action_spec: specs.DiscreteArray) -> snt.Module: + return snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, action_spec.num_values]), + ]) + + +class DiscreteBCQLearnerTest(absltest.TestCase): + + def test_full_learner(self): + # Create dataset. + environment = fakes.DiscreteEnvironment( + num_actions=5, + num_observations=10, + obs_dtype=np.float32, + episode_length=10) + spec = specs.make_environment_spec(environment) + dataset = fakes.transition_dataset(environment).batch(2) + + # Build network. + g_network = _make_network(spec.actions) + q_network = _make_network(spec.actions) + network = discrete_networks.DiscreteFilteredQNetwork(g_network=g_network, + q_network=q_network, + threshold=0.5) + tf2_utils.create_variables(network, [spec.observations]) + + # Build learner. + counter = counting.Counter() + learner = bcq.DiscreteBCQLearner( + network=network, + dataset=dataset, + learning_rate=1e-4, + discount=0.99, + importance_sampling_exponent=0.2, + target_update_period=100, + counter=counter) + + # Run a learner step. + learner.step() + + # Check counts from BC and BCQ learners. + counts = counter.get_counts() + self.assertEqual(1, counts['bc_steps']) + self.assertEqual(1, counts['bcq_steps']) + + # Check learner state. + self.assertEqual(1, learner.state['bc_num_steps'].numpy()) + self.assertEqual(1, learner.state['bcq_num_steps'].numpy()) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/crr/__init__.py b/acme/acme/agents/tf/crr/__init__.py new file mode 100644 index 00000000..c6416dc8 --- /dev/null +++ b/acme/acme/agents/tf/crr/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of a CRR agent.""" + +from acme.agents.tf.crr.recurrent_learning import RCRRLearner diff --git a/acme/acme/agents/tf/crr/recurrent_learning.py b/acme/acme/agents/tf/crr/recurrent_learning.py new file mode 100644 index 00000000..9369b695 --- /dev/null +++ b/acme/acme/agents/tf/crr/recurrent_learning.py @@ -0,0 +1,407 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Recurrent CRR learner implementation.""" + +import operator +import time +from typing import Dict, List, Optional + +from acme import core +from acme.tf import losses +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import tree + + +class RCRRLearner(core.Learner): + """Recurrent CRR learner. + + This is the learning component of a RCRR agent. It takes a dataset as + input and implements update functionality to learn from this dataset. + """ + + def __init__(self, + policy_network: snt.RNNCore, + critic_network: networks.CriticDeepRNN, + target_policy_network: snt.RNNCore, + target_critic_network: networks.CriticDeepRNN, + dataset: tf.data.Dataset, + accelerator_strategy: Optional[tf.distribute.Strategy] = None, + behavior_network: Optional[snt.Module] = None, + cwp_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + discount: float = 0.99, + target_update_period: int = 100, + num_action_samples_td_learning: int = 1, + num_action_samples_policy_weight: int = 4, + baseline_reduce_function: str = 'mean', + clipping: bool = True, + policy_improvement_modes: str = 'exp', + ratio_upper_bound: float = 20., + beta: float = 1.0, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False): + """Initializes the learner. + + Args: + policy_network: the online (optimized) policy. + critic_network: the online critic. + target_policy_network: the target policy (which lags behind the online + policy). + target_critic_network: the target critic. + dataset: dataset to learn from, whether fixed or from a replay buffer + (see `acme.datasets.reverb.make_reverb_dataset` documentation). + accelerator_strategy: the strategy used to distribute computation, + whether on a single, or multiple, GPU or TPU; as supported by + tf.distribute. 
+ behavior_network: The network to snapshot under `policy` name. If None, + snapshots `policy_network` instead. + cwp_network: CWP network to snapshot: samples actions + from the policy and weighs them with the critic, then returns the action + by sampling from the softmax distribution using critic values as logits. + Used only for snapshotting, not training. + policy_optimizer: the optimizer to be applied to the policy loss. + critic_optimizer: the optimizer to be applied to the distributional + Bellman loss. + discount: discount to use for TD updates. + target_update_period: number of learner steps to perform before updating + the target networks. + num_action_samples_td_learning: number of action samples to use to + estimate expected value of the critic loss w.r.t. stochastic policy. + num_action_samples_policy_weight: number of action samples to use to + estimate the advantage function for the CRR weighting of the policy + loss. + baseline_reduce_function: one of 'mean', 'max', 'min'. Way of aggregating + values from `num_action_samples` estimates of the value function. + clipping: whether to clip gradients by global norm. + policy_improvement_modes: one of 'exp', 'binary', 'all'. CRR mode which + determines how the advantage function is processed before being + multiplied by the policy loss. + ratio_upper_bound: if policy_improvement_modes is 'exp', determines + the upper bound of the weight (i.e. the weight is + min(exp(advantage / beta), upper_bound) + ). + beta: if policy_improvement_modes is 'exp', determines the beta (see + above). + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + + if accelerator_strategy is None: + accelerator_strategy = snt.distribute.Replicator() + self._accelerator_strategy = accelerator_strategy + self._policy_improvement_modes = policy_improvement_modes + self._ratio_upper_bound = ratio_upper_bound + self._num_action_samples_td_learning = num_action_samples_td_learning + self._num_action_samples_policy_weight = num_action_samples_policy_weight + self._baseline_reduce_function = baseline_reduce_function + self._beta = beta + + # When running on TPUs we have to know the amount of memory required (and + # thus the sequence length) at the graph compilation stage. At the moment, + # the only way to get it is to sample from the dataset, since the dataset + # does not have any metadata, see b/160672927 to track this upcoming + # feature. + sample = next(dataset.as_numpy_iterator()) + self._sequence_length = sample.action.shape[1] + + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) + self._discount = discount + self._clipping = clipping + + self._target_update_period = target_update_period + + with self._accelerator_strategy.scope(): + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # (Maybe) distributing the dataset across multiple accelerators. + distributed_dataset = self._accelerator_strategy.experimental_distribute_dataset( + dataset) + self._iterator = iter(distributed_dataset) + + # Create the optimizers. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + + # Store online and target networks. 
+ self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Expose the variables. + self._variables = { + 'critic': self._target_critic_network.variables, + 'policy': self._target_policy_network.variables, + } + + # Create a checkpointer object. + self._checkpointer = None + self._snapshotter = None + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + objects_to_save={ + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'num_steps': self._num_steps, + }, + time_delta_minutes=30.) + + raw_policy = snt.DeepRNN( + [policy_network, networks.StochasticSamplingHead()]) + critic_mean = networks.CriticDeepRNN( + [critic_network, networks.StochasticMeanHead()]) + objects_to_save = { + 'raw_policy': raw_policy, + 'critic': critic_mean, + } + if behavior_network is not None: + objects_to_save['policy'] = behavior_network + if cwp_network is not None: + objects_to_save['cwp_policy'] = cwp_network + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save=objects_to_save, time_delta_minutes=30) + # Timestamp to keep track of the wall time. + self._walltime_timestamp = time.time() + + def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]: + # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...]. + sample = tf2_utils.batch_to_sequence(sample) + observations = sample.observation + actions = sample.action + rewards = sample.reward + discounts = sample.discount + + dtype = rewards.dtype + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=discounts.dtype) + + # Loss cumulants across time. These cannot be python mutable objects. + critic_loss = 0. + policy_loss = 0. + + # Each transition induces a policy loss, which we then weight using + # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134. + # `policy_loss_coef` is a scalar average of these coefficients across + # the batch and sequence length dimensions. + policy_loss_coef = 0. + + per_device_batch_size = actions.shape[1] + + # Initialize recurrent states. + critic_state = self._critic_network.initial_state(per_device_batch_size) + target_critic_state = critic_state + policy_state = self._policy_network.initial_state(per_device_batch_size) + target_policy_state = policy_state + + with tf.GradientTape(persistent=True) as tape: + for t in range(1, self._sequence_length): + o_tm1 = tree.map_structure(operator.itemgetter(t - 1), observations) + a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions) + r_t = tree.map_structure(operator.itemgetter(t - 1), rewards) + d_t = tree.map_structure(operator.itemgetter(t - 1), discounts) + o_t = tree.map_structure(operator.itemgetter(t), observations) + + if t != 1: + # By only updating the target critic state here we are forcing + # the target critic to ignore observations[0]. Otherwise, the + # target_critic will be unrolled for one more timestep than critic. + # The smaller the sequence length, the more problematic this is: if + # you use RNN on sequences of length 2, you would expect the code to + # never use recurrent connections. 
But if you don't skip updating the + # target_critic_state on observation[0] here, it won't be the case. + _, target_critic_state = self._target_critic_network( + o_tm1, a_tm1, target_critic_state) + + # ========================= Critic learning ============================ + q_tm1, next_critic_state = self._critic_network(o_tm1, a_tm1, + critic_state) + target_action_distribution, target_policy_state = self._target_policy_network( + o_t, target_policy_state) + + sampled_actions_t = target_action_distribution.sample( + self._num_action_samples_td_learning) + # [N, B, ...] + tiled_o_t = tf2_utils.tile_nested( + o_t, self._num_action_samples_td_learning) + tiled_target_critic_state = tf2_utils.tile_nested( + target_critic_state, self._num_action_samples_td_learning) + + # Compute the target critic's Q-value of the sampled actions. + sampled_q_t, _ = snt.BatchApply(self._target_critic_network)( + tiled_o_t, sampled_actions_t, tiled_target_critic_state) + + # Compute average logits by first reshaping them to [N, B, A] and then + # normalizing them across atoms. + new_shape = [self._num_action_samples_td_learning, r_t.shape[0], -1] + sampled_logits = tf.reshape(sampled_q_t.logits, new_shape) + sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) + averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) + + # Construct the expected distributional value for bootstrapping. + q_t = networks.DiscreteValuedDistribution( + values=sampled_q_t.values, logits=averaged_logits) + critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t, q_t) + critic_loss_t = tf.reduce_mean(critic_loss_t) + + # ========================= Actor learning ============================= + action_distribution_tm1, policy_state = self._policy_network( + o_tm1, policy_state) + q_tm1_mean = q_tm1.mean() + + # Compute the estimate of the value function based on + # self._num_action_samples_policy_weight samples from the policy. + tiled_o_tm1 = tf2_utils.tile_nested( + o_tm1, self._num_action_samples_policy_weight) + tiled_critic_state = tf2_utils.tile_nested( + critic_state, self._num_action_samples_policy_weight) + action_tm1 = action_distribution_tm1.sample( + self._num_action_samples_policy_weight) + tiled_z_tm1, _ = snt.BatchApply(self._critic_network)( + tiled_o_tm1, action_tm1, tiled_critic_state) + tiled_v_tm1 = tf.reshape(tiled_z_tm1.mean(), + [self._num_action_samples_policy_weight, -1]) + + # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the + # final estimate of the value function. + if self._baseline_reduce_function == 'mean': + v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0) + elif self._baseline_reduce_function == 'max': + v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0) + elif self._baseline_reduce_function == 'min': + v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0) + + # Assert that action_distribution_tm1 is a batch of multivariate + # distributions (in contrast to e.g. a [batch, action_size] collection + # of 1d distributions). + assert len(action_distribution_tm1.batch_shape) == 1 + policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1) + + advantage = q_tm1_mean - v_tm1_estimate + if self._policy_improvement_modes == 'exp': + policy_loss_coef_t = tf.math.minimum( + tf.math.exp(advantage / self._beta), self._ratio_upper_bound) + elif self._policy_improvement_modes == 'binary': + policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype) + elif self._policy_improvement_modes == 'all': + # Regress against all actions (effectively pure BC). 
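+          # (Editor's note: with a constant coefficient of 1, the weighted
+          # loss below reduces to plain negative log-likelihood; compare the
+          # 'exp' weight min(exp(advantage / beta), ratio_upper_bound) and
+          # the 'binary' weight 1{advantage > 0} above.)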
+ policy_loss_coef_t = 1. + policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t) + + policy_loss_batch *= policy_loss_coef_t + policy_loss_t = tf.reduce_mean(policy_loss_batch) + + critic_state = next_critic_state + + critic_loss += critic_loss_t + policy_loss += policy_loss_t + policy_loss_coef += tf.reduce_mean(policy_loss_coef_t) # For logging. + + # Divide by sequence length to get mean losses. + critic_loss /= tf.cast(self._sequence_length, dtype=dtype) + policy_loss /= tf.cast(self._sequence_length, dtype=dtype) + policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype) + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, + self._critic_network.trainable_variables) + policy_gradients = tape.gradient(policy_loss, + self._policy_network.trainable_variables) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Sync gradients across GPUs or TPUs. + ctx = tf.distribute.get_replica_context() + critic_gradients = ctx.all_reduce('mean', critic_gradients) + policy_gradients = ctx.all_reduce('mean', policy_gradients) + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0] + critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0] + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, + self._critic_network.trainable_variables) + self._policy_optimizer.apply(policy_gradients, + self._policy_network.trainable_variables) + + source_variables = ( + self._critic_network.variables + self._policy_network.variables) + target_variables = ( + self._target_critic_network.variables + + self._target_policy_network.variables) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(source_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + return { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + 'policy_loss_coef': policy_loss_coef, + } + + @tf.function + def _replicated_step(self) -> Dict[str, tf.Tensor]: + sample = next(self._iterator) + fetches = self._accelerator_strategy.run(self._step, args=(sample,)) + mean = tf.distribute.ReduceOp.MEAN + return { + k: self._accelerator_strategy.reduce(mean, fetches[k], axis=None) + for k in fetches + } + + def step(self): + # Run the learning step. + with self._accelerator_strategy.scope(): + fetches = self._replicated_step() + + # Update our counts and record it. + new_timestamp = time.time() + time_passed = new_timestamp - self._walltime_timestamp + self._walltime_timestamp = new_timestamp + counts = self._counter.increment(steps=1, wall_time=time_passed) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. 
+    if self._checkpointer is not None:
+      self._checkpointer.save()
+      self._snapshotter.save()
+    self._logger.write(fetches)
+
+  def get_variables(self, names: List[str]) -> List[List[np.ndarray]]:
+    return [tf2_utils.to_numpy(self._variables[name]) for name in names]
diff --git a/acme/acme/agents/tf/d4pg/README.md b/acme/acme/agents/tf/d4pg/README.md
new file mode 100644
index 00000000..48c3d75c
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/README.md
@@ -0,0 +1,24 @@
+# Distributed Distributional Deep Deterministic Policy Gradient (D4PG)
+
+This folder contains an implementation of the D4PG agent introduced in
+([Barth-Maron et al., 2018]), which extends previous Deterministic Policy
+Gradient (DPG) algorithms ([Silver et al., 2014]; [Lillicrap et al., 2015]) by
+using a distributional Q-network similar to C51 ([Bellemare et al., 2017]).
+
+Note that since the synchronous agent is not distributed (i.e. not using
+multiple asynchronous actors), it is not, strictly speaking, D4PG; a more
+accurate name would be Distributional DDPG. In this algorithm, the critic
+outputs a distribution over state-action values; in this particular case the
+discrete distribution is parametrized as in C51.
+
+Detailed notes:
+
+- The `vmin`/`vmax` hyperparameters of the distributional critic may need
+  tuning depending on your environment's rewards. A good rule of thumb is to
+  set `vmax` to the discounted sum of the maximum instantaneous rewards for
+  the maximum episode length; then set `vmin` to `-vmax`.
+
+[Barth-Maron et al., 2018]: https://arxiv.org/abs/1804.08617
+[Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887
+[Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971
+[Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14
diff --git a/acme/acme/agents/tf/d4pg/__init__.py b/acme/acme/agents/tf/d4pg/__init__.py
new file mode 100644
index 00000000..f2a39ee5
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementations of a D4PG agent."""
+
+from acme.agents.tf.d4pg.agent import D4PG
+from acme.agents.tf.d4pg.agent_distributed import DistributedD4PG
+from acme.agents.tf.d4pg.learning import D4PGLearner
+from acme.agents.tf.d4pg.networks import make_default_networks
diff --git a/acme/acme/agents/tf/d4pg/agent.py b/acme/acme/agents/tf/d4pg/agent.py
new file mode 100644
index 00000000..96daa1ae
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/agent.py
@@ -0,0 +1,463 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""D4PG agent implementation.""" + +import copy +import dataclasses +import functools +from typing import Iterator, List, Optional, Tuple, Union, Sequence + +from acme import adders +from acme import core +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as reverb_adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf.d4pg import learning +from acme.tf import networks as network_utils +from acme.tf import utils +from acme.tf import variable_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + +Replicator = Union[snt.distribute.Replicator, snt.distribute.TpuReplicator] + + +@dataclasses.dataclass +class D4PGConfig: + """Configuration options for the D4PG agent.""" + + accelerator: Optional[str] = None + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_update_period: int = 100 + variable_update_period: int = 1000 + policy_optimizer: Optional[snt.Optimizer] = None + critic_optimizer: Optional[snt.Optimizer] = None + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + n_step: int = 5 + sigma: float = 0.3 + clipping: bool = True + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + + +@dataclasses.dataclass +class D4PGNetworks: + """Structure containing the networks for D4PG.""" + + policy_network: snt.Module + critic_network: snt.Module + observation_network: snt.Module + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation, + ): + # This method is implemented (rather than added by the dataclass decorator) + # in order to allow observation network to be passed as an arbitrary tensor + # transformation rather than as a snt Module. + # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. + self.policy_network = policy_network + self.critic_network = critic_network + self.observation_network = utils.to_sonnet_module(observation_network) + + def init(self, environment_spec: specs.EnvironmentSpec): + """Initialize the networks given an environment spec.""" + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + + # Create variables for the observation net and, as a side-effect, get a + # spec describing the embedding space. + emb_spec = utils.create_variables(self.observation_network, [obs_spec]) + + # Create variables for the policy and critic nets. + _ = utils.create_variables(self.policy_network, [emb_spec]) + _ = utils.create_variables(self.critic_network, [emb_spec, act_spec]) + + def make_policy( + self, + environment_spec: specs.EnvironmentSpec, + sigma: float = 0.0, + ) -> snt.Module: + """Create a single network which evaluates the policy.""" + # Stack the observation and policy networks. 
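+    # (The observation network embeds raw observations; the policy network
+    # then maps those embeddings to actions.)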
+ stack = [ + self.observation_network, + self.policy_network, + ] + + # If a stochastic/non-greedy policy is requested, add Gaussian noise on + # top to enable a simple form of exploration. + # TODO(mwhoffman): Refactor this to remove it from the class. + if sigma > 0.0: + stack += [ + network_utils.ClippedGaussian(sigma), + network_utils.ClipToSpec(environment_spec.actions), + ] + + # Return a network which sequentially evaluates everything in the stack. + return snt.Sequential(stack) + + +class D4PGBuilder: + """Builder for D4PG which constructs individual components of the agent.""" + + def __init__(self, config: D4PGConfig): + self._config = config + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + if self._config.samples_per_insert is None: + # We will take a samples_per_insert ratio of None to mean that there is + # no limit, i.e. this only implies a min size limit. + limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + + else: + # Create enough of an error buffer to give a 10% tolerance in rate. + samples_per_insert_tolerance = 0.1 * self._config.samples_per_insert + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + + replay_table = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=reverb_adders.NStepTransitionAdder.signature( + environment_spec)) + + return [replay_table] + + def make_dataset_iterator( + self, + reverb_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=reverb_client.server_address, + batch_size=self._config.batch_size, + prefetch_size=self._config.prefetch_size) + + replicator = get_replicator(self._config.accelerator) + dataset = replicator.experimental_distribute_dataset(dataset) + + # TODO(b/155086959): Fix type stubs and remove. + return iter(dataset) # pytype: disable=wrong-arg-types + + def make_adder( + self, + replay_client: reverb.Client, + ) -> adders.Adder: + """Create an adder which records data generated by the actor/environment.""" + return reverb_adders.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: lambda x: 1.}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount) + + def make_actor( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_source: Optional[core.VariableSource] = None, + ): + """Create an actor instance.""" + if variable_source: + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.variables}, + update_period=self._config.variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. 
+ variable_client.update_and_wait() + + else: + variable_client = None + + # Create the actor which defines how we take actions. + return actors.FeedForwardActor( + policy_network=policy_network, + adder=adder, + variable_client=variable_client, + ) + + def make_learner( + self, + networks: Tuple[D4PGNetworks, D4PGNetworks], + dataset: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Creates an instance of the learner.""" + online_networks, target_networks = networks + + # The learner updates the parameters (and initializes them). + return learning.D4PGLearner( + policy_network=online_networks.policy_network, + critic_network=online_networks.critic_network, + observation_network=online_networks.observation_network, + target_policy_network=target_networks.policy_network, + target_critic_network=target_networks.critic_network, + target_observation_network=target_networks.observation_network, + policy_optimizer=self._config.policy_optimizer, + critic_optimizer=self._config.critic_optimizer, + clipping=self._config.clipping, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + dataset_iterator=dataset, + replicator=get_replicator(self._config.accelerator), + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) + + +class D4PG(agent.Agent): + """D4PG Agent. + + This implements a single-process D4PG agent. This is an actor-critic algorithm + that generates data via a behavior policy, inserts N-step transitions into + a replay buffer, and periodically updates the policy (and as a result the + behavior) by sampling uniformly from this buffer. + """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + accelerator: Optional[str] = None, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + policy_network: the online (optimized) policy. + critic_network: the online critic. + observation_network: optional network to transform the observations before + they are fed into any network. + accelerator: 'TPU', 'GPU', or 'CPU'. If omitted, the first available + accelerator type from ['TPU', 'GPU', 'CPU'] will be selected. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_update_period: number of learner steps to perform before updating + the target networks. + policy_optimizer: optimizer for the policy network updates. + critic_optimizer: optimizer for the critic network updates. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert + that is made. 
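+        In this single-process agent the ratio is enforced by the Agent base
+        class rather than by a Reverb rate limiter.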
+ n_step: number of steps to squash into a single transition. + sigma: standard deviation of zero-mean, Gaussian exploration noise. + clipping: whether to clip gradients by global norm. + replay_table_name: string indicating what name to give the replay table. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + if not accelerator: + accelerator = _get_first_available_accelerator_type(['TPU', 'GPU', 'CPU']) + + # Create the Builder object which will internally create agent components. + builder = D4PGBuilder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + # Right now this modifies min_replay_size and samples_per_insert so that + # they are not controlled by a limiter and are instead handled by the + # Agent base class (the above TODO directly references this behavior). + D4PGConfig( + accelerator=accelerator, + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + min_replay_size=1, # Let the Agent class handle this. + max_replay_size=max_replay_size, + samples_per_insert=None, # Let the Agent class handle this. + n_step=n_step, + sigma=sigma, + clipping=clipping, + replay_table_name=replay_table_name, + )) + + replicator = get_replicator(accelerator) + + with replicator.scope(): + # TODO(mwhoffman): pass the network dataclass in directly. + online_networks = D4PGNetworks(policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network) + + # Target networks are just a copy of the online networks. + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(environment_spec) + target_networks.init(environment_spec) + + # TODO(mwhoffman): either make this Dataclass or pass only one struct. + # The network struct passed to make_learner is just a tuple for the + # time-being (for backwards compatibility). + networks = (online_networks, target_networks) + + # Create the behavior policy. + policy_network = online_networks.make_policy(environment_spec, sigma) + + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f'localhost:{replay_server.port}') + + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client) + actor = builder.make_actor(policy_network, adder) + dataset = builder.make_dataset_iterator(replay_client) + learner = builder.make_learner(networks, dataset, counter, logger, + checkpoint) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) + + # Save the replay so we don't garbage collect it. + self._replay_server = replay_server + + +def _ensure_accelerator(accelerator: str) -> str: + """Checks for the existence of the expected accelerator type. + + Args: + accelerator: 'CPU', 'GPU' or 'TPU'. + + Returns: + The validated `accelerator` argument. + + Raises: + RuntimeError: Thrown if the expected accelerator isn't found. 
+  """
+  devices = tf.config.get_visible_devices(device_type=accelerator)
+
+  if devices:
+    return accelerator
+  else:
+    error_messages = [f'Couldn\'t find any {accelerator} devices.',
+                      'tf.config.get_visible_devices() returned:']
+    # List the devices that *are* visible so the error is actionable.
+    error_messages.extend([str(d) for d in tf.config.get_visible_devices()])
+    raise RuntimeError('\n'.join(error_messages))
+
+
+def _get_first_available_accelerator_type(
+    wishlist: Sequence[str] = ('TPU', 'GPU', 'CPU')) -> str:
+  """Returns the first available accelerator type listed in a wishlist.
+
+  Args:
+    wishlist: A sequence of elements from {'CPU', 'GPU', 'TPU'}, listed in
+      order of descending preference.
+
+  Returns:
+    The first available accelerator type from `wishlist`.
+
+  Raises:
+    RuntimeError: Thrown if no accelerators from the `wishlist` are found.
+  """
+  get_visible_devices = tf.config.get_visible_devices
+
+  for wishlist_device in wishlist:
+    devices = get_visible_devices(device_type=wishlist_device)
+    if devices:
+      return wishlist_device
+
+  available = ', '.join(
+      sorted(frozenset([d.type for d in get_visible_devices()])))
+  raise RuntimeError(
+      f'Couldn\'t find any devices from {wishlist}. '
+      f'Only the following types are available: {available}.')
+
+
+# Only instantiate one replicator per (process, accelerator type), in case
+# a replicator stores state that needs to be carried between its method calls.
+@functools.lru_cache()
+def get_replicator(accelerator: Optional[str]) -> Replicator:
+  """Returns a replicator instance appropriate for the given accelerator.
+
+  This caches the instance using functools.lru_cache, so that only one
+  replicator is instantiated per process and argument value.
+
+  Args:
+    accelerator: None, 'TPU', 'GPU', or 'CPU'. If None, the first available
+      accelerator type will be chosen from ('TPU', 'GPU', 'CPU').
+
+  Returns:
+    A replicator, for replicating weights, datasets, and updates across
+    one or more accelerators.
+  """
+  if accelerator:
+    accelerator = _ensure_accelerator(accelerator)
+  else:
+    accelerator = _get_first_available_accelerator_type()
+
+  if accelerator == 'TPU':
+    tf.tpu.experimental.initialize_tpu_system()
+    return snt.distribute.TpuReplicator()
+  else:
+    return snt.distribute.Replicator()
diff --git a/acme/acme/agents/tf/d4pg/agent_distributed.py b/acme/acme/agents/tf/d4pg/agent_distributed.py
new file mode 100644
index 00000000..8ac8ff3b
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/agent_distributed.py
@@ -0,0 +1,268 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the D4PG agent class."""
+
+import copy
+from typing import Callable, Dict, Optional
+
+import acme
+from acme import specs
+from acme.agents.tf.d4pg import agent
+from acme.tf import savers as tf2_savers
+from acme.utils import counting
+from acme.utils import loggers
+from acme.utils import lp_utils
+import dm_env
+import launchpad as lp
+import reverb
+import sonnet as snt
+import tensorflow as tf
+
+# Valid values of the "accelerator" argument.
+_ACCELERATORS = ('CPU', 'GPU', 'TPU') + + +class DistributedD4PG: + """Program definition for D4PG.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + accelerator: Optional[str] = None, + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + discount: float = 0.99, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + target_update_period: int = 100, + variable_update_period: int = 1000, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if accelerator is not None and accelerator not in _ACCELERATORS: + raise ValueError(f'Accelerator must be one of {_ACCELERATORS}, ' + f'not "{accelerator}".') + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + # TODO(mwhoffman): Make network_factory directly return the struct. + # TODO(mwhoffman): Make the factory take the entire spec. + def wrapped_network_factory(action_spec): + networks_dict = network_factory(action_spec) + networks = agent.D4PGNetworks( + policy_network=networks_dict.get('policy'), + critic_network=networks_dict.get('critic'), + observation_network=networks_dict.get('observation', tf.identity)) + return networks + + self._environment_factory = environment_factory + self._network_factory = wrapped_network_factory + self._environment_spec = environment_spec + self._sigma = sigma + self._num_actors = num_actors + self._num_caches = num_caches + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._accelerator = accelerator + self._variable_update_period = variable_update_period + + self._builder = agent.D4PGBuilder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + agent.D4PGConfig( + accelerator=accelerator, + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + variable_update_period=variable_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + n_step=n_step, + sigma=sigma, + clipping=clipping, + )) + + def replay(self): + """The replay storage.""" + return self._builder.make_replay_tables(self._environment_spec) + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter): + return lp_utils.StepsLimiter(counter, self._max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # If we are running on multiple accelerator devices, this replicates + # weights and updates across devices. + replicator = agent.get_replicator(self._accelerator) + + with replicator.scope(): + # Create the networks to optimize (online) and target networks. + online_networks = self._network_factory(self._environment_spec.actions) + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. 
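+      # (Creating the variables inside the replicator scope mirrors them
+      # across the available accelerator devices.)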
+ online_networks.init(self._environment_spec) + target_networks.init(self._environment_spec) + + dataset = self._builder.make_dataset_iterator(replay) + + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + return self._builder.make_learner( + networks=(online_networks, target_networks), + dataset=dataset, + counter=counter, + logger=logger, + checkpoint=True, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy( + environment_spec=self._environment_spec, + sigma=self._sigma, + ) + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + ) + + # Create the environment. + environment = self._environment_factory(False) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + logger: Optional[loggers.Logger] = None, + ): + """The evaluation process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy(self._environment_spec) + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + variable_source=variable_source, + ) + + # Make the environment. + environment = self._environment_factory(True) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = logger or loggers.make_default_logger( + 'evaluator', + time_delta=self._log_every, + steps_key='evaluator_steps', + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name='d4pg'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + with program.group('coordinator'): + _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. 
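+      # For example, with num_actors=4 and num_caches=2, actors 0 and 2 read
+      # from cache 0 while actors 1 and 3 read from cache 1.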
+      for actor_id in range(self._num_actors):
+        source = sources[actor_id % len(sources)]
+        program.add_node(lp.CourierNode(self.actor, replay, source, counter))
+
+    return program
diff --git a/acme/acme/agents/tf/d4pg/agent_distributed_test.py b/acme/acme/agents/tf/d4pg/agent_distributed_test.py
new file mode 100644
index 00000000..cc71b890
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/agent_distributed_test.py
@@ -0,0 +1,84 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration test for the distributed agent."""
+
+import acme
+from acme import specs
+from acme.agents.tf import d4pg
+from acme.testing import fakes
+from acme.tf import networks
+from acme.tf import utils as tf2_utils
+import launchpad as lp
+import numpy as np
+import sonnet as snt
+
+from absl.testing import absltest
+
+
+def make_networks(action_spec: specs.BoundedArray):
+  """Simple networks for testing."""
+
+  num_dimensions = np.prod(action_spec.shape, dtype=int)
+
+  policy_network = snt.Sequential([
+      networks.LayerNormMLP([50], activate_final=True),
+      networks.NearZeroInitializedLinear(num_dimensions),
+      networks.TanhToSpec(action_spec)
+  ])
+  # The multiplexer concatenates the (maybe transformed) observations/actions.
+  critic_network = snt.Sequential([
+      networks.CriticMultiplexer(
+          critic_network=networks.LayerNormMLP(
+              [50], activate_final=True)),
+      networks.DiscreteValuedHead(-1., 1., 10)
+  ])
+
+  return {
+      'policy': policy_network,
+      'critic': critic_network,
+      'observation': tf2_utils.batch_concat,
+  }
+
+
+class DistributedAgentTest(absltest.TestCase):
+  """Simple integration/smoke test for the distributed agent."""
+
+  def test_control_suite(self):
+    """Tests that the agent can run on the control suite without crashing."""
+
+    agent = d4pg.DistributedD4PG(
+        environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True),
+        network_factory=make_networks,
+        accelerator='CPU',
+        num_actors=2,
+        batch_size=32,
+        min_replay_size=32,
+        max_replay_size=1000,
+    )
+    program = agent.build()
+
+    (learner_node,) = program.groups['learner']
+    learner_node.disable_run()
+
+    lp.launch(program, launch_type='test_mt')
+
+    learner: acme.Learner = learner_node.create_handle().dereference()
+
+    for _ in range(5):
+      learner.step()
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/tf/d4pg/agent_test.py b/acme/acme/agents/tf/d4pg/agent_test.py
new file mode 100644
index 00000000..10b89b9f
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/agent_test.py
@@ -0,0 +1,91 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the D4PG agent."""
+
+from typing import Dict, Sequence
+
+import acme
+from acme import specs
+from acme import types
+from acme.agents.tf import d4pg
+from acme.testing import fakes
+from acme.tf import networks
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+
+from absl.testing import absltest
+
+
+def make_networks(
+    action_spec: types.NestedSpec,
+    policy_layer_sizes: Sequence[int] = (10, 10),
+    critic_layer_sizes: Sequence[int] = (10, 10),
+    vmin: float = -150.,
+    vmax: float = 150.,
+    num_atoms: int = 51,
+) -> Dict[str, snt.Module]:
+  """Creates networks used by the agent."""
+
+  num_dimensions = np.prod(action_spec.shape, dtype=int)
+  policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions]
+
+  policy_network = snt.Sequential(
+      [networks.LayerNormMLP(policy_layer_sizes), tf.tanh])
+  critic_network = snt.Sequential([
+      networks.CriticMultiplexer(
+          critic_network=networks.LayerNormMLP(
+              critic_layer_sizes, activate_final=True)),
+      networks.DiscreteValuedHead(vmin, vmax, num_atoms)
+  ])
+
+  return {
+      'policy': policy_network,
+      'critic': critic_network,
+  }
+
+
+class D4PGTest(absltest.TestCase):
+
+  def test_d4pg(self):
+    # Create a fake environment to test with.
+    environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True)
+    spec = specs.make_environment_spec(environment)
+
+    # Create the networks.
+    agent_networks = make_networks(spec.actions)
+
+    # Construct the agent.
+    agent = d4pg.D4PG(
+        environment_spec=spec,
+        accelerator='CPU',
+        policy_network=agent_networks['policy'],
+        critic_network=agent_networks['critic'],
+        batch_size=10,
+        samples_per_insert=2,
+        min_replay_size=10,
+    )
+
+    # Try running the environment loop. We have no assertions here because all
+    # we care about is that the agent runs without raising any errors.
+    loop = acme.EnvironmentLoop(environment, agent)
+    loop.run(num_episodes=2)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/tf/d4pg/learning.py b/acme/acme/agents/tf/d4pg/learning.py
new file mode 100644
index 00000000..b167fb98
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/learning.py
@@ -0,0 +1,372 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""D4PG learner implementation."""
+
+import time
+from typing import Dict, Iterator, List, Optional, Union, Sequence
+
+import acme
+from acme import types
+from acme.tf import losses
+from acme.tf import networks as acme_nets
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import numpy as np
+import reverb
+import sonnet as snt
+import tensorflow as tf
+import tree
+
+Replicator = Union[snt.distribute.Replicator, snt.distribute.TpuReplicator]
+
+
+class D4PGLearner(acme.Learner):
+  """D4PG learner.
+
+  This is the learning component of a D4PG agent. That is, it takes a dataset
+  as input and implements update functionality to learn from this dataset.
+  """
+
+  def __init__(
+      self,
+      policy_network: snt.Module,
+      critic_network: snt.Module,
+      target_policy_network: snt.Module,
+      target_critic_network: snt.Module,
+      discount: float,
+      target_update_period: int,
+      dataset_iterator: Iterator[reverb.ReplaySample],
+      replicator: Optional[Replicator] = None,
+      observation_network: types.TensorTransformation = lambda x: x,
+      target_observation_network: types.TensorTransformation = lambda x: x,
+      policy_optimizer: Optional[snt.Optimizer] = None,
+      critic_optimizer: Optional[snt.Optimizer] = None,
+      clipping: bool = True,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      checkpoint: bool = True,
+  ):
+    """Initializes the learner.
+
+    Args:
+      policy_network: the online (optimized) policy.
+      critic_network: the online critic.
+      target_policy_network: the target policy (which lags behind the online
+        policy).
+      target_critic_network: the target critic.
+      discount: discount to use for TD updates.
+      target_update_period: number of learner steps to perform before updating
+        the target networks.
+      dataset_iterator: dataset to learn from, whether fixed or from a replay
+        buffer (see `acme.datasets.reverb.make_reverb_dataset` documentation).
+      replicator: Replicates variables and their update methods over multiple
+        accelerators, such as the multiple chips in a TPU.
+      observation_network: an optional online network to process observations
+        before the policy and the critic.
+      target_observation_network: the target observation network.
+      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
+      critic_optimizer: the optimizer to be applied to the distributional
+        Bellman loss.
+      clipping: whether to clip gradients by global norm.
+      counter: counter object used to keep track of steps.
+      logger: logger object to be used by learner.
+      checkpoint: boolean indicating whether to checkpoint the learner.
+    """
+
+    # Store online and target networks.
+    self._policy_network = policy_network
+    self._critic_network = critic_network
+    self._target_policy_network = target_policy_network
+    self._target_critic_network = target_critic_network
+
+    # Make sure observation networks are snt.Module instances so they have
+    # variables.
+    self._observation_network = tf2_utils.to_sonnet_module(observation_network)
+    self._target_observation_network = tf2_utils.to_sonnet_module(
+        target_observation_network)
+
+    # General learner book-keeping and loggers.
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.make_default_logger('learner')
+
+    # Other learner parameters.
+ self._discount = discount + self._clipping = clipping + + # Replicates Variables across multiple accelerators + if not replicator: + accelerator = _get_first_available_accelerator_type() + if accelerator == 'TPU': + replicator = snt.distribute.TpuReplicator() + else: + replicator = snt.distribute.Replicator() + + self._replicator = replicator + + with replicator.scope(): + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_update_period = target_update_period + + # Create optimizers if they aren't given. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + + # Batch dataset and create iterator. + self._iterator = dataset_iterator + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network]) + self._variables = { + 'critic': self._target_critic_network.variables, + 'policy': policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter objects. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='d4pg_learner', + objects_to_save={ + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'observation': self._observation_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'target_observation': self._target_observation_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'num_steps': self._num_steps, + }) + critic_mean = snt.Sequential( + [self._critic_network, acme_nets.StochasticMeanHead()]) + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={ + 'policy': self._policy_network, + 'critic': critic_mean, + }) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self, sample) -> Dict[str, tf.Tensor]: + transitions: types.Transition = sample.data # Assuming ReverbSample. + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=transitions.discount.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + o_t = self._target_observation_network(transitions.next_observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tree.map_structure(tf.stop_gradient, o_t) + + # Critic learning. + q_tm1 = self._critic_network(o_tm1, transitions.action) + q_t = self._target_critic_network(o_t, self._target_policy_network(o_t)) + + # Critic loss. + critic_loss = losses.categorical(q_tm1, transitions.reward, + discount * transitions.discount, q_t) + critic_loss = tf.reduce_mean(critic_loss, axis=[0]) + + # Actor learning. 
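+      # The deterministic policy gradient objective maximizes the critic's
+      # value estimate at the policy's own action, i.e. it follows
+      #   grad_theta Q(s, pi_theta(s)) =
+      #       grad_a Q(s, a)|_{a=pi_theta(s)} * grad_theta pi_theta(s),
+      # which losses.dpg computes below using gradients from the tape.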
+ dpg_a_t = self._policy_network(o_t) + dpg_z_t = self._critic_network(o_t, dpg_a_t) + dpg_q_t = dpg_z_t.mean() + + # Actor loss. If clipping is true use dqda clipping and clip the norm. + dqda_clipping = 1.0 if self._clipping else None + policy_loss = losses.dpg( + dpg_q_t, + dpg_a_t, + tape=tape, + dqda_clipping=dqda_clipping, + clip_norm=self._clipping) + policy_loss = tf.reduce_mean(policy_loss, axis=[0]) + + # Get trainable variables. + policy_variables = self._policy_network.trainable_variables + critic_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables) + + # Compute gradients. + replica_context = tf.distribute.get_replica_context() + policy_gradients = _average_gradients_across_replicas( + replica_context, + tape.gradient(policy_loss, policy_variables)) + critic_gradients = _average_gradients_across_replicas( + replica_context, + tape.gradient(critic_loss, critic_variables)) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0] + critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0] + + # Apply gradients. + self._policy_optimizer.apply(policy_gradients, policy_variables) + self._critic_optimizer.apply(critic_gradients, critic_variables) + + # Losses to track. + return { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + + @tf.function + def _replicated_step(self): + # Update target network + online_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + *self._policy_network.variables, + ) + target_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + *self._target_policy_network.variables, + ) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + sample = next(self._iterator) + + # This mirrors the structure of the fetches returned by self._step(), + # but the Tensors are replaced with replicated Tensors, one per accelerator. + replicated_fetches = self._replicator.run(self._step, args=(sample,)) + + def reduce_mean_over_replicas(replicated_value): + """Averages a replicated_value across replicas.""" + # The "axis=None" arg means reduce across replicas, not internal axes. + return self._replicator.reduce( + reduce_op=tf.distribute.ReduceOp.MEAN, + value=replicated_value, + axis=None) + + fetches = tree.map_structure(reduce_mean_over_replicas, replicated_fetches) + + return fetches + + def step(self): + # Run the learning step. + fetches = self._replicated_step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. 
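+    # (The checkpointer and snapshotter throttle their actual disk writes
+    # internally, so calling save() on every learner step is inexpensive.)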
+    if self._checkpointer is not None:
+      self._checkpointer.save()
+    if self._snapshotter is not None:
+      self._snapshotter.save()
+    self._logger.write(fetches)
+
+  def get_variables(self, names: List[str]) -> List[List[np.ndarray]]:
+    return [tf2_utils.to_numpy(self._variables[name]) for name in names]
+
+
+def _get_first_available_accelerator_type(
+    wishlist: Sequence[str] = ('TPU', 'GPU', 'CPU')) -> str:
+  """Returns the first available accelerator type listed in a wishlist.
+
+  Args:
+    wishlist: A sequence of elements from {'CPU', 'GPU', 'TPU'}, listed in
+      order of descending preference.
+
+  Returns:
+    The first available accelerator type from `wishlist`.
+
+  Raises:
+    RuntimeError: Thrown if no accelerators from the `wishlist` are found.
+  """
+  get_visible_devices = tf.config.get_visible_devices
+
+  for wishlist_device in wishlist:
+    devices = get_visible_devices(device_type=wishlist_device)
+    if devices:
+      return wishlist_device
+
+  available = ', '.join(
+      sorted(frozenset([d.type for d in get_visible_devices()])))
+  raise RuntimeError(
+      f'Couldn\'t find any devices from {wishlist}. '
+      f'Only the following types are available: {available}.')
+
+
+def _average_gradients_across_replicas(replica_context, gradients):
+  """Computes the average gradient across replicas.
+
+  This computes the gradient locally on this device, then copies over the
+  gradients computed on the other replicas, and takes the average across
+  replicas.
+
+  This is faster than copying the gradients from TPU to CPU, and averaging
+  them on the CPU (which is what we do for the losses/fetches).
+
+  Args:
+    replica_context: the return value of `tf.distribute.get_replica_context()`.
+    gradients: The output of tape.gradient(loss, variables).
+
+  Returns:
+    A list of (d_loss/d_variable)s.
+  """
+
+  # We must remove any Nones from gradients before passing them to all_reduce.
+  # Nones occur when you call tape.gradient(loss, variables) with some
+  # variables that don't affect the loss.
+  # See: https://github.com/tensorflow/tensorflow/issues/783
+  gradients_without_nones = [g for g in gradients if g is not None]
+  original_indices = [i for i, g in enumerate(gradients) if g is not None]
+
+  results_without_nones = replica_context.all_reduce('mean',
+                                                     gradients_without_nones)
+  results = [None] * len(gradients)
+  for ii, result in zip(original_indices, results_without_nones):
+    results[ii] = result
+
+  return results
diff --git a/acme/acme/agents/tf/d4pg/networks.py b/acme/acme/agents/tf/d4pg/networks.py
new file mode 100644
index 00000000..d0a225c1
--- /dev/null
+++ b/acme/acme/agents/tf/d4pg/networks.py
@@ -0,0 +1,63 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared helpers for different experiment flavours."""
+
+from typing import Mapping, Sequence
+
+from acme import specs
+from acme import types
+from acme.tf import networks
+from acme.tf import utils as tf2_utils
+
+import numpy as np
+import sonnet as snt
+
+
+def make_default_networks(
+    action_spec: specs.BoundedArray,
+    policy_layer_sizes: Sequence[int] = (256, 256, 256),
+    critic_layer_sizes: Sequence[int] = (512, 512, 256),
+    vmin: float = -150.,
+    vmax: float = 150.,
+    num_atoms: int = 51,
+) -> Mapping[str, types.TensorTransformation]:
+  """Creates networks used by the agent."""
+
+  # Get total number of action dimensions from action spec.
+  num_dimensions = np.prod(action_spec.shape, dtype=int)
+
+  # Create the shared observation network; here simply a state-less operation.
+  observation_network = tf2_utils.batch_concat
+
+  # Create the policy network.
+  policy_network = snt.Sequential([
+      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
+      networks.NearZeroInitializedLinear(num_dimensions),
+      networks.TanhToSpec(action_spec),
+  ])
+
+  # Create the critic network.
+  critic_network = snt.Sequential([
+      # The multiplexer concatenates the observations/actions.
+      networks.CriticMultiplexer(),
+      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
+      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
+  ])
+
+  return {
+      'policy': policy_network,
+      'critic': critic_network,
+      'observation': observation_network,
+  }
diff --git a/acme/acme/agents/tf/ddpg/README.md b/acme/acme/agents/tf/ddpg/README.md
new file mode 100644
index 00000000..ff5c5fe5
--- /dev/null
+++ b/acme/acme/agents/tf/ddpg/README.md
@@ -0,0 +1,16 @@
+# Deep Deterministic Policy Gradient (DDPG)
+
+This folder contains an implementation of the DDPG agent introduced in
+([Lillicrap et al., 2015]), which extends the Deterministic Policy Gradient
+(DPG) algorithm (introduced in [Silver et al., 2014]) to the realm of deep
+learning.
+
+DDPG is an off-policy [actor-critic algorithm]. In this algorithm, the critic
+is a network that takes an observation and an action and outputs a value
+estimate based on the current policy. It is trained to minimize the squared
+temporal-difference (TD) error. The actor is the policy network that takes
+observations as input and outputs actions. For each observation, it is trained
+to maximize the critic's value estimate.
+
+[Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971
+[Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14
+[actor-critic algorithm]: http://incompleteideas.net/book/RLbook2018.pdf#page=353
diff --git a/acme/acme/agents/tf/ddpg/__init__.py b/acme/acme/agents/tf/ddpg/__init__.py
new file mode 100644
index 00000000..f5f93231
--- /dev/null
+++ b/acme/acme/agents/tf/ddpg/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Implementations of a DDPG agent.""" + +from acme.agents.tf.ddpg.agent import DDPG +from acme.agents.tf.ddpg.agent_distributed import DistributedDDPG +from acme.agents.tf.ddpg.learning import DDPGLearner diff --git a/acme/acme/agents/tf/ddpg/agent.py b/acme/acme/agents/tf/ddpg/agent.py new file mode 100644 index 00000000..4e6ea791 --- /dev/null +++ b/acme/acme/agents/tf/ddpg/agent.py @@ -0,0 +1,173 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DDPG agent implementation.""" + +import copy +from typing import Optional + +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf.ddpg import learning +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + + +class DDPG(agent.Agent): + """DDPG Agent. + + This implements a single-process DDPG agent. This is an actor-critic algorithm + that generates data via a behavior policy, inserts N-step transitions into + a replay buffer, and periodically updates the policy (and as a result the + behavior) by sampling uniformly from this buffer. + """ + + def __init__(self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + policy_network: the online (optimized) policy. + critic_network: the online critic. + observation_network: optional network to transform the observations before + they are fed into any network. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_update_period: number of learner steps to perform before updating + the target networks. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert + that is made. + n_step: number of steps to squash into a single transition. + sigma: standard deviation of zero-mean, Gaussian exploration noise. + clipping: whether to clip gradients by global norm. + logger: logger object to be used by learner. + counter: counter object used to keep track of steps. 
+ checkpoint: boolean indicating whether to checkpoint the learner. + replay_table_name: string indicating what name to give the replay table. + """ + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature(environment_spec)) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f'localhost:{self._server.port}' + adder = adders.NStepTransitionAdder( + priority_fns={replay_table_name: lambda x: 1.}, + client=reverb.Client(address), + n_step=n_step, + discount=discount) + + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size) + + # Make sure observation network is a Sonnet Module. + observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create target networks. + target_policy_network = copy.deepcopy(policy_network) + target_critic_network = copy.deepcopy(critic_network) + target_observation_network = copy.deepcopy(observation_network) + + # Create the behavior policy. + behavior_network = snt.Sequential([ + observation_network, + policy_network, + networks.ClippedGaussian(sigma), + networks.ClipToSpec(act_spec), + ]) + + # Create variables. + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_policy_network, [emb_spec]) + tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(behavior_network, adder=adder) + + # Create optimizers. + policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + + # The learner updates the parameters (and initializes them). + learner = learning.DDPGLearner( + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + target_policy_network=target_policy_network, + target_critic_network=target_critic_network, + target_observation_network=target_observation_network, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=clipping, + discount=discount, + target_update_period=target_update_period, + dataset=dataset, + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) diff --git a/acme/acme/agents/tf/ddpg/agent_distributed.py b/acme/acme/agents/tf/ddpg/agent_distributed.py new file mode 100644 index 00000000..f9f852d4 --- /dev/null +++ b/acme/acme/agents/tf/ddpg/agent_distributed.py @@ -0,0 +1,319 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the distributed DDPG (D3PG) agent class."""
+
+from typing import Callable, Dict, Optional
+
+import acme
+from acme import datasets
+from acme import specs
+from acme.adders import reverb as adders
+from acme.agents.tf import actors
+from acme.agents.tf.ddpg import learning
+from acme.tf import networks
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.tf import variable_utils as tf2_variable_utils
+from acme.utils import counting
+from acme.utils import loggers
+from acme.utils import lp_utils
+import dm_env
+import launchpad as lp
+import reverb
+import sonnet as snt
+import tensorflow as tf
+
+
+class DistributedDDPG:
+  """Program definition for distributed DDPG (D3PG)."""
+
+  def __init__(
+      self,
+      environment_factory: Callable[[bool], dm_env.Environment],
+      network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]],
+      num_actors: int = 1,
+      num_caches: int = 0,
+      environment_spec: Optional[specs.EnvironmentSpec] = None,
+      batch_size: int = 256,
+      prefetch_size: int = 4,
+      min_replay_size: int = 1000,
+      max_replay_size: int = 1000000,
+      samples_per_insert: Optional[float] = 32.0,
+      n_step: int = 5,
+      sigma: float = 0.3,
+      clipping: bool = True,
+      discount: float = 0.99,
+      target_update_period: int = 100,
+      variable_update_period: int = 1000,
+      max_actor_steps: Optional[int] = None,
+      log_every: float = 10.0,
+  ):
+
+    if not environment_spec:
+      environment_spec = specs.make_environment_spec(environment_factory(False))
+
+    self._environment_factory = environment_factory
+    self._network_factory = network_factory
+    self._environment_spec = environment_spec
+    self._num_actors = num_actors
+    self._num_caches = num_caches
+    self._batch_size = batch_size
+    self._prefetch_size = prefetch_size
+    self._min_replay_size = min_replay_size
+    self._max_replay_size = max_replay_size
+    self._samples_per_insert = samples_per_insert
+    self._n_step = n_step
+    self._sigma = sigma
+    self._clipping = clipping
+    self._discount = discount
+    self._target_update_period = target_update_period
+    self._variable_update_period = variable_update_period
+    self._max_actor_steps = max_actor_steps
+    self._log_every = log_every
+
+  def replay(self):
+    """The replay storage."""
+    if self._samples_per_insert is not None:
+      # Create enough of an error buffer to give a 10% tolerance in rate.
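+      # E.g., with the defaults (min_replay_size=1000, samples_per_insert=32)
+      # the tolerance is 3.2 and the error buffer below is 3200 samples.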
+ samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature( + self._environment_spec)) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create the networks to optimize (online) and target networks. + online_networks = self._network_factory(act_spec) + target_networks = self._network_factory(act_spec) + + # Make sure observation network is a Sonnet Module. + observation_network = online_networks.get('observation', tf.identity) + target_observation_network = target_networks.get('observation', tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network) + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create variables. + tf2_utils.create_variables(online_networks['policy'], [emb_spec]) + tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks['policy'], [emb_spec]) + tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size) + + # Create optimizers. + policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + # Return the learning agent. 
+ return learning.DDPGLearner( + policy_network=online_networks['policy'], + critic_network=online_networks['critic'], + observation_network=observation_network, + target_policy_network=target_networks['policy'], + target_critic_network=target_networks['critic'], + target_observation_network=target_observation_network, + discount=self._discount, + target_update_period=self._target_update_period, + dataset=dataset, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=self._clipping, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and behavior networks + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec) + + # Create behavior network by adding some random dithering. + behavior_network = snt.Sequential([ + agent_networks.get('observation', tf.identity), + agent_networks.get('policy'), + networks.ClippedGaussian(self._sigma), + ]) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + variables = {'policy': behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, variables, update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._discount) + + # Create the agent. + actor = actors.FeedForwardActor( + behavior_network, adder=adder, variable_client=variable_client) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and evaluator networks + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec) + + # Create evaluator network. + evaluator_network = snt.Sequential([ + agent_networks.get('observation', tf.identity), + agent_networks.get('policy'), + ]) + + # Ensure network variables are created. + tf2_utils.create_variables(evaluator_network, [observation_spec]) + variables = {'policy': evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, variables, update_period=self._variable_update_period) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the evaluator; note it will not add experience to replay. 
+    evaluator = actors.FeedForwardActor(
+        evaluator_network, variable_client=variable_client)
+
+    # Create logger and counter.
+    counter = counting.Counter(counter, 'evaluator')
+    logger = loggers.make_default_logger(
+        'evaluator', time_delta=self._log_every, steps_key='evaluator_steps')
+
+    # Create the run loop and return it.
+    return acme.EnvironmentLoop(
+        environment, evaluator, counter, logger)
+
+  def build(self, name='ddpg'):
+    """Build the distributed agent topology."""
+    program = lp.Program(name=name)
+
+    with program.group('replay'):
+      replay = program.add_node(lp.ReverbNode(self.replay))
+
+    with program.group('counter'):
+      counter = program.add_node(lp.CourierNode(self.counter))
+
+    if self._max_actor_steps:
+      _ = program.add_node(
+          lp.CourierNode(self.coordinator, counter, self._max_actor_steps))
+
+    with program.group('learner'):
+      learner = program.add_node(
+          lp.CourierNode(self.learner, replay, counter))
+
+    with program.group('evaluator'):
+      program.add_node(
+          lp.CourierNode(self.evaluator, learner, counter))
+
+    if not self._num_caches:
+      # Use our learner as a single variable source.
+      sources = [learner]
+    else:
+      with program.group('cacher'):
+        # Create a set of learner caches.
+        sources = []
+        for _ in range(self._num_caches):
+          cacher = program.add_node(
+              lp.CacherNode(
+                  learner, refresh_interval_ms=2000, stale_after_ms=4000))
+          sources.append(cacher)
+
+    with program.group('actor'):
+      # Add actors which pull round-robin from our variable sources.
+      for actor_id in range(self._num_actors):
+        source = sources[actor_id % len(sources)]
+        program.add_node(lp.CourierNode(self.actor, replay, source, counter))
+
+    return program
diff --git a/acme/acme/agents/tf/ddpg/agent_distributed_test.py b/acme/acme/agents/tf/ddpg/agent_distributed_test.py
new file mode 100644
index 00000000..1b930a97
--- /dev/null
+++ b/acme/acme/agents/tf/ddpg/agent_distributed_test.py
@@ -0,0 +1,89 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration test for the distributed agent."""
+
+import acme
+from acme import specs
+from acme.agents.tf import ddpg
+from acme.testing import fakes
+from acme.tf import networks
+from acme.tf import utils as tf2_utils
+import launchpad as lp
+import numpy as np
+import sonnet as snt
+
+from absl.testing import absltest
+
+
+def make_networks(action_spec: specs.BoundedArray):
+  """Creates simple networks for testing."""
+
+  num_dimensions = np.prod(action_spec.shape, dtype=int)
+
+  # Create the observation network shared between the policy and critic.
+  observation_network = tf2_utils.batch_concat
+
+  # Create the policy network (head) and the evaluation network.
+  policy_network = snt.Sequential([
+      networks.LayerNormMLP([50], activate_final=True),
+      networks.NearZeroInitializedLinear(num_dimensions),
+      networks.TanhToSpec(action_spec)
+  ])
+  evaluator_network = snt.Sequential([observation_network, policy_network])
+
+  # Create the critic network.
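+  # (A plain scalar value head keeps this smoke test small; the default
+  # networks above use a distributional head instead.)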
+ critic_network = snt.Sequential([ + # The multiplexer concatenates the observations/actions. + networks.CriticMultiplexer(), + networks.LayerNormMLP([50], activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + 'observation': observation_network, + 'evaluator': evaluator_network, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_agent(self): + + agent = ddpg.DistributedDDPG( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/ddpg/agent_test.py b/acme/acme/agents/tf/ddpg/agent_test.py new file mode 100644 index 00000000..9287e8c2 --- /dev/null +++ b/acme/acme/agents/tf/ddpg/agent_test.py @@ -0,0 +1,82 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the DDPG agent.""" + +from typing import Dict, Sequence + +import acme +from acme import specs +from acme import types +from acme.agents.tf import ddpg +from acme.testing import fakes +from acme.tf import networks +import numpy as np +import sonnet as snt +import tensorflow as tf + +from absl.testing import absltest + + +def make_networks( + action_spec: types.NestedSpec, + policy_layer_sizes: Sequence[int] = (10, 10), + critic_layer_sizes: Sequence[int] = (10, 10), +) -> Dict[str, snt.Module]: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions] + critic_layer_sizes = list(critic_layer_sizes) + [1] + + policy_network = snt.Sequential( + [networks.LayerNormMLP(policy_layer_sizes), tf.tanh]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes)) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class DDPGTest(absltest.TestCase): + + def test_ddpg(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True) + spec = specs.make_environment_spec(environment) + + # Create the networks to optimize (online) and target networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. 
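+    # With batch_size=10 and samples_per_insert=2, the agent below performs
+    # one learner step per five environment steps once ten transitions have
+    # been collected (observations_per_step = batch_size / samples_per_insert).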
+    agent = ddpg.DDPG(
+        environment_spec=spec,
+        policy_network=agent_networks['policy'],
+        critic_network=agent_networks['critic'],
+        batch_size=10,
+        samples_per_insert=2,
+        min_replay_size=10,
+    )
+
+    # Try running the environment loop. We have no assertions here because all
+    # we care about is that the agent runs without raising any errors.
+    loop = acme.EnvironmentLoop(environment, agent)
+    loop.run(num_episodes=2)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/tf/ddpg/learning.py b/acme/acme/agents/tf/ddpg/learning.py
new file mode 100644
index 00000000..1c74200a
--- /dev/null
+++ b/acme/acme/agents/tf/ddpg/learning.py
@@ -0,0 +1,257 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DDPG learner implementation."""
+
+import time
+from typing import List, Optional
+
+import acme
+from acme import types
+from acme.tf import losses
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+import tree
+import trfl
+
+
+class DDPGLearner(acme.Learner):
+  """DDPG learner.
+
+  This is the learning component of a DDPG agent, i.e. it takes a dataset as
+  input and implements update functionality to learn from this dataset.
+  """
+
+  def __init__(
+      self,
+      policy_network: snt.Module,
+      critic_network: snt.Module,
+      target_policy_network: snt.Module,
+      target_critic_network: snt.Module,
+      discount: float,
+      target_update_period: int,
+      dataset: tf.data.Dataset,
+      observation_network: types.TensorTransformation = lambda x: x,
+      target_observation_network: types.TensorTransformation = lambda x: x,
+      policy_optimizer: Optional[snt.Optimizer] = None,
+      critic_optimizer: Optional[snt.Optimizer] = None,
+      clipping: bool = True,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      checkpoint: bool = True,
+  ):
+    """Initializes the learner.
+
+    Args:
+      policy_network: the online (optimized) policy.
+      critic_network: the online critic.
+      target_policy_network: the target policy (which lags behind the online
+        policy).
+      target_critic_network: the target critic.
+      discount: discount to use for TD updates.
+      target_update_period: number of learner steps to perform before updating
+        the target networks.
+      dataset: dataset to learn from, whether fixed or from a replay buffer
+        (see `acme.datasets.reverb.make_reverb_dataset` documentation).
+      observation_network: an optional online network to process observations
+        before the policy and the critic.
+      target_observation_network: the target observation network.
+      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
+      critic_optimizer: the optimizer to be applied to the critic loss.
+      clipping: whether to clip gradients by global norm.
+      counter: counter object used to keep track of steps.
+      logger: logger object to be used by learner.
+      checkpoint: boolean indicating whether to checkpoint the learner.
+    """
+
+    # Store online and target networks.
+    self._policy_network = policy_network
+    self._critic_network = critic_network
+    self._target_policy_network = target_policy_network
+    self._target_critic_network = target_critic_network
+
+    # Make sure observation networks are snt.Modules so they have variables.
+    self._observation_network = tf2_utils.to_sonnet_module(observation_network)
+    self._target_observation_network = tf2_utils.to_sonnet_module(
+        target_observation_network)
+
+    # General learner book-keeping and loggers.
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.make_default_logger('learner')
+
+    # Other learner parameters.
+    self._discount = discount
+    self._clipping = clipping
+
+    # Necessary to track when to update target networks.
+    self._num_steps = tf.Variable(0, dtype=tf.int32)
+    self._target_update_period = target_update_period
+
+    # Create an iterator to go through the dataset.
+    # TODO(b/155086959): Fix type stubs and remove.
+    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
+
+    # Create optimizers if they aren't given.
+    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
+    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
+
+    # Expose the variables.
+    policy_network_to_expose = snt.Sequential(
+        [self._target_observation_network, self._target_policy_network])
+    self._variables = {
+        'critic': target_critic_network.variables,
+        'policy': policy_network_to_expose.variables,
+    }
+
+    self._checkpointer = tf2_savers.Checkpointer(
+        time_delta_minutes=5,
+        objects_to_save={
+            'counter': self._counter,
+            'policy': self._policy_network,
+            'critic': self._critic_network,
+            'target_policy': self._target_policy_network,
+            'target_critic': self._target_critic_network,
+            'policy_optimizer': self._policy_optimizer,
+            'critic_optimizer': self._critic_optimizer,
+            'num_steps': self._num_steps,
+        },
+        enable_checkpointing=checkpoint,
+    )
+
+    # Do not record timestamps until after the first learning step is done.
+    # This is to avoid including the time it takes for actors to come online
+    # and fill the replay buffer.
+    self._timestamp = None
+
+  @tf.function
+  def _step(self):
+    # Update target network.
+    online_variables = (
+        *self._observation_network.variables,
+        *self._critic_network.variables,
+        *self._policy_network.variables,
+    )
+    target_variables = (
+        *self._target_observation_network.variables,
+        *self._target_critic_network.variables,
+        *self._target_policy_network.variables,
+    )
+
+    # Make online -> target network update ops.
+    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
+      for src, dest in zip(online_variables, target_variables):
+        dest.assign(src)
+    self._num_steps.assign_add(1)
+
+    # Get data from replay (dropping extras if any). Note there is no
+    # extra data here because we do not insert any into Reverb.
+    inputs = next(self._iterator)
+    transitions: types.Transition = inputs.data
+
+    # Cast the additional discount to match the environment discount dtype.
+    discount = tf.cast(self._discount, dtype=transitions.discount.dtype)
+
+    with tf.GradientTape(persistent=True) as tape:
+      # Maybe transform the observation before feeding into policy and critic.
+      # Transforming the observations this way at the start of the learning
+      # step effectively means that the policy and critic share observation
+      # network weights.
+      o_tm1 = self._observation_network(transitions.observation)
+      o_t = self._target_observation_network(transitions.next_observation)
+      # This stop_gradient prevents gradients from propagating into the target
+      # observation network. In addition, since the online policy network is
+      # evaluated at o_t, this also means the policy loss does not influence
+      # the observation network training.
+      o_t = tree.map_structure(tf.stop_gradient, o_t)
+
+      # Critic learning.
+      q_tm1 = self._critic_network(o_tm1, transitions.action)
+      q_t = self._target_critic_network(o_t, self._target_policy_network(o_t))
+
+      # Squeeze into the shape expected by the td_learning implementation.
+      q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
+      q_t = tf.squeeze(q_t, axis=-1)  # [B]
+
+      # Critic loss.
+      critic_loss = trfl.td_learning(q_tm1, transitions.reward,
+                                     discount * transitions.discount, q_t).loss
+      critic_loss = tf.reduce_mean(critic_loss, axis=0)
+
+      # Actor learning.
+      dpg_a_t = self._policy_network(o_t)
+      dpg_q_t = self._critic_network(o_t, dpg_a_t)
+
+      # Actor loss. If clipping is true, use dqda clipping and clip the norm.
+      dqda_clipping = 1.0 if self._clipping else None
+      policy_loss = losses.dpg(
+          dpg_q_t,
+          dpg_a_t,
+          tape=tape,
+          dqda_clipping=dqda_clipping,
+          clip_norm=self._clipping)
+      policy_loss = tf.reduce_mean(policy_loss, axis=0)
+
+    # Get trainable variables.
+    policy_variables = self._policy_network.trainable_variables
+    critic_variables = (
+        # In this agent, the critic loss trains the observation network.
+        self._observation_network.trainable_variables +
+        self._critic_network.trainable_variables)
+
+    # Compute gradients.
+    policy_gradients = tape.gradient(policy_loss, policy_variables)
+    critic_gradients = tape.gradient(critic_loss, critic_variables)
+
+    # Delete the tape manually because of the persistent=True flag.
+    del tape
+
+    # Maybe clip gradients.
+    if self._clipping:
+      policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
+      critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]
+
+    # Apply gradients.
+    self._policy_optimizer.apply(policy_gradients, policy_variables)
+    self._critic_optimizer.apply(critic_gradients, critic_variables)
+
+    # Losses to track.
+    return {
+        'critic_loss': critic_loss,
+        'policy_loss': policy_loss,
+    }
+
+  def step(self):
+    # Run the learning step.
+    fetches = self._step()
+
+    # Compute elapsed time.
+    timestamp = time.time()
+    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
+    self._timestamp = timestamp
+
+    # Update our counts and record it.
+    counts = self._counter.increment(steps=1, walltime=elapsed_time)
+    fetches.update(counts)
+
+    # Checkpoint and attempt to write the logs.
+    self._checkpointer.save()
+    self._logger.write(fetches)
+
+  def get_variables(self, names: List[str]) -> List[List[np.ndarray]]:
+    return [tf2_utils.to_numpy(self._variables[name]) for name in names]
diff --git a/acme/acme/agents/tf/dmpo/README.md b/acme/acme/agents/tf/dmpo/README.md
new file mode 100644
index 00000000..e0ac5f51
--- /dev/null
+++ b/acme/acme/agents/tf/dmpo/README.md
@@ -0,0 +1,32 @@
+# Distributional Maximum a posteriori Policy Optimization (DMPO)
+
+This folder contains an implementation of a novel agent (DMPO) introduced in
+the original Acme release. This work extends the MPO algorithm
+([Abdolmaleki et al., 2018a], [2018b]) by using a distributional Q-network
+similar to C51 ([Bellemare et al., 2017]). Therefore, as in the case of the
+D4PG agent, this algorithm's critic outputs a distribution over state-action
+values.
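+
+As a rough sketch (the layer sizes and value-support settings below are
+illustrative, not the exact configuration used by this folder's helpers), a
+DMPO network factory pairs a stochastic policy head with a distributional
+critic head, all built from `acme.tf.networks`:
+
+```python
+import numpy as np
+import sonnet as snt
+from acme.tf import networks
+
+
+def make_sketch_networks(action_spec):
+  """Illustrative DMPO networks; `action_spec` is a bounded action spec."""
+  num_dimensions = np.prod(action_spec.shape, dtype=int)
+  policy_network = snt.Sequential([
+      networks.LayerNormMLP((256, 256), activate_final=True),
+      # A diagonal Gaussian head; actions are sampled from this distribution.
+      networks.MultivariateNormalDiagHead(num_dimensions),
+  ])
+  critic_network = snt.Sequential([
+      networks.CriticMultiplexer(),  # Concatenates observations and actions.
+      networks.LayerNormMLP((512, 512), activate_final=True),
+      networks.DiscreteValuedHead(-150., 150., 51),  # vmin, vmax, num_atoms.
+  ])
+  return {'policy': policy_network, 'critic': critic_network}
+```
+
+Relative to MPO, the main extra hyperparameters are the number of atoms and
+the `[vmin, vmax]` support of the critic's value distribution.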
+
+As with our MPO agent, while this is a more general algorithm, the current
+implementation targets the continuous control setting and is most readily
+applied to the DeepMind control suite or similar control tasks. This
+implementation also includes the options of:
+
+* per-dimension KL constraint satisfaction, and
+* action penalization via the multi-objective MPO work of
+  [Abdolmaleki et al., 2020].
+
+Detailed notes:
+
+* The `vmin|vmax` hyperparameters of the distributional critic may need tuning
+  depending on your environment's rewards. A good rule of thumb is to set
+  `vmax` to the discounted sum of the maximum instantaneous rewards for the
+  maximum episode length; then set `vmin` to `-vmax`. For example, with
+  per-step rewards of at most 1, a discount of 0.99, and 1000-step episodes,
+  this gives `vmax = (1 - 0.99^1000) / (1 - 0.99) ≈ 100`.
+* When using per-dimension KL constraint satisfaction, you may need to tune
+  the value of `epsilon_mean` (and `epsilon_stddev` if not fixed). A good rule
+  of thumb would be to divide it by the number of dimensions in the action
+  space.
+
+[Abdolmaleki et al., 2018a]: https://arxiv.org/pdf/1806.06920.pdf
+[2018b]: https://arxiv.org/pdf/1812.02256.pdf
+[Abdolmaleki et al., 2020]: https://arxiv.org/pdf/2005.07513.pdf
+[Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887
diff --git a/acme/acme/agents/tf/dmpo/__init__.py b/acme/acme/agents/tf/dmpo/__init__.py
new file mode 100644
index 00000000..58b03546
--- /dev/null
+++ b/acme/acme/agents/tf/dmpo/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementations of a distributional MPO agent."""
+
+from acme.agents.tf.dmpo.agent import DistributionalMPO
+from acme.agents.tf.dmpo.agent_distributed import DistributedDistributionalMPO
+from acme.agents.tf.dmpo.learning import DistributionalMPOLearner
diff --git a/acme/acme/agents/tf/dmpo/agent.py b/acme/acme/agents/tf/dmpo/agent.py
new file mode 100644
index 00000000..8ca0c621
--- /dev/null
+++ b/acme/acme/agents/tf/dmpo/agent.py
@@ -0,0 +1,188 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Distributional MPO agent implementation.""" + +import copy +from typing import Optional + +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf.dmpo import learning +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + + +class DistributionalMPO(agent.Agent): + """Distributional MPO Agent. + + This implements a single-process distributional MPO agent. This is an + actor-critic algorithm that generates data via a behavior policy, inserts + N-step transitions into a replay buffer, and periodically updates the policy + (and as a result the behavior) by sampling uniformly from this buffer. + This agent distinguishes itself from the MPO agent by using a distributional + critic (state-action value approximator). + """ + + def __init__(self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + n_step: int = 5, + num_samples: int = 20, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + policy_network: the online (optimized) policy. + critic_network: the online critic. + observation_network: optional network to transform the observations before + they are fed into any network. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_policy_update_period: number of updates to perform before updating + the target policy network. + target_critic_update_period: number of updates to perform before updating + the target critic network. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert + that is made. + policy_loss_module: configured MPO loss function for the policy + optimization; defaults to sensible values on the control suite. + See `acme/tf/losses/mpo.py` for more details. + policy_optimizer: optimizer to be used on the policy. + critic_optimizer: optimizer to be used on the critic. + n_step: number of steps to squash into a single transition. + num_samples: number of actions to sample when doing a Monte Carlo + integration with respect to the policy. + clipping: whether to clip gradients by global norm. + logger: logging object used to write to logs. + counter: counter object used to keep track of steps. + checkpoint: boolean indicating whether to checkpoint the learner. + replay_table_name: string indicating what name to give the replay table. + """ + + # Create a replay server to add data to. 
+    replay_table = reverb.Table(
+        name=replay_table_name,
+        sampler=reverb.selectors.Uniform(),
+        remover=reverb.selectors.Fifo(),
+        max_size=max_replay_size,
+        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
+        signature=adders.NStepTransitionAdder.signature(environment_spec))
+    self._server = reverb.Server([replay_table], port=None)
+
+    # The adder is used to insert observations into replay.
+    address = f'localhost:{self._server.port}'
+    adder = adders.NStepTransitionAdder(
+        priority_fns={replay_table_name: lambda x: 1.},
+        client=reverb.Client(address),
+        n_step=n_step,
+        discount=discount)
+
+    # The dataset object to learn from.
+    dataset = datasets.make_reverb_dataset(
+        table=replay_table_name,
+        server_address=address,
+        batch_size=batch_size,
+        prefetch_size=prefetch_size)
+
+    # Make sure observation network is a Sonnet Module.
+    observation_network = tf2_utils.to_sonnet_module(observation_network)
+
+    # Create target networks before creating online/target network variables.
+    target_policy_network = copy.deepcopy(policy_network)
+    target_critic_network = copy.deepcopy(critic_network)
+    target_observation_network = copy.deepcopy(observation_network)
+
+    # Get observation and action specs.
+    act_spec = environment_spec.actions
+    obs_spec = environment_spec.observations
+    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
+
+    # Create the behavior policy.
+    behavior_network = snt.Sequential([
+        observation_network,
+        policy_network,
+        networks.StochasticSamplingHead(),
+    ])
+
+    # Create variables.
+    tf2_utils.create_variables(policy_network, [emb_spec])
+    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
+    tf2_utils.create_variables(target_policy_network, [emb_spec])
+    tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
+    tf2_utils.create_variables(target_observation_network, [obs_spec])
+
+    # Create the actor which defines how we take actions.
+    actor = actors.FeedForwardActor(
+        policy_network=behavior_network, adder=adder)
+
+    # Create optimizers.
+    policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
+    critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
+
+    # The learner updates the parameters (and initializes them).
+    learner = learning.DistributionalMPOLearner(
+        policy_network=policy_network,
+        critic_network=critic_network,
+        observation_network=observation_network,
+        target_policy_network=target_policy_network,
+        target_critic_network=target_critic_network,
+        target_observation_network=target_observation_network,
+        policy_loss_module=policy_loss_module,
+        policy_optimizer=policy_optimizer,
+        critic_optimizer=critic_optimizer,
+        clipping=clipping,
+        discount=discount,
+        num_samples=num_samples,
+        target_policy_update_period=target_policy_update_period,
+        target_critic_update_period=target_critic_update_period,
+        dataset=dataset,
+        logger=logger,
+        counter=counter,
+        checkpoint=checkpoint)
+
+    super().__init__(
+        actor=actor,
+        learner=learner,
+        min_observations=max(batch_size, min_replay_size),
+        observations_per_step=float(batch_size) / samples_per_insert)
diff --git a/acme/acme/agents/tf/dmpo/agent_distributed.py b/acme/acme/agents/tf/dmpo/agent_distributed.py
new file mode 100644
index 00000000..4fff2a17
--- /dev/null
+++ b/acme/acme/agents/tf/dmpo/agent_distributed.py
@@ -0,0 +1,358 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the distributional MPO distributed agent class.""" + +from typing import Callable, Dict, Optional, Sequence + +import acme +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as adders +from acme.agents.tf import actors +from acme.agents.tf.dmpo import learning +from acme.datasets import image_augmentation +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +from acme.utils import observers as observers_lib +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + + +class DistributedDistributionalMPO: + """Program definition for distributional MPO.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + observation_augmentation: Optional[types.TensorTransformation] = None, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + num_samples: int = 20, + additional_discount: float = 0.99, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + variable_update_period: int = 1000, + policy_loss_factory: Optional[Callable[[], snt.Module]] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + make_observers: Optional[Callable[ + [], Sequence[observers_lib.EnvLoopObserver]]] = None): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._observation_augmentation = observation_augmentation + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._additional_discount = additional_discount + self._num_samples = num_samples + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._variable_update_period = variable_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._make_observers = make_observers + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. 
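+      # With the default min_replay_size=1000 and samples_per_insert=32.0,
+      # the error buffer below works out to 1000 * 3.2 = 3200 samples.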
+ samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature( + self._environment_spec)) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create online and target networks. + online_networks = self._network_factory(act_spec) + target_networks = self._network_factory(act_spec) + + # Make sure observation network is a Sonnet Module. + observation_network = online_networks.get('observation', tf.identity) + target_observation_network = target_networks.get('observation', tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network) + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create variables. + tf2_utils.create_variables(online_networks['policy'], [emb_spec]) + tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks['policy'], [emb_spec]) + tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset(server_address=replay.server_address) + dataset = dataset.batch(self._batch_size, drop_remainder=True) + if self._observation_augmentation: + transform = image_augmentation.make_transform( + observation_transform=self._observation_augmentation) + dataset = dataset.map( + transform, num_parallel_calls=16, deterministic=False) + dataset = dataset.prefetch(self._prefetch_size) + + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. 
+ return learning.DistributionalMPOLearner( + policy_network=online_networks['policy'], + critic_network=online_networks['critic'], + observation_network=observation_network, + target_policy_network=target_networks['policy'], + target_critic_network=target_networks['critic'], + target_observation_network=target_observation_network, + discount=self._additional_discount, + num_samples=self._num_samples, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + actor_id: int, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec) + + # Make sure observation network is defined. + observation_network = agent_networks.get('observation', tf.identity) + + # Create a stochastic behavior policy. + behavior_network = snt.Sequential([ + observation_network, + agent_networks['policy'], + networks.StochasticSamplingHead(), + ]) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + policy_variables = {'policy': behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, + n_step=self._n_step, + discount=self._additional_discount) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client) + + # Create logger and counter; only the first actor stores logs to bigtable. + save_data = actor_id == 0 + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=save_data, + time_delta=self._log_every, + steps_key='actor_steps') + observers = self._make_observers() if self._make_observers else () + + # Create the run loop and return it. + return acme.EnvironmentLoop( + environment, actor, counter, logger, observers=observers) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec) + + # Make sure observation network is defined. + observation_network = agent_networks.get('observation', tf.identity) + + # Create a stochastic behavior policy. + evaluator_network = snt.Sequential([ + observation_network, + agent_networks['policy'], + networks.StochasticMeanHead(), + ]) + + # Ensure network variables are created. 
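+    # (Sonnet modules create their variables lazily on first call, so this
+    # traces the network once with dummy inputs built from the spec.)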
+ tf2_utils.create_variables(evaluator_network, [observation_spec]) + policy_variables = {'policy': evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = loggers.make_default_logger( + 'evaluator', time_delta=self._log_every, steps_key='evaluator_steps') + observers = self._make_observers() if self._make_observers else () + + # Create the run loop and return it. + return acme.EnvironmentLoop( + environment, + evaluator, + counter, + logger, + observers=observers) + + def build(self, name='dmpo'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) + + with program.group('learner'): + learner = program.add_node( + lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node( + lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node( + lp.CourierNode(self.actor, replay, source, counter, actor_id)) + + return program diff --git a/acme/acme/agents/tf/dmpo/agent_distributed_test.py b/acme/acme/agents/tf/dmpo/agent_distributed_test.py new file mode 100644 index 00000000..80085a42 --- /dev/null +++ b/acme/acme/agents/tf/dmpo/agent_distributed_test.py @@ -0,0 +1,97 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration test for the distributed agent.""" + +from typing import Sequence + +import acme +from acme import specs +from acme.agents.tf import dmpo +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils +import launchpad as lp +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def make_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (50,), + critic_layer_sizes: Sequence[int] = (50,), + vmin: float = -150., + vmax: float = 150., + num_atoms: int = 51, +): + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + init_scale=0.3, + fixed_scale=True, + use_tfd_independent=False) + ]) + + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP( + critic_layer_sizes, activate_final=True), + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential( + [critic_network, + networks.DiscreteValuedHead(vmin, vmax, num_atoms)]) + + return { + 'policy': policy_network, + 'critic': critic_network, + 'observation': tf2_utils.batch_concat, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_agent(self): + + agent = dmpo.DistributedDistributionalMPO( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/dmpo/agent_test.py b/acme/acme/agents/tf/dmpo/agent_test.py new file mode 100644 index 00000000..366a35cf --- /dev/null +++ b/acme/acme/agents/tf/dmpo/agent_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the distributional MPO agent.""" + +from typing import Dict, Sequence + +import acme +from acme import specs +from acme.agents.tf import dmpo +from acme.testing import fakes +from acme.tf import networks +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def make_networks( + action_spec: specs.Array, + policy_layer_sizes: Sequence[int] = (300, 200), + critic_layer_sizes: Sequence[int] = (400, 300), +) -> Dict[str, snt.Module]: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + critic_layer_sizes = list(critic_layer_sizes) + + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes), + networks.MultivariateNormalDiagHead(num_dimensions), + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential([ + networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes)), + networks.DiscreteValuedHead(0., 1., 10), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class DMPOTest(absltest.TestCase): + + def test_dmpo(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = dmpo.DistributionalMPO( + spec, + policy_network=agent_networks['policy'], + critic_network=agent_networks['critic'], + batch_size=10, + samples_per_insert=2, + min_replay_size=10) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/dmpo/learning.py b/acme/acme/agents/tf/dmpo/learning.py new file mode 100644 index 00000000..1812297f --- /dev/null +++ b/acme/acme/agents/tf/dmpo/learning.py @@ -0,0 +1,300 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Distributional MPO learner implementation.""" + +import time +from typing import List, Optional + +import acme +from acme import types +from acme.tf import losses +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import sonnet as snt +import tensorflow as tf + + +class DistributionalMPOLearner(acme.Learner): + """Distributional MPO learner.""" + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: types.TensorTransformation = tf.identity, + target_observation_network: types.TensorTransformation = tf.identity, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. + self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner') + + # Other learner parameters. + self._discount = discount + self._num_samples = num_samples + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + self._policy_loss_module = policy_loss_module or losses.MPO( + epsilon=1e-1, + epsilon_penalty=1e-3, + epsilon_mean=2.5e-3, + epsilon_stddev=1e-6, + init_log_temperature=10., + init_log_alpha_mean=10., + init_log_alpha_stddev=1000.) + + # Create the optimizers. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network]) + self._variables = { + 'critic': self._target_critic_network.variables, + 'policy': policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. 
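+    # (The Checkpointer below persists the full training state, including the
+    # networks, optimizers, dual variables and step counter, so a preempted
+    # learner can resume where it left off; the Snapshotter only exports the
+    # target policy, for evaluation or serving.)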
+ self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='dmpo_learner', + objects_to_save={ + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'observation': self._observation_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'target_observation': self._target_observation_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'dual_optimizer': self._dual_optimizer, + 'policy_loss_module': self._policy_loss_module, + 'num_steps': self._num_steps, + }) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={ + 'policy': + snt.Sequential([ + self._target_observation_network, + self._target_policy_network + ]), + }) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> types.NestedTensor: + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + ) + target_critic_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + ) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + # Make online critic -> target critic network update ops. + if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) + + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + + # Get batch size and scalar dtype. + batch_size = transitions.reward.shape[0] + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=transitions.discount.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient( + self._target_observation_network(transitions.next_observation)) + + # Get online and target action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Sample actions to evaluate policy; of size [N, B, ...]. + sampled_actions = target_action_distribution.sample(self._num_samples) + + # Tile embedded observations to feed into the target critic network. 
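+      # (Shape note: o_t is [B, ...] while sampled_actions is [N, B, ...];
+      # tiling o_t to [N, B, ...] pairs each of the N sampled actions with its
+      # own copy of the observation before the leading dims are merged.)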
+ # Note: this is more efficient than tiling before the embedding layer. + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute target-estimated distributional value of sampled actions at o_t. + sampled_q_t_distributions = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. + snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2)) + + # Compute average logits by first reshaping them and normalizing them + # across atoms. + new_shape = [self._num_samples, batch_size, -1] # [N, B, A] + sampled_logits = tf.reshape(sampled_q_t_distributions.logits, new_shape) + sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) + averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) + + # Construct the expected distributional value for bootstrapping. + q_t_distribution = networks.DiscreteValuedDistribution( + values=sampled_q_t_distributions.values, logits=averaged_logits) + + # Compute online critic value distribution of a_tm1 in state o_tm1. + q_tm1_distribution = self._critic_network(o_tm1, transitions.action) + + # Compute critic distributional loss. + critic_loss = losses.categorical(q_tm1_distribution, transitions.reward, + discount * transitions.discount, + q_t_distribution) + critic_loss = tf.reduce_mean(critic_loss) + + # Compute Q-values of sampled actions and reshape to [N, B]. + sampled_q_values = sampled_q_t_distributions.mean() + sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1)) + + # Compute MPO policy loss. + policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_values) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables)) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + fetches.update(policy_stats) # Log MPO stats. + + return fetches + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. 
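+    # (elapsed_time is 0 on the first step since the timestamp is only set
+    # once a step completes; this keeps actor warm-up time out of walltime.)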
+ counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/acme/agents/tf/dqfd/README.md b/acme/acme/agents/tf/dqfd/README.md new file mode 100644 index 00000000..d53ebb5d --- /dev/null +++ b/acme/acme/agents/tf/dqfd/README.md @@ -0,0 +1,7 @@ +# Deep Q-learning from Demonstrations (DQfD) + +This folder contains an implementation of the DQfD algorithm +([Hester et al., 2017]). This agent extends DQN by mixing expert demonstrations +with the agent's experience in each mini-batch. + +[Hester et al., 2017]: https://arxiv.org/abs/1704.03732 diff --git a/acme/acme/agents/tf/dqfd/__init__.py b/acme/acme/agents/tf/dqfd/__init__.py new file mode 100644 index 00000000..afba45de --- /dev/null +++ b/acme/acme/agents/tf/dqfd/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for DQfD.""" + +from acme.agents.tf.dqfd.agent import DQfD +from acme.agents.tf.dqfd.bsuite_demonstrations import DemonstrationRecorder diff --git a/acme/acme/agents/tf/dqfd/agent.py b/acme/acme/agents/tf/dqfd/agent.py new file mode 100644 index 00000000..8b9d21aa --- /dev/null +++ b/acme/acme/agents/tf/dqfd/agent.py @@ -0,0 +1,211 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DQfD Agent implementation.""" + +import copy +import functools +import operator +from typing import Optional + +from acme import datasets +from acme import specs +from acme import types as acme_types +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf import dqn +from acme.tf import utils as tf2_utils +import reverb +import sonnet as snt +import tensorflow as tf +import tree +import trfl + + +class DQfD(agent.Agent): + """DQfD agent. + + This implements a single-process DQN agent that mixes demonstrations with + actor experience. 
+ """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.Module, + demonstration_dataset: tf.data.Dataset, + demonstration_ratio: float, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + importance_sampling_exponent: float = 0.2, + n_step: int = 5, + epsilon: Optional[tf.Tensor] = None, + learning_rate: float = 1e-3, + discount: float = 0.99, + ): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + network: the online Q network (the one being optimized) + demonstration_dataset: tf.data.Dataset producing (timestep, action) + tuples containing full episodes. + demonstration_ratio: Ratio of transitions coming from demonstrations. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_update_period: number of learner steps to perform before updating + the target networks. + samples_per_insert: number of samples to take from replay for every insert + that is made. + min_replay_size: minimum replay size before updating. This and all + following arguments are related to dataset construction and will be + ignored if a dataset argument is passed. + max_replay_size: maximum replay size. + importance_sampling_exponent: power to which importance weights are raised + before normalizing. + n_step: number of steps to squash into a single transition. + epsilon: probability of taking a random action; ignored if a policy + network is given. + learning_rate: learning rate for the q-network update. + discount: discount to use for TD updates. + """ + + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature(environment_spec)) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f'localhost:{self._server.port}' + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), + n_step=n_step, + discount=discount) + + # The dataset provides an interface to sample from replay. + replay_client = reverb.TFClient(address) + dataset = datasets.make_reverb_dataset(server_address=address) + + # Combine with demonstration dataset. + transition = functools.partial(_n_step_transition_from_episode, + n_step=n_step, + discount=discount) + dataset_demos = demonstration_dataset.map(transition) + dataset = tf.data.experimental.sample_from_datasets( + [dataset, dataset_demos], + [1 - demonstration_ratio, demonstration_ratio]) + + # Batch and prefetch. + dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.prefetch(prefetch_size) + + # Use constant 0.05 epsilon greedy policy by default. + if epsilon is None: + epsilon = tf.Variable(0.05, trainable=False) + policy_network = snt.Sequential([ + network, + lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), + ]) + + # Create a target network. + target_network = copy.deepcopy(network) + + # Ensure that we create the variables before proceeding (maybe not needed). 
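+    # (Sonnet modules create variables lazily on first call, so
+    # create_variables effectively runs a dummy forward pass built from the
+    # observation spec to force variable creation before the learner and the
+    # target-copying logic touch them.)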
+ tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network, adder) + + # The learner updates the parameters (and initializes them). + learner = dqn.DQNLearner( + network=network, + target_network=target_network, + discount=discount, + importance_sampling_exponent=importance_sampling_exponent, + learning_rate=learning_rate, + target_update_period=target_update_period, + dataset=dataset, + replay_client=replay_client) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) + + +def _n_step_transition_from_episode(observations: acme_types.NestedTensor, + actions: tf.Tensor, + rewards: tf.Tensor, + discounts: tf.Tensor, + n_step: int, + discount: float): + """Produce Reverb-like N-step transition from a full episode. + + Observations, actions, rewards and discounts have the same length. This + function will ignore the first reward and discount and the last action. + + Args: + observations: [L, ...] Tensor. + actions: [L, ...] Tensor. + rewards: [L] Tensor. + discounts: [L] Tensor. + n_step: number of steps to squash into a single transition. + discount: discount to use for TD updates. + + Returns: + (o_t, a_t, r_t, d_t, o_tp1) tuple. + """ + + max_index = tf.shape(rewards)[0] - 1 + first = tf.random.uniform(shape=(), minval=0, maxval=max_index - 1, + dtype=tf.int32) + last = tf.minimum(first + n_step, max_index) + + o_t = tree.map_structure(operator.itemgetter(first), observations) + a_t = tree.map_structure(operator.itemgetter(first), actions) + o_tp1 = tree.map_structure(operator.itemgetter(last), observations) + + # 0, 1, ..., n-1. + discount_range = tf.cast(tf.range(last - first), tf.float32) + # 1, g, ..., g^{n-1}. + additional_discounts = tf.pow(discount, discount_range) + # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}. + discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last-1])], 0) + # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}. + discounts *= additional_discounts + # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1} + # We have to shift rewards by one so last=max_index corresponds to transitions + # that include the last reward. + r_t = tf.reduce_sum(rewards[first+1:last+1] * discounts) + + # g^{n-1} * d_{t} * ... * d_{t+n-1}. + d_t = discounts[-1] + + info = tree.map_structure(lambda dtype: tf.ones([], dtype), + reverb.SampleInfo.tf_dtypes()) + return reverb.ReplaySample( + info=info, data=acme_types.Transition(o_t, a_t, r_t, d_t, o_tp1)) diff --git a/acme/acme/agents/tf/dqfd/agent_test.py b/acme/acme/agents/tf/dqfd/agent_test.py new file mode 100644 index 00000000..9b7d8c5c --- /dev/null +++ b/acme/acme/agents/tf/dqfd/agent_test.py @@ -0,0 +1,75 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the DQfD agent."""
+
+import acme
+from acme import specs
+from acme.agents.tf.dqfd import agent as dqfd
+from acme.agents.tf.dqfd import bsuite_demonstrations
+from acme.testing import fakes
+import dm_env
+import numpy as np
+import sonnet as snt
+
+from absl.testing import absltest
+
+
+def _make_network(action_spec: specs.DiscreteArray) -> snt.Module:
+  return snt.Sequential([
+      snt.Flatten(),
+      snt.nets.MLP([50, 50, action_spec.num_values]),
+  ])
+
+
+class DQfDTest(absltest.TestCase):
+
+  def test_dqfd(self):
+    # Create a fake environment to test with.
+    # TODO(b/152596848): Allow DQN to deal with integer observations.
+    environment = fakes.DiscreteEnvironment(
+        num_actions=5,
+        num_observations=10,
+        obs_dtype=np.float32,
+        episode_length=10)
+    spec = specs.make_environment_spec(environment)
+
+    # Build demonstrations.
+    dummy_action = np.zeros((), dtype=np.int32)
+    recorder = bsuite_demonstrations.DemonstrationRecorder()
+    timestep = environment.reset()
+    while timestep.step_type is not dm_env.StepType.LAST:
+      recorder.step(timestep, dummy_action)
+      timestep = environment.step(dummy_action)
+    recorder.step(timestep, dummy_action)
+    recorder.record_episode()
+
+    # Construct the agent.
+    agent = dqfd.DQfD(
+        environment_spec=spec,
+        network=_make_network(spec.actions),
+        demonstration_dataset=recorder.make_tf_dataset(),
+        demonstration_ratio=0.5,
+        batch_size=10,
+        samples_per_insert=2,
+        min_replay_size=10)
+
+    # Try running the environment loop. We have no assertions here because all
+    # we care about is that the agent runs without raising any errors.
+    loop = acme.EnvironmentLoop(environment, agent)
+    loop.run(num_episodes=10)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/tf/dqfd/bsuite_demonstrations.py b/acme/acme/agents/tf/dqfd/bsuite_demonstrations.py
new file mode 100644
index 00000000..67c9a8d5
--- /dev/null
+++ b/acme/acme/agents/tf/dqfd/bsuite_demonstrations.py
@@ -0,0 +1,135 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""bsuite demonstrations."""
+
+from typing import Any, List
+
+from absl import flags
+from bsuite.environments import deep_sea
+import dm_env
+import numpy as np
+import tensorflow as tf
+import tree
+
+FLAGS = flags.FLAGS
+
+
+def _nested_stack(sequence: List[Any]):
+  """Stack nested elements in a sequence."""
+  return tree.map_structure(lambda *x: np.stack(x), *sequence)
+
+
+class DemonstrationRecorder:
+  """Records demonstrations.
+
+  A demonstration is an (observation, action, reward, discount) tuple where
+  every element is a numpy array corresponding to a full episode.
+  """
+
+  def __init__(self):
+    self._demos = []
+    self._reset_episode()
+
+  def step(self, timestep: dm_env.TimeStep, action: np.ndarray):
+    reward = np.array(timestep.reward or 0, np.float32)
+    self._episode_reward += reward
+    self._episode.append((timestep.observation, action, reward,
+                          np.array(timestep.discount or 0, np.float32)))
+
+  def record_episode(self):
+    self._demos.append(_nested_stack(self._episode))
+    self._reset_episode()
+
+  def discard_episode(self):
+    self._reset_episode()
+
+  def _reset_episode(self):
+    self._episode = []
+    self._episode_reward = 0
+
+  @property
+  def episode_reward(self):
+    return self._episode_reward
+
+  def make_tf_dataset(self):
+    types = tree.map_structure(lambda x: x.dtype, self._demos[0])
+    shapes = tree.map_structure(lambda x: x.shape, self._demos[0])
+    ds = tf.data.Dataset.from_generator(lambda: self._demos, types, shapes)
+    return ds.repeat().shuffle(len(self._demos))
+
+
+def _optimal_deep_sea_policy(environment: deep_sea.DeepSea,
+                             timestep: dm_env.TimeStep):
+  action = environment._action_mapping[np.where(timestep.observation)]  # pylint: disable=protected-access
+  return action[0].astype(np.int32)
+
+
+def _run_optimal_deep_sea_episode(environment: deep_sea.DeepSea,
+                                  recorder: DemonstrationRecorder):
+  timestep = environment.reset()
+  while timestep.step_type is not dm_env.StepType.LAST:
+    action = _optimal_deep_sea_policy(environment, timestep)
+    recorder.step(timestep, action)
+    timestep = environment.step(action)
+  recorder.step(timestep, np.zeros_like(action))
+
+
+def _make_deep_sea_dataset(environment: deep_sea.DeepSea):
+  """Make DeepSea demonstration dataset."""
+
+  recorder = DemonstrationRecorder()
+
+  _run_optimal_deep_sea_episode(environment, recorder)
+  assert recorder.episode_reward > 0
+  recorder.record_episode()
+  return recorder.make_tf_dataset()
+
+
+def _make_deep_sea_stochastic_dataset(environment: deep_sea.DeepSea):
+  """Make stochastic DeepSea demonstration dataset."""
+
+  recorder = DemonstrationRecorder()
+
+  # Use 10*size demos, 80% success, 20% failure.
+  num_demos = environment._size * 10  # pylint: disable=protected-access
+  num_failures = num_demos // 5
+  num_successes = num_demos - num_failures
+
+  successes_saved = 0
+  failures_saved = 0
+  while (successes_saved < num_successes) or (failures_saved < num_failures):
+    _run_optimal_deep_sea_episode(environment, recorder)
+
+    if recorder.episode_reward > 0 and successes_saved < num_successes:
+      recorder.record_episode()
+      successes_saved += 1
+    elif recorder.episode_reward <= 0 and failures_saved < num_failures:
+      recorder.record_episode()
+      failures_saved += 1
+    else:
+      recorder.discard_episode()
+
+  return recorder.make_tf_dataset()
+
+
+def make_dataset(environment: dm_env.Environment, stochastic: bool):
+  """Make bsuite demos for the current task."""
+
+  if not stochastic:
+    assert isinstance(environment, deep_sea.DeepSea)
+    return _make_deep_sea_dataset(environment)
+  else:
+    assert isinstance(environment, deep_sea.DeepSea)
+    return _make_deep_sea_stochastic_dataset(environment)
diff --git a/acme/acme/agents/tf/dqn/README.md b/acme/acme/agents/tf/dqn/README.md
new file mode 100644
index 00000000..4d32548b
--- /dev/null
+++ b/acme/acme/agents/tf/dqn/README.md
@@ -0,0 +1,20 @@
+# Deep Q-Networks (DQN)
+
+This folder contains an implementation of the DQN algorithm
+([Mnih et al., 2013], [Mnih et al., 2015]), with extra bells & whistles,
+similar to Rainbow DQN ([Hessel et al., 2017]).
+
+* Q-learning with neural network function approximation.
The loss is given by + the Huber loss applied to the temporal difference error. +* Target Q' network updated periodically ([Mnih et al., 2015]). +* N-step bootstrapping ([Sutton & Barto, 2018]). +* Double Q-learning ([van Hasselt et al., 2015]). +* Prioritized experience replay ([Schaul et al., 2015]). + +[Mnih et al., 2013]: https://arxiv.org/abs/1312.5602 +[Mnih et al., 2015]: https://www.nature.com/articles/nature14236 +[van Hasselt et al., 2015]: https://arxiv.org/abs/1509.06461 +[Schaul et al., 2015]: https://arxiv.org/abs/1511.05952 +[Hessel et al., 2017]: https://arxiv.org/abs/1710.02298 +[Horgan et al., 2018]: https://arxiv.org/abs/1803.00933 +[Sutton & Barto, 2018]: http://incompleteideas.net/book/the-book.html diff --git a/acme/acme/agents/tf/dqn/__init__.py b/acme/acme/agents/tf/dqn/__init__.py new file mode 100644 index 00000000..b1429a95 --- /dev/null +++ b/acme/acme/agents/tf/dqn/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of a deep Q-networks (DQN) agent.""" + +from acme.agents.tf.dqn.agent import DQN +from acme.agents.tf.dqn.agent_distributed import DistributedDQN +from acme.agents.tf.dqn.learning import DQNLearner diff --git a/acme/acme/agents/tf/dqn/agent.py b/acme/acme/agents/tf/dqn/agent.py new file mode 100644 index 00000000..77844c7f --- /dev/null +++ b/acme/acme/agents/tf/dqn/agent.py @@ -0,0 +1,177 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DQN agent implementation.""" + +import copy +from typing import Optional + +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf.dqn import learning +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + + +class DQN(agent.Agent): + """DQN agent. + + This implements a single-process DQN agent. This is a simple Q-learning + algorithm that inserts N-step transitions into a replay buffer, and + periodically updates its policy by sampling these transitions using + prioritization. 
+ """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.Module, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + n_step: int = 5, + epsilon: Optional[tf.Variable] = None, + learning_rate: float = 1e-3, + discount: float = 0.99, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + checkpoint_subpath: str = '~/acme', + policy_network: Optional[snt.Module] = None, + max_gradient_norm: Optional[float] = None, + ): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + network: the online Q network (the one being optimized) + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_update_period: number of learner steps to perform before updating + the target networks. + samples_per_insert: number of samples to take from replay for every insert + that is made. + min_replay_size: minimum replay size before updating. This and all + following arguments are related to dataset construction and will be + ignored if a dataset argument is passed. + max_replay_size: maximum replay size. + importance_sampling_exponent: power to which importance weights are raised + before normalizing. + priority_exponent: exponent used in prioritized sampling. + n_step: number of steps to squash into a single transition. + epsilon: probability of taking a random action; ignored if a policy + network is given. + learning_rate: learning rate for the q-network update. + discount: discount to use for TD updates. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + checkpoint_subpath: string indicating where the agent should save + checkpoints and snapshots. + policy_network: if given, this will be used as the policy network. + Otherwise, an epsilon greedy policy using the online Q network will be + created. Policy network is used in the actor to sample actions. + max_gradient_norm: used for gradient clipping. + """ + + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Prioritized(priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature(environment_spec)) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f'localhost:{self._server.port}' + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), + n_step=n_step, + discount=discount) + + # The dataset provides an interface to sample from replay. + replay_client = reverb.Client(address) + dataset = datasets.make_reverb_dataset( + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size) + + # Create epsilon greedy policy network by default. + if policy_network is None: + # Use constant 0.05 epsilon greedy policy by default. + if epsilon is None: + epsilon = tf.Variable(0.05, trainable=False) + policy_network = snt.Sequential([ + network, + lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), + ]) + + # Create a target network. 
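+    # (copy.deepcopy yields a structurally identical network with independent
+    # weights; the learner refreshes those weights only every
+    # target_update_period steps, which stabilises the TD targets.)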
+ target_network = copy.deepcopy(network) + + # Ensure that we create the variables before proceeding (maybe not needed). + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network, adder) + + # The learner updates the parameters (and initializes them). + learner = learning.DQNLearner( + network=network, + target_network=target_network, + discount=discount, + importance_sampling_exponent=importance_sampling_exponent, + learning_rate=learning_rate, + target_update_period=target_update_period, + dataset=dataset, + replay_client=replay_client, + max_gradient_norm=max_gradient_norm, + logger=logger, + checkpoint=checkpoint, + save_directory=checkpoint_subpath) + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + directory=checkpoint_subpath, + objects_to_save=learner.state, + subdirectory='dqn_learner', + time_delta_minutes=60.) + else: + self._checkpointer = None + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) + + def update(self): + super().update() + if self._checkpointer is not None: + self._checkpointer.save() diff --git a/acme/acme/agents/tf/dqn/agent_distributed.py b/acme/acme/agents/tf/dqn/agent_distributed.py new file mode 100644 index 00000000..0e22dd61 --- /dev/null +++ b/acme/acme/agents/tf/dqn/agent_distributed.py @@ -0,0 +1,272 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
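The default behaviour policy constructed in `DQN.__init__` above wraps the Q-network with `trfl.epsilon_greedy(...)`. Functionally this is ordinary epsilon-greedy action selection, which a small NumPy sketch (an illustration, not the trfl implementation) makes explicit:

```python
import numpy as np


def epsilon_greedy(q_values: np.ndarray, epsilon: float,
                   rng: np.random.Generator) -> int:
    """Random action with probability epsilon, otherwise the argmax action."""
    if rng.random() < epsilon:
        return int(rng.integers(q_values.shape[-1]))  # uniform exploration
    return int(np.argmax(q_values))                   # greedy exploitation


rng = np.random.default_rng(0)
action = epsilon_greedy(np.array([0.1, 0.9, 0.3]), epsilon=0.05, rng=rng)
```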
+ +"""Defines the DQN agent class.""" + +import copy +from typing import Callable, Optional + +import acme +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents.tf import actors +from acme.agents.tf.dqn import learning +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import numpy as np +import reverb +import sonnet as snt +import trfl + + +class DistributedDQN: + """Distributed DQN agent.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.DiscreteArray], snt.Module], + num_actors: int, + num_caches: int = 1, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1_000_000, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + n_step: int = 5, + learning_rate: float = 1e-3, + evaluator_epsilon: float = 0., + max_actor_steps: Optional[int] = None, + discount: float = 0.99, + environment_spec: Optional[specs.EnvironmentSpec] = None, + variable_update_period: int = 1000, + ): + + assert num_caches >= 1 + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._num_actors = num_actors + self._num_caches = num_caches + self._env_spec = environment_spec + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._target_update_period = target_update_period + self._samples_per_insert = samples_per_insert + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._importance_sampling_exponent = importance_sampling_exponent + self._priority_exponent = priority_exponent + self._n_step = n_step + self._learning_rate = learning_rate + self._evaluator_epsilon = evaluator_epsilon + self._max_actor_steps = max_actor_steps + self._discount = discount + self._variable_update_period = variable_update_period + + def replay(self): + """The replay storage.""" + if self._samples_per_insert: + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=self._batch_size) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Prioritized(self._priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature(self._env_spec)) + return [replay_table] + + def counter(self): + """Creates the master counter process.""" + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory='counter') + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner(self, replay: reverb.Client, counter: counting.Counter): + """The Learning part of the agent.""" + + # Create the networks. 
+ network = self._network_factory(self._env_spec.actions) + target_network = copy.deepcopy(network) + + tf2_utils.create_variables(network, [self._env_spec.observations]) + tf2_utils.create_variables(target_network, [self._env_spec.observations]) + + # The dataset object to learn from. + replay_client = reverb.Client(replay.server_address) + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size) + + logger = loggers.make_default_logger('learner', steps_key='learner_steps') + + # Return the learning agent. + counter = counting.Counter(counter, 'learner') + + learner = learning.DQNLearner( + network=network, + target_network=target_network, + discount=self._discount, + importance_sampling_exponent=self._importance_sampling_exponent, + learning_rate=self._learning_rate, + target_update_period=self._target_update_period, + dataset=dataset, + replay_client=replay_client, + counter=counter, + logger=logger) + return tf2_savers.CheckpointingRunner( + learner, subdirectory='dqn_learner', time_delta_minutes=60) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + epsilon: float, + ) -> acme.EnvironmentLoop: + """The actor process.""" + environment = self._environment_factory(False) + network = self._network_factory(self._env_spec.actions) + + # Just inline the policy network here. + policy_network = snt.Sequential([ + network, + lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), + ]) + + tf2_utils.create_variables(policy_network, [self._env_spec.observations]) + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.trainable_variables}, + update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, + n_step=self._n_step, + discount=self._discount, + ) + + # Create the agent. + actor = actors.FeedForwardActor(policy_network, adder, variable_client) + + # Create the loop to connect environment and agent. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', save_data=False, steps_key='actor_steps') + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + environment = self._environment_factory(True) + network = self._network_factory(self._env_spec.actions) + + # Just inline the policy network here. + policy_network = snt.Sequential([ + network, + lambda q: trfl.epsilon_greedy(q, self._evaluator_epsilon).sample(), + ]) + + tf2_utils.create_variables(policy_network, [self._env_spec.observations]) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.trainable_variables}, + update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network, variable_client=variable_client) + + # Create the run loop and return it. 
+ logger = loggers.make_default_logger( + 'evaluator', steps_key='evaluator_steps') + counter = counting.Counter(counter, 'evaluator') + return acme.EnvironmentLoop( + environment, actor, counter=counter, logger=logger) + + def build(self, name='dqn'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + # Generate an epsilon for each actor. + epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0) + + with program.group('cacher'): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id, epsilon in enumerate(epsilons): + source = sources[actor_id % len(sources)] + program.add_node( + lp.CourierNode(self.actor, replay, source, counter, epsilon)) + + return program diff --git a/acme/acme/agents/tf/dqn/agent_distributed_test.py b/acme/acme/agents/tf/dqn/agent_distributed_test.py new file mode 100644 index 00000000..d6b4788a --- /dev/null +++ b/acme/acme/agents/tf/dqn/agent_distributed_test.py @@ -0,0 +1,56 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
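`build()` above assigns each actor its own exploration rate via `np.logspace(1, 8, num_actors, base=0.4)`, in the spirit of the per-actor epsilons of Ape-X ([Horgan et al., 2018]). For example, with four actors (the test below uses two) the flipped schedule is ascending, roughly:

```python
import numpy as np

num_actors = 4  # example value
epsilons = np.flip(np.logspace(1, 8, num_actors, base=0.4), axis=0)
print(epsilons)  # approx. [0.00066, 0.0056, 0.047, 0.4], i.e. 0.4**8 up to 0.4**1
```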
+ +"""Integration test for the distributed agent.""" + +import acme +from acme.agents.tf import dqn +from acme.testing import fakes +from acme.tf import networks +import launchpad as lp + +from absl.testing import absltest + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_atari(self): + """Tests that the agent can run for some steps without crashing.""" + env_factory = lambda x: fakes.fake_atari_wrapped() + net_factory = lambda spec: networks.DQNAtariNetwork(spec.num_values) + + agent = dqn.DistributedDQN( + environment_factory=env_factory, + network_factory=net_factory, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/dqn/agent_test.py b/acme/acme/agents/tf/dqn/agent_test.py new file mode 100644 index 00000000..c4dbfe7c --- /dev/null +++ b/acme/acme/agents/tf/dqn/agent_test.py @@ -0,0 +1,60 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DQN agent.""" + +import acme +from acme import specs +from acme.agents.tf import dqn +from acme.testing import fakes +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def _make_network(action_spec: specs.DiscreteArray) -> snt.Module: + return snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, action_spec.num_values]), + ]) + + +class DQNTest(absltest.TestCase): + + def test_dqn(self): + # Create a fake environment to test with. + environment = fakes.DiscreteEnvironment( + num_actions=5, + num_observations=10, + obs_dtype=np.float32, + episode_length=10) + spec = specs.make_environment_spec(environment) + + # Construct the agent. + agent = dqn.DQN( + environment_spec=spec, + network=_make_network(spec.actions), + batch_size=10, + samples_per_insert=2, + min_replay_size=10) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/dqn/learning.py b/acme/acme/agents/tf/dqn/learning.py new file mode 100644 index 00000000..00ce17af --- /dev/null +++ b/acme/acme/agents/tf/dqn/learning.py @@ -0,0 +1,238 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DQN learner implementation."""
+
+import time
+from typing import Dict, List, Optional, Union
+
+import acme
+from acme import types
+from acme.adders import reverb as adders
+from acme.tf import losses
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import numpy as np
+import reverb
+import sonnet as snt
+import tensorflow as tf
+import trfl
+
+
+class DQNLearner(acme.Learner, tf2_savers.TFSaveable):
+  """DQN learner.
+
+  This is the learning component of a DQN agent. It takes a dataset as input
+  and implements update functionality to learn from this dataset. Optionally
+  it takes a replay client as well to allow for updating of priorities.
+  """
+
+  def __init__(
+      self,
+      network: snt.Module,
+      target_network: snt.Module,
+      discount: float,
+      importance_sampling_exponent: float,
+      learning_rate: float,
+      target_update_period: int,
+      dataset: tf.data.Dataset,
+      max_abs_reward: Optional[float] = 1.,
+      huber_loss_parameter: float = 1.,
+      replay_client: Optional[Union[reverb.Client, reverb.TFClient]] = None,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      checkpoint: bool = True,
+      save_directory: str = '~/acme',
+      max_gradient_norm: Optional[float] = None,
+  ):
+    """Initializes the learner.
+
+    Args:
+      network: the online Q network (the one being optimized).
+      target_network: the target Q network (which lags behind the online
+        network).
+      discount: discount to use for TD updates.
+      importance_sampling_exponent: power to which importance weights are
+        raised before normalizing.
+      learning_rate: learning rate for the q-network update.
+      target_update_period: number of learner steps to perform before updating
+        the target networks.
+      dataset: dataset to learn from, whether fixed or from a replay buffer
+        (see `acme.datasets.reverb.make_reverb_dataset` documentation).
+      max_abs_reward: Optional maximum absolute value for the reward.
+      huber_loss_parameter: Quadratic-linear boundary for Huber loss.
+      replay_client: client to replay to allow for updating priorities.
+      counter: Counter object for (potentially distributed) counting.
+      logger: Logger object for writing logs to.
+      checkpoint: boolean indicating whether to checkpoint the learner.
+      save_directory: string indicating where the learner should save
+        checkpoints and snapshots.
+      max_gradient_norm: used for gradient clipping.
+    """
+
+    # TODO(mwhoffman): stop allowing replay_client to be passed as a TFClient.
+    # This is just here for backwards compatibility for agents which reuse this
+    # Learner and still pass a TFClient instance.
+    if isinstance(replay_client, reverb.TFClient):
+      # TODO(b/170419518): open source pytype does not understand this
+      # isinstance() check because it does not have a way of getting precise
+      # type information for pip-installed packages.
+      replay_client = reverb.Client(replay_client._server_address)  # pytype: disable=attribute-error
+
+    # Internalise agent components (replay buffer, networks, optimizer).
+    # TODO(b/155086959): Fix type stubs and remove.
+ self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._target_network = target_network + self._optimizer = snt.optimizers.Adam(learning_rate) + self._replay_client = replay_client + + # Make sure to initialize the optimizer so that its variables (e.g. the Adam + # moments) are included in the state returned by the learner (which can then + # be checkpointed and restored). + self._optimizer._initialize(network.trainable_variables) # pylint: disable= protected-access + + # Internalise the hyperparameters. + self._discount = discount + self._target_update_period = target_update_period + self._importance_sampling_exponent = importance_sampling_exponent + self._max_abs_reward = max_abs_reward + self._huber_loss_parameter = huber_loss_parameter + if max_gradient_norm is None: + max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. + self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) + + # Learner state. + self._variables: List[List[tf.Tensor]] = [network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # Internalise logging/counting objects. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) + + # Create a snapshotter object. + if checkpoint: + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={'network': network}, + directory=save_directory, + time_delta_minutes=60.) + else: + self._snapshotter = None + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" + + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + keys, probs = inputs.info[:2] + + with tf.GradientTape() as tape: + # Evaluate our networks. + q_tm1 = self._network(transitions.observation) + q_t_value = self._target_network(transitions.next_observation) + q_t_selector = self._network(transitions.next_observation) + + # The rewards and discounts have to have the same type as network values. + r_t = tf.cast(transitions.reward, q_tm1.dtype) + if self._max_abs_reward: + r_t = tf.clip_by_value(r_t, -self._max_abs_reward, self._max_abs_reward) + d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast( + self._discount, q_tm1.dtype) + + # Compute the loss. + _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t, d_t, + q_t_value, q_t_selector) + loss = losses.huber(extra.td_error, self._huber_loss_parameter) + + # Get the importance weights. + importance_weights = 1. / probs # [B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + + # Reweight. + loss *= tf.cast(importance_weights, loss.dtype) # [B] + loss = tf.reduce_mean(loss, axis=[0]) # [] + + # Do a step of SGD. + gradients = tape.gradient(loss, self._network.trainable_variables) + gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm) + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Get the priorities that we'll use to update. + priorities = tf.abs(extra.td_error) + + # Periodically update the target network. 
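+    # (This is a hard parameter copy every target_update_period steps, as in
+    # the original DQN, rather than an exponential/Polyak moving average.)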
+ if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(self._network.variables, + self._target_network.variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Report loss & statistics for logging. + fetches = { + 'loss': loss, + 'keys': keys, + 'priorities': priorities, + } + + return fetches + + def step(self): + # Do a batch of SGD. + result = self._step() + + # Get the keys and priorities. + keys = result.pop('keys') + priorities = result.pop('priorities') + + # Update the priorities in the replay buffer. + if self._replay_client: + self._replay_client.mutate_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, + updates=dict(zip(keys.numpy(), priorities.numpy()))) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + result.update(counts) + + # Snapshot and attempt to write logs. + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + 'network': self._network, + 'target_network': self._target_network, + 'optimizer': self._optimizer, + 'num_steps': self._num_steps + } diff --git a/acme/acme/agents/tf/impala/README.md b/acme/acme/agents/tf/impala/README.md new file mode 100644 index 00000000..17fe33ec --- /dev/null +++ b/acme/acme/agents/tf/impala/README.md @@ -0,0 +1,7 @@ +# Importance-weighted actor-learner architecture (IMPALA) + +This agent is an implementation of the algorithm described in *IMPALA: Scalable +Distributed Deep-RL with Importance Weighted Actor-Learner Architectures* +([Espeholt et al., 2018]). + +[Espeholt et al., 2018]: https://arxiv.org/abs/1802.01561 diff --git a/acme/acme/agents/tf/impala/__init__.py b/acme/acme/agents/tf/impala/__init__.py new file mode 100644 index 00000000..eb87d173 --- /dev/null +++ b/acme/acme/agents/tf/impala/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Importance-weighted actor-learner architecture (IMPALA) agent.""" + +from acme.agents.tf.impala.acting import IMPALAActor +from acme.agents.tf.impala.agent import IMPALA +from acme.agents.tf.impala.agent_distributed import DistributedIMPALA +from acme.agents.tf.impala.learning import IMPALALearner diff --git a/acme/acme/agents/tf/impala/acting.py b/acme/acme/agents/tf/impala/acting.py new file mode 100644 index 00000000..748d9d50 --- /dev/null +++ b/acme/acme/agents/tf/impala/acting.py @@ -0,0 +1,96 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IMPALA actor implementation.""" + +from typing import Optional + +from acme import adders +from acme import core +from acme import types +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + +import dm_env +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions + + +class IMPALAActor(core.Actor): + """A recurrent actor.""" + + def __init__( + self, + network: snt.RNNCore, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): + + # Store these for later use. + self._adder = adder + self._variable_client = variable_client + self._network = network + + # TODO(b/152382420): Ideally we would call tf.function(network) instead but + # this results in an error when using acme RNN snapshots. + self._policy = tf.function(network.__call__) + + self._state = None + self._prev_state = None + self._prev_logits = None + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_obs = tf2_utils.add_batch_dim(observation) + + if self._state is None: + self._state = self._network.initial_state(1) + + # Forward. + (logits, _), new_state = self._policy(batched_obs, self._state) + + self._prev_logits = logits + self._prev_state = self._state + self._state = new_state + + action = tfd.Categorical(logits).sample() + action = tf2_utils.to_numpy_squeeze(action) + + return action + + def observe_first(self, timestep: dm_env.TimeStep): + if self._adder: + self._adder.add_first(timestep) + + # Set the state to None so that we re-initialize at the next policy call. + self._state = None + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + if not self._adder: + return + + extras = {'logits': self._prev_logits, 'core_state': self._prev_state} + extras = tf2_utils.to_numpy_squeeze(extras) + self._adder.add(action, next_timestep, extras) + + def update(self, wait: bool = False): + if self._variable_client: + self._variable_client.update(wait) diff --git a/acme/acme/agents/tf/impala/agent.py b/acme/acme/agents/tf/impala/agent.py new file mode 100644 index 00000000..807c5823 --- /dev/null +++ b/acme/acme/agents/tf/impala/agent.py @@ -0,0 +1,123 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Importance weighted advantage actor-critic (IMPALA) agent implementation.""" + +from typing import Optional + +import acme +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as adders +from acme.agents.tf.impala import acting +from acme.agents.tf.impala import learning +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import dm_env +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf + + +class IMPALA(acme.Actor): + """IMPALA Agent.""" + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.RNNCore, + sequence_length: int, + sequence_period: int, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + discount: float = 0.99, + max_queue_size: int = 100000, + batch_size: int = 16, + learning_rate: float = 1e-3, + entropy_cost: float = 0.01, + baseline_cost: float = 0.5, + max_abs_reward: Optional[float] = None, + max_gradient_norm: Optional[float] = None, + ): + + num_actions = environment_spec.actions.num_values + self._logger = logger or loggers.TerminalLogger('agent') + + extra_spec = { + 'core_state': network.initial_state(1), + 'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32) + } + # Remove batch dimensions. + extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) + + queue = reverb.Table.queue( + name=adders.DEFAULT_PRIORITY_TABLE, + max_size=max_queue_size, + signature=adders.SequenceAdder.signature( + environment_spec, + extras_spec=extra_spec, + sequence_length=sequence_length)) + self._server = reverb.Server([queue], port=None) + self._can_sample = lambda: queue.can_sample(batch_size) + address = f'localhost:{self._server.port}' + + # Component to add things into replay. + adder = adders.SequenceAdder( + client=reverb.Client(address), + period=sequence_period, + sequence_length=sequence_length, + ) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=address, + batch_size=batch_size) + + tf2_utils.create_variables(network, [environment_spec.observations]) + + self._actor = acting.IMPALAActor(network, adder) + self._learner = learning.IMPALALearner( + environment_spec=environment_spec, + network=network, + dataset=dataset, + counter=counter, + logger=logger, + discount=discount, + learning_rate=learning_rate, + entropy_cost=entropy_cost, + baseline_cost=baseline_cost, + max_gradient_norm=max_gradient_norm, + max_abs_reward=max_abs_reward, + ) + + def observe_first(self, timestep: dm_env.TimeStep): + self._actor.observe_first(timestep) + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + self._actor.observe(action, next_timestep) + + def update(self, wait: bool = False): + # Run a number of learner steps (usually gradient steps). + while self._can_sample(): + self._learner.step() + + def select_action(self, observation: np.ndarray) -> int: + return self._actor.select_action(observation) diff --git a/acme/acme/agents/tf/impala/agent_distributed.py b/acme/acme/agents/tf/impala/agent_distributed.py new file mode 100644 index 00000000..6002601b --- /dev/null +++ b/acme/acme/agents/tf/impala/agent_distributed.py @@ -0,0 +1,230 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the IMPALA Launchpad program."""
+
+from typing import Callable, Optional
+
+import acme
+from acme import datasets
+from acme import specs
+from acme.adders import reverb as adders
+from acme.agents.tf.impala import acting
+from acme.agents.tf.impala import learning
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.tf import variable_utils as tf2_variable_utils
+from acme.utils import counting
+from acme.utils import loggers
+import dm_env
+import launchpad as lp
+import reverb
+import sonnet as snt
+import tensorflow as tf
+
+
+class DistributedIMPALA:
+  """Program definition for IMPALA."""
+
+  def __init__(self,
+               environment_factory: Callable[[bool], dm_env.Environment],
+               network_factory: Callable[[specs.DiscreteArray], snt.RNNCore],
+               num_actors: int,
+               sequence_length: int,
+               sequence_period: int,
+               environment_spec: Optional[specs.EnvironmentSpec] = None,
+               batch_size: int = 256,
+               prefetch_size: int = 4,
+               max_queue_size: int = 10_000,
+               learning_rate: float = 1e-3,
+               discount: float = 0.99,
+               entropy_cost: float = 0.01,
+               baseline_cost: float = 0.5,
+               max_abs_reward: Optional[float] = None,
+               max_gradient_norm: Optional[float] = None,
+               variable_update_period: int = 1000,
+               save_logs: bool = False):
+
+    if environment_spec is None:
+      environment_spec = specs.make_environment_spec(environment_factory(False))
+
+    self._environment_factory = environment_factory
+    self._network_factory = network_factory
+    self._environment_spec = environment_spec
+    self._num_actors = num_actors
+    self._batch_size = batch_size
+    self._prefetch_size = prefetch_size
+    self._sequence_length = sequence_length
+    self._max_queue_size = max_queue_size
+    self._sequence_period = sequence_period
+    self._discount = discount
+    self._learning_rate = learning_rate
+    self._entropy_cost = entropy_cost
+    self._baseline_cost = baseline_cost
+    self._max_abs_reward = max_abs_reward
+    self._max_gradient_norm = max_gradient_norm
+    self._variable_update_period = variable_update_period
+    self._save_logs = save_logs
+
+  def queue(self):
+    """Creates the Reverb queue that carries sequences from actors to the learner."""
+    num_actions = self._environment_spec.actions.num_values
+    network = self._network_factory(self._environment_spec.actions)
+    extra_spec = {
+        'core_state': network.initial_state(1),
+        'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
+    }
+    # Remove batch dimensions.
+    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
+    signature = adders.SequenceAdder.signature(
+        self._environment_spec,
+        extra_spec,
+        sequence_length=self._sequence_length)
+    queue = reverb.Table.queue(
+        name=adders.DEFAULT_PRIORITY_TABLE,
+        max_size=self._max_queue_size,
+        signature=signature)
+    return [queue]
+
+  def counter(self):
+    """Creates the master counter process."""
+    return tf2_savers.CheckpointingRunner(
+        counting.Counter(), time_delta_minutes=1, subdirectory='counter')
+
+  def learner(self, queue: reverb.Client, counter: counting.Counter):
+    """The learning part of the agent."""
+    # Create the networks.
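+    # Build the learner's own copy of the network; creating its variables here
+    # lets them be checkpointed and served to actors via the variable source.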
+ network = self._network_factory(self._environment_spec.actions) + tf2_utils.create_variables(network, [self._environment_spec.observations]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=queue.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size) + + logger = loggers.make_default_logger('learner', steps_key='learner_steps') + counter = counting.Counter(counter, 'learner') + + # Return the learning agent. + learner = learning.IMPALALearner( + environment_spec=self._environment_spec, + network=network, + dataset=dataset, + discount=self._discount, + learning_rate=self._learning_rate, + entropy_cost=self._entropy_cost, + baseline_cost=self._baseline_cost, + max_abs_reward=self._max_abs_reward, + max_gradient_norm=self._max_gradient_norm, + counter=counter, + logger=logger, + ) + + return tf2_savers.CheckpointingRunner(learner, + time_delta_minutes=5, + subdirectory='impala_learner') + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + environment = self._environment_factory(False) + network = self._network_factory(self._environment_spec.actions) + tf2_utils.create_variables(network, [self._environment_spec.observations]) + + # Component to add things into the queue. + adder = adders.SequenceAdder( + client=replay, + period=self._sequence_period, + sequence_length=self._sequence_length) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={'policy': network.variables}, + update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = acting.IMPALAActor( + network=network, + variable_client=variable_client, + adder=adder) + + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', save_data=False, steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator(self, variable_source: acme.VariableSource, + counter: counting.Counter): + """The evaluation process.""" + environment = self._environment_factory(True) + network = self._network_factory(self._environment_spec.actions) + tf2_utils.create_variables(network, [self._environment_spec.observations]) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={'policy': network.variables}, + update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = acting.IMPALAActor( + network=network, variable_client=variable_client) + + # Create the run loop and return it. 
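+    # The evaluator's actor is constructed without an adder, so its experience
+    # never enters the queue; its episodes purely measure the current policy.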
+ logger = loggers.make_default_logger( + 'evaluator', steps_key='evaluator_steps') + counter = counting.Counter(counter, 'evaluator') + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name='impala'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + queue = program.add_node(lp.ReverbNode(self.queue)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + with program.group('learner'): + learner = program.add_node( + lp.CourierNode(self.learner, queue, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + with program.group('cacher'): + cacher = program.add_node( + lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000)) + + with program.group('actor'): + for _ in range(self._num_actors): + program.add_node(lp.CourierNode(self.actor, queue, cacher, counter)) + + return program diff --git a/acme/acme/agents/tf/impala/agent_distributed_test.py b/acme/acme/agents/tf/impala/agent_distributed_test.py new file mode 100644 index 00000000..04e59d2f --- /dev/null +++ b/acme/acme/agents/tf/impala/agent_distributed_test.py @@ -0,0 +1,56 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for the distributed agent.""" + +import acme +from acme.agents.tf import impala +from acme.testing import fakes +from acme.tf import networks +import launchpad as lp + +from absl.testing import absltest + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_atari(self): + """Tests that the agent can run for some steps without crashing.""" + env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) + net_factory = lambda spec: networks.IMPALAAtariNetwork(spec.num_values) + + agent = impala.DistributedIMPALA( + environment_factory=env_factory, + network_factory=net_factory, + num_actors=2, + batch_size=32, + sequence_length=5, + sequence_period=1, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/impala/agent_test.py b/acme/acme/agents/tf/impala/agent_test.py new file mode 100644 index 00000000..71dd5740 --- /dev/null +++ b/acme/acme/agents/tf/impala/agent_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for IMPALA agent.""" + +import acme +from acme import specs +from acme.agents.tf import impala +from acme.testing import fakes +from acme.tf import networks +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def _make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore: + return snt.DeepRNN([ + snt.Flatten(), + snt.LSTM(20), + snt.nets.MLP([50, 50]), + networks.PolicyValueHead(action_spec.num_values), + ]) + + +class IMPALATest(absltest.TestCase): + + # TODO(b/200509080): This test case is timing out. + @absltest.SkipTest + def test_impala(self): + # Create a fake environment to test with. + environment = fakes.DiscreteEnvironment( + num_actions=5, + num_observations=10, + obs_dtype=np.float32, + episode_length=10) + spec = specs.make_environment_spec(environment) + + # Construct the agent. + agent = impala.IMPALA( + environment_spec=spec, + network=_make_network(spec.actions), + sequence_length=3, + sequence_period=3, + batch_size=6, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=20) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/impala/learning.py b/acme/acme/agents/tf/impala/learning.py new file mode 100644 index 00000000..b2ea2614 --- /dev/null +++ b/acme/acme/agents/tf/impala/learning.py @@ -0,0 +1,190 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Learner for the IMPALA actor-critic agent."""
+
+import time
+from typing import Dict, List, Mapping, Optional
+
+import acme
+from acme import specs
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import numpy as np
+import reverb
+import sonnet as snt
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tree
+import trfl
+
+tfd = tfp.distributions
+
+
+class IMPALALearner(acme.Learner, tf2_savers.TFSaveable):
+  """Learner for an importance-weighted advantage actor-critic."""
+
+  def __init__(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      network: snt.RNNCore,
+      dataset: tf.data.Dataset,
+      learning_rate: float,
+      discount: float = 0.99,
+      entropy_cost: float = 0.,
+      baseline_cost: float = 1.,
+      max_abs_reward: Optional[float] = None,
+      max_gradient_norm: Optional[float] = None,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+  ):
+
+    # Internalise network, optimizer, and dataset.
+    self._env_spec = environment_spec
+    self._optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
+    self._network = network
+    self._variables = network.variables
+    # TODO(b/155086959): Fix type stubs and remove.
+    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
+
+    # Hyperparameters.
+    self._discount = discount
+    self._entropy_cost = entropy_cost
+    self._baseline_cost = baseline_cost
+
+    # Set up reward/gradient clipping.
+    if max_abs_reward is None:
+      max_abs_reward = np.inf
+    if max_gradient_norm is None:
+      max_gradient_norm = 1e10  # A very large number. Infinity results in NaNs.
+    self._max_abs_reward = tf.convert_to_tensor(max_abs_reward)
+    self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm)
+
+    # Set up logging/counting.
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)
+
+    self._snapshotter = tf2_savers.Snapshotter(
+        objects_to_save={'network': network}, time_delta_minutes=60.)
+
+    # Do not record timestamps until after the first learning step is done.
+    # This is to avoid including the time it takes for actors to come online and
+    # fill the replay buffer.
+    self._timestamp = None
+
+  @property
+  def state(self) -> Mapping[str, tf2_savers.Checkpointable]:
+    """Returns the stateful objects for checkpointing."""
+    return {
+        'network': self._network,
+        'optimizer': self._optimizer,
+    }
+
+  @tf.function
+  def _step(self) -> Dict[str, tf.Tensor]:
+    """Does an SGD step on a batch of sequences."""
+
+    # Retrieve a batch of data from replay.
+    inputs: reverb.ReplaySample = next(self._iterator)
+    data = tf2_utils.batch_to_sequence(inputs.data)
+    observations, actions, rewards, discounts, extra = (data.observation,
+                                                        data.action,
+                                                        data.reward,
+                                                        data.discount,
+                                                        data.extras)
+    core_state = tree.map_structure(lambda s: s[0], extra['core_state'])
+
+    # Use only the first T-1 steps of each sequence; the final step is kept
+    # solely to bootstrap the value estimate.
+    actions = actions[:-1]  # [T-1]
+    rewards = rewards[:-1]  # [T-1]
+    discounts = discounts[:-1]  # [T-1]
+
+    with tf.GradientTape() as tape:
+      # Unroll current policy over observations.
+      (logits, values), _ = snt.static_unroll(self._network, observations,
+                                              core_state)
+
+      # Compute importance sampling weights: current policy / behavior policy.
+      behaviour_logits = extra['logits']
+      pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1])
+      pi_target = tfd.Categorical(logits=logits[:-1])
+      log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions)
+
+      # Optionally clip rewards.
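+      # `max_abs_reward` defaults to np.inf, which makes the clip below a
+      # no-op; a finite value (e.g. 1.0 on Atari) bounds the per-step reward
+      # magnitude before the V-trace targets are computed.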
+ rewards = tf.clip_by_value(rewards, + tf.cast(-self._max_abs_reward, rewards.dtype), + tf.cast(self._max_abs_reward, rewards.dtype)) + + # Critic loss. + vtrace_returns = trfl.vtrace_from_importance_weights( + log_rhos=tf.cast(log_rhos, tf.float32), + discounts=tf.cast(self._discount * discounts, tf.float32), + rewards=tf.cast(rewards, tf.float32), + values=tf.cast(values[:-1], tf.float32), + bootstrap_value=values[-1], + ) + critic_loss = tf.square(vtrace_returns.vs - values[:-1]) + + # Policy-gradient loss. + policy_gradient_loss = trfl.policy_gradient( + policies=pi_target, + actions=actions, + action_values=vtrace_returns.pg_advantages, + ) + + # Entropy regulariser. + entropy_loss = trfl.policy_entropy_loss(pi_target).loss + + # Combine weighted sum of actor & critic losses. + loss = tf.reduce_mean(policy_gradient_loss + + self._baseline_cost * critic_loss + + self._entropy_cost * entropy_loss) + + # Compute gradients and optionally apply clipping. + gradients = tape.gradient(loss, self._network.trainable_variables) + gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm) + self._optimizer.apply(gradients, self._network.trainable_variables) + + metrics = { + 'loss': loss, + 'critic_loss': tf.reduce_mean(critic_loss), + 'entropy_loss': tf.reduce_mean(entropy_loss), + 'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss), + } + + return metrics + + def step(self): + """Does a step of SGD and logs the results.""" + + # Do a batch of SGD. + results = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + results.update(counts) + + # Snapshot and attempt to write logs. + self._snapshotter.save() + self._logger.write(results) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables)] diff --git a/acme/acme/agents/tf/iqn/README.md b/acme/acme/agents/tf/iqn/README.md new file mode 100644 index 00000000..b2b4247e --- /dev/null +++ b/acme/acme/agents/tf/iqn/README.md @@ -0,0 +1,6 @@ +# Implicit Quantile Networks for Distributional RL (IQN) + +This folder contains an implementation of the IQN algorithm introduced in +([Dabney et al., 2018]). + +[Dabney et al., 2018]: https://arxiv.org/abs/1806.06923 diff --git a/acme/acme/agents/tf/iqn/__init__.py b/acme/acme/agents/tf/iqn/__init__.py new file mode 100644 index 00000000..ccac6a13 --- /dev/null +++ b/acme/acme/agents/tf/iqn/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Implementation of an IQN agent.""" + +from acme.agents.tf.iqn.learning import IQNLearner diff --git a/acme/acme/agents/tf/iqn/learning.py b/acme/acme/agents/tf/iqn/learning.py new file mode 100644 index 00000000..aba47c61 --- /dev/null +++ b/acme/acme/agents/tf/iqn/learning.py @@ -0,0 +1,266 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implicit Quantile Network (IQN) learner implementation.""" + +from typing import Dict, List, Optional, Tuple + +from acme import core +from acme import types +from acme.adders import reverb as adders +from acme.tf import losses +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf + + +class IQNLearner(core.Learner, tf2_savers.TFSaveable): + """Distributional DQN learner.""" + + def __init__( + self, + network: networks.IQNNetwork, + target_network: snt.Module, + discount: float, + importance_sampling_exponent: float, + learning_rate: float, + target_update_period: int, + dataset: tf.data.Dataset, + huber_loss_parameter: float = 1., + replay_client: Optional[reverb.TFClient] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initializes the learner. + + Args: + network: the online Q network (the one being optimized) that outputs + (q_values, q_logits, atoms). + target_network: the target Q critic (which lags behind the online net). + discount: discount to use for TD updates. + importance_sampling_exponent: power to which importance weights are raised + before normalizing. + learning_rate: learning rate for the q-network update. + target_update_period: number of learner steps to perform before updating + the target networks. + dataset: dataset to learn from, whether fixed or from a replay buffer + (see `acme.datasets.reverb.make_reverb_dataset` documentation). + huber_loss_parameter: Quadratic-linear boundary for Huber loss. + replay_client: client to replay to allow for updating priorities. + counter: Counter object for (potentially distributed) counting. + logger: Logger object for writing logs to. + checkpoint: boolean indicating whether to checkpoint the learner or not. + """ + + # Internalise agent components (replay buffer, networks, optimizer). + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._target_network = target_network + self._optimizer = snt.optimizers.Adam(learning_rate) + self._replay_client = replay_client + + # Internalise the hyperparameters. + self._discount = discount + self._target_update_period = target_update_period + self._importance_sampling_exponent = importance_sampling_exponent + self._huber_loss_parameter = huber_loss_parameter + + # Learner state. 
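+    # The step counter is a tf.Variable so the periodic target-network update
+    # can be evaluated inside the compiled _step function below.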
+ self._variables: List[List[tf.Tensor]] = [network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # Internalise logging/counting objects. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) + + # Create a snapshotter object. + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + time_delta_minutes=5, + objects_to_save={ + 'network': self._network, + 'target_network': self._target_network, + 'optimizer': self._optimizer, + 'num_steps': self._num_steps + }) + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={'network': network}, time_delta_minutes=60.) + else: + self._checkpointer = None + self._snapshotter = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" + + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + keys, probs, *_ = inputs.info + + with tf.GradientTape() as tape: + loss, fetches = self._loss_and_fetches(transitions.observation, + transitions.action, + transitions.reward, + transitions.discount, + transitions.next_observation) + + # Get the importance weights. + importance_weights = 1. / probs # [B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + + # Reweight. + loss *= tf.cast(importance_weights, loss.dtype) # [B] + loss = tf.reduce_mean(loss, axis=[0]) # [] + + # Do a step of SGD. + gradients = tape.gradient(loss, self._network.trainable_variables) + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Update the priorities in the replay buffer. + if self._replay_client: + priorities = tf.clip_by_value(tf.abs(loss), -100, 100) + priorities = tf.cast(priorities, tf.float64) + self._replay_client.update_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(self._network.variables, + self._target_network.variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Report gradient norms. + fetches.update( + loss=loss, + gradient_norm=tf.linalg.global_norm(gradients)) + return fetches + + def step(self): + # Do a batch of SGD. + result = self._step() + + # Update our counts and record it. + counts = self._counter.increment(steps=1) + result.update(counts) + + # Checkpoint and attempt to write logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + def _loss_and_fetches( + self, + o_tm1: tf.Tensor, + a_tm1: tf.Tensor, + r_t: tf.Tensor, + d_t: tf.Tensor, + o_t: tf.Tensor, + ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + # Evaluate our networks. 
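+    # The network returns (mean Q-values, per-quantile values, tau). The
+    # greedy action a_t is chosen from the target network's mean Q-values, and
+    # the matching quantile estimates from the same target network form the
+    # distributional TD target.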
+ _, dist_tm1, tau = self._network(o_tm1) + q_tm1 = _index_embs_with_actions(dist_tm1, a_tm1) + + q_selector, _, _ = self._target_network(o_t) + a_t = tf.argmax(q_selector, axis=1) + + _, dist_t, _ = self._target_network(o_t) + q_t = _index_embs_with_actions(dist_t, a_t) + + q_tm1 = losses.QuantileDistribution(values=q_tm1, + logits=tf.zeros_like(q_tm1)) + q_t = losses.QuantileDistribution(values=q_t, logits=tf.zeros_like(q_t)) + + # The rewards and discounts have to have the same type as network values. + r_t = tf.cast(r_t, tf.float32) + r_t = tf.clip_by_value(r_t, -1., 1.) + d_t = tf.cast(d_t, tf.float32) * tf.cast(self._discount, tf.float32) + + # Compute the loss. + loss_module = losses.NonUniformQuantileRegression( + self._huber_loss_parameter) + loss = loss_module(q_tm1, r_t, d_t, q_t, tau) + + # Compute statistics of the Q-values for logging. + max_q = tf.reduce_max(q_t.values) + min_q = tf.reduce_min(q_t.values) + mean_q, var_q = tf.nn.moments(q_t.values, [0, 1]) + fetches = { + 'max_q': max_q, + 'mean_q': mean_q, + 'min_q': min_q, + 'var_q': var_q, + } + + return loss, fetches + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + 'network': self._network, + 'target_network': self._target_network, + 'optimizer': self._optimizer, + 'num_steps': self._num_steps + } + + +def _index_embs_with_actions( + embeddings: tf.Tensor, + actions: tf.Tensor, +) -> tf.Tensor: + """Slice an embedding Tensor with action indices. + + Take embeddings of the form [batch_size, num_actions, embed_dim] + and actions of the form [batch_size], and return the sliced embeddings + like embeddings[:, actions, :]. Doing this my way because the comments in + the official op are scary. + + Args: + embeddings: Tensor of embeddings to index. + actions: int Tensor to use as index into embeddings + + Returns: + Tensor of embeddings indexed by actions + """ + batch_size, num_actions, _ = embeddings.shape.as_list() + + # Values are the 'values' in a sparse tensor we will be setting + act_indx = tf.cast(actions, tf.int64)[:, None] + values = tf.ones([tf.size(actions)], dtype=tf.bool) + + # Create a range for each index into the batch + act_range = tf.range(0, batch_size, dtype=tf.int64)[:, None] + # Combine this into coordinates with the action indices + indices = tf.concat([act_range, act_indx], 1) + + actions_mask = tf.SparseTensor(indices, values, [batch_size, num_actions]) + actions_mask = tf.stop_gradient( + tf.sparse.to_dense(actions_mask, default_value=False)) + sliced_emb = tf.boolean_mask(embeddings, actions_mask) + return sliced_emb diff --git a/acme/acme/agents/tf/iqn/learning_test.py b/acme/acme/agents/tf/iqn/learning_test.py new file mode 100644 index 00000000..9e2bdae6 --- /dev/null +++ b/acme/acme/agents/tf/iqn/learning_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for IQN learner.""" + +import copy + +from acme import specs +from acme.agents.tf import iqn +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import counting +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def _make_torso_network(num_outputs: int) -> snt.Module: + """Create torso network (outputs intermediate representation).""" + return snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([num_outputs]) + ]) + + +def _make_head_network(num_outputs: int) -> snt.Module: + """Create head network (outputs Q-values).""" + return snt.nets.MLP([num_outputs]) + + +class IQNLearnerTest(absltest.TestCase): + + def test_full_learner(self): + # Create dataset. + environment = fakes.DiscreteEnvironment( + num_actions=5, + num_observations=10, + obs_dtype=np.float32, + episode_length=10) + spec = specs.make_environment_spec(environment) + dataset = fakes.transition_dataset(environment).batch( + 2, drop_remainder=True) + + # Build network. + network = networks.IQNNetwork( + torso=_make_torso_network(num_outputs=2), + head=_make_head_network(num_outputs=spec.actions.num_values), + latent_dim=2, + num_quantile_samples=1) + tf2_utils.create_variables(network, [spec.observations]) + + # Build learner. + counter = counting.Counter() + learner = iqn.IQNLearner( + network=network, + target_network=copy.deepcopy(network), + dataset=dataset, + learning_rate=1e-4, + discount=0.99, + importance_sampling_exponent=0.2, + target_update_period=1, + counter=counter) + + # Run a learner step. + learner.step() + + # Check counts from IQN learner. + counts = counter.get_counts() + self.assertEqual(1, counts['steps']) + + # Check learner state. + self.assertEqual(1, learner.state['num_steps'].numpy()) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mcts/README.md b/acme/acme/agents/tf/mcts/README.md new file mode 100644 index 00000000..e4ffb185 --- /dev/null +++ b/acme/acme/agents/tf/mcts/README.md @@ -0,0 +1,12 @@ +# Monte-Carlo Tree Search (MCTS) + +This agent implements planning with a simulator (learned or otherwise), with +search guided by policy and value networks. This can be thought of as a +scaled-down and simplified version of the AlphaZero algorithm +([Silver et al., 2018]). + +The algorithm is agnostic to the choice of environment model -- this can be an +exact simulator (as in AlphaZero), or a learned transition model; we provide +examples of both cases. + +[Silver et al., 2018]: https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go diff --git a/acme/acme/agents/tf/mcts/__init__.py b/acme/acme/agents/tf/mcts/__init__.py new file mode 100644 index 00000000..5438e923 --- /dev/null +++ b/acme/acme/agents/tf/mcts/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Monte-Carlo tree search (MCTS) agent.""" + +from acme.agents.tf.mcts.agent import MCTS +from acme.agents.tf.mcts.agent_distributed import DistributedMCTS diff --git a/acme/acme/agents/tf/mcts/acting.py b/acme/acme/agents/tf/mcts/acting.py new file mode 100644 index 00000000..887d7c25 --- /dev/null +++ b/acme/acme/agents/tf/mcts/acting.py @@ -0,0 +1,120 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A MCTS actor.""" + +from typing import Optional, Tuple + +import acme +from acme import adders +from acme import specs +from acme.agents.tf.mcts import models +from acme.agents.tf.mcts import search +from acme.agents.tf.mcts import types +from acme.tf import variable_utils as tf2_variable_utils + +import dm_env +import numpy as np +from scipy import special +import sonnet as snt +import tensorflow as tf + + +class MCTSActor(acme.Actor): + """Executes a policy- and value-network guided MCTS search.""" + + _prev_timestep: dm_env.TimeStep + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + model: models.Model, + network: snt.Module, + discount: float, + num_simulations: int, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): + + # Internalize components: model, network, data sink and variable source. + self._model = model + self._network = tf.function(network) + self._variable_client = variable_client + self._adder = adder + + # Internalize hyperparameters. + self._num_actions = environment_spec.actions.num_values + self._num_simulations = num_simulations + self._actions = list(range(self._num_actions)) + self._discount = discount + + # We need to save the policy so as to add it to replay on the next step. + self._probs = np.ones( + shape=(self._num_actions,), dtype=np.float32) / self._num_actions + + def _forward( + self, observation: types.Observation) -> Tuple[types.Probs, types.Value]: + """Performs a forward pass of the policy-value network.""" + logits, value = self._network(tf.expand_dims(observation, axis=0)) + + # Convert to numpy & take softmax. + logits = logits.numpy().squeeze(axis=0) + value = value.numpy().item() + probs = special.softmax(logits) + + return probs, value + + def select_action(self, observation: types.Observation) -> types.Action: + """Computes the agent's policy via MCTS.""" + if self._model.needs_reset: + self._model.reset(observation) + + # Compute a fresh MCTS plan. + root = search.mcts( + observation, + model=self._model, + search_policy=search.puct, + evaluation=self._forward, + num_simulations=self._num_simulations, + num_actions=self._num_actions, + discount=self._discount, + ) + + # The agent's policy is softmax w.r.t. the *visit counts* as in AlphaZero. + probs = search.visit_count_policy(root) + action = np.int32(np.random.choice(self._actions, p=probs)) + + # Save the policy probs so that we can add them to replay in `observe()`. 
+ self._probs = probs.astype(np.float32) + + return action + + def update(self, wait: bool = False): + """Fetches the latest variables from the variable source, if needed.""" + if self._variable_client: + self._variable_client.update(wait) + + def observe_first(self, timestep: dm_env.TimeStep): + self._prev_timestep = timestep + if self._adder: + self._adder.add_first(timestep) + + def observe(self, action: types.Action, next_timestep: dm_env.TimeStep): + """Updates the agent's internal model and adds the transition to replay.""" + self._model.update(self._prev_timestep, action, next_timestep) + + self._prev_timestep = next_timestep + + if self._adder: + self._adder.add(action, next_timestep, extras={'pi': self._probs}) diff --git a/acme/acme/agents/tf/mcts/agent.py b/acme/acme/agents/tf/mcts/agent.py new file mode 100644 index 00000000..c7b58bc0 --- /dev/null +++ b/acme/acme/agents/tf/mcts/agent.py @@ -0,0 +1,99 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A single-process MCTS agent.""" + +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf.mcts import acting +from acme.agents.tf.mcts import learning +from acme.agents.tf.mcts import models +from acme.tf import utils as tf2_utils + +import numpy as np +import reverb +import sonnet as snt + + +class MCTS(agent.Agent): + """A single-process MCTS agent.""" + + def __init__( + self, + network: snt.Module, + model: models.Model, + optimizer: snt.Optimizer, + n_step: int, + discount: float, + replay_capacity: int, + num_simulations: int, + environment_spec: specs.EnvironmentSpec, + batch_size: int, + ): + + extra_spec = { + 'pi': + specs.Array( + shape=(environment_spec.actions.num_values,), dtype=np.float32) + } + # Create a replay server for storing transitions. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=replay_capacity, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature( + environment_spec, extra_spec)) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f'localhost:{self._server.port}' + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), + n_step=n_step, + discount=discount) + + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset(server_address=address) + dataset = dataset.batch(batch_size, drop_remainder=True) + + tf2_utils.create_variables(network, [environment_spec.observations]) + + # Now create the agent components: actor & learner. 
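+    # The actor plans with MCTS against the model and writes n-step
+    # transitions, along with its visit-count policy 'pi', into replay; the
+    # learner then distils those search results into the policy/value network.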
+ actor = acting.MCTSActor( + environment_spec=environment_spec, + model=model, + network=network, + discount=discount, + adder=adder, + num_simulations=num_simulations, + ) + + learner = learning.AZLearner( + network=network, + optimizer=optimizer, + dataset=dataset, + discount=discount, + ) + + # The parent class combines these together into one 'agent'. + super().__init__( + actor=actor, + learner=learner, + min_observations=10, + observations_per_step=1, + ) diff --git a/acme/acme/agents/tf/mcts/agent_distributed.py b/acme/acme/agents/tf/mcts/agent_distributed.py new file mode 100644 index 00000000..b2eae72a --- /dev/null +++ b/acme/acme/agents/tf/mcts/agent_distributed.py @@ -0,0 +1,233 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the distributed MCTS agent topology via Launchpad.""" + +from typing import Callable, Optional + +import acme +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents.tf.mcts import acting +from acme.agents.tf.mcts import learning +from acme.agents.tf.mcts import models +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting +from acme.utils import loggers +import dm_env +import launchpad as lp +import reverb +import sonnet as snt + + +class DistributedMCTS: + """Distributed MCTS agent.""" + + def __init__( + self, + environment_factory: Callable[[], dm_env.Environment], + network_factory: Callable[[specs.DiscreteArray], snt.Module], + model_factory: Callable[[specs.EnvironmentSpec], models.Model], + num_actors: int, + num_simulations: int = 50, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + n_step: int = 5, + learning_rate: float = 1e-3, + discount: float = 0.99, + environment_spec: Optional[specs.EnvironmentSpec] = None, + save_logs: bool = False, + variable_update_period: int = 1000, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory()) + + # These 'factories' create the relevant components on the workers. + self._environment_factory = environment_factory + self._network_factory = network_factory + self._model_factory = model_factory + + # Internalize hyperparameters. 
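+    # Stored on the instance so each worker constructor below (replay,
+    # learner, actor, evaluator) builds its components with the same settings.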
+ self._num_actors = num_actors + self._num_simulations = num_simulations + self._env_spec = environment_spec + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._target_update_period = target_update_period + self._samples_per_insert = samples_per_insert + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._importance_sampling_exponent = importance_sampling_exponent + self._priority_exponent = priority_exponent + self._n_step = n_step + self._learning_rate = learning_rate + self._discount = discount + self._save_logs = save_logs + self._variable_update_period = variable_update_period + + def replay(self): + """The replay storage worker.""" + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=self._batch_size) + extra_spec = { + 'pi': + specs.Array( + shape=(self._env_spec.actions.num_values,), dtype='float32') + } + signature = adders.NStepTransitionAdder.signature(self._env_spec, + extra_spec) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=signature) + return [replay_table] + + def learner(self, replay: reverb.Client, counter: counting.Counter): + """The learning part of the agent.""" + # Create the networks. + network = self._network_factory(self._env_spec.actions) + + tf2_utils.create_variables(network, [self._env_spec.observations]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size) + + # Create the optimizer. + optimizer = snt.optimizers.Adam(self._learning_rate) + + # Return the learning agent. + return learning.AZLearner( + network=network, + discount=self._discount, + dataset=dataset, + optimizer=optimizer, + counter=counter, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Build environment, model, network. + environment = self._environment_factory() + network = self._network_factory(self._env_spec.actions) + model = self._model_factory(self._env_spec) + + # Create variable client for communicating with the learner. + tf2_utils.create_variables(network, [self._env_spec.observations]) + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={'network': network.trainable_variables}, + update_period=self._variable_update_period) + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, + n_step=self._n_step, + discount=self._discount, + ) + + # Create the agent. + actor = acting.MCTSActor( + environment_spec=self._env_spec, + model=model, + network=network, + discount=self._discount, + adder=adder, + variable_client=variable_client, + num_simulations=self._num_simulations, + ) + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + + # Build environment, model, network. 
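+    # The evaluator mirrors an actor but adds nothing to replay, so its
+    # episode returns track the current policy without affecting learning.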
+ environment = self._environment_factory() + network = self._network_factory(self._env_spec.actions) + model = self._model_factory(self._env_spec) + + # Create variable client for communicating with the learner. + tf2_utils.create_variables(network, [self._env_spec.observations]) + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={'policy': network.trainable_variables}, + update_period=self._variable_update_period) + + # Create the agent. + actor = acting.MCTSActor( + environment_spec=self._env_spec, + model=model, + network=network, + discount=self._discount, + variable_client=variable_client, + num_simulations=self._num_simulations, + ) + + # Create the run loop and return it. + logger = loggers.make_default_logger('evaluator') + return acme.EnvironmentLoop( + environment, actor, counter=counter, logger=logger) + + def build(self, name='MCTS'): + """Builds the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay), label='replay') + + with program.group('counter'): + counter = program.add_node( + lp.CourierNode(counting.Counter), label='counter') + + with program.group('learner'): + learner = program.add_node( + lp.CourierNode(self.learner, replay, counter), label='learner') + + with program.group('evaluator'): + program.add_node( + lp.CourierNode(self.evaluator, learner, counter), label='evaluator') + + with program.group('actor'): + program.add_node( + lp.CourierNode(self.actor, replay, learner, counter), label='actor') + + return program diff --git a/acme/acme/agents/tf/mcts/agent_test.py b/acme/acme/agents/tf/mcts/agent_test.py new file mode 100644 index 00000000..de1a8029 --- /dev/null +++ b/acme/acme/agents/tf/mcts/agent_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the MCTS agent.""" + +import acme +from acme import specs +from acme.agents.tf import mcts +from acme.agents.tf.mcts.models import simulator +from acme.testing import fakes +from acme.tf import networks +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +class MCTSTest(absltest.TestCase): + + def test_mcts(self): + # Create a fake environment to test with. + num_actions = 5 + environment = fakes.DiscreteEnvironment( + num_actions=num_actions, + num_observations=10, + obs_dtype=np.float32, + episode_length=10) + spec = specs.make_environment_spec(environment) + + network = snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([50, 50]), + networks.PolicyValueHead(spec.actions.num_values), + ]) + model = simulator.Simulator(environment) + optimizer = snt.optimizers.Adam(1e-3) + + # Construct the agent. + agent = mcts.MCTS( + environment_spec=spec, + network=network, + model=model, + optimizer=optimizer, + n_step=1, + discount=1., + replay_capacity=100, + num_simulations=10, + batch_size=10) + + # Try running the environment loop. 
We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mcts/learning.py b/acme/acme/agents/tf/mcts/learning.py new file mode 100644 index 00000000..6ec52c3d --- /dev/null +++ b/acme/acme/agents/tf/mcts/learning.py @@ -0,0 +1,89 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A MCTS "AlphaZero-style" learner.""" + +from typing import List, Optional + +import acme +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import sonnet as snt +import tensorflow as tf + + +class AZLearner(acme.Learner): + """AlphaZero-style learning.""" + + def __init__( + self, + network: snt.Module, + optimizer: snt.Optimizer, + dataset: tf.data.Dataset, + discount: float, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + ): + + # Logger and counter for tracking statistics / writing out to terminal. + self._counter = counting.Counter(counter, 'learner') + self._logger = logger or loggers.TerminalLogger('learner', time_delta=30.) + + # Internalize components. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._optimizer = optimizer + self._network = network + self._variables = network.trainable_variables + self._discount = np.float32(discount) + + @tf.function + def _step(self) -> tf.Tensor: + """Do a step of SGD on the loss.""" + + inputs = next(self._iterator) + o_t, _, r_t, d_t, o_tp1, extras = inputs.data + pi_t = extras['pi'] + + with tf.GradientTape() as tape: + # Forward the network on the two states in the transition. + logits, value = self._network(o_t) + _, target_value = self._network(o_tp1) + target_value = tf.stop_gradient(target_value) + + # Value loss is simply on-policy TD learning. + value_loss = tf.square(r_t + self._discount * d_t * target_value - value) + + # Policy loss distills MCTS policy into the policy network. + policy_loss = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=pi_t) + + # Compute gradients. 
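+      # Both heads share the same network, so a single mean over the summed
+      # losses yields one backward pass that trains policy and value jointly.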
+ loss = tf.reduce_mean(value_loss + policy_loss) + gradients = tape.gradient(loss, self._network.trainable_variables) + + self._optimizer.apply(gradients, self._network.trainable_variables) + + return loss + + def step(self): + """Does a step of SGD and logs the results.""" + loss = self._step() + self._logger.write({'loss': loss}) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + """Exposes the variables for actors to update from.""" + return tf2_utils.to_numpy(self._variables) diff --git a/acme/acme/agents/tf/mcts/models/__init__.py b/acme/acme/agents/tf/mcts/models/__init__.py new file mode 100644 index 00000000..f30e5294 --- /dev/null +++ b/acme/acme/agents/tf/mcts/models/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Models for planning via MCTS.""" + +# pylint: disable=unused-import + +from acme.agents.tf.mcts.models.base import Model diff --git a/acme/acme/agents/tf/mcts/models/base.py b/acme/acme/agents/tf/mcts/models/base.py new file mode 100644 index 00000000..616b8f95 --- /dev/null +++ b/acme/acme/agents/tf/mcts/models/base.py @@ -0,0 +1,52 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Base model class, specifying the interface..""" + +import abc +from typing import Optional + +from acme.agents.tf.mcts import types + +import dm_env + + +class Model(dm_env.Environment, abc.ABC): + """Base (abstract) class for models used for planning via MCTS.""" + + @abc.abstractmethod + def load_checkpoint(self): + """Loads a saved model state, if it exists.""" + + @abc.abstractmethod + def save_checkpoint(self): + """Saves the model state so that we can reset it after a rollout.""" + + @abc.abstractmethod + def update( + self, + timestep: dm_env.TimeStep, + action: types.Action, + next_timestep: dm_env.TimeStep, + ) -> dm_env.TimeStep: + """Updates the model given an observation, action, reward, and discount.""" + + @abc.abstractmethod + def reset(self, initial_state: Optional[types.Observation] = None): + """Resets the model, optionally to an initial state.""" + + @property + @abc.abstractmethod + def needs_reset(self) -> bool: + """Returns whether or not the model needs to be reset.""" diff --git a/acme/acme/agents/tf/mcts/models/mlp.py b/acme/acme/agents/tf/mcts/models/mlp.py new file mode 100644 index 00000000..6f4c8adb --- /dev/null +++ b/acme/acme/agents/tf/mcts/models/mlp.py @@ -0,0 +1,220 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A simple (deterministic) environment transition model from pixels.""" + +from typing import Optional, Tuple + +from acme import specs +from acme.agents.tf.mcts import types +from acme.agents.tf.mcts.models import base +from acme.tf import utils as tf2_utils + +from bsuite.baselines.utils import replay +import dm_env +import numpy as np +from scipy import special +import sonnet as snt +import tensorflow as tf + + +class MLPTransitionModel(snt.Module): + """This uses MLPs to model (s, a) -> (r, d, s').""" + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + hidden_sizes: Tuple[int, ...], + ): + super(MLPTransitionModel, self).__init__(name='mlp_transition_model') + + # Get num actions/observation shape. + self._num_actions = environment_spec.actions.num_values + self._input_shape = environment_spec.observations.shape + self._flat_shape = int(np.prod(self._input_shape)) + + # Prediction networks. + self._state_network = snt.Sequential([ + snt.nets.MLP(hidden_sizes + (self._flat_shape,)), + snt.Reshape(self._input_shape) + ]) + self._reward_network = snt.Sequential([ + snt.nets.MLP(hidden_sizes + (1,)), + lambda r: tf.squeeze(r, axis=-1), + ]) + self._discount_network = snt.Sequential([ + snt.nets.MLP(hidden_sizes + (1,)), + lambda d: tf.squeeze(d, axis=-1), + ]) + + def __call__(self, state: tf.Tensor, + action: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + + embedded_state = snt.Flatten()(state) + embedded_action = tf.one_hot(action, depth=self._num_actions) + + embedding = tf.concat([embedded_state, embedded_action], axis=-1) + + # Predict the next state, reward, and termination. 
+ next_state = self._state_network(embedding) + reward = self._reward_network(embedding) + discount_logits = self._discount_network(embedding) + + return next_state, reward, discount_logits + + +class MLPModel(base.Model): + """A simple environment model.""" + + _checkpoint: types.Observation + _state: types.Observation + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + replay_capacity: int, + batch_size: int, + hidden_sizes: Tuple[int, ...], + learning_rate: float = 1e-3, + terminal_tol: float = 1e-3, + ): + self._obs_spec = environment_spec.observations + self._action_spec = environment_spec.actions + # Hyperparameters. + self._batch_size = batch_size + self._terminal_tol = terminal_tol + + # Modelling + self._replay = replay.Replay(replay_capacity) + self._transition_model = MLPTransitionModel(environment_spec, hidden_sizes) + self._optimizer = snt.optimizers.Adam(learning_rate) + self._forward = tf.function(self._transition_model) + tf2_utils.create_variables( + self._transition_model, [self._obs_spec, self._action_spec]) + self._variables = self._transition_model.trainable_variables + + # Model state. + self._needs_reset = True + + @tf.function + def _step( + self, + o_t: tf.Tensor, + a_t: tf.Tensor, + r_t: tf.Tensor, + d_t: tf.Tensor, + o_tp1: tf.Tensor, + ) -> tf.Tensor: + + with tf.GradientTape() as tape: + next_state, reward, discount = self._transition_model(o_t, a_t) + + state_loss = tf.square(next_state - o_tp1) + reward_loss = tf.square(reward - r_t) + discount_loss = tf.nn.sigmoid_cross_entropy_with_logits(d_t, discount) + + loss = sum([ + tf.reduce_mean(state_loss), + tf.reduce_mean(reward_loss), + tf.reduce_mean(discount_loss), + ]) + + gradients = tape.gradient(loss, self._variables) + self._optimizer.apply(gradients, self._variables) + + return loss + + def step(self, action: types.Action): + # Reset if required. + if self._needs_reset: + raise ValueError('Model must be reset with an initial timestep.') + + # Step the model. + state, action = tf2_utils.add_batch_dim([self._state, action]) + new_state, reward, discount_logits = [ + x.numpy().squeeze(axis=0) for x in self._forward(state, action) + ] + discount = special.softmax(discount_logits) + + # Save the resulting state for the next step. + self._state = new_state + + # We threshold discount on a given tolerance. + if discount < self._terminal_tol: + self._needs_reset = True + return dm_env.termination(reward=reward, observation=self._state.copy()) + return dm_env.transition(reward=reward, observation=self._state.copy()) + + def reset(self, initial_state: Optional[types.Observation] = None): + if initial_state is None: + raise ValueError('Model must be reset with an initial state.') + # We reset to an initial state that we are explicitly given. + # This allows us to handle environments with stochastic resets (e.g. Catch). + self._state = initial_state.copy() + self._needs_reset = False + return dm_env.restart(self._state) + + def update( + self, + timestep: dm_env.TimeStep, + action: types.Action, + next_timestep: dm_env.TimeStep, + ) -> dm_env.TimeStep: + # Add the true transition to replay. + transition = [ + timestep.observation, + action, + next_timestep.reward, + next_timestep.discount, + next_timestep.observation, + ] + self._replay.add(transition) + + # Step the model to generate a synthetic transition. + ts = self.step(action) + + # Copy the *true* state on update. 
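+ # Tracking the real observation, rather than the model's own prediction, + # keeps model error from compounding across steps while acting.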
+ self._state = next_timestep.observation.copy() + + if ts.last() or next_timestep.last(): + # Model believes that a termination has happened. + # This will result in a crash during planning if the true environment + # didn't terminate here as well. So, we indicate that we need a reset. + self._needs_reset = True + + # Sample from replay and do SGD. + if self._replay.size >= self._batch_size: + batch = self._replay.sample(self._batch_size) + self._step(*batch) + + return ts + + def save_checkpoint(self): + if self._needs_reset: + raise ValueError('Cannot save checkpoint: model must be reset first.') + self._checkpoint = self._state.copy() + + def load_checkpoint(self): + self._needs_reset = False + self._state = self._checkpoint.copy() + + def action_spec(self): + return self._action_spec + + def observation_spec(self): + return self._obs_spec + + @property + def needs_reset(self) -> bool: + return self._needs_reset diff --git a/acme/acme/agents/tf/mcts/models/simulator.py b/acme/acme/agents/tf/mcts/models/simulator.py new file mode 100644 index 00000000..e0acb5d7 --- /dev/null +++ b/acme/acme/agents/tf/mcts/models/simulator.py @@ -0,0 +1,87 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A simulator model, which wraps a copy of the true environment.""" + +import copy +import dataclasses + +from acme.agents.tf.mcts import types +from acme.agents.tf.mcts.models import base +import dm_env + + +@dataclasses.dataclass +class Checkpoint: + """Holds the checkpoint state for the environment simulator.""" + needs_reset: bool + environment: dm_env.Environment + + +class Simulator(base.Model): + """A simulator model, which wraps a copy of the true environment. + + Assumptions: + - The environment (including RNG) is fully copyable via `deepcopy`. + - Environment dynamics (modulo episode resets) are deterministic. + """ + + _checkpoint: Checkpoint + _env: dm_env.Environment + + def __init__(self, env: dm_env.Environment): + # Make a 'checkpoint' copy env to save/load from when doing rollouts. + self._env = copy.deepcopy(env) + self._needs_reset = True + self.save_checkpoint() + + def update( + self, + timestep: dm_env.TimeStep, + action: types.Action, + next_timestep: dm_env.TimeStep, + ) -> dm_env.TimeStep: + # Call update() once per 'real' experience to keep this env in sync. 
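+ # Since dynamics are assumed deterministic (see the class docstring), + # stepping the wrapped copy with the same action reproduces the true + # transition.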
+ return self.step(action) + + def save_checkpoint(self): + self._checkpoint = Checkpoint( + needs_reset=self._needs_reset, + environment=copy.deepcopy(self._env), + ) + + def load_checkpoint(self): + self._env = copy.deepcopy(self._checkpoint.environment) + self._needs_reset = self._checkpoint.needs_reset + + def step(self, action: types.Action) -> dm_env.TimeStep: + if self._needs_reset: + raise ValueError('This model needs to be explicitly reset.') + timestep = self._env.step(action) + self._needs_reset = timestep.last() + return timestep + + def reset(self, *unused_args, **unused_kwargs): + self._needs_reset = False + return self._env.reset() + + def observation_spec(self): + return self._env.observation_spec() + + def action_spec(self): + return self._env.action_spec() + + @property + def needs_reset(self) -> bool: + return self._needs_reset diff --git a/acme/acme/agents/tf/mcts/models/simulator_test.py b/acme/acme/agents/tf/mcts/models/simulator_test.py new file mode 100644 index 00000000..90b68d59 --- /dev/null +++ b/acme/acme/agents/tf/mcts/models/simulator_test.py @@ -0,0 +1,90 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for simulator.py.""" + +from acme.agents.tf.mcts.models import simulator +from bsuite.environments import catch +import dm_env +import numpy as np + +from absl.testing import absltest + + +class SimulatorTest(absltest.TestCase): + + def _check_equal(self, a: dm_env.TimeStep, b: dm_env.TimeStep): + self.assertEqual(a.reward, b.reward) + self.assertEqual(a.discount, b.discount) + self.assertEqual(a.step_type, b.step_type) + np.testing.assert_array_equal(a.observation, b.observation) + + def test_simulator_fidelity(self): + """Tests whether the simulator match the ground truth.""" + + # Given an environment. + env = catch.Catch() + + # If we instantiate a simulator 'model' of this environment. + model = simulator.Simulator(env) + + # Then the model and environment should always agree as we step them. + num_actions = env.action_spec().num_values + for _ in range(10): + true_timestep = env.reset() + self.assertTrue(model.needs_reset) + model_timestep = model.reset() + self.assertFalse(model.needs_reset) + self._check_equal(true_timestep, model_timestep) + + while not true_timestep.last(): + action = np.random.randint(num_actions) + true_timestep = env.step(action) + model_timestep = model.step(action) + self._check_equal(true_timestep, model_timestep) + + def test_checkpointing(self): + """Tests whether checkpointing restores the state correctly.""" + # Given an environment, and a model based on this environment. + model = simulator.Simulator(catch.Catch()) + num_actions = model.action_spec().num_values + + model.reset() + + # Now, we save a checkpoint. + model.save_checkpoint() + + ts = model.step(1) + + # Step the model once and load the checkpoint. 
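+ # After load_checkpoint() the model should replay from the saved state, so + # stepping it again with action 1 must reproduce `ts` exactly.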
+ timestep = model.step(np.random.randint(num_actions)) + model.load_checkpoint() + self._check_equal(ts, model.step(1)) + + while not timestep.last(): + timestep = model.step(np.random.randint(num_actions)) + + # The model should require a reset. + self.assertTrue(model.needs_reset) + + # Once we load checkpoint, the model should no longer require reset. + model.load_checkpoint() + self.assertFalse(model.needs_reset) + + # Further steps should agree with the original environment state. + self._check_equal(ts, model.step(1)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mcts/search.py b/acme/acme/agents/tf/mcts/search.py new file mode 100644 index 00000000..7bcc85f4 --- /dev/null +++ b/acme/acme/agents/tf/mcts/search.py @@ -0,0 +1,194 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Monte Carlo Tree Search implementation.""" + +import dataclasses +from typing import Callable, Dict + +from acme.agents.tf.mcts import models +from acme.agents.tf.mcts import types +import numpy as np + + +@dataclasses.dataclass +class Node: + """A MCTS node.""" + + reward: float = 0. + visit_count: int = 0 + terminal: bool = False + prior: float = 1. + total_value: float = 0. + children: Dict[types.Action, 'Node'] = dataclasses.field(default_factory=dict) + + def expand(self, prior: np.ndarray): + """Expands this node, adding child nodes.""" + assert prior.ndim == 1 # Prior should be a flat vector. + for a, p in enumerate(prior): + self.children[a] = Node(prior=p) + + @property + def value(self) -> types.Value: # Q(s, a) + """Returns the value from this node.""" + if self.visit_count: + return self.total_value / self.visit_count + return 0. + + @property + def children_visits(self) -> np.ndarray: + """Return array of visit counts of visited children.""" + return np.array([c.visit_count for c in self.children.values()]) + + @property + def children_values(self) -> np.ndarray: + """Return array of values of visited children.""" + return np.array([c.value for c in self.children.values()]) + + +SearchPolicy = Callable[[Node], types.Action] + + +def mcts( + observation: types.Observation, + model: models.Model, + search_policy: SearchPolicy, + evaluation: types.EvaluationFn, + num_simulations: int, + num_actions: int, + discount: float = 1., + dirichlet_alpha: float = 1, + exploration_fraction: float = 0., +) -> Node: + """Does Monte Carlo tree search (MCTS), AlphaZero style.""" + + # Evaluate the prior policy for this state. + prior, value = evaluation(observation) + assert prior.shape == (num_actions,) + + # Add exploration noise to the prior. + noise = np.random.dirichlet(alpha=[dirichlet_alpha] * num_actions) + prior = prior * (1 - exploration_fraction) + noise * exploration_fraction + + # Create a fresh tree search. + root = Node() + root.expand(prior) + + # Save the model state so that we can reset it for each simulation. 
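+ # Each simulation below mutates the model as it rolls out a trajectory; + # load_checkpoint() at the end of the loop rewinds it to this root state.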
+ model.save_checkpoint() + for _ in range(num_simulations): + # Start a new simulation from the top. + trajectory = [root] + node = root + + # Generate a trajectory. + timestep = None + while node.children: + # Select an action according to the search policy. + action = search_policy(node) + + # Point the node at the corresponding child. + node = node.children[action] + + # Step the simulator and add this timestep to the node. + timestep = model.step(action) + node.reward = timestep.reward or 0. + node.terminal = timestep.last() + trajectory.append(node) + + if timestep is None: + raise ValueError('Generated an empty rollout; this should not happen.') + + # Calculate the bootstrap for leaf nodes. + if node.terminal: + # If terminal, there is no bootstrap value. + value = 0. + else: + # Otherwise, bootstrap from this node with our value function. + prior, value = evaluation(timestep.observation) + + # We also want to expand this node for next time. + node.expand(prior) + + # Load the saved model state. + model.load_checkpoint() + + # Monte Carlo back-up with bootstrap from value function. + ret = value + while trajectory: + # Pop off the latest node in the trajectory. + node = trajectory.pop() + + # Accumulate the discounted return + ret *= discount + ret += node.reward + + # Update the node. + node.total_value += ret + node.visit_count += 1 + + return root + + +def bfs(node: Node) -> types.Action: + """Breadth-first search policy.""" + visit_counts = np.array([c.visit_count for c in node.children.values()]) + return argmax(-visit_counts) + + +def puct(node: Node, ucb_scaling: float = 1.) -> types.Action: + """PUCT search policy, i.e. UCT with 'prior' policy.""" + # Action values Q(s,a). + value_scores = np.array([child.value for child in node.children.values()]) + check_numerics(value_scores) + + # Policy prior P(s,a). + priors = np.array([child.prior for child in node.children.values()]) + check_numerics(priors) + + # Visit ratios. + visit_ratios = np.array([ + np.sqrt(node.visit_count) / (child.visit_count + 1) + for child in node.children.values() + ]) + check_numerics(visit_ratios) + + # Combine. + puct_scores = value_scores + ucb_scaling * priors * visit_ratios + return argmax(puct_scores) + + +def visit_count_policy(root: Node, temperature: float = 1.) -> types.Probs: + """Probability weighted by visit^{1/temp} of children nodes.""" + visits = root.children_visits + if np.sum(visits) == 0: # uniform policy for zero visits + visits += 1 + rescaled_visits = visits**(1 / temperature) + probs = rescaled_visits / np.sum(rescaled_visits) + check_numerics(probs) + + return probs + + +def argmax(values: np.ndarray) -> types.Action: + """Argmax with random tie-breaking.""" + check_numerics(values) + max_value = np.max(values) + return np.int32(np.random.choice(np.flatnonzero(values == max_value))) + + +def check_numerics(values: np.ndarray): + """Raises a ValueError if any of the inputs are NaN or Inf.""" + if not np.isfinite(values).all(): + raise ValueError('check_numerics failed. Inputs: {}. '.format(values)) diff --git a/acme/acme/agents/tf/mcts/search_test.py b/acme/acme/agents/tf/mcts/search_test.py new file mode 100644 index 00000000..c5b4190c --- /dev/null +++ b/acme/acme/agents/tf/mcts/search_test.py @@ -0,0 +1,65 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for search.py.""" + +from typing import Text + +from acme.agents.tf.mcts import search +from acme.agents.tf.mcts.models import simulator +from bsuite.environments import catch +import numpy as np + +from absl.testing import absltest +from absl.testing import parameterized + + +class TestSearch(parameterized.TestCase): + + @parameterized.parameters([ + 'puct', + 'bfs', + ]) + def test_catch(self, policy_type: Text): + env = catch.Catch(rows=2, seed=1) + num_actions = env.action_spec().num_values + model = simulator.Simulator(env) + eval_fn = lambda _: (np.ones(num_actions) / num_actions, 0.) + + timestep = env.reset() + model.reset() + + search_policy = search.bfs if policy_type == 'bfs' else search.puct + + root = search.mcts( + observation=timestep.observation, + model=model, + search_policy=search_policy, + evaluation=eval_fn, + num_simulations=100, + num_actions=num_actions) + + values = np.array([c.value for c in root.children.values()]) + best_action = search.argmax(values) + + if env._paddle_x > env._ball_x: + self.assertEqual(best_action, 0) + if env._paddle_x == env._ball_x: + self.assertEqual(best_action, 1) + if env._paddle_x < env._ball_x: + self.assertEqual(best_action, 2) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mcts/types.py b/acme/acme/agents/tf/mcts/types.py new file mode 100644 index 00000000..93d8d8cd --- /dev/null +++ b/acme/acme/agents/tf/mcts/types.py @@ -0,0 +1,39 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Type aliases and assumptions that are specific to the MCTS agent.""" + +from typing import Callable, Tuple, Union +import numpy as np + +# pylint: disable=invalid-name + +# Assumption: actions are scalar and discrete (integral). +Action = Union[int, np.int32, np.int64] + +# Assumption: observations are array-like. +Observation = np.ndarray + +# Assumption: rewards and discounts are scalar. +Reward = Union[float, np.float32, np.float64] +Discount = Union[float, np.float32, np.float64] + +# Notation: policy logits/probabilities are simply a vector of floats. +Probs = np.ndarray + +# Notation: the value function is scalar-valued. +Value = float + +# Notation: the 'evaluation function' maps observations -> (probs, value). 
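+# (In this agent it is typically the policy-value network applied to a +# single observation.)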
+EvaluationFn = Callable[[Observation], Tuple[Probs, Value]] diff --git a/acme/acme/agents/tf/mog_mpo/README.md b/acme/acme/agents/tf/mog_mpo/README.md new file mode 100644 index 00000000..b36d5249 --- /dev/null +++ b/acme/acme/agents/tf/mog_mpo/README.md @@ -0,0 +1,49 @@ +# Mixture of Gaussian distributional MPO (MoG-DMPO) + +This folder contains an implementation of a novel agent (MoG-MPO) introduced in +[this technical report](https://arxiv.org/abs/2204.10256). +This work extends the MPO algorithm ([Abdolmaleki et al., 2018a], [2018b]) by +using a distributional Q-network parameterized as a mixture of Gaussians. +Therefore, as in the case of the D4PG and DMPO agent, this algorithm's critic +outputs a distribution over state-action values. + +As in our MPO agent, this is a more general algorithm, the current +implementation targets the continuous control setting and is most readily +applied to the DeepMind control suite or similar control tasks. This +implementation also includes the options of: + +* per-dimension KL constraint satisfaction, and +* action penalization via the multi-objective MPO work of + [Abdolmaleki et al., 2020]. + +Detailed notes: + +* When using per-dimension KL constraint satisfaction, you may need to tune + the value of `epsilon_mean` (and `epsilon_stddev` if not fixed). A good rule + of thumb would be to divide it by the number of dimensions in the action + space. + +[Abdolmaleki et al., 2018a]: https://arxiv.org/pdf/1806.06920.pdf +[2018b]: https://arxiv.org/pdf/1812.02256.pdf +[Abdolmaleki et al., 2020]: https://arxiv.org/pdf/2005.07513.pdf + +Citation: + +``` +@misc{mog_mpo, + title = {Revisiting Gaussian mixture critics in off-policy reinforcement + learning: a sample-based approach}, + url = {https://arxiv.org/abs/2204.10256}, + author = {Shahriari, Bobak and + Abdolmaleki, Abbas and + Byravan, Arunkumar and + Friesen, Abram and + Liu, Siqi and + Springenberg, Jost Tobias and + Heess, Nicolas and + Hoffman, Matt and + Riedmiller, Martin}, + publisher = {arXiv}, + year = {2022}, +} +``` diff --git a/acme/acme/agents/tf/mog_mpo/__init__.py b/acme/acme/agents/tf/mog_mpo/__init__.py new file mode 100644 index 00000000..bd2906f0 --- /dev/null +++ b/acme/acme/agents/tf/mog_mpo/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of a (MoG) distributional MPO agent.""" + +from acme.agents.tf.mog_mpo.agent_distributed import DistributedMoGMPO +from acme.agents.tf.mog_mpo.learning import MoGMPOLearner +from acme.agents.tf.mog_mpo.learning import PolicyEvaluationConfig +from acme.agents.tf.mog_mpo.networks import make_default_networks diff --git a/acme/acme/agents/tf/mog_mpo/agent_distributed.py b/acme/acme/agents/tf/mog_mpo/agent_distributed.py new file mode 100644 index 00000000..de9a4e60 --- /dev/null +++ b/acme/acme/agents/tf/mog_mpo/agent_distributed.py @@ -0,0 +1,292 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the (MoG) distributional MPO distributed agent class.""" + +from typing import Callable, Dict, Optional + +import acme +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents.tf import actors +from acme.agents.tf.mog_mpo import learning +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import reverb +import sonnet as snt + + +class DistributedMoGMPO: + """Program definition for distributional (MoG) MPO.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.EnvironmentSpec], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1_000, + max_replay_size: int = 1_000_000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + num_samples: int = 20, + policy_evaluation_config: Optional[ + learning.PolicyEvaluationConfig] = None, + additional_discount: float = 0.99, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + policy_loss_factory: Optional[Callable[[], snt.Module]] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._additional_discount = additional_discount + self._num_samples = num_samples + self._policy_evaluation_config = policy_evaluation_config + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. 
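+ # For example, with the defaults above (min_replay_size=1_000, + # samples_per_insert=32.0) this yields error_buffer = 1_000 * 3.2 = 3_200.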
+ samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature( + self._environment_spec)) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory='counter') + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # Create online and target networks. + online_networks = self._network_factory(self._environment_spec) + target_networks = self._network_factory(self._environment_spec) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size, + ) + + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger('learner', time_delta=self._log_every) + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. + return learning.MoGMPOLearner( + policy_network=online_networks['policy'], + critic_network=online_networks['critic'], + observation_network=online_networks['observation'], + target_policy_network=target_networks['policy'], + target_critic_network=target_networks['critic'], + target_observation_network=target_networks['observation'], + discount=self._additional_discount, + num_samples=self._num_samples, + policy_evaluation_config=self._policy_evaluation_config, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + actor_id: int, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(self._environment_spec) + + # Create a stochastic behavior policy. + behavior_network = snt.Sequential([ + agent_networks['observation'], + agent_networks['policy'], + networks.StochasticSamplingHead(), + ]) + + # Ensure network variables are created. + policy_variables = {'policy': behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, policy_variables, update_period=1000) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. 
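+ # The adder squashes n_step consecutive environment steps into a single + # transition with an accumulated discounted reward before writing to Reverb.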
+ adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._additional_discount) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client) + + # Create logger and counter; actors will not spam bigtable. + save_data = actor_id == 0 + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', save_data=save_data, time_delta=self._log_every) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(self._environment_spec) + + # Create a stochastic behavior policy. + evaluator_network = snt.Sequential([ + agent_networks['observation'], + agent_networks['policy'], + networks.StochasticMeanHead(), + ]) + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + variables={'policy': evaluator_network.variables}, + update_period=1000) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = loggers.make_default_logger( + 'evaluator', time_delta=self._log_every) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, evaluator, counter, logger) + + def build(self, name='dmpo'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node( + lp.CourierNode(self.actor, replay, source, counter, actor_id)) + + return program diff --git a/acme/acme/agents/tf/mog_mpo/learning.py b/acme/acme/agents/tf/mog_mpo/learning.py new file mode 100644 index 00000000..60b358af --- /dev/null +++ b/acme/acme/agents/tf/mog_mpo/learning.py @@ -0,0 +1,317 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distributional MPO with MoG critic learner implementation.""" + +import dataclasses +import time +from typing import List, Optional + +import acme +from acme import types +from acme.tf import losses +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions + + +@dataclasses.dataclass +class PolicyEvaluationConfig: + evaluate_stochastic_policy: bool = True + num_value_samples: int = 128 + + +class MoGMPOLearner(acme.Learner): + """Distributional (MoG) MPO learner.""" + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: snt.Module, + target_observation_network: snt.Module, + policy_evaluation_config: Optional[PolicyEvaluationConfig] = None, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._observation_network = observation_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + self._target_observation_network = target_observation_network + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner') + + # Other learner parameters. + self._discount = discount + self._num_samples = num_samples + if policy_evaluation_config is None: + policy_evaluation_config = PolicyEvaluationConfig() + self._policy_evaluation_config = policy_evaluation_config + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + self._policy_loss_module = policy_loss_module or losses.MPO( + epsilon=1e-1, + epsilon_mean=3e-3, + epsilon_stddev=1e-6, + epsilon_penalty=1e-3, + init_log_temperature=10., + init_log_alpha_mean=10., + init_log_alpha_stddev=1000.) + + # Create the optimizers. 
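+ # Note that the dual-variable optimizer below defaults to a larger step + # size (1e-2) than the policy and critic optimizers (1e-4).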
+ self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network]) + self._variables = { + 'critic': self._target_critic_network.variables, + 'policy': policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='mog_mpo_learner', + objects_to_save={ + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'observation': self._observation_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'target_observation': self._target_observation_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'dual_optimizer': self._dual_optimizer, + 'policy_loss_module': self._policy_loss_module, + 'num_steps': self._num_steps, + }) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={ + 'policy': + snt.Sequential([ + self._target_observation_network, + self._target_policy_network + ]), + }) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self, inputs: reverb.ReplaySample) -> types.NestedTensor: + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + o_tm1, a_tm1, r_t, d_t, o_t = (inputs.data.observation, inputs.data.action, + inputs.data.reward, inputs.data.discount, + inputs.data.next_observation) + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=d_t.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(o_tm1) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient(self._target_observation_network(o_t)) + + # Get online and target action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Sample actions to evaluate policy; of size [N, B, ...]. + sampled_actions = target_action_distribution.sample(self._num_samples) + + # Tile embedded observations to feed into the target critic network. + # Note: this is more efficient than tiling before the embedding layer. + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute target-estimated distributional value of sampled actions at o_t. + sampled_q_t_distributions = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. 
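+ # (The critic then treats the N policy samples as extra batch entries.)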
+ snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2)) + + # Compute online critic value distribution of a_tm1 in state o_tm1. + q_tm1_distribution = self._critic_network(o_tm1, a_tm1) # [B, ...] + + # Get the return distributions used in the policy evaluation bootstrap. + if self._policy_evaluation_config.evaluate_stochastic_policy: + z_distributions = sampled_q_t_distributions + num_joint_samples = self._num_samples + else: + z_distributions = self._target_critic_network( + o_t, target_action_distribution.mean()) + num_joint_samples = 1 + + num_value_samples = self._policy_evaluation_config.num_value_samples + num_joint_samples *= num_value_samples + z_samples = z_distributions.sample(num_value_samples) + z_samples = tf.reshape(z_samples, (num_joint_samples, -1, 1)) + + # Expand dims of reward and discount tensors. + reward = r_t[..., tf.newaxis] # [B, 1] + full_discount = discount * d_t[..., tf.newaxis] + target_q = reward + full_discount * z_samples # [N, B, 1] + target_q = tf.stop_gradient(target_q) + + # Compute sample-based cross-entropy. + log_probs_q = q_tm1_distribution.log_prob(target_q) # [N, B, 1] + critic_loss = -tf.reduce_mean(log_probs_q, axis=0) # [B, 1] + critic_loss = tf.reduce_mean(critic_loss) + + # Compute Q-values of sampled actions and reshape to [N, B]. + sampled_q_values = sampled_q_t_distributions.mean() + sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1)) + + # Compute MPO policy loss. + policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_values) + policy_loss = tf.reduce_mean(policy_loss) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables)) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + # Log MPO stats. + fetches.update(policy_stats) + + return fetches + + def step(self): + self._maybe_update_target_networks() + self._num_steps.assign_add(1) + + # Run the learning step. + fetches = self._step(next(self._iterator)) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. 
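+ # Note the walltime increment is zero on the very first learning step; + # see the comment on self._timestamp in the constructor.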
+ counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] + + def _maybe_update_target_networks(self): + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = (*self._observation_network.variables, + *self._critic_network.variables) + target_critic_variables = (*self._target_observation_network.variables, + *self._target_critic_network.variables) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + + # Make online critic -> target critic network update ops. + if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) diff --git a/acme/acme/agents/tf/mog_mpo/networks.py b/acme/acme/agents/tf/mog_mpo/networks.py new file mode 100644 index 00000000..1d9b18b7 --- /dev/null +++ b/acme/acme/agents/tf/mog_mpo/networks.py @@ -0,0 +1,76 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for different experiment flavours.""" + +from typing import Mapping, Sequence + +from acme import specs +from acme.tf import networks +from acme.tf import utils as tf2_utils + +import numpy as np +import sonnet as snt + + +def make_default_networks( + environment_spec: specs.EnvironmentSpec, + *, + policy_layer_sizes: Sequence[int] = (256, 256, 256), + critic_layer_sizes: Sequence[int] = (512, 512, 256), + policy_init_scale: float = 0.7, + critic_init_scale: float = 1e-3, + critic_num_components: int = 5, +) -> Mapping[str, snt.Module]: + """Creates networks used by the agent.""" + + # Unpack the environment spec to get appropriate shapes, dtypes, etc. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + num_dimensions = np.prod(act_spec.shape, dtype=int) + + # Create the observation network and make sure it's a Sonnet module. + observation_network = tf2_utils.batch_concat + observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Create the policy network. + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + init_scale=policy_init_scale, + use_tfd_independent=True) + ]) + + # The multiplexer concatenates the (maybe transformed) observations/actions. 
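+ # ClipToSpec keeps sampled actions within the action spec before they are + # concatenated with the observation embedding.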
+ critic_network = snt.Sequential([ + networks.CriticMultiplexer(action_network=networks.ClipToSpec(act_spec)), + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.GaussianMixtureHead( + num_dimensions=1, + num_components=critic_num_components, + init_scale=critic_init_scale) + ]) + + # Create network variables. + # Get embedding spec by creating observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + + return { + 'policy': policy_network, + 'critic': critic_network, + 'observation': observation_network, + } diff --git a/acme/acme/agents/tf/mompo/README.md b/acme/acme/agents/tf/mompo/README.md new file mode 100644 index 00000000..3c39ff48 --- /dev/null +++ b/acme/acme/agents/tf/mompo/README.md @@ -0,0 +1,28 @@ +# Multi-Objective Maximum a posteriori Policy Optimization (MO-MPO) + +This folder contains an implementation of Multi-Objective Maximum a posteriori +Policy Optimization (MO-MPO), introduced in ([Abdolmaleki, Huang et al., 2020]). +This trains a policy that optimizes for multiple objectives, with the desired +preference across objectives encoded by the hyperparameters `epsilon`. + +As with our MPO agent, while this is a more general algorithm, the current +implementation targets the continuous control setting and is most readily +applied to the DeepMind control suite or similar control tasks. This +implementation also includes the options of: + +* per-dimension KL constraint satisfaction, and +* distributional (per-objective) critics, as used by the DMPO agent + +Detailed notes: + +* When using per-dimension KL constraint satisfaction, you may need to tune + the value of `epsilon_mean` (and `epsilon_stddev` if not fixed). A good rule + of thumb would be to divide it by the number of dimensions in the action + space. +* If using a distributional critic, the `vmin|vmax` hyperparameters of the + distributional critic may need tuning depending on your environment's + rewards. A good rule of thumb is to set `vmax` to the discounted sum of the + maximum instantaneous rewards for the maximum episode length; then set + `vmin` to `-vmax`. + +[Abdolmaleki, Huang et al., 2020]: https://arxiv.org/abs/2005.07513 diff --git a/acme/acme/agents/tf/mompo/__init__.py b/acme/acme/agents/tf/mompo/__init__.py new file mode 100644 index 00000000..cee0d99a --- /dev/null +++ b/acme/acme/agents/tf/mompo/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Implementations of a distributional MPO agent.""" + +from acme.agents.tf.mompo.agent import MultiObjectiveMPO +from acme.agents.tf.mompo.agent_distributed import DistributedMultiObjectiveMPO +from acme.agents.tf.mompo.learning import MultiObjectiveMPOLearner +from acme.agents.tf.mompo.learning import QValueObjective +from acme.agents.tf.mompo.learning import RewardObjective diff --git a/acme/acme/agents/tf/mompo/agent.py b/acme/acme/agents/tf/mompo/agent.py new file mode 100644 index 00000000..e597a559 --- /dev/null +++ b/acme/acme/agents/tf/mompo/agent.py @@ -0,0 +1,204 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-objective MPO agent implementation.""" + +import copy +from typing import Optional, Sequence + +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf.mompo import learning +from acme.tf import losses +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + + +class MultiObjectiveMPO(agent.Agent): + """Multi-objective MPO Agent. + + This implements a single-process multi-objective MPO agent. This is an + actor-critic algorithm that generates data via a behavior policy, inserts + N-step transitions into a replay buffer, and periodically updates the policy + (and as a result the behavior) by sampling uniformly from this buffer. + This agent distinguishes itself from the MPO agent in two ways: + - Allowing for one or more objectives (see `acme/agents/tf/mompo/learning.py` + for details on what form this sequence of objectives should take) + - Optionally using a distributional critic (state-action value approximator) + as in DMPO. In other words, the critic network can output either scalar + Q-values or a DiscreteValuedDistribution. 
+ """ + + def __init__(self, + reward_objectives: Sequence[learning.RewardObjective], + qvalue_objectives: Sequence[learning.QValueObjective], + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 512, + prefetch_size: int = 4, + target_policy_update_period: int = 200, + target_critic_update_period: int = 200, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 16., + policy_loss_module: Optional[losses.MultiObjectiveMPO] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + n_step: int = 5, + num_samples: int = 20, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): + """Initialize the agent. + + Args: + reward_objectives: list of the objectives that the policy should optimize; + each objective is defined by its reward function + qvalue_objectives: list of the objectives that the policy should optimize; + each objective is defined by its Q-value function + environment_spec: description of the actions, observations, etc. + policy_network: the online (optimized) policy. + critic_network: the online critic. + observation_network: optional network to transform the observations before + they are fed into any network. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_policy_update_period: number of updates to perform before updating + the target policy network. + target_critic_update_period: number of updates to perform before updating + the target critic network. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert + that is made. + policy_loss_module: configured MO-MPO loss function for the policy + optimization; defaults to sensible values on the control suite. + See `acme/tf/losses/mompo.py` for more details. + policy_optimizer: optimizer to be used on the policy. + critic_optimizer: optimizer to be used on the critic. + n_step: number of steps to squash into a single transition. + num_samples: number of actions to sample when doing a Monte Carlo + integration with respect to the policy. + clipping: whether to clip gradients by global norm. + logger: logging object used to write to logs. + counter: counter object used to keep track of steps. + checkpoint: boolean indicating whether to checkpoint the learner. + replay_table_name: string indicating what name to give the replay table. + """ + # Check that at least one objective's reward function is specified. + if not reward_objectives: + raise ValueError('Must specify at least one reward objective.') + + # Create a replay server to add data to. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.NStepTransitionAdder.signature(environment_spec)) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. 
+ address = f'localhost:{self._server.port}' + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), + n_step=n_step, + discount=discount) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size) + + # Make sure observation network is a Sonnet Module. + observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Create target networks before creating online/target network variables. + target_policy_network = copy.deepcopy(policy_network) + target_critic_network = copy.deepcopy(critic_network) + target_observation_network = copy.deepcopy(observation_network) + + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create the behavior policy. + behavior_network = snt.Sequential([ + observation_network, + policy_network, + networks.StochasticSamplingHead(), + ]) + + # Create variables. + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_policy_network, [emb_spec]) + tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor( + policy_network=behavior_network, adder=adder) + + # Create optimizers. + policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + + # The learner updates the parameters (and initializes them). + learner = learning.MultiObjectiveMPOLearner( + reward_objectives=reward_objectives, + qvalue_objectives=qvalue_objectives, + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + target_policy_network=target_policy_network, + target_critic_network=target_critic_network, + target_observation_network=target_observation_network, + policy_loss_module=policy_loss_module, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=clipping, + discount=discount, + num_samples=num_samples, + target_policy_update_period=target_policy_update_period, + target_critic_update_period=target_critic_update_period, + dataset=dataset, + logger=logger, + counter=counter, + checkpoint=checkpoint) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) diff --git a/acme/acme/agents/tf/mompo/agent_distributed.py b/acme/acme/agents/tf/mompo/agent_distributed.py new file mode 100644 index 00000000..7fc2a741 --- /dev/null +++ b/acme/acme/agents/tf/mompo/agent_distributed.py @@ -0,0 +1,361 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the multi-objective MPO distributed agent class.""" + +from typing import Callable, Dict, Optional, Sequence + +import acme +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents.tf import actors +from acme.agents.tf.mompo import learning +from acme.tf import losses +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + +MultiObjectiveNetworkFactorySpec = Callable[ + [specs.BoundedArray, int], Dict[str, snt.Module]] +MultiObjectivePolicyLossFactorySpec = Callable[[], losses.MultiObjectiveMPO] + + +class DistributedMultiObjectiveMPO: + """Program definition for multi-objective MPO. + + This agent distinguishes itself from the distributed MPO agent in two ways: + - Allowing for one or more objectives (see `acme/agents/tf/mompo/learning.py` + for details on what form this sequence of objectives should take) + - Optionally using a distributional critic (state-action value approximator) + as in DMPO. In other words, the critic network can output either scalar + Q-values or a DiscreteValuedDistribution. + """ + + def __init__( + self, + reward_objectives: Sequence[learning.RewardObjective], + qvalue_objectives: Sequence[learning.QValueObjective], + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: MultiObjectiveNetworkFactorySpec, + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 512, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = None, + n_step: int = 5, + max_in_flight_items: int = 5, + num_samples: int = 20, + additional_discount: float = 0.99, + target_policy_update_period: int = 200, + target_critic_update_period: int = 200, + policy_loss_factory: Optional[MultiObjectivePolicyLossFactorySpec] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._max_in_flight_items = max_in_flight_items + self._additional_discount = additional_discount + self._num_samples = num_samples + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + self._reward_objectives = reward_objectives + self._qvalue_objectives = qvalue_objectives + self._num_critic_heads = len(self._reward_objectives) + + if not self._reward_objectives: + raise ValueError('Must specify at least one reward 
objective.') + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. + samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature( + self._environment_spec)) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create online and target networks. + online_networks = self._network_factory(act_spec, self._num_critic_heads) + target_networks = self._network_factory(act_spec, self._num_critic_heads) + + # Make sure observation network is a Sonnet Module. + observation_network = online_networks.get('observation', tf.identity) + target_observation_network = target_networks.get('observation', tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network) + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create variables. + tf2_utils.create_variables(online_networks['policy'], [emb_spec]) + tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks['policy'], [emb_spec]) + tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset(server_address=replay.server_address) + dataset = dataset.batch(self._batch_size, drop_remainder=True) + dataset = dataset.prefetch(self._prefetch_size) + + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. 
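+    # Note that `additional_discount` is passed as the learner's `discount`;
+    # in the learner it multiplies the environment discount d_t when forming
+    # the TD bootstrap target.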
+ return learning.MultiObjectiveMPOLearner( + reward_objectives=self._reward_objectives, + qvalue_objectives=self._qvalue_objectives, + policy_network=online_networks['policy'], + critic_network=online_networks['critic'], + observation_network=observation_network, + target_policy_network=target_networks['policy'], + target_critic_network=target_networks['critic'], + target_observation_network=target_observation_network, + discount=self._additional_discount, + num_samples=self._num_samples, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec, self._num_critic_heads) + + # Make sure observation network is defined. + observation_network = agent_networks.get('observation', tf.identity) + + # Create a stochastic behavior policy. + behavior_network = snt.Sequential([ + observation_network, + agent_networks['policy'], + networks.StochasticSamplingHead(), + ]) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + policy_variables = {'policy': behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, policy_variables, update_period=1000) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, + n_step=self._n_step, + max_in_flight_items=self._max_in_flight_items, + discount=self._additional_discount) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the run loop and return it. + return acme.EnvironmentLoop( + environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec, self._num_critic_heads) + + # Make sure observation network is defined. + observation_network = agent_networks.get('observation', tf.identity) + + # Create a deterministic behavior policy. 
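+    # StochasticMeanHead outputs the mean of the policy's action distribution
+    # rather than a sample, so evaluation runs are deterministic.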
+ evaluator_modules = [ + observation_network, + agent_networks['policy'], + networks.StochasticMeanHead(), + ] + if isinstance(action_spec, specs.BoundedArray): + evaluator_modules += [networks.ClipToSpec(action_spec)] + evaluator_network = snt.Sequential(evaluator_modules) + + # Ensure network variables are created. + tf2_utils.create_variables(evaluator_network, [observation_spec]) + policy_variables = {'policy': evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, policy_variables, update_period=1000) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = loggers.make_default_logger( + 'evaluator', time_delta=self._log_every, steps_key='evaluator_steps') + + # Create the run loop and return it. + return acme.EnvironmentLoop( + environment, evaluator, counter, logger) + + def build(self, name='mompo'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) + + with program.group('learner'): + learner = program.add_node( + lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node( + lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/acme/agents/tf/mompo/agent_distributed_test.py b/acme/acme/agents/tf/mompo/agent_distributed_test.py new file mode 100644 index 00000000..a76b20d2 --- /dev/null +++ b/acme/acme/agents/tf/mompo/agent_distributed_test.py @@ -0,0 +1,163 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration test for the distributed agent.""" + +from typing import Sequence, Tuple + +import acme +from acme import specs +from acme import wrappers +from acme.agents.tf import mompo +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import lp_utils +from dm_control import suite +import launchpad as lp +import numpy as np +import sonnet as snt +import tensorflow as tf + +from absl.testing import absltest +from absl.testing import parameterized + + +def make_networks( + action_spec: specs.BoundedArray, + num_critic_heads: int, + policy_layer_sizes: Sequence[int] = (50,), + critic_layer_sizes: Sequence[int] = (50,), + num_layers_shared: int = 1, + distributional_critic: bool = True, + vmin: float = -150., + vmax: float = 150., + num_atoms: int = 51, +): + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=False, + init_scale=0.69) + ]) + + if not distributional_critic: + critic_layer_sizes = list(critic_layer_sizes) + [1] + + if not num_layers_shared: + # No layers are shared + critic_network_base = None + else: + critic_network_base = networks.LayerNormMLP( + critic_layer_sizes[:num_layers_shared], activate_final=True) + critic_network_heads = [ + snt.nets.MLP(critic_layer_sizes, activation=tf.nn.elu, + activate_final=False) + for _ in range(num_critic_heads)] + if distributional_critic: + critic_network_heads = [ + snt.Sequential([ + c, networks.DiscreteValuedHead(vmin, vmax, num_atoms) + ]) for c in critic_network_heads] + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential([ + networks.CriticMultiplexer( + critic_network=critic_network_base, + action_network=networks.ClipToSpec(action_spec)), + networks.Multihead(network_heads=critic_network_heads), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + 'observation': tf2_utils.batch_concat, + } + + +def make_environment(evaluation: bool = False): + del evaluation # Unused. + environment = suite.load('cartpole', 'balance') + wrapped = wrappers.SinglePrecisionWrapper(environment) + return wrapped + + +def compute_action_norm(target_pi_samples: tf.Tensor, + target_q_target_pi_samples: tf.Tensor) -> tf.Tensor: + """Compute Q-values for the action norm objective from action samples.""" + del target_q_target_pi_samples + action_norm = tf.norm(target_pi_samples, ord=2, axis=-1) + return tf.stop_gradient(-1 * action_norm) + + +def task_reward_fn(observation: tf.Tensor, + action: tf.Tensor, + reward: tf.Tensor) -> tf.Tensor: + del observation, action + return tf.stop_gradient(reward) + + +def make_objectives() -> Tuple[ + Sequence[mompo.RewardObjective], Sequence[mompo.QValueObjective]]: + """Define the multiple objectives for the policy to learn.""" + task_reward = mompo.RewardObjective( + name='task', + reward_fn=task_reward_fn) + action_norm = mompo.QValueObjective( + name='action_norm_q', + qvalue_fn=compute_action_norm) + return [task_reward], [action_norm] + + +class DistributedAgentTest(parameterized.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + @parameterized.named_parameters( + ('distributional_critic', True), + ('vanilla_critic', False)) + def test_agent(self, distributional_critic): + # Create objectives. 
+ reward_objectives, qvalue_objectives = make_objectives() + + network_factory = lp_utils.partial_kwargs( + make_networks, distributional_critic=distributional_critic) + + agent = mompo.DistributedMultiObjectiveMPO( + reward_objectives, + qvalue_objectives, + environment_factory=make_environment, + network_factory=network_factory, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mompo/agent_test.py b/acme/acme/agents/tf/mompo/agent_test.py new file mode 100644 index 00000000..c08a8b06 --- /dev/null +++ b/acme/acme/agents/tf/mompo/agent_test.py @@ -0,0 +1,149 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the multi-objective MPO agent.""" + +from typing import Dict, Sequence, Tuple + +import acme +from acme import specs +from acme.agents.tf import mompo +from acme.testing import fakes +from acme.tf import networks +import numpy as np +import sonnet as snt +import tensorflow as tf + +from absl.testing import absltest +from absl.testing import parameterized + + +def make_networks( + action_spec: specs.Array, + num_critic_heads: int, + policy_layer_sizes: Sequence[int] = (300, 200), + critic_layer_sizes: Sequence[int] = (400, 300), + num_layers_shared: int = 1, + distributional_critic: bool = True, + vmin: float = -150., + vmax: float = 150., + num_atoms: int = 51, +) -> Dict[str, snt.Module]: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=False, + init_scale=0.69) + ]) + + if not distributional_critic: + critic_layer_sizes = list(critic_layer_sizes) + [1] + + if not num_layers_shared: + # No layers are shared + critic_network_base = None + else: + critic_network_base = networks.LayerNormMLP( + critic_layer_sizes[:num_layers_shared], activate_final=True) + critic_network_heads = [ + snt.nets.MLP(critic_layer_sizes, activation=tf.nn.elu, + activate_final=False) + for _ in range(num_critic_heads)] + if distributional_critic: + critic_network_heads = [ + snt.Sequential([ + c, networks.DiscreteValuedHead(vmin, vmax, num_atoms) + ]) for c in critic_network_heads] + # The multiplexer concatenates the (maybe transformed) observations/actions. 
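+  # Multihead then applies each per-objective head to the shared torso output,
+  # producing one set of Q-values (or one value distribution) per reward
+  # objective.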
+ critic_network = snt.Sequential([ + networks.CriticMultiplexer( + critic_network=critic_network_base), + networks.Multihead(network_heads=critic_network_heads), + ]) + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +def compute_action_norm(target_pi_samples: tf.Tensor, + target_q_target_pi_samples: tf.Tensor) -> tf.Tensor: + """Compute Q-values for the action norm objective from action samples.""" + del target_q_target_pi_samples + action_norm = tf.norm(target_pi_samples, ord=2, axis=-1) + return tf.stop_gradient(-1 * action_norm) + + +def task_reward_fn(observation: tf.Tensor, + action: tf.Tensor, + reward: tf.Tensor) -> tf.Tensor: + del observation, action + return tf.stop_gradient(reward) + + +def make_objectives() -> Tuple[ + Sequence[mompo.RewardObjective], Sequence[mompo.QValueObjective]]: + """Define the multiple objectives for the policy to learn.""" + task_reward = mompo.RewardObjective( + name='task', + reward_fn=task_reward_fn) + action_norm = mompo.QValueObjective( + name='action_norm_q', + qvalue_fn=compute_action_norm) + return [task_reward], [action_norm] + + +class MOMPOTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('distributional_critic', True), + ('vanilla_critic', False)) + def test_mompo(self, distributional_critic): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create objectives. + reward_objectives, qvalue_objectives = make_objectives() + num_critic_heads = len(reward_objectives) + + # Create networks. + agent_networks = make_networks( + spec.actions, num_critic_heads=num_critic_heads, + distributional_critic=distributional_critic) + + # Construct the agent. + agent = mompo.MultiObjectiveMPO( + reward_objectives, + qvalue_objectives, + spec, + policy_network=agent_networks['policy'], + critic_network=agent_networks['critic'], + batch_size=10, + samples_per_insert=2, + min_replay_size=10) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mompo/learning.py b/acme/acme/agents/tf/mompo/learning.py new file mode 100644 index 00000000..095e253b --- /dev/null +++ b/acme/acme/agents/tf/mompo/learning.py @@ -0,0 +1,454 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Multi-objective MPO learner implementation.""" + +import dataclasses +import time +from typing import Callable, List, Optional, Sequence + +import acme +from acme import types +from acme.tf import losses +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import sonnet as snt +import tensorflow as tf +import trfl + +QValueFunctionSpec = Callable[[tf.Tensor, tf.Tensor], tf.Tensor] +RewardFunctionSpec = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor] + +_DEFAULT_EPSILON = 1e-1 +_DEFAULT_EPSILON_MEAN = 1e-3 +_DEFAULT_EPSILON_STDDEV = 1e-6 +_DEFAULT_INIT_LOG_TEMPERATURE = 1. +_DEFAULT_INIT_LOG_ALPHA_MEAN = 1. +_DEFAULT_INIT_LOG_ALPHA_STDDEV = 10. + + +@dataclasses.dataclass +class QValueObjective: + """Defines an objective by specifying its 'Q-values' directly.""" + + name: str + # This computes "Q-values" directly from the sampled actions and other Q's. + qvalue_fn: QValueFunctionSpec + + +@dataclasses.dataclass +class RewardObjective: + """Defines an objective by specifying its reward function.""" + + name: str + # This computes the reward from observations, actions, and environment task + # reward. In the learner, a head will automatically be added to the critic + # network, to learn Q-values for this objective. + reward_fn: RewardFunctionSpec + + +class MultiObjectiveMPOLearner(acme.Learner): + """Distributional MPO learner. + + This is the learning component of a multi-objective MPO (MO-MPO) agent. Two + sequences of objectives must be specified. Otherwise, the inputs are identical + to those of the MPO / DMPO learners. + + Each objective must be defined as either a RewardObjective or an + QValueObjective. These objectives are provided by the reward_objectives and + qvalue_objectives parameters, respectively. For each RewardObjective, a critic + will be trained to estimate Q-values for that objective. Whereas for each + QValueObjective, the Q-values are computed directly by its qvalue_fn. + + A RewardObjective's reward_fn takes the observation, action, and environment + reward as input, and returns the reward for that objective. For example, if + the environment reward is a scalar, then an objective corresponding to the = + task would simply return the environment reward. + + A QValueObjective's qvalue_fn takes the actions and reward-based objectives' + Q-values as input, and outputs the "Q-values" for that objective. For + instance, in the MO-MPO paper ([Abdolmaleki, Huang et al., 2020]), the action + norm objective in the Humanoid run task is defined by setting the qvalue_fn + to be the l2-norm of the actions. + + Note: If there is only one objective and that is the task reward, then this + algorithm becomes exactly the same as (D)MPO. 
+ + (Abdolmaleki, Huang et al., 2020): https://arxiv.org/pdf/2005.07513.pdf + """ + + def __init__( + self, + reward_objectives: Sequence[RewardObjective], + qvalue_objectives: Sequence[QValueObjective], + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: types.TensorTransformation = tf.identity, + target_observation_network: types.TensorTransformation = tf.identity, + policy_loss_module: Optional[losses.MultiObjectiveMPO] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. + self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner') + + # Other learner parameters. + self._discount = discount + self._num_samples = num_samples + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + # Store objectives + self._reward_objectives = reward_objectives + self._qvalue_objectives = qvalue_objectives + if self._qvalue_objectives is None: + self._qvalue_objectives = [] + self._num_critic_heads = len(self._reward_objectives) # C + self._objective_names = ( + [x.name for x in self._reward_objectives] + + [x.name for x in self._qvalue_objectives]) + + self._policy_loss_module = policy_loss_module or losses.MultiObjectiveMPO( + epsilons=[losses.KLConstraint(name, _DEFAULT_EPSILON) + for name in self._objective_names], + epsilon_mean=_DEFAULT_EPSILON_MEAN, + epsilon_stddev=_DEFAULT_EPSILON_STDDEV, + init_log_temperature=_DEFAULT_INIT_LOG_TEMPERATURE, + init_log_alpha_mean=_DEFAULT_INIT_LOG_ALPHA_MEAN, + init_log_alpha_stddev=_DEFAULT_INIT_LOG_ALPHA_STDDEV) + + # Check that ordering of objectives matches the policy_loss_module's + if self._objective_names != list(self._policy_loss_module.objective_names): + raise ValueError("Agent's ordering of objectives doesn't match " + "the policy loss module's ordering of epsilons.") + + # Create the optimizers. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. 
+ policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network]) + self._variables = { + 'critic': self._target_critic_network.variables, + 'policy': policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='mompo_learner', + objects_to_save={ + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'observation': self._observation_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'target_observation': self._target_observation_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'dual_optimizer': self._dual_optimizer, + 'policy_loss_module': self._policy_loss_module, + 'num_steps': self._num_steps, + }) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={ + 'policy': + snt.Sequential([ + self._target_observation_network, + self._target_policy_network + ]), + }) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp: float = None + + @tf.function + def _step(self) -> types.NestedTensor: + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + ) + target_critic_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + ) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + # Make online critic -> target critic network update ops. + if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) + + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient( + self._target_observation_network(transitions.next_observation)) + + # Get online and target action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Sample actions to evaluate policy; of size [N, B, ...]. 
+ sampled_actions = target_action_distribution.sample(self._num_samples) + + # Tile embedded observations to feed into the target critic network. + # Note: this is more efficient than tiling before the embedding layer. + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute target-estimated distributional value of sampled actions at o_t. + sampled_q_t_all = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. + snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2)) + + # Compute online critic value distribution of a_tm1 in state o_tm1. + q_tm1_all = self._critic_network(o_tm1, transitions.action) + + # Compute rewards for objectives with defined reward_fn + reward_stats = {} + r_t_all = [] + for objective in self._reward_objectives: + r = objective.reward_fn(o_tm1, transitions.action, transitions.reward) + reward_stats['{}_reward'.format(objective.name)] = tf.reduce_mean(r) + r_t_all.append(r) + r_t_all = tf.stack(r_t_all, axis=-1) + r_t_all.get_shape().assert_has_rank(2) # [B, C] + + if isinstance(sampled_q_t_all, list): # Distributional critics + critic_loss, sampled_q_t = _compute_distributional_critic_loss( + sampled_q_t_all, q_tm1_all, r_t_all, transitions.discount, + self._discount, self._num_samples) + else: + critic_loss, sampled_q_t = _compute_critic_loss( + sampled_q_t_all, q_tm1_all, r_t_all, transitions.discount, + self._discount, self._num_samples, self._num_critic_heads) + + # Add sampled Q-values for objectives with defined qvalue_fn + sampled_q_t_k = [sampled_q_t] + for objective in self._qvalue_objectives: + sampled_q_t_k.append(tf.expand_dims(tf.stop_gradient( + objective.qvalue_fn(sampled_actions, sampled_q_t)), axis=-1)) + sampled_q_t_k = tf.concat(sampled_q_t_k, axis=-1) # [N, B, K] + + # Compute MPO policy loss. + policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_t_k) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables)) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + fetches.update(policy_stats) # Log MPO stats. + fetches.update(reward_stats) # Log reward stats. 
+ + return fetches + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] + + +def _compute_distributional_critic_loss( + sampled_q_t_all: List[tf.Tensor], + q_tm1_all: List[tf.Tensor], + r_t_all: tf.Tensor, + d_t: tf.Tensor, + discount: float, + num_samples: int): + """Compute loss and sampled Q-values for distributional critics.""" + # Compute average logits by first reshaping them and normalizing them + # across atoms. + batch_size = r_t_all.get_shape()[0] + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(discount, dtype=d_t.dtype) + critic_losses = [] + sampled_q_ts = [] + for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate( + zip(sampled_q_t_all, q_tm1_all)): + # Compute loss for distributional critic for objective c + sampled_logits = tf.reshape( + sampled_q_t_distributions.logits, + [num_samples, batch_size, -1]) # [N, B, A] + sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) + averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) + + # Construct the expected distributional value for bootstrapping. + q_t_distribution = networks.DiscreteValuedDistribution( + values=sampled_q_t_distributions.values, logits=averaged_logits) + + # Compute critic distributional loss. + critic_loss = losses.categorical( + q_tm1_distribution, r_t_all[:, idx], discount * d_t, + q_t_distribution) + critic_losses.append(tf.reduce_mean(critic_loss)) + + # Compute Q-values of sampled actions and reshape to [N, B]. + sampled_q_ts.append(tf.reshape( + sampled_q_t_distributions.mean(), (num_samples, -1))) + + critic_loss = tf.reduce_mean(critic_losses) + sampled_q_t = tf.stack(sampled_q_ts, axis=-1) # [N, B, C] + return critic_loss, sampled_q_t + + +def _compute_critic_loss( + sampled_q_t_all: tf.Tensor, + q_tm1_all: tf.Tensor, + r_t_all: tf.Tensor, + d_t: tf.Tensor, + discount: float, + num_samples: int, + num_critic_heads: int): + """Compute loss and sampled Q-values for (non-distributional) critics.""" + # Reshape Q-value samples back to original batch dimensions and average + # them to compute the TD-learning bootstrap target. + batch_size = r_t_all.get_shape()[0] + sampled_q_t = tf.reshape( + sampled_q_t_all, + (num_samples, batch_size, num_critic_heads)) # [N,B,C] + q_t = tf.reduce_mean(sampled_q_t, axis=0) # [B, C] + + # Flatten q_t and q_tm1; necessary for trfl.td_learning + q_t = tf.reshape(q_t, [-1]) # [B*C] + q_tm1 = tf.reshape(q_tm1_all, [-1]) # [B*C] + + # Flatten r_t_all; necessary for trfl.td_learning + r_t_all = tf.reshape(r_t_all, [-1]) # [B*C] + + # Broadcast and then flatten d_t, to match shape of q_t and q_tm1 + d_t = tf.tile(d_t, [num_critic_heads]) # [B*C] + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(discount, dtype=d_t.dtype) + + # Critic loss. 
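+  # trfl.td_learning returns a namedtuple whose `loss` field is the squared
+  # TD error 0.5 * (r_t + discount * d_t * q_t - q_tm1)**2, computed per
+  # flattened (batch, objective) element before averaging.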
+  critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t, q_t).loss
+  critic_loss = tf.reduce_mean(critic_loss)
+  return critic_loss, sampled_q_t
diff --git a/acme/acme/agents/tf/mpo/README.md b/acme/acme/agents/tf/mpo/README.md
new file mode 100644
index 00000000..94df5628
--- /dev/null
+++ b/acme/acme/agents/tf/mpo/README.md
@@ -0,0 +1,28 @@
+# Maximum a posteriori Policy Optimization (MPO)
+
+This folder contains an implementation of Maximum a posteriori Policy
+Optimization (MPO), introduced in ([Abdolmaleki et al., 2018a], [2018b]). While
+this is a more general algorithm, the current implementation targets the
+continuous control setting and is most readily applied to the DeepMind control
+suite or similar control tasks.
+
+This implementation includes a few important options such as:
+
+* per-dimension KL constraint satisfaction, and
+* action penalization via the multi-objective MPO work of
+  [Abdolmaleki, Huang et al., 2020].
+
+See the DMPO agent directory for a similar agent that uses a distributional
+critic. See the MO-MPO agent directory for an agent that optimizes for multiple
+objectives.
+
+Detailed notes:
+
+* When using per-dimension KL constraint satisfaction, you may need to tune
+  the value of `epsilon_mean` (and `epsilon_stddev` if not fixed). A good rule
+  of thumb would be to divide it by the number of dimensions in the action
+  space.
+
+[Abdolmaleki et al., 2018a]: https://arxiv.org/pdf/1806.06920.pdf
+[2018b]: https://arxiv.org/pdf/1812.02256.pdf
+[Abdolmaleki, Huang et al., 2020]: https://arxiv.org/pdf/2005.07513.pdf
diff --git a/acme/acme/agents/tf/mpo/__init__.py b/acme/acme/agents/tf/mpo/__init__.py
new file mode 100644
index 00000000..8e913ffa
--- /dev/null
+++ b/acme/acme/agents/tf/mpo/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementations of an MPO agent."""
+
+from acme.agents.tf.mpo.agent import MPO
+from acme.agents.tf.mpo.agent_distributed import DistributedMPO
+from acme.agents.tf.mpo.learning import MPOLearner
diff --git a/acme/acme/agents/tf/mpo/agent.py b/acme/acme/agents/tf/mpo/agent.py
new file mode 100644
index 00000000..82f60392
--- /dev/null
+++ b/acme/acme/agents/tf/mpo/agent.py
@@ -0,0 +1,191 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""MPO agent implementation.""" + +import copy +from typing import Optional + +from acme import datasets +from acme import specs +from acme import types +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf.mpo import learning +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + + +class MPO(agent.Agent): + """MPO Agent. + + This implements a single-process MPO agent. This is an actor-critic algorithm + that generates data via a behavior policy, inserts N-step transitions into + a replay buffer, and periodically updates the policy (and as a result the + behavior) by sampling uniformly from this buffer. This agent distinguishes + itself from the DPG agent by using MPO to learn a stochastic policy. + """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + n_step: int = 5, + num_samples: int = 20, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + save_directory: str = '~/acme', + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + ): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + policy_network: the online (optimized) policy. + critic_network: the online critic. + observation_network: optional network to transform the observations before + they are fed into any network. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_policy_update_period: number of updates to perform before updating + the target policy network. + target_critic_update_period: number of updates to perform before updating + the target critic network. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert + that is made. + policy_loss_module: configured MPO loss function for the policy + optimization; defaults to sensible values on the control suite. See + `acme/tf/losses/mpo.py` for more details. + policy_optimizer: optimizer to be used on the policy. + critic_optimizer: optimizer to be used on the critic. + n_step: number of steps to squash into a single transition. + num_samples: number of actions to sample when doing a Monte Carlo + integration with respect to the policy. + clipping: whether to clip gradients by global norm. + logger: logging object used to write to logs. + counter: counter object used to keep track of steps. + checkpoint: boolean indicating whether to checkpoint the learner. + save_directory: string indicating where the learner should save + checkpoints and snapshots. + replay_table_name: string indicating what name to give the replay table. 
+ """ + + # Create a replay server to add data to. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.NStepTransitionAdder.signature(environment_spec)) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f'localhost:{self._server.port}' + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), n_step=n_step, discount=discount) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size) + + # Make sure observation network is a Sonnet Module. + observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Create target networks before creating online/target network variables. + target_policy_network = copy.deepcopy(policy_network) + target_critic_network = copy.deepcopy(critic_network) + target_observation_network = copy.deepcopy(observation_network) + + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create the behavior policy. + behavior_network = snt.Sequential([ + observation_network, + policy_network, + networks.StochasticSamplingHead(), + ]) + + # Create variables. + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_policy_network, [emb_spec]) + tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor( + policy_network=behavior_network, adder=adder) + + # Create optimizers. + policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + + # The learner updates the parameters (and initializes them). + learner = learning.MPOLearner( + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + target_policy_network=target_policy_network, + target_critic_network=target_critic_network, + target_observation_network=target_observation_network, + policy_loss_module=policy_loss_module, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=clipping, + discount=discount, + num_samples=num_samples, + target_policy_update_period=target_policy_update_period, + target_critic_update_period=target_critic_update_period, + dataset=dataset, + logger=logger, + counter=counter, + checkpoint=checkpoint, + save_directory=save_directory) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) diff --git a/acme/acme/agents/tf/mpo/agent_distributed.py b/acme/acme/agents/tf/mpo/agent_distributed.py new file mode 100644 index 00000000..6e7799c7 --- /dev/null +++ b/acme/acme/agents/tf/mpo/agent_distributed.py @@ -0,0 +1,338 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the MPO distributed agent class.""" + +from typing import Callable, Dict, Optional + +import acme +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents.tf import actors +from acme.agents.tf.mpo import learning +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + + +class DistributedMPO: + """Program definition for MPO.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + num_samples: int = 20, + additional_discount: float = 0.99, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + variable_update_period: int = 1000, + policy_loss_factory: Optional[Callable[[], snt.Module]] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._additional_discount = additional_discount + self._num_samples = num_samples + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._variable_update_period = variable_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. 
+ samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer) + else: + limiter = reverb.rate_limiters.MinSize( + min_size_to_sample=self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature( + self._environment_spec)) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create online and target networks. + online_networks = self._network_factory(act_spec) + target_networks = self._network_factory(act_spec) + + # Make sure observation networks are Sonnet Modules. + observation_network = online_networks.get('observation', tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + online_networks['observation'] = observation_network + target_observation_network = target_networks.get('observation', tf.identity) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network) + target_networks['observation'] = target_observation_network + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + tf2_utils.create_variables(online_networks['policy'], [emb_spec]) + tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks['observation'], [obs_spec]) + tf2_utils.create_variables(target_networks['policy'], [emb_spec]) + tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address) + dataset = dataset.batch(self._batch_size, drop_remainder=True) + dataset = dataset.prefetch(self._prefetch_size) + + # Create a counter and logger for bookkeeping steps and performance. + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. 
+ return learning.MPOLearner( + policy_network=online_networks['policy'], + critic_network=online_networks['critic'], + observation_network=observation_network, + target_policy_network=target_networks['policy'], + target_critic_network=target_networks['critic'], + target_observation_network=target_observation_network, + discount=self._additional_discount, + num_samples=self._num_samples, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec) + + # Create a stochastic behavior policy. + behavior_modules = [ + agent_networks.get('observation', tf.identity), + agent_networks.get('policy'), + networks.StochasticSamplingHead() + ] + behavior_network = snt.Sequential(behavior_modules) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + policy_variables = {'policy': behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, + n_step=self._n_step, + discount=self._additional_discount) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec) + + # Create a stochastic behavior policy. + evaluator_modules = [ + agent_networks.get('observation', tf.identity), + agent_networks.get('policy'), + networks.StochasticMeanHead(), + ] + + if isinstance(action_spec, specs.BoundedArray): + evaluator_modules += [networks.ClipToSpec(action_spec)] + evaluator_network = snt.Sequential(evaluator_modules) + + # Ensure network variables are created. + tf2_utils.create_variables(evaluator_network, [observation_spec]) + policy_variables = {'policy': evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. 
+ variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = loggers.make_default_logger( + 'evaluator', time_delta=self._log_every, steps_key='evaluator_steps') + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, evaluator, counter, logger) + + def build(self, name='mpo'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) + + with program.group('learner'): + learner = program.add_node( + lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node( + lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/acme/agents/tf/mpo/agent_distributed_test.py b/acme/acme/agents/tf/mpo/agent_distributed_test.py new file mode 100644 index 00000000..1bf1bda1 --- /dev/null +++ b/acme/acme/agents/tf/mpo/agent_distributed_test.py @@ -0,0 +1,100 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
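
Beyond the `test_mt` launch used in the integration test below, the program returned by `DistributedMPO.build()` can be run locally with Launchpad. A minimal sketch, assuming hypothetical `make_environment` and `make_networks` factories that are not part of this diff:

```python
# Sketch: launching the distributed MPO program locally (not part of the diff).
# `make_environment` and `make_networks` are assumed helper factories.
import launchpad as lp
from acme.agents.tf import mpo

program = mpo.DistributedMPO(
    environment_factory=make_environment,  # Callable[[bool], dm_env.Environment]
    network_factory=make_networks,  # returns {'policy', 'critic', 'observation'}
    num_actors=4,
    max_actor_steps=1_000_000,  # adds the StepsLimiter coordinator node
).build()

# Each node (replay, counter, learner, actors, evaluator) runs as its own
# local process.
lp.launch(program, launch_type=lp.LaunchType.LOCAL_MULTI_PROCESSING)
```
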
+ +"""Integration test for the distributed agent.""" + +from typing import Sequence + +import acme +from acme import specs +from acme.agents.tf import mpo +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils +import launchpad as lp +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def make_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (50, 50), + critic_layer_sizes: Sequence[int] = (50, 50), +): + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + observation_network = tf2_utils.batch_concat + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + init_scale=0.3, + fixed_scale=True, + use_tfd_independent=False) + ]) + evaluator_network = snt.Sequential([ + observation_network, + policy_network, + networks.StochasticMeanHead(), + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + 'observation': observation_network, + 'evaluator': evaluator_network, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_agent(self): + + agent = mpo.DistributedMPO( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mpo/agent_test.py b/acme/acme/agents/tf/mpo/agent_test.py new file mode 100644 index 00000000..9a763419 --- /dev/null +++ b/acme/acme/agents/tf/mpo/agent_test.py @@ -0,0 +1,77 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
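
To see how the smoke test below paces learning, note that `MPO.__init__` passes `min_observations = max(batch_size, min_replay_size)` and `observations_per_step = batch_size / samples_per_insert` to the base `agent.Agent`. A quick numeric sketch with the values the test uses:

```python
# Pacing arithmetic for the MPO smoke test below (values only, a sketch).
batch_size = 10
samples_per_insert = 2.0
min_replay_size = 10

min_observations = max(batch_size, min_replay_size)             # 10
observations_per_step = float(batch_size) / samples_per_insert  # 5.0

# The agent therefore waits for 10 observations before learning starts and
# then takes one learner step for every 5 new observations added to replay.
```
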
+ +"""Tests for the MPO agent.""" + +import acme +from acme import specs +from acme.agents.tf import mpo +from acme.testing import fakes +from acme.tf import networks +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def make_networks( + action_spec, + policy_layer_sizes=(10, 10), + critic_layer_sizes=(10, 10), +): + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + critic_layer_sizes = list(critic_layer_sizes) + [1] + + policy_network = snt.Sequential([ + networks.LayerNormMLP(policy_layer_sizes), + networks.MultivariateNormalDiagHead(num_dimensions) + ]) + critic_network = networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes)) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class MPOTest(absltest.TestCase): + + def test_mpo(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10, bounded=False) + spec = specs.make_environment_spec(environment) + + # Create networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = mpo.MPO( + spec, + policy_network=agent_networks['policy'], + critic_network=agent_networks['critic'], + batch_size=10, + samples_per_insert=2, + min_replay_size=10) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/mpo/learning.py b/acme/acme/agents/tf/mpo/learning.py new file mode 100644 index 00000000..b18ff459 --- /dev/null +++ b/acme/acme/agents/tf/mpo/learning.py @@ -0,0 +1,287 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
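
The learner defined below builds its TD bootstrap target by averaging the target critic over `num_samples` actions drawn from the target policy. A toy shape walk-through of that computation, using plain TensorFlow and made-up sizes (a sketch, not the learner code itself):

```python
import tensorflow as tf

N, B, obs_dim, act_dim = 4, 3, 5, 2  # num_samples, batch size, toy widths.
o_t = tf.random.normal([B, obs_dim])  # Embedded next observations.
sampled_actions = tf.random.normal([N, B, act_dim])  # Target-policy samples.

# Tile observations to [N, B, ...] and merge the leading dims to [N*B, ...]
# so the critic sees an ordinary batch.
tiled_o_t = tf.tile(o_t[None], [N, 1, 1])
merged_o = tf.reshape(tiled_o_t, [N * B, obs_dim])
merged_a = tf.reshape(sampled_actions, [N * B, act_dim])

q = tf.random.normal([N * B, 1])  # Stand-in for the target critic's output.
q_t = tf.reduce_mean(tf.reshape(q, [N, B]), axis=0)  # [B] bootstrap target.
assert q_t.shape == (B,)
```
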
+ +"""MPO learner implementation.""" + +import time +from typing import List, Optional + +import acme +from acme import types +from acme.tf import losses +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import sonnet as snt +import tensorflow as tf +import trfl + + +class MPOLearner(acme.Learner): + """MPO learner.""" + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: types.TensorTransformation = tf.identity, + target_observation_network: types.TensorTransformation = tf.identity, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + save_directory: str = '~/acme', + ): + + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner') + self._discount = discount + self._num_samples = num_samples + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. + self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network) + + self._policy_loss_module = policy_loss_module or losses.MPO( + epsilon=1e-1, + epsilon_penalty=1e-3, + epsilon_mean=2.5e-3, + epsilon_stddev=1e-6, + init_log_temperature=10., + init_log_alpha_mean=10., + init_log_alpha_stddev=1000.) + + # Create the optimizers. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network]) + self._variables = { + 'critic': self._target_critic_network.variables, + 'policy': policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. 
+ self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + directory=save_directory, + subdirectory='mpo_learner', + objects_to_save={ + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'observation_network': self._observation_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'target_observation_network': self._target_observation_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'dual_optimizer': self._dual_optimizer, + 'policy_loss_module': self._policy_loss_module, + 'num_steps': self._num_steps, + }) + + self._snapshotter = tf2_savers.Snapshotter( + directory=save_directory, + objects_to_save={ + 'policy': + snt.Sequential([ + self._target_observation_network, + self._target_policy_network + ]), + }) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> types.Nest: + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + ) + target_critic_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + ) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + # Make online critic -> target critic network update ops. + if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) + + # Increment number of learner steps for periodic update bookkeeping. + self._num_steps.assign_add(1) + + # Get next batch of data. + inputs = next(self._iterator) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + transitions: types.Transition = inputs.data + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=transitions.discount.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient( + self._target_observation_network(transitions.next_observation)) + + # Get action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Get sampled actions to evaluate policy; of size [N, B, ...]. 
+ sampled_actions = target_action_distribution.sample(self._num_samples) + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute the target critic's Q-value of the sampled actions in state o_t. + sampled_q_t = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. + snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2)) + + # Reshape Q-value samples back to original batch dimensions and average + # them to compute the TD-learning bootstrap target. + sampled_q_t = tf.reshape(sampled_q_t, (self._num_samples, -1)) # [N, B] + q_t = tf.reduce_mean(sampled_q_t, axis=0) # [B] + + # Compute online critic value of a_tm1 in state o_tm1. + q_tm1 = self._critic_network(o_tm1, transitions.action) # [B, 1] + q_tm1 = tf.squeeze(q_tm1, axis=-1) # [B]; necessary for trfl.td_learning. + + # Critic loss. + critic_loss = trfl.td_learning(q_tm1, transitions.reward, + discount * transitions.discount, q_t).loss + critic_loss = tf.reduce_mean(critic_loss) + + # Actor learning. + policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_t) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables)) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + fetches.update(policy_stats) # Log MPO stats. + + return fetches + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. 
+    if self._checkpointer is not None:
+      self._checkpointer.save()
+    if self._snapshotter is not None:
+      self._snapshotter.save()
+    self._logger.write(fetches)
+
+  def get_variables(self, names: List[str]) -> List[List[np.ndarray]]:
+    return [tf2_utils.to_numpy(self._variables[name]) for name in names]
diff --git a/acme/acme/agents/tf/r2d2/README.md b/acme/acme/agents/tf/r2d2/README.md
new file mode 100644
index 00000000..a66fef18
--- /dev/null
+++ b/acme/acme/agents/tf/r2d2/README.md
@@ -0,0 +1,12 @@
+# R2D2 - Recurrent Experience Replay in Distributed Reinforcement Learning
+
+This folder contains an implementation of the R2D2 agent introduced in
+([Kapturowski et al., 2019]). This work builds upon the DQN algorithm
+([Mnih et al., 2013], [Mnih et al., 2015]) and the Ape-X framework ([Horgan et al.,
+2018]), extending distributed Q-learning to use recurrent neural networks. This
+version of the agent is synchronous and therefore not distributed.
+
+[Kapturowski et al., 2019]: https://openreview.net/forum?id=r1lyTjAqYX
+[Mnih et al., 2013]: https://arxiv.org/abs/1312.5602
+[Mnih et al., 2015]: https://www.nature.com/articles/nature14236
+[Horgan et al., 2018]: https://arxiv.org/pdf/1803.00933
diff --git a/acme/acme/agents/tf/r2d2/__init__.py b/acme/acme/agents/tf/r2d2/__init__.py
new file mode 100644
index 00000000..fb6c7447
--- /dev/null
+++ b/acme/acme/agents/tf/r2d2/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for Recurrent DQN (R2D2)."""
+
+from acme.agents.tf.r2d2.agent import R2D2
+from acme.agents.tf.r2d2.agent_distributed import DistributedR2D2
diff --git a/acme/acme/agents/tf/r2d2/agent.py b/acme/acme/agents/tf/r2d2/agent.py
new file mode 100644
index 00000000..9acd5668
--- /dev/null
+++ b/acme/acme/agents/tf/r2d2/agent.py
@@ -0,0 +1,152 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
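
The agent below forms its behavior policy by stacking an epsilon-greedy sampling head on top of the recurrent Q-network. A minimal sketch of that wrapper, mirroring the `snt.DeepRNN` construction in `agent.py`:

```python
# Sketch of the behavior-policy wrapper used by the R2D2 agent below.
import sonnet as snt
import trfl


def make_behavior_policy(network: snt.RNNCore, epsilon: float) -> snt.DeepRNN:
  """Stacks an epsilon-greedy action sampler on a recurrent Q-network."""
  return snt.DeepRNN([
      network,  # Maps (observation, state) -> (q_values, next_state).
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])
```
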
+
+"""Recurrent DQN (R2D2) agent implementation."""
+
+import copy
+from typing import Optional
+
+from acme import datasets
+from acme import specs
+from acme.adders import reverb as adders
+from acme.agents import agent
+from acme.agents.tf import actors
+from acme.agents.tf.r2d2 import learning
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import reverb
+import sonnet as snt
+import tensorflow as tf
+import trfl
+
+
+class R2D2(agent.Agent):
+  """R2D2 Agent.
+
+  This implements a single-process R2D2 agent. This is a Q-learning algorithm
+  that generates data via an (epsilon-greedy) behavior policy, inserts
+  trajectories into a replay buffer, and periodically updates the policy (and
+  as a result the behavior) by sampling from this buffer.
+  """
+
+  def __init__(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      network: snt.RNNCore,
+      burn_in_length: int,
+      trace_length: int,
+      replay_period: int,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      discount: float = 0.99,
+      batch_size: int = 32,
+      prefetch_size: int = tf.data.experimental.AUTOTUNE,
+      target_update_period: int = 100,
+      importance_sampling_exponent: float = 0.2,
+      priority_exponent: float = 0.6,
+      epsilon: float = 0.01,
+      learning_rate: float = 1e-3,
+      min_replay_size: int = 1000,
+      max_replay_size: int = 1000000,
+      samples_per_insert: float = 32.0,
+      store_lstm_state: bool = True,
+      max_priority_weight: float = 0.9,
+      checkpoint: bool = True,
+  ):
+
+    if store_lstm_state:
+      extra_spec = {
+          'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1)),
+      }
+    else:
+      extra_spec = ()
+
+    sequence_length = burn_in_length + trace_length + 1
+    replay_table = reverb.Table(
+        name=adders.DEFAULT_PRIORITY_TABLE,
+        sampler=reverb.selectors.Prioritized(priority_exponent),
+        remover=reverb.selectors.Fifo(),
+        max_size=max_replay_size,
+        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
+        signature=adders.SequenceAdder.signature(
+            environment_spec, extra_spec, sequence_length=sequence_length))
+    self._server = reverb.Server([replay_table], port=None)
+    address = f'localhost:{self._server.port}'
+
+    # Component to add things into replay.
+    adder = adders.SequenceAdder(
+        client=reverb.Client(address),
+        period=replay_period,
+        sequence_length=sequence_length,
+    )
+
+    # The dataset object to learn from.
+ dataset = datasets.make_reverb_dataset( + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size) + + target_network = copy.deepcopy(network) + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + learner = learning.R2D2Learner( + environment_spec=environment_spec, + network=network, + target_network=target_network, + burn_in_length=burn_in_length, + sequence_length=sequence_length, + dataset=dataset, + reverb_client=reverb.TFClient(address), + counter=counter, + logger=logger, + discount=discount, + target_update_period=target_update_period, + importance_sampling_exponent=importance_sampling_exponent, + max_replay_size=max_replay_size, + learning_rate=learning_rate, + store_lstm_state=store_lstm_state, + max_priority_weight=max_priority_weight, + ) + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='r2d2_learner', + time_delta_minutes=60, + objects_to_save=learner.state, + enable_checkpointing=checkpoint, + ) + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={'network': network}, time_delta_minutes=60.) + + policy_network = snt.DeepRNN([ + network, + lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), + ]) + + actor = actors.RecurrentActor( + policy_network, adder, store_recurrent_state=store_lstm_state) + observations_per_step = ( + float(replay_period * batch_size) / samples_per_insert) + super().__init__( + actor=actor, + learner=learner, + min_observations=replay_period * max(batch_size, min_replay_size), + observations_per_step=observations_per_step) + + def update(self): + super().update() + self._snapshotter.save() + self._checkpointer.save() diff --git a/acme/acme/agents/tf/r2d2/agent_distributed.py b/acme/acme/agents/tf/r2d2/agent_distributed.py new file mode 100644 index 00000000..6be00251 --- /dev/null +++ b/acme/acme/agents/tf/r2d2/agent_distributed.py @@ -0,0 +1,276 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
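
Both the single-process agent above and the distributed program below delegate updates to `R2D2Learner` (defined later in this diff), which corrects for prioritized sampling with max-normalized importance weights. A sketch of that correction in isolation, assuming `probs` are the sampling probabilities returned by Reverb:

```python
import tensorflow as tf


def importance_weights(probs: tf.Tensor, max_replay_size: int,
                       exponent: float) -> tf.Tensor:
  """Computes w = (1 / (N * p))**beta, normalized by the batch maximum."""
  weights = (1.0 / (max_replay_size * tf.cast(probs, tf.float32)))**exponent
  return weights / tf.reduce_max(weights)
```
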
+ +"""Defines the Recurrent DQN Launchpad program.""" + +import copy +from typing import Callable, List, Optional + +import acme +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.agents.tf import actors +from acme.agents.tf.r2d2 import learning +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting +from acme.utils import loggers +import dm_env +import launchpad as lp +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + + +class DistributedR2D2: + """Program definition for Recurrent Replay Distributed DQN (R2D2).""" + + def __init__(self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.DiscreteArray], snt.RNNCore], + num_actors: int, + burn_in_length: int, + trace_length: int, + replay_period: int, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = tf.data.experimental.AUTOTUNE, + min_replay_size: int = 1000, + max_replay_size: int = 100_000, + samples_per_insert: float = 32.0, + discount: float = 0.99, + priority_exponent: float = 0.6, + importance_sampling_exponent: float = 0.2, + variable_update_period: int = 1000, + learning_rate: float = 1e-3, + evaluator_epsilon: float = 0., + target_update_period: int = 100, + save_logs: bool = False): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._burn_in_length = burn_in_length + self._trace_length = trace_length + self._replay_period = replay_period + self._discount = discount + self._target_update_period = target_update_period + self._variable_update_period = variable_update_period + self._save_logs = save_logs + self._priority_exponent = priority_exponent + self._learning_rate = learning_rate + self._evaluator_epsilon = evaluator_epsilon + self._importance_sampling_exponent = importance_sampling_exponent + + self._obs_spec = environment_spec.observations + + def replay(self) -> List[reverb.Table]: + """The replay storage.""" + network = self._network_factory(self._environment_spec.actions) + extra_spec = { + 'core_state': network.initial_state(1), + } + # Remove batch dimensions. 
+    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
+    if self._samples_per_insert:
+      limiter = reverb.rate_limiters.SampleToInsertRatio(
+          min_size_to_sample=self._min_replay_size,
+          samples_per_insert=self._samples_per_insert,
+          error_buffer=self._batch_size)
+    else:
+      limiter = reverb.rate_limiters.MinSize(self._min_replay_size)
+    table = reverb.Table(
+        name=adders.DEFAULT_PRIORITY_TABLE,
+        sampler=reverb.selectors.Prioritized(self._priority_exponent),
+        remover=reverb.selectors.Fifo(),
+        max_size=self._max_replay_size,
+        rate_limiter=limiter,
+        signature=adders.SequenceAdder.signature(
+            self._environment_spec,
+            extra_spec,
+            sequence_length=self._burn_in_length + self._trace_length + 1))
+
+    return [table]
+
+  def counter(self):
+    """Creates the master counter process."""
+    return tf2_savers.CheckpointingRunner(
+        counting.Counter(), time_delta_minutes=1, subdirectory='counter')
+
+  def learner(self, replay: reverb.Client, counter: counting.Counter):
+    """The Learning part of the agent."""
+    # Create the online and target networks. Note that the learner never
+    # interacts with an environment; it only consumes data from replay.
+    network = self._network_factory(self._environment_spec.actions)
+    target_network = copy.deepcopy(network)
+
+    tf2_utils.create_variables(network, [self._obs_spec])
+    tf2_utils.create_variables(target_network, [self._obs_spec])
+
+    # The dataset object to learn from.
+    reverb_client = reverb.TFClient(replay.server_address)
+    sequence_length = self._burn_in_length + self._trace_length + 1
+    dataset = datasets.make_reverb_dataset(
+        server_address=replay.server_address,
+        batch_size=self._batch_size,
+        prefetch_size=self._prefetch_size)
+
+    counter = counting.Counter(counter, 'learner')
+    logger = loggers.make_default_logger(
+        'learner', save_data=True, steps_key='learner_steps')
+    # Return the learning agent.
+    learner = learning.R2D2Learner(
+        environment_spec=self._environment_spec,
+        network=network,
+        target_network=target_network,
+        burn_in_length=self._burn_in_length,
+        sequence_length=sequence_length,
+        dataset=dataset,
+        reverb_client=reverb_client,
+        counter=counter,
+        logger=logger,
+        discount=self._discount,
+        target_update_period=self._target_update_period,
+        importance_sampling_exponent=self._importance_sampling_exponent,
+        learning_rate=self._learning_rate,
+        max_replay_size=self._max_replay_size)
+    return tf2_savers.CheckpointingRunner(
+        wrapped=learner, time_delta_minutes=60, subdirectory='r2d2_learner')
+
+  def actor(
+      self,
+      replay: reverb.Client,
+      variable_source: acme.VariableSource,
+      counter: counting.Counter,
+      epsilon: float,
+  ) -> acme.EnvironmentLoop:
+    """The actor process."""
+    environment = self._environment_factory(False)
+    network = self._network_factory(self._environment_spec.actions)
+
+    tf2_utils.create_variables(network, [self._obs_spec])
+
+    policy_network = snt.DeepRNN([
+        network,
+        lambda qs: tf.cast(trfl.epsilon_greedy(qs, epsilon).sample(), tf.int32),
+    ])
+
+    # Component to add things into replay.
+    sequence_length = self._burn_in_length + self._trace_length + 1
+    adder = adders.SequenceAdder(
+        client=replay,
+        period=self._replay_period,
+        sequence_length=sequence_length,
+        delta_encoded=True,
+    )
+
+    variable_client = tf2_variable_utils.VariableClient(
+        client=variable_source,
+        variables={'policy': policy_network.variables},
+        update_period=self._variable_update_period)
+
+    # Make sure not to use a random policy after checkpoint restoration by
+    # assigning variables before running the environment loop.
+    variable_client.update_and_wait()
+
+    # Create the agent.
+ actor = actors.RecurrentActor( + policy_network=policy_network, + variable_client=variable_client, + adder=adder) + + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', save_data=False, steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + environment = self._environment_factory(True) + network = self._network_factory(self._environment_spec.actions) + + tf2_utils.create_variables(network, [self._obs_spec]) + policy_network = snt.DeepRNN([ + network, + lambda qs: tf.cast(tf.argmax(qs, axis=-1), tf.int32), + ]) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.variables}, + update_period=self._variable_update_period) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = actors.RecurrentActor( + policy_network=policy_network, variable_client=variable_client) + + # Create the run loop and return it. + logger = loggers.make_default_logger( + 'evaluator', save_data=True, steps_key='evaluator_steps') + counter = counting.Counter(counter, 'evaluator') + + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name='r2d2'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('cacher'): + cacher = program.add_node( + lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, cacher, counter)) + + # Generate an epsilon for each actor. + epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0) + + with program.group('actor'): + for epsilon in epsilons: + program.add_node( + lp.CourierNode(self.actor, replay, cacher, counter, epsilon)) + + return program diff --git a/acme/acme/agents/tf/r2d2/agent_distributed_test.py b/acme/acme/agents/tf/r2d2/agent_distributed_test.py new file mode 100644 index 00000000..81000f63 --- /dev/null +++ b/acme/acme/agents/tf/r2d2/agent_distributed_test.py @@ -0,0 +1,58 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
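
For intuition on the per-actor exploration schedule in `DistributedR2D2.build` above: the flipped `logspace` call spreads epsilons log-uniformly between 0.4**8 and 0.4, so lower-indexed actors explore the least. A numeric sketch with `num_actors=4` (approximate values):

```python
import numpy as np

num_actors = 4
epsilons = np.flip(np.logspace(1, 8, num_actors, base=0.4), axis=0)
# 0.4**8, 0.4**(17/3), 0.4**(10/3), 0.4**1
# ~ [0.00066, 0.0056, 0.047, 0.4]
```
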
+ +"""Integration test for the distributed agent.""" + +import acme +from acme.agents.tf import r2d2 +from acme.testing import fakes +from acme.tf import networks +import launchpad as lp + +from absl.testing import absltest + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_agent(self): + env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) + net_factory = lambda spec: networks.R2D2AtariNetwork(spec.num_values) + + agent = r2d2.DistributedR2D2( + environment_factory=env_factory, + network_factory=net_factory, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + replay_period=1, + burn_in_length=1, + trace_length=10, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/r2d2/agent_test.py b/acme/acme/agents/tf/r2d2/agent_test.py new file mode 100644 index 00000000..03366782 --- /dev/null +++ b/acme/acme/agents/tf/r2d2/agent_test.py @@ -0,0 +1,84 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RDQN agent.""" + +import acme +from acme import specs +from acme.agents.tf import r2d2 +from acme.testing import fakes +from acme.tf import networks +import numpy as np +import sonnet as snt + +from absl.testing import absltest +from absl.testing import parameterized + + +class SimpleNetwork(networks.RNNCore): + + def __init__(self, action_spec: specs.DiscreteArray): + super().__init__(name='r2d2_test_network') + self._net = snt.DeepRNN([ + snt.Flatten(), + snt.LSTM(20), + snt.nets.MLP([50, 50, action_spec.num_values]) + ]) + + def __call__(self, inputs, state): + return self._net(inputs, state) + + def initial_state(self, batch_size: int, **kwargs): + return self._net.initial_state(batch_size) + + def unroll(self, inputs, state, sequence_length): + return snt.static_unroll(self._net, inputs, state, sequence_length) + + +class R2D2Test(parameterized.TestCase): + + @parameterized.parameters(True, False) + def test_r2d2(self, store_lstm_state: bool): + # Create a fake environment to test with. + # TODO(b/152596848): Allow R2D2 to deal with integer observations. + environment = fakes.DiscreteEnvironment( + num_actions=5, + num_observations=10, + obs_shape=(10, 4), + obs_dtype=np.float32, + episode_length=10) + spec = specs.make_environment_spec(environment) + + # Construct the agent. + agent = r2d2.R2D2( + environment_spec=spec, + network=SimpleNetwork(spec.actions), + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + store_lstm_state=store_lstm_state, + burn_in_length=2, + trace_length=6, + replay_period=4, + checkpoint=False, + ) + + # Try running the environment loop. 
We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=5) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/r2d2/learning.py b/acme/acme/agents/tf/r2d2/learning.py new file mode 100644 index 00000000..5d8f9d52 --- /dev/null +++ b/acme/acme/agents/tf/r2d2/learning.py @@ -0,0 +1,241 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Recurrent Replay Distributed DQN (R2D2) learner implementation.""" + +import functools +import time +from typing import Dict, Iterator, List, Mapping, Union, Optional + +import acme +from acme import specs +from acme.adders import reverb as adders +from acme.tf import losses +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import tree + +Variables = List[np.ndarray] + + +class R2D2Learner(acme.Learner, tf2_savers.TFSaveable): + """R2D2 learner. + + This is the learning component of the R2D2 agent. It takes a dataset as input + and implements update functionality to learn from this dataset. + """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: Union[networks.RNNCore, snt.RNNCore], + target_network: Union[networks.RNNCore, snt.RNNCore], + burn_in_length: int, + sequence_length: int, + dataset: tf.data.Dataset, + reverb_client: Optional[reverb.TFClient] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + discount: float = 0.99, + target_update_period: int = 100, + importance_sampling_exponent: float = 0.2, + max_replay_size: int = 1_000_000, + learning_rate: float = 1e-3, + # TODO(sergomez): rename to use_core_state for consistency with JAX agent. + store_lstm_state: bool = True, + max_priority_weight: float = 0.9, + n_step: int = 5, + clip_grad_norm: Optional[float] = None, + ): + + if not isinstance(network, networks.RNNCore): + network.unroll = functools.partial(snt.static_unroll, network) + target_network.unroll = functools.partial(snt.static_unroll, + target_network) + + # Internalise agent components (replay buffer, networks, optimizer). + # TODO(b/155086959): Fix type stubs and remove. + self._iterator: Iterator[reverb.ReplaySample] = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._target_network = target_network + self._optimizer = snt.optimizers.Adam(learning_rate, epsilon=1e-3) + self._reverb_client = reverb_client + + # Internalise the hyperparameters. 
+ self._store_lstm_state = store_lstm_state + self._burn_in_length = burn_in_length + self._discount = discount + self._max_replay_size = max_replay_size + self._importance_sampling_exponent = importance_sampling_exponent + self._max_priority_weight = max_priority_weight + self._target_update_period = target_update_period + self._num_actions = environment_spec.actions.num_values + self._sequence_length = sequence_length + self._n_step = n_step + self._clip_grad_norm = clip_grad_norm + + if burn_in_length: + self._burn_in = lambda o, s: self._network.unroll(o, s, burn_in_length) + else: + self._burn_in = lambda o, s: (o, s) # pylint: disable=unnecessary-lambda + + # Learner state. + self._variables = network.variables + self._num_steps = tf.Variable( + 0., dtype=tf.float32, trainable=False, name='step') + + # Internalise logging/counting objects. + self._counter = counting.Counter(counter, 'learner') + self._logger = logger or loggers.TerminalLogger('learner', time_delta=100.) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + + # Draw a batch of data from replay. + sample: reverb.ReplaySample = next(self._iterator) + + data = tf2_utils.batch_to_sequence(sample.data) + observations, actions, rewards, discounts, extra = (data.observation, + data.action, + data.reward, + data.discount, + data.extras) + unused_sequence_length, batch_size = actions.shape + + # Get initial state for the LSTM, either from replay or simply use zeros. + if self._store_lstm_state: + core_state = tree.map_structure(lambda x: x[0], extra['core_state']) + else: + core_state = self._network.initial_state(batch_size) + target_core_state = tree.map_structure(tf.identity, core_state) + + # Before training, optionally unroll the LSTM for a fixed warmup period. + burn_in_obs = tree.map_structure(lambda x: x[:self._burn_in_length], + observations) + _, core_state = self._burn_in(burn_in_obs, core_state) + _, target_core_state = self._burn_in(burn_in_obs, target_core_state) + + # Don't train on the warmup period. + observations, actions, rewards, discounts, extra = tree.map_structure( + lambda x: x[self._burn_in_length:], + (observations, actions, rewards, discounts, extra)) + + with tf.GradientTape() as tape: + # Unroll the online and target Q-networks on the sequences. + q_values, _ = self._network.unroll(observations, core_state, + self._sequence_length) + target_q_values, _ = self._target_network.unroll(observations, + target_core_state, + self._sequence_length) + + # Compute the target policy distribution (greedy). + greedy_actions = tf.argmax(q_values, output_type=tf.int32, axis=-1) + target_policy_probs = tf.one_hot( + greedy_actions, depth=self._num_actions, dtype=q_values.dtype) + + # Compute the transformed n-step loss. + rewards = tree.map_structure(lambda x: x[:-1], rewards) + discounts = tree.map_structure(lambda x: x[:-1], discounts) + loss, extra = losses.transformed_n_step_loss( + qs=q_values, + targnet_qs=target_q_values, + actions=actions, + rewards=rewards, + pcontinues=discounts * self._discount, + target_policy_probs=target_policy_probs, + bootstrap_n=self._n_step, + ) + + # Calculate importance weights and use them to scale the loss. + sample_info = sample.info + keys, probs = sample_info.key, sample_info.probability + importance_weights = 1. 
/ (self._max_replay_size * probs) # [T, B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + loss *= tf.cast(importance_weights, tf.float32) # [T, B] + loss = tf.reduce_mean(loss) # [] + + # Apply gradients via optimizer. + gradients = tape.gradient(loss, self._network.trainable_variables) + # Clip and apply gradients. + if self._clip_grad_norm is not None: + gradients, _ = tf.clip_by_global_norm(gradients, self._clip_grad_norm) + + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(self._network.variables, + self._target_network.variables): + dest.assign(src) + self._num_steps.assign_add(1) + + if self._reverb_client: + # Compute updated priorities. + priorities = compute_priority(extra.errors, self._max_priority_weight) + # Compute priorities and add an op to update them on the reverb side. + self._reverb_client.update_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, + keys=keys, + priorities=tf.cast(priorities, tf.float64)) + + return {'loss': loss} + + def step(self): + # Run the learning step. + results = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + results.update(counts) + self._logger.write(results) + + def get_variables(self, names: List[str]) -> List[Variables]: + return [tf2_utils.to_numpy(self._variables)] + + @property + def state(self) -> Mapping[str, tf2_savers.Checkpointable]: + """Returns the stateful parts of the learner for checkpointing.""" + return { + 'network': self._network, + 'target_network': self._target_network, + 'optimizer': self._optimizer, + 'num_steps': self._num_steps, + } + + +def compute_priority(errors: tf.Tensor, alpha: float): + """Compute priority as mixture of max and mean sequence errors.""" + abs_errors = tf.abs(errors) + mean_priority = tf.reduce_mean(abs_errors, axis=0) + max_priority = tf.reduce_max(abs_errors, axis=0) + + return alpha * max_priority + (1 - alpha) * mean_priority diff --git a/acme/acme/agents/tf/r2d3/README.md b/acme/acme/agents/tf/r2d3/README.md new file mode 100644 index 00000000..ab6c008f --- /dev/null +++ b/acme/acme/agents/tf/r2d3/README.md @@ -0,0 +1,11 @@ +# R2D3 - R2D2 from Demonstrations + +This folder contains an implementation of the R2D3 agent introduced in +([Paine et al., 2019]). This work builds upon the R2D2 algorithm +([Kapturowski et al., 2019]). + +In this case a learner similar to the one used in R2D2 receives batches with a +fixed proportion of replay buffer and demonstration data. + +[Paine et al., 2019]: https://arxiv.org/abs/1909.01387 +[Kapturowski et al., 2019]: https://openreview.net/forum?id=r1lyTjAqYX diff --git a/acme/acme/agents/tf/r2d3/__init__.py b/acme/acme/agents/tf/r2d3/__init__.py new file mode 100644 index 00000000..55be53c6 --- /dev/null +++ b/acme/acme/agents/tf/r2d3/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for Recurrent DQfD (R2D3).""" + +from acme.agents.tf.r2d3.agent import R2D3 diff --git a/acme/acme/agents/tf/r2d3/agent.py b/acme/acme/agents/tf/r2d3/agent.py new file mode 100644 index 00000000..8bbe89d2 --- /dev/null +++ b/acme/acme/agents/tf/r2d3/agent.py @@ -0,0 +1,226 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Recurrent DQfD (R2D3) agent implementation.""" + +import functools +from typing import Optional + +from acme import datasets +from acme import specs +from acme import types as acme_types +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors +from acme.agents.tf.r2d2 import learning +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf +import tree +import trfl + + +class R2D3(agent.Agent): + """R2D3 Agent. + + This implements a single-process R2D2 agent that mixes demonstrations with + actor experience. + """ + + def __init__(self, + environment_spec: specs.EnvironmentSpec, + network: snt.RNNCore, + target_network: snt.RNNCore, + burn_in_length: int, + trace_length: int, + replay_period: int, + demonstration_dataset: tf.data.Dataset, + demonstration_ratio: float, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + discount: float = 0.99, + batch_size: int = 32, + target_update_period: int = 100, + importance_sampling_exponent: float = 0.2, + epsilon: float = 0.01, + learning_rate: float = 1e-3, + save_logs: bool = False, + log_name: str = 'agent', + checkpoint: bool = True, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0): + + sequence_length = burn_in_length + trace_length + 1 + extra_spec = { + 'core_state': network.initial_state(1), + } + # Remove batch dimensions. + extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.SequenceAdder.signature( + environment_spec, extra_spec, sequence_length=sequence_length)) + self._server = reverb.Server([replay_table], port=None) + address = f'localhost:{self._server.port}' + + # Component to add things into replay. 
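+    # The adder below writes overlapping sequences of `sequence_length` steps
+    # to replay, starting a new sequence every `replay_period` steps; e.g.
+    # burn_in_length=2 and trace_length=6 give sequence_length=9, so a period
+    # of 4 makes consecutive sequences share 5 timesteps.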
+ sequence_kwargs = dict( + period=replay_period, + sequence_length=sequence_length, + ) + adder = adders.SequenceAdder(client=reverb.Client(address), + **sequence_kwargs) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=address) + + # Combine with demonstration dataset. + transition = functools.partial(_sequence_from_episode, + extra_spec=extra_spec, + **sequence_kwargs) + dataset_demos = demonstration_dataset.map(transition) + dataset = tf.data.experimental.sample_from_datasets( + [dataset, dataset_demos], + [1 - demonstration_ratio, demonstration_ratio]) + + # Batch and prefetch. + dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + learner = learning.R2D2Learner( + environment_spec=environment_spec, + network=network, + target_network=target_network, + burn_in_length=burn_in_length, + dataset=dataset, + reverb_client=reverb.TFClient(address), + counter=counter, + logger=logger, + sequence_length=sequence_length, + discount=discount, + target_update_period=target_update_period, + importance_sampling_exponent=importance_sampling_exponent, + max_replay_size=max_replay_size, + learning_rate=learning_rate, + store_lstm_state=False, + ) + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='r2d2_learner', + time_delta_minutes=60, + objects_to_save=learner.state, + enable_checkpointing=checkpoint, + ) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={'network': network}, time_delta_minutes=60.) + + policy_network = snt.DeepRNN([ + network, + lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), + ]) + + actor = actors.RecurrentActor(policy_network, adder) + observations_per_step = (float(replay_period * batch_size) / + samples_per_insert) + super().__init__( + actor=actor, + learner=learner, + min_observations=replay_period * max(batch_size, min_replay_size), + observations_per_step=observations_per_step) + + def update(self): + super().update() + self._snapshotter.save() + self._checkpointer.save() + + +def _sequence_from_episode(observations: acme_types.NestedTensor, + actions: tf.Tensor, + rewards: tf.Tensor, + discounts: tf.Tensor, + extra_spec: acme_types.NestedSpec, + period: int, + sequence_length: int): + """Produce Reverb-like sequence from a full episode. + + Observations, actions, rewards and discounts have the same length. This + function will ignore the first reward and discount and the last action. + + This function generates fake (all-zero) extras. + + See docs for reverb.SequenceAdder() for more details. + + Args: + observations: [L, ...] Tensor. + actions: [L, ...] Tensor. + rewards: [L] Tensor. + discounts: [L] Tensor. + extra_spec: A possibly nested structure of specs for extras. This function + will generate fake (all-zero) extras. + period: The period with which we add sequences. + sequence_length: The fixed length of sequences we wish to add. + + Returns: + (o_t, a_t, r_t, d_t, e_t) Tuple. + """ + + length = tf.shape(rewards)[0] + first = tf.random.uniform(shape=(), minval=0, maxval=length, dtype=tf.int32) + first = first // period * period # Get a multiple of `period`. 
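+  # `to` is the exclusive end of the slice; when fewer than `sequence_length`
+  # steps remain in the episode, `_slice_and_pad` below pads with zeros.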
+ to = tf.minimum(first + sequence_length, length) + + def _slice_and_pad(x): + pad_length = sequence_length + first - to + padding_shape = tf.concat([[pad_length], tf.shape(x)[1:]], axis=0) + result = tf.concat([x[first:to], tf.zeros(padding_shape, x.dtype)], axis=0) + result.set_shape([sequence_length] + x.shape.as_list()[1:]) + return result + + o_t = tree.map_structure(_slice_and_pad, observations) + a_t = tree.map_structure(_slice_and_pad, actions) + r_t = _slice_and_pad(rewards) + d_t = _slice_and_pad(discounts) + start_of_episode = tf.equal(first, 0) + start_of_episode = tf.expand_dims(start_of_episode, axis=0) + start_of_episode = tf.tile(start_of_episode, [sequence_length]) + + def _sequence_zeros(spec): + return tf.zeros([sequence_length] + spec.shape, spec.dtype) + + e_t = tree.map_structure(_sequence_zeros, extra_spec) + info = tree.map_structure(lambda dtype: tf.ones([], dtype), + reverb.SampleInfo.tf_dtypes()) + return reverb.ReplaySample( + info=info, + data=adders.Step( + observation=o_t, + action=a_t, + reward=r_t, + discount=d_t, + start_of_episode=start_of_episode, + extras=e_t)) diff --git a/acme/acme/agents/tf/r2d3/agent_test.py b/acme/acme/agents/tf/r2d3/agent_test.py new file mode 100644 index 00000000..e5822166 --- /dev/null +++ b/acme/acme/agents/tf/r2d3/agent_test.py @@ -0,0 +1,94 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for R2D3 agent.""" + +import acme +from acme import specs +from acme.agents.tf import r2d3 +from acme.agents.tf.dqfd import bsuite_demonstrations +from acme.testing import fakes +from acme.tf import networks +import dm_env +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +class SimpleNetwork(networks.RNNCore): + + def __init__(self, action_spec: specs.DiscreteArray): + super().__init__(name='r2d2_test_network') + self._net = snt.DeepRNN([ + snt.Flatten(), + snt.LSTM(20), + snt.nets.MLP([50, 50, action_spec.num_values]) + ]) + + def __call__(self, inputs, state): + return self._net(inputs, state) + + def initial_state(self, batch_size: int, **kwargs): + return self._net.initial_state(batch_size) + + def unroll(self, inputs, state, sequence_length): + return snt.static_unroll(self._net, inputs, state, sequence_length) + + +class R2D3Test(absltest.TestCase): + + def test_r2d3(self): + # Create a fake environment to test with. + environment = fakes.DiscreteEnvironment( + num_actions=5, + num_observations=10, + obs_dtype=np.float32, + episode_length=10) + spec = specs.make_environment_spec(environment) + + # Build demonstrations. + dummy_action = np.zeros((), dtype=np.int32) + recorder = bsuite_demonstrations.DemonstrationRecorder() + timestep = environment.reset() + while timestep.step_type is not dm_env.StepType.LAST: + recorder.step(timestep, dummy_action) + timestep = environment.step(dummy_action) + recorder.step(timestep, dummy_action) + recorder.record_episode() + + # Construct the agent. 
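+    # Tiny batch/replay sizes keep the test fast; demonstration_ratio=0.5
+    # means roughly half of the learner's examples are drawn from the
+    # demonstration dataset in expectation.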
+    agent = r2d3.R2D3(
+        environment_spec=spec,
+        network=SimpleNetwork(spec.actions),
+        target_network=SimpleNetwork(spec.actions),
+        demonstration_dataset=recorder.make_tf_dataset(),
+        demonstration_ratio=0.5,
+        batch_size=10,
+        samples_per_insert=2,
+        min_replay_size=10,
+        burn_in_length=2,
+        trace_length=6,
+        replay_period=4,
+        checkpoint=False,
+    )
+
+    # Try running the environment loop. We have no assertions here because all
+    # we care about is that the agent runs without raising any errors.
+    loop = acme.EnvironmentLoop(environment, agent)
+    loop.run(num_episodes=5)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/tf/svg0_prior/README.md b/acme/acme/agents/tf/svg0_prior/README.md
new file mode 100644
index 00000000..0a38a422
--- /dev/null
+++ b/acme/acme/agents/tf/svg0_prior/README.md
@@ -0,0 +1,24 @@
+# Stochastic Value Gradients (SVG) with Behavior Prior.
+
+This folder contains a version of the SVG-0 agent introduced in
+([Heess et al., 2015]) that has been extended with an entropy bonus, RETRACE
+([Munos et al., 2016]) for off-policy correction, and code to learn behavior
+priors ([Tirumala et al., 2020], [Galashov et al., 2019]).
+
+The base SVG-0 algorithm is similar to DPG and DDPG ([Silver et al., 2014],
+[Lillicrap et al., 2015]) but uses the reparameterization trick to learn
+stochastic rather than deterministic policies. In addition, the RETRACE
+algorithm is used to learn value functions using multiple timesteps of data
+with importance sampling for off-policy correction.
+
+Additionally, an optional Behavior Prior can be learnt using this setup with
+an information asymmetry, which has been shown to boost performance in some
+domains. Example code to run with and without behavior priors on the DeepMind
+Control Suite and Locomotion tasks is provided in the `examples` folder.
+
+[Heess et al., 2015]: https://arxiv.org/abs/1510.09142
+[Munos et al., 2016]: https://arxiv.org/abs/1606.02647
+[Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971
+[Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14
+[Tirumala et al., 2020]: https://arxiv.org/abs/2010.14274
+[Galashov et al., 2019]: https://arxiv.org/abs/1905.01240
diff --git a/acme/acme/agents/tf/svg0_prior/__init__.py b/acme/acme/agents/tf/svg0_prior/__init__.py
new file mode 100644
index 00000000..b4218db2
--- /dev/null
+++ b/acme/acme/agents/tf/svg0_prior/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Implementations of a SVG0 agent with prior.""" + +from acme.agents.tf.svg0_prior.agent import SVG0 +from acme.agents.tf.svg0_prior.agent_distributed import DistributedSVG0 +from acme.agents.tf.svg0_prior.learning import SVG0Learner +from acme.agents.tf.svg0_prior.networks import make_default_networks +from acme.agents.tf.svg0_prior.networks import make_network_with_prior diff --git a/acme/acme/agents/tf/svg0_prior/acting.py b/acme/acme/agents/tf/svg0_prior/acting.py new file mode 100644 index 00000000..f044a14e --- /dev/null +++ b/acme/acme/agents/tf/svg0_prior/acting.py @@ -0,0 +1,67 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG0 actor implementation.""" + +from typing import Optional + +from acme import adders +from acme import types + +from acme.agents.tf import actors +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + +import dm_env +import sonnet as snt + + +class SVG0Actor(actors.FeedForwardActor): + """An actor that also returns `log_prob`.""" + + def __init__( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + deterministic_policy: Optional[bool] = False, + ): + super().__init__(policy_network, adder, variable_client) + self._log_prob = None + self._deterministic_policy = deterministic_policy + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) + + # Compute the policy, conditioned on the observation. + policy = self._policy_network(batched_observation) + if self._deterministic_policy: + action = policy.mean() + else: + action = policy.sample() + self._log_prob = policy.log_prob(action) + return tf2_utils.to_numpy_squeeze(action) + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + if not self._adder: + return + + extras = {'log_prob': self._log_prob} + extras = tf2_utils.to_numpy_squeeze(extras) + self._adder.add(action, next_timestep, extras) diff --git a/acme/acme/agents/tf/svg0_prior/agent.py b/acme/acme/agents/tf/svg0_prior/agent.py new file mode 100644 index 00000000..e9303c72 --- /dev/null +++ b/acme/acme/agents/tf/svg0_prior/agent.py @@ -0,0 +1,370 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG0 agent implementation.""" + +import copy +import dataclasses +from typing import Iterator, List, Optional, Tuple + +from acme import adders +from acme import core +from acme import datasets +from acme import specs +from acme.adders import reverb as reverb_adders +from acme.agents import agent +from acme.agents.tf.svg0_prior import acting +from acme.agents.tf.svg0_prior import learning +from acme.tf import utils +from acme.tf import variable_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + + +@dataclasses.dataclass +class SVG0Config: + """Configuration options for the agent.""" + + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_update_period: int = 100 + policy_optimizer: Optional[snt.Optimizer] = None + critic_optimizer: Optional[snt.Optimizer] = None + prior_optimizer: Optional[snt.Optimizer] = None + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + sequence_length: int = 10 + sigma: float = 0.3 + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + distillation_cost: Optional[float] = 1e-3 + entropy_regularizer_cost: Optional[float] = 1e-3 + + +@dataclasses.dataclass +class SVG0Networks: + """Structure containing the networks for SVG0.""" + + policy_network: snt.Module + critic_network: snt.Module + prior_network: Optional[snt.Module] + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + prior_network: Optional[snt.Module] = None + ): + # This method is implemented (rather than added by the dataclass decorator) + # in order to allow observation network to be passed as an arbitrary tensor + # transformation rather than as a snt Module. + # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. + self.policy_network = policy_network + self.critic_network = critic_network + self.prior_network = prior_network + + def init(self, environment_spec: specs.EnvironmentSpec): + """Initialize the networks given an environment spec.""" + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + + # Create variables for the policy and critic nets. + _ = utils.create_variables(self.policy_network, [obs_spec]) + _ = utils.create_variables(self.critic_network, [obs_spec, act_spec]) + if self.prior_network is not None: + _ = utils.create_variables(self.prior_network, [obs_spec]) + + def make_policy( + self, + ) -> snt.Module: + """Create a single network which evaluates the policy.""" + return self.policy_network + + def make_prior( + self, + ) -> snt.Module: + """Create a single network which evaluates the prior.""" + behavior_prior = self.prior_network + return behavior_prior + + +class SVG0Builder: + """Builder for SVG0 which constructs individual components of the agent.""" + + def __init__(self, config: SVG0Config): + self._config = config + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + sequence_length: int, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + if self._config.samples_per_insert is None: + # We will take a samples_per_insert ratio of None to mean that there is + # no limit, i.e. this only implies a min size limit. 
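+      # A MinSize limiter only blocks sampling until `min_replay_size` items
+      # are present; it never throttles inserts.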
+ limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + + else: + error_buffer = max(1, self._config.samples_per_insert) + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + + extras_spec = { + 'log_prob': tf.ones( + shape=(), dtype=tf.float32) + } + replay_table = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=reverb_adders.SequenceAdder.signature( + environment_spec, + extras_spec=extras_spec, + sequence_length=sequence_length + 1)) + + return [replay_table] + + def make_dataset_iterator( + self, + reverb_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=reverb_client.server_address, + batch_size=self._config.batch_size, + prefetch_size=self._config.prefetch_size) + + # TODO(b/155086959): Fix type stubs and remove. + return iter(dataset) # pytype: disable=wrong-arg-types + + def make_adder( + self, + replay_client: reverb.Client, + ) -> adders.Adder: + """Create an adder which records data generated by the actor/environment.""" + return reverb_adders.SequenceAdder( + client=replay_client, + sequence_length=self._config.sequence_length+1, + priority_fns={self._config.replay_table_name: lambda x: 1.}, + period=self._config.sequence_length, + end_of_episode_behavior=reverb_adders.EndBehavior.CONTINUE, + ) + + def make_actor( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_source: Optional[core.VariableSource] = None, + deterministic_policy: Optional[bool] = False, + ): + """Create an actor instance.""" + if variable_source: + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.variables}, + update_period=1000, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + else: + variable_client = None + + # Create the actor which defines how we take actions. + return acting.SVG0Actor( + policy_network=policy_network, + adder=adder, + variable_client=variable_client, + deterministic_policy=deterministic_policy + ) + + def make_learner( + self, + networks: Tuple[SVG0Networks, SVG0Networks], + dataset: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Creates an instance of the learner.""" + online_networks, target_networks = networks + + # The learner updates the parameters (and initializes them). 
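+    # When no behavior prior is configured the prior entries below are None,
+    # in which case the learner skips the distillation and prior losses.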
+    return learning.SVG0Learner(
+        policy_network=online_networks.policy_network,
+        critic_network=online_networks.critic_network,
+        target_policy_network=target_networks.policy_network,
+        target_critic_network=target_networks.critic_network,
+        prior_network=online_networks.prior_network,
+        target_prior_network=target_networks.prior_network,
+        policy_optimizer=self._config.policy_optimizer,
+        critic_optimizer=self._config.critic_optimizer,
+        prior_optimizer=self._config.prior_optimizer,
+        distillation_cost=self._config.distillation_cost,
+        entropy_regularizer_cost=self._config.entropy_regularizer_cost,
+        discount=self._config.discount,
+        target_update_period=self._config.target_update_period,
+        dataset_iterator=dataset,
+        counter=counter,
+        logger=logger,
+        checkpoint=checkpoint,
+    )
+
+
+class SVG0(agent.Agent):
+  """SVG0 Agent with prior.
+
+  This implements a single-process SVG0 agent. This is an actor-critic
+  algorithm that generates data via a behavior policy, inserts fixed-length
+  sequences of transitions into a replay buffer, and periodically updates the
+  policy (and as a result the behavior) by sampling uniformly from this buffer.
+  """
+
+  def __init__(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+      policy_network: snt.Module,
+      critic_network: snt.Module,
+      discount: float = 0.99,
+      batch_size: int = 256,
+      prefetch_size: int = 4,
+      target_update_period: int = 100,
+      prior_network: Optional[snt.Module] = None,
+      policy_optimizer: Optional[snt.Optimizer] = None,
+      critic_optimizer: Optional[snt.Optimizer] = None,
+      prior_optimizer: Optional[snt.Optimizer] = None,
+      distillation_cost: Optional[float] = 1e-3,
+      entropy_regularizer_cost: Optional[float] = 1e-3,
+      min_replay_size: int = 1000,
+      max_replay_size: int = 1000000,
+      samples_per_insert: float = 32.0,
+      sequence_length: int = 10,
+      sigma: float = 0.3,
+      replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      checkpoint: bool = True,
+  ):
+    """Initialize the agent.
+
+    Args:
+      environment_spec: description of the actions, observations, etc.
+      policy_network: the online (optimized) policy.
+      critic_network: the online critic.
+      discount: discount to use for TD updates.
+      batch_size: batch size for updates.
+      prefetch_size: size to prefetch from replay.
+      target_update_period: number of learner steps to perform before updating
+        the target networks.
+      prior_network: an optional `behavior prior` to regularize against.
+      policy_optimizer: optimizer for the policy network updates.
+      critic_optimizer: optimizer for the critic network updates.
+      prior_optimizer: optimizer for the prior network updates.
+      distillation_cost: a multiplier to be used when adding distillation
+        against the prior to the losses.
+      entropy_regularizer_cost: a multiplier used for per-state sample-based
+        entropy added to the actor loss.
+      min_replay_size: minimum replay size before updating.
+      max_replay_size: maximum replay size.
+      samples_per_insert: number of samples to take from replay for every
+        insert that is made.
+      sequence_length: number of timesteps to store for each trajectory.
+      sigma: standard deviation of zero-mean, Gaussian exploration noise.
+      replay_table_name: string indicating what name to give the replay table.
+      counter: counter object used to keep track of steps.
+      logger: logger object to be used by learner.
+      checkpoint: boolean indicating whether to checkpoint the learner.
+ """ + # Create the Builder object which will internally create agent components. + builder = SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + # Right now this modifies min_replay_size and samples_per_insert so that + # they are not controlled by a limiter and are instead handled by the + # Agent base class (the above TODO directly references this behavior). + SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + min_replay_size=1, # Let the Agent class handle this. + max_replay_size=max_replay_size, + samples_per_insert=None, # Let the Agent class handle this. + sequence_length=sequence_length, + sigma=sigma, + replay_table_name=replay_table_name, + )) + + # TODO(mwhoffman): pass the network dataclass in directly. + online_networks = SVG0Networks(policy_network=policy_network, + critic_network=critic_network, + prior_network=prior_network,) + + # Target networks are just a copy of the online networks. + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(environment_spec) + target_networks.init(environment_spec) + + # TODO(mwhoffman): either make this Dataclass or pass only one struct. + # The network struct passed to make_learner is just a tuple for the + # time-being (for backwards compatibility). + networks = (online_networks, target_networks) + + # Create the behavior policy. + policy_network = online_networks.make_policy() + + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec, + sequence_length) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f'localhost:{replay_server.port}') + + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client) + actor = builder.make_actor(policy_network, adder) + dataset = builder.make_dataset_iterator(replay_client) + learner = builder.make_learner(networks, dataset, counter, logger, + checkpoint) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) + + # Save the replay so we don't garbage collect it. + self._replay_server = replay_server diff --git a/acme/acme/agents/tf/svg0_prior/agent_distributed.py b/acme/acme/agents/tf/svg0_prior/agent_distributed.py new file mode 100644 index 00000000..8bf0bebc --- /dev/null +++ b/acme/acme/agents/tf/svg0_prior/agent_distributed.py @@ -0,0 +1,251 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Defines the SVG0 agent class.""" + +import copy +from typing import Callable, Dict, Optional + +import acme +from acme import specs +from acme.agents.tf.svg0_prior import agent +from acme.tf import savers as tf2_savers +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import reverb +import sonnet as snt + + +class DistributedSVG0: + """Program definition for SVG0.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + sequence_length: int = 10, + sigma: float = 0.3, + discount: float = 0.99, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + target_update_period: int = 100, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + # TODO(mwhoffman): Make network_factory directly return the struct. + # TODO(mwhoffman): Make the factory take the entire spec. + def wrapped_network_factory(action_spec): + networks_dict = network_factory(action_spec) + networks = agent.SVG0Networks( + policy_network=networks_dict.get('policy'), + critic_network=networks_dict.get('critic'), + prior_network=networks_dict.get('prior', None),) + return networks + + self._environment_factory = environment_factory + self._network_factory = wrapped_network_factory + self._environment_spec = environment_spec + self._sigma = sigma + self._num_actors = num_actors + self._num_caches = num_caches + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._sequence_length = sequence_length + + self._builder = agent.SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + agent.SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + sequence_length=sequence_length, + sigma=sigma, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + )) + + def replay(self): + """The replay storage.""" + return self._builder.make_replay_tables(self._environment_spec, + self._sequence_length) + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter): + return lp_utils.StepsLimiter(counter, self._max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # Create the networks to optimize (online) and target networks. 
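+    # The targets start as a deep copy of the online networks and are
+    # refreshed by direct assignment every `target_update_period` learner
+    # steps.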
+ online_networks = self._network_factory(self._environment_spec.actions) + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(self._environment_spec) + target_networks.init(self._environment_spec) + + dataset = self._builder.make_dataset_iterator(replay) + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + return self._builder.make_learner( + networks=(online_networks, target_networks), + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + ) + + # Create the environment. + environment = self._environment_factory(False) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + logger: Optional[loggers.Logger] = None, + ): + """The evaluation process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + variable_source=variable_source, + deterministic_policy=True, + ) + + # Make the environment. + environment = self._environment_factory(True) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = logger or loggers.make_default_logger( + 'evaluator', + time_delta=self._log_every, + steps_key='evaluator_steps', + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name='svg0'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + with program.group('coordinator'): + _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. 
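+        # Each cacher holds a slightly stale copy of the learner's variables,
+        # so actors fetch from the caches instead of the learner itself.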
+        sources = []
+        for _ in range(self._num_caches):
+          cacher = program.add_node(
+              lp.CacherNode(
+                  learner, refresh_interval_ms=2000, stale_after_ms=4000))
+          sources.append(cacher)
+
+    with program.group('actor'):
+      # Add actors which pull round-robin from our variable sources.
+      for actor_id in range(self._num_actors):
+        source = sources[actor_id % len(sources)]
+        program.add_node(lp.CourierNode(self.actor, replay, source, counter))
+
+    return program
diff --git a/acme/acme/agents/tf/svg0_prior/agent_distributed_test.py b/acme/acme/agents/tf/svg0_prior/agent_distributed_test.py
new file mode 100644
index 00000000..070231ab
--- /dev/null
+++ b/acme/acme/agents/tf/svg0_prior/agent_distributed_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration test for the distributed agent."""
+
+from typing import Sequence
+
+import acme
+from acme import specs
+from acme.agents.tf import svg0_prior
+from acme.testing import fakes
+from acme.tf import networks
+from acme.tf import utils as tf2_utils
+import launchpad as lp
+import numpy as np
+import sonnet as snt
+
+from absl.testing import absltest
+
+
+def make_networks(
+    action_spec: specs.BoundedArray,
+    policy_layer_sizes: Sequence[int] = (10, 10),
+    critic_layer_sizes: Sequence[int] = (10, 10),
+):
+  """Simple networks for testing."""
+
+  # Get total number of action dimensions from action spec.
+  num_dimensions = np.prod(action_spec.shape, dtype=int)
+
+  policy_network = snt.Sequential([
+      tf2_utils.batch_concat,
+      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
+      networks.MultivariateNormalDiagHead(
+          num_dimensions,
+          tanh_mean=True,
+          min_scale=0.3,
+          init_scale=0.7,
+          fixed_scale=False,
+          use_tfd_independent=False)
+  ])
+  # The multiplexer concatenates the (maybe transformed) observations/actions.
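+  # The critic maps the concatenated pair through a LayerNorm MLP into a
+  # near-zero-initialized scalar value head.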
+ multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_control_suite(self): + """Tests that the agent can run on the control suite without crashing.""" + + agent = svg0_prior.DistributedSVG0( + environment_factory=lambda x: fakes.ContinuousEnvironment(), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/agents/tf/svg0_prior/agent_test.py b/acme/acme/agents/tf/svg0_prior/agent_test.py new file mode 100644 index 00000000..c8f0b03c --- /dev/null +++ b/acme/acme/agents/tf/svg0_prior/agent_test.py @@ -0,0 +1,96 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the SVG agent.""" + +import sys +from typing import Dict, Sequence + +import acme +from acme import specs +from acme import types +from acme.agents.tf import svg0_prior +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils +import numpy as np +import sonnet as snt + +from absl.testing import absltest + + +def make_networks( + action_spec: types.NestedSpec, + policy_layer_sizes: Sequence[int] = (10, 10), + critic_layer_sizes: Sequence[int] = (10, 10), +) -> Dict[str, snt.Module]: + """Creates networks used by the agent.""" + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class SVG0Test(absltest.TestCase): + + def test_svg0(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create the networks. 
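+    # make_networks returns only 'policy' and 'critic' entries, so this
+    # exercises the prior-free SVG0 path.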
+    agent_networks = make_networks(spec.actions)
+
+    # Construct the agent.
+    agent = svg0_prior.SVG0(
+        environment_spec=spec,
+        policy_network=agent_networks['policy'],
+        critic_network=agent_networks['critic'],
+        batch_size=10,
+        samples_per_insert=2,
+        min_replay_size=10,
+    )
+
+    # Try running the environment loop. We have no assertions here because all
+    # we care about is that the agent runs without raising any errors.
+    loop = acme.EnvironmentLoop(environment, agent)
+    loop.run(num_episodes=2)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/agents/tf/svg0_prior/learning.py b/acme/acme/agents/tf/svg0_prior/learning.py
new file mode 100644
index 00000000..297228e1
--- /dev/null
+++ b/acme/acme/agents/tf/svg0_prior/learning.py
@@ -0,0 +1,386 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SVG0 learner implementation."""
+
+import time
+from typing import Dict, Iterator, List, Optional
+
+import acme
+from acme.agents.tf.svg0_prior import utils as svg0_utils
+from acme.tf import savers as tf2_savers
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import numpy as np
+import reverb
+import sonnet as snt
+import tensorflow as tf
+from trfl import continuous_retrace_ops
+
+_MIN_LOG_VAL = 1e-20
+
+
+class SVG0Learner(acme.Learner):
+  """SVG0 learner with optional prior.
+
+  This is the learning component of an SVG0 agent, i.e. it takes a dataset as
+  input and implements update functionality to learn from this dataset.
+  """
+
+  def __init__(
+      self,
+      policy_network: snt.Module,
+      critic_network: snt.Module,
+      target_policy_network: snt.Module,
+      target_critic_network: snt.Module,
+      discount: float,
+      target_update_period: int,
+      dataset_iterator: Iterator[reverb.ReplaySample],
+      prior_network: Optional[snt.Module] = None,
+      target_prior_network: Optional[snt.Module] = None,
+      policy_optimizer: Optional[snt.Optimizer] = None,
+      critic_optimizer: Optional[snt.Optimizer] = None,
+      prior_optimizer: Optional[snt.Optimizer] = None,
+      distillation_cost: Optional[float] = 1e-3,
+      entropy_regularizer_cost: Optional[float] = 1e-3,
+      num_action_samples: int = 10,
+      lambda_: float = 1.0,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      checkpoint: bool = True,
+  ):
+    """Initializes the learner.
+
+    Args:
+      policy_network: the online (optimized) policy.
+      critic_network: the online critic.
+      target_policy_network: the target policy (which lags behind the online
+        policy).
+      target_critic_network: the target critic.
+      discount: discount to use for TD updates.
+      target_update_period: number of learner steps to perform before updating
+        the target networks.
+      dataset_iterator: dataset to learn from, whether fixed or from a replay
+        buffer (see `acme.datasets.reverb.make_reverb_dataset` documentation).
+      prior_network: the online (optimized) prior.
+ target_prior_network: the target prior (which lags behind the online + prior). + policy_optimizer: the optimizer to be applied to the SVG-0 (policy) loss. + critic_optimizer: the optimizer to be applied to the distributional + Bellman loss. + prior_optimizer: the optimizer to be applied to the prior (distillation) + loss. + distillation_cost: a multiplier to be used when adding distillation + against the prior to the losses. + entropy_regularizer_cost: a multiplier used for per state sample based + entropy added to the actor loss. + num_action_samples: the number of action samples to use for estimating the + value function and sample based entropy. + lambda_: the `lambda` value to be used with retrace. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + self._prior_network = prior_network + self._target_prior_network = target_prior_network + + self._lambda = lambda_ + self._num_action_samples = num_action_samples + self._distillation_cost = distillation_cost + self._entropy_regularizer_cost = entropy_regularizer_cost + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner') + + # Other learner parameters. + self._discount = discount + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_update_period = target_update_period + + # Batch dataset and create iterator. + self._iterator = dataset_iterator + + # Create optimizers if they aren't given. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._prior_optimizer = prior_optimizer or snt.optimizers.Adam(1e-4) + + # Expose the variables. + self._variables = { + 'critic': self._critic_network.variables, + 'policy': self._policy_network.variables, + } + if self._prior_network is not None: + self._variables['prior'] = self._prior_network.variables + + # Create a checkpointer and snapshotter objects. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + objects_to_save = { + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'num_steps': self._num_steps, + } + if self._prior_network is not None: + objects_to_save['prior'] = self._prior_network + objects_to_save['target_prior'] = self._target_prior_network + objects_to_save['prior_optimizer'] = self._prior_optimizer + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='svg0_learner', + objects_to_save=objects_to_save) + objects_to_snapshot = { + 'policy': self._policy_network, + 'critic': self._critic_network, + } + if self._prior_network is not None: + objects_to_snapshot['prior'] = self._prior_network + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save=objects_to_snapshot) + + # Do not record timestamps until after the first learning step is done. 
+ # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + # Update target network + online_variables = [ + *self._critic_network.variables, + *self._policy_network.variables, + ] + if self._prior_network is not None: + online_variables += [*self._prior_network.variables] + online_variables = tuple(online_variables) + + target_variables = [ + *self._target_critic_network.variables, + *self._target_policy_network.variables, + ] + if self._prior_network is not None: + target_variables += [*self._target_prior_network.variables] + target_variables = tuple(target_variables) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any) and flip to `[T, B, ...]`. + sample: reverb.ReplaySample = next(self._iterator) + data = tf2_utils.batch_to_sequence(sample.data) + observations, actions, rewards, discounts, extra = (data.observation, + data.action, + data.reward, + data.discount, + data.extras) + online_target_pi_q = svg0_utils.OnlineTargetPiQ( + online_pi=self._policy_network, + online_q=self._critic_network, + target_pi=self._target_policy_network, + target_q=self._target_critic_network, + num_samples=self._num_action_samples, + online_prior=self._prior_network, + target_prior=self._target_prior_network, + ) + with tf.GradientTape(persistent=True) as tape: + step_outputs = svg0_utils.static_rnn( + core=online_target_pi_q, + inputs=(observations, actions), + unroll_length=rewards.shape[0]) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + target_pi_samples = tf2_utils.batch_to_sequence( + step_outputs.target_samples) + # Tile observations to have shape [S, T+1, B,..]. + tiled_observations = tf2_utils.tile_nested(observations, + self._num_action_samples) + + # Finally compute target Q values on the new action samples. + # Shape: [S, T+1, B, 1] + target_q_target_pi_samples = snt.BatchApply(self._target_critic_network, + 3)(tiled_observations, + target_pi_samples) + # Compute the value estimate by averaging over the action dimension. + # Shape: [T+1, B, 1]. + target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples, axis=0) + + # Split the target V's into the target for learning + # `value_function_target` and the bootstrap value. Shape: [T, B]. + value_function_target = tf.squeeze(target_v_target_pi[:-1], axis=-1) + # Shape: [B]. + bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1) + + # When learning with a prior, add entropy terms to value targets. + if self._prior_network is not None: + value_function_target -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[:-1] + ) + bootstrap_value -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[-1]) + + # Get target log probs and behavior log probs from rollout. + # Shape: [T+1, B]. + target_log_probs_behavior_actions = ( + step_outputs.target_log_probs_behavior_actions) + behavior_log_probs = extra['log_prob'] + # Calculate importance weights. Shape: [T+1, B]. + rhos = tf.exp(target_log_probs_behavior_actions - behavior_log_probs) + + # Filter the importance weights to mask out episode restarts. 
Ignore the + # last action and consider the step type of the next step for masking. + # Shape: [T, B]. + episode_start_mask = tf2_utils.batch_to_sequence( + sample.data.start_of_episode)[1:] + + rhos = svg0_utils.mask_out_restarting(rhos[:-1], episode_start_mask) + + # rhos = rhos[:-1] + # Compute the log importance weights with a small value added for + # stability. + # Shape: [T, B] + log_rhos = tf.math.log(rhos + _MIN_LOG_VAL) + + # Retrieve the target and online Q values and throw away the last action. + # Shape: [T, B]. + target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1) + online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + online_pi_samples = tf2_utils.batch_to_sequence( + step_outputs.online_samples) + target_q_online_pi_samples = snt.BatchApply(self._target_critic_network, + 3)(tiled_observations, + online_pi_samples) + expected_q = tf.reduce_mean( + tf.squeeze(target_q_online_pi_samples, -1), axis=0) + + # Flip online_log_probs to be of shape [S, T+1, B] and then compute + # entropy by averaging over num samples. Final shape: [T+1, B]. + online_log_probs = tf2_utils.batch_to_sequence( + step_outputs.online_log_probs) + sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0) + retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights( + log_rhos=log_rhos, + discounts=self._discount * discounts[:-1], + rewards=rewards[:-1], + q_values=target_q_values, + values=value_function_target, + bootstrap_value=bootstrap_value, + lambda_=self._lambda, + ) + + # Critic loss. Shape: [T, B]. + critic_loss = 0.5 * tf.math.squared_difference( + tf.stop_gradient(retrace_outputs.qs), online_q_values) + + # Policy loss- SVG0 with sample based entropy. Shape: [T, B] + policy_loss = -( + expected_q + self._entropy_regularizer_cost * sample_based_entropy) + policy_loss = policy_loss[:-1] + + if self._prior_network is not None: + # When training the prior, also add the per-timestep KL cost. + policy_loss += ( + self._distillation_cost * step_outputs.analytic_kl_to_target[:-1]) + + # Ensure episode restarts are masked out when computing the losses. + critic_loss = svg0_utils.mask_out_restarting(critic_loss, + episode_start_mask) + critic_loss = tf.reduce_mean(critic_loss) + + policy_loss = svg0_utils.mask_out_restarting(policy_loss, + episode_start_mask) + policy_loss = tf.reduce_mean(policy_loss) + + if self._prior_network is not None: + prior_loss = step_outputs.analytic_kl_divergence[:-1] + prior_loss = svg0_utils.mask_out_restarting(prior_loss, + episode_start_mask) + prior_loss = tf.reduce_mean(prior_loss) + + # Get trainable variables. + policy_variables = self._policy_network.trainable_variables + critic_variables = self._critic_network.trainable_variables + + # Compute gradients. + policy_gradients = tape.gradient(policy_loss, policy_variables) + critic_gradients = tape.gradient(critic_loss, critic_variables) + if self._prior_network is not None: + prior_variables = self._prior_network.trainable_variables + prior_gradients = tape.gradient(prior_loss, prior_variables) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Apply gradients. 
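+    # The policy and critic use separate optimizers; the prior optimizer is
+    # only stepped (below) when a prior network is present.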
+ self._policy_optimizer.apply(policy_gradients, policy_variables) + self._critic_optimizer.apply(critic_gradients, critic_variables) + losses = { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + + if self._prior_network is not None: + self._prior_optimizer.apply(prior_gradients, prior_variables) + losses['prior_loss'] = prior_loss + + # Losses to track. + return losses + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/acme/agents/tf/svg0_prior/networks.py b/acme/acme/agents/tf/svg0_prior/networks.py new file mode 100644 index 00000000..945a2003 --- /dev/null +++ b/acme/acme/agents/tf/svg0_prior/networks.py @@ -0,0 +1,118 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for different experiment flavours.""" + +import functools +from typing import Mapping, Sequence, Optional + +from acme import specs +from acme import types +from acme.agents.tf.svg0_prior import utils as svg0_utils +from acme.tf import networks +from acme.tf import utils as tf2_utils + +import numpy as np +import sonnet as snt + + +def make_default_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (256, 256, 256), + critic_layer_sizes: Sequence[int] = (512, 512, 256), +) -> Mapping[str, types.TensorTransformation]: + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
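+  # ClipToSpec bounds the actions to the action spec before they are fed to
+  # the critic.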
+ multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + "policy": policy_network, + "critic": critic_network, + } + + +def make_network_with_prior( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (200, 100), + critic_layer_sizes: Sequence[int] = (400, 300), + prior_layer_sizes: Sequence[int] = (200, 100), + policy_keys: Optional[Sequence[str]] = None, + prior_keys: Optional[Sequence[str]] = None, +) -> Mapping[str, types.TensorTransformation]: + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + flatten_concat_policy = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=policy_keys) + flatten_concat_prior = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=prior_keys) + + policy_network = snt.Sequential([ + flatten_concat_policy, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer( + observation_network=flatten_concat_policy, + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + prior_network = snt.Sequential([ + flatten_concat_prior, + networks.LayerNormMLP(prior_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + return { + "policy": policy_network, + "critic": critic_network, + "prior": prior_network, + } diff --git a/acme/acme/agents/tf/svg0_prior/utils.py b/acme/acme/agents/tf/svg0_prior/utils.py new file mode 100644 index 00000000..8474fea6 --- /dev/null +++ b/acme/acme/agents/tf/svg0_prior/utils.py @@ -0,0 +1,157 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for SVG0 algorithm with priors.""" + +import collections +from typing import Tuple, Optional, Dict, Iterable + +from acme import types +from acme.tf import utils as tf2_utils + +import sonnet as snt +import tensorflow as tf +import tree + + +class OnlineTargetPiQ(snt.Module): + """Core to unroll online and target policies and Q functions at once. + + A core that runs online and target policies and Q functions. This can be more + efficient if the core needs to be unrolled across time and called many times. 
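+
+  A rough usage sketch (the pi/Q modules below are placeholders):
+
+    core = OnlineTargetPiQ(online_pi, online_q, target_pi, target_q,
+                           num_samples=10)
+    step_out = core((observations, actions))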
+ """ + + def __init__(self, + online_pi: snt.Module, + online_q: snt.Module, + target_pi: snt.Module, + target_q: snt.Module, + num_samples: int, + online_prior: Optional[snt.Module] = None, + target_prior: Optional[snt.Module] = None, + name='OnlineTargetPiQ'): + super().__init__(name) + + self._online_pi = online_pi + self._target_pi = target_pi + self._online_q = online_q + self._target_q = target_q + self._online_prior = online_prior + self._target_prior = target_prior + + self._num_samples = num_samples + output_list = [ + 'online_samples', 'target_samples', 'target_log_probs_behavior_actions', + 'online_log_probs', 'online_q', 'target_q' + ] + if online_prior is not None: + output_list += ['analytic_kl_divergence', 'analytic_kl_to_target'] + self._output_tuple = collections.namedtuple( + 'OnlineTargetPiQ', output_list) + + def __call__(self, input_obs_and_action: Tuple[tf.Tensor, tf.Tensor]): + (obs, action) = input_obs_and_action + online_pi_dist = self._online_pi(obs) + target_pi_dist = self._target_pi(obs) + + online_samples = online_pi_dist.sample(self._num_samples) + target_samples = target_pi_dist.sample(self._num_samples) + target_log_probs_behavior_actions = target_pi_dist.log_prob(action) + + online_log_probs = online_pi_dist.log_prob(tf.stop_gradient(online_samples)) + + online_q_out = self._online_q(obs, action) + target_q_out = self._target_q(obs, action) + + output_list = [ + online_samples, target_samples, target_log_probs_behavior_actions, + online_log_probs, online_q_out, target_q_out + ] + + if self._online_prior is not None: + prior_dist = self._online_prior(obs) + target_prior_dist = self._target_prior(obs) + analytic_kl_divergence = online_pi_dist.kl_divergence(prior_dist) + analytic_kl_to_target = online_pi_dist.kl_divergence(target_prior_dist) + + output_list += [analytic_kl_divergence, analytic_kl_to_target] + output = self._output_tuple(*output_list) + return output + + +def static_rnn(core: snt.Module, inputs: types.NestedTensor, + unroll_length: int): + """Unroll core along inputs for unroll_length steps. + + Note: for time-major input tensors whose leading dimension is less than + unroll_length, `None` would be provided instead. + + Args: + core: an instance of snt.Module. + inputs: a `nest` of time-major input tensors. + unroll_length: number of time steps to unroll. + + Returns: + step_outputs: a `nest` of time-major stacked output tensors of length + `unroll_length`. + """ + step_outputs = [] + for time_dim in range(unroll_length): + inputs_t = tree.map_structure( + lambda t, i_=time_dim: t[i_] if i_ < t.shape[0] else None, inputs) + step_output = core(inputs_t) + step_outputs.append(step_output) + + step_outputs = _nest_stack(step_outputs) + return step_outputs + + +def mask_out_restarting(tensor: tf.Tensor, start_of_episode: tf.Tensor): + """Mask out `tensor` taken on the step that resets the environment. + + Args: + tensor: a time-major 2-D `Tensor` of shape [T, B]. + start_of_episode: a 2-D `Tensor` of shape [T, B] that contains the points + where the episode restarts. 
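+      Entries are booleans; True marks the first step of a new episode.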
+
+  Returns:
+    A tensor of shape [T, B] whose entries are masked out (zeroed) on steps
+    where the episode restarts.
+  """
+  tensor.get_shape().assert_has_rank(2)
+  start_of_episode.get_shape().assert_has_rank(2)
+  weights = tf.cast(~start_of_episode, dtype=tf.float32)
+  masked_tensor = tensor * weights
+  return masked_tensor
+
+
+def batch_concat_selection(observation_dict: Dict[str, types.NestedTensor],
+                           concat_keys: Optional[Iterable[str]] = None,
+                           output_dtype=tf.float32) -> tf.Tensor:
+  """Concatenate a dict of observations into 2-D tensors."""
+  concat_keys = concat_keys or sorted(observation_dict.keys())
+  to_concat = []
+  for obs in concat_keys:
+    if obs not in observation_dict:
+      raise KeyError(
+          'Missing observation. Requested: {} (available: {})'.format(
+              obs, list(observation_dict.keys())))
+    to_concat.append(tf.cast(observation_dict[obs], output_dtype))
+
+  return tf2_utils.batch_concat(to_concat)
+
+
+def _nest_stack(list_of_nests, axis=0):
+  """Convert a list of nests to a nest of stacked lists."""
+  return tree.map_structure(lambda *ts: tf.stack(ts, axis=axis), *list_of_nests)
diff --git a/acme/acme/core.py b/acme/acme/core.py
new file mode 100644
index 00000000..edd7ad6f
--- /dev/null
+++ b/acme/acme/core.py
@@ -0,0 +1,176 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Core Acme interfaces.
+
+This file specifies and documents the notions of `Actor` and `Learner`.
+"""
+
+import abc
+import itertools
+from typing import Generic, Iterator, List, Optional, Sequence, TypeVar
+
+from acme import types
+from acme.utils import metrics
+import dm_env
+
+T = TypeVar('T')
+
+
+@metrics.record_class_usage
+class Actor(abc.ABC):
+  """Interface for an agent that can act.
+
+  This interface defines an API for an Actor to interact with an EnvironmentLoop
+  (see acme.environment_loop), e.g. a simple RL loop where each step is of the
+  form:
+
+    # Make the first observation.
+    timestep = env.reset()
+    actor.observe_first(timestep)
+
+    # Take a step and observe.
+    action = actor.select_action(timestep.observation)
+    next_timestep = env.step(action)
+    actor.observe(action, next_timestep)
+
+    # Update the actor policy/parameters.
+    actor.update()
+  """
+
+  @abc.abstractmethod
+  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
+    """Samples from the policy and returns an action."""
+
+  @abc.abstractmethod
+  def observe_first(self, timestep: dm_env.TimeStep):
+    """Make a first observation from the environment.
+
+    Note that this need not be an initial state; it merely begins the
+    recording of a trajectory.
+
+    Args:
+      timestep: first timestep.
+    """
+
+  @abc.abstractmethod
+  def observe(
+      self,
+      action: types.NestedArray,
+      next_timestep: dm_env.TimeStep,
+  ):
+    """Make an observation of timestep data from the environment.
+
+    Args:
+      action: action taken in the environment.
+      next_timestep: timestep produced by the environment given the action.
+    """
+
+  @abc.abstractmethod
+  def update(self, wait: bool = False):
+    """Perform an update of the actor parameters from past observations.
+
+    Args:
+      wait: if True, the update will be blocking.
+    """
+
+
+class VariableSource(abc.ABC):
+  """Abstract source of variables.
+
+  Objects which implement this interface provide a source of variables, returned
+  as a collection of (nested) numpy arrays. Generally this will be used to
+  provide variables to some learned policy/etc.
+  """
+
+  @abc.abstractmethod
+  def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]:
+    """Return the named variables as a collection of (nested) numpy arrays.
+
+    Args:
+      names: a sequence of strings, each identifying a predefined subset of
+        the variables.
+
+    Returns:
+      A list of (nested) numpy arrays `variables` such that `variables[i]`
+      corresponds to the collection named by `names[i]`.
+    """
+
+
+@metrics.record_class_usage
+class Worker(abc.ABC):
+  """An interface for (potentially) distributed workers."""
+
+  @abc.abstractmethod
+  def run(self):
+    """Runs the worker."""
+
+
+class Saveable(abc.ABC, Generic[T]):
+  """An interface for saveable objects."""
+
+  @abc.abstractmethod
+  def save(self) -> T:
+    """Returns the state from the object to be saved."""

+  @abc.abstractmethod
+  def restore(self, state: T):
+    """Given the state, restores the object."""
+
+
+class Learner(VariableSource, Worker, Saveable):
+  """Abstract learner object.
+
+  This corresponds to an object which implements a learning loop. A single step
+  of learning should be implemented via the `step` method, and this step is
+  generally interacted with via the `run` method, which runs updates
+  continuously.
+
+  All objects implementing this interface should also be able to take in an
+  external dataset (see acme.datasets) and run updates using data from this
+  dataset. This can be accomplished by explicitly running `learner.step()`
+  inside a for/while loop or by using the `learner.run()` convenience function.
+  Data will be read from this dataset asynchronously, which is primarily
+  useful when the dataset is filled by an external process.
+  """
+
+  @abc.abstractmethod
+  def step(self):
+    """Perform an update step of the learner's parameters."""
+
+  def run(self, num_steps: Optional[int] = None) -> None:
+    """Run the update loop; typically an infinite loop which calls step."""
+
+    iterator = range(num_steps) if num_steps is not None else itertools.count()
+
+    for _ in iterator:
+      self.step()
+
+  def save(self):
+    raise NotImplementedError('Method "save" is not implemented.')
+
+  def restore(self, state):
+    raise NotImplementedError('Method "restore" is not implemented.')
+
+
+class PrefetchingIterator(Iterator[T], abc.ABC):
+  """Abstract iterator object which supports `ready` method."""
+
+  @abc.abstractmethod
+  def ready(self) -> bool:
+    """Returns whether there is any data waiting for processing."""
+
+  @abc.abstractmethod
+  def retrieved_elements(self) -> int:
+    """Returns the number of elements retrieved from the iterator."""
diff --git a/acme/acme/core_test.py b/acme/acme/core_test.py
new file mode 100644
index 00000000..a7f2db55
--- /dev/null
+++ b/acme/acme/core_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for core.py.""" + +from typing import List + +from acme import core +from acme import types + +from absl.testing import absltest + + +class StepCountingLearner(core.Learner): + """A learner which counts `num_steps` and then raises `StopIteration`.""" + + def __init__(self, num_steps: int): + self.step_count = 0 + self.num_steps = num_steps + + def step(self): + self.step_count += 1 + if self.step_count >= self.num_steps: + raise StopIteration() + + def get_variables(self, unused: List[str]) -> List[types.NestedArray]: + del unused + return [] + + +class CoreTest(absltest.TestCase): + + def test_learner_run_with_limit(self): + learner = StepCountingLearner(100) + learner.run(7) + self.assertEqual(learner.step_count, 7) + + def test_learner_run_no_limit(self): + learner = StepCountingLearner(100) + with self.assertRaises(StopIteration): + learner.run() + self.assertEqual(learner.step_count, 100) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/datasets/__init__.py b/acme/acme/datasets/__init__.py new file mode 100644 index 00000000..6dcfae02 --- /dev/null +++ b/acme/acme/datasets/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset interfaces.""" + +from acme.datasets.numpy_iterator import NumpyIterator +from acme.datasets.reverb import make_reverb_dataset +# from acme.datasets.reverb import make_reverb_dataset_trajectory diff --git a/acme/acme/datasets/image_augmentation.py b/acme/acme/datasets/image_augmentation.py new file mode 100644 index 00000000..bcf4f070 --- /dev/null +++ b/acme/acme/datasets/image_augmentation.py @@ -0,0 +1,120 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Transformations to be applied to replay datasets for augmentation purposes."""
+
+import enum
+
+from acme import types
+from acme.datasets import reverb as reverb_dataset
+import reverb
+import tensorflow as tf
+
+
+class CropType(enum.Enum):
+  """Types of cropping supported by the image augmentation transforms.
+
+  BILINEAR: Crop window located continuously at random, then bilinearly
+    interpolated.
+  ALIGNED: Crop window aligned with the input image's pixel grid.
+  """
+  BILINEAR = 'bilinear'
+  ALIGNED = 'aligned'
+
+
+def pad_and_crop(img: tf.Tensor,
+                 pad_size: int = 4,
+                 method: CropType = CropType.ALIGNED) -> tf.Tensor:
+  """Pad and crop image to mimic a random translation with mirroring at edges.
+
+  This implements the image augmentation from section 3.1 in (Kostrikov et al.)
+  https://arxiv.org/abs/2004.13649.
+
+  Args:
+    img: The image to pad and crop. Its dimensions are [..., H, W, C] where ...
+      are batch dimensions (if it has any).
+    pad_size: The amount of padding to apply to the image before cropping it.
+    method: The method to use for cropping the image, see `CropType` for
+      details.
+
+  Returns:
+    The image after having been padded and cropped.
+  """
+  num_batch_dims = img.shape[:-3].rank
+
+  if img.shape.is_fully_defined():
+    img_shape = img.shape.as_list()
+  else:
+    img_shape = tf.shape(img)
+
+  # Set paddings for height and width only, batches and channels set to [0, 0].
+  paddings = [[0, 0]] * num_batch_dims  # Do not pad batch dims.
+  paddings.extend([[pad_size, pad_size], [pad_size, pad_size], [0, 0]])
+
+  # Pad using symmetric padding.
+  padded_img = tf.pad(img, paddings=paddings, mode='SYMMETRIC')
+
+  # Crop padded image using requested method.
+  if method == CropType.ALIGNED:
+    cropped_img = tf.image.random_crop(padded_img, img_shape)
+  elif method == CropType.BILINEAR:
+    height, width = img_shape[-3:-1]
+    padded_height, padded_width = height + 2 * pad_size, width + 2 * pad_size
+
+    # Pick a top-left point uniformly at random.
+    top_left = tf.random.uniform(
+        shape=(2,), maxval=2 * pad_size + 1, dtype=tf.int32)
+
+    # This single box is applied to the entire batch if a batch is passed.
+    batch_size = tf.shape(padded_img)[0]
+    box = tf.cast(
+        tf.tile(
+            tf.expand_dims([
+                top_left[0] / padded_height,
+                top_left[1] / padded_width,
+                (top_left[0] + height) / padded_height,
+                (top_left[1] + width) / padded_width,
+            ], axis=0), [batch_size, 1]),
+        tf.float32)  # Shape [batch_size, 4].
+
+    # Crop and resize according to `box` then reshape back to input shape.
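+    # tf.image.crop_and_resize takes normalized box coordinates; passing
+    # `tf.range(batch_size)` pairs the i-th box with the i-th image, so the
+    # same crop window is applied to every element of the batch.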
+    cropped_img = tf.image.crop_and_resize(
+        padded_img,
+        box,
+        tf.range(batch_size),
+        (height, width),
+        method='bilinear')
+    cropped_img = tf.reshape(cropped_img, img_shape)
+
+  return cropped_img
+
+
+def make_transform(
+    observation_transform: types.TensorTransformation,
+    transform_next_observation: bool = True,
+) -> reverb_dataset.Transform:
+  """Creates the appropriate dataset transform for the given signature."""
+
+  if transform_next_observation:
+    def transform(x: reverb.ReplaySample) -> reverb.ReplaySample:
+      return x._replace(
+          data=x.data._replace(
+              observation=observation_transform(x.data.observation),
+              next_observation=observation_transform(x.data.next_observation)))
+  else:
+    def transform(x: reverb.ReplaySample) -> reverb.ReplaySample:
+      return x._replace(
+          data=x.data._replace(
+              observation=observation_transform(x.data.observation)))
+
+  return transform
diff --git a/acme/acme/datasets/numpy_iterator.py b/acme/acme/datasets/numpy_iterator.py
new file mode 100644
index 00000000..8d673369
--- /dev/null
+++ b/acme/acme/datasets/numpy_iterator.py
@@ -0,0 +1,48 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An iterator that does zero-copy conversion of `tf.Tensor`s into `np.ndarray`s."""
+
+from typing import Iterator
+
+from acme import types
+import numpy as np
+import tree
+
+
+class NumpyIterator(Iterator[types.NestedArray]):
+  """Iterator over a dataset with elements converted to numpy.
+
+  Note: This iterator returns read-only numpy arrays.
+
+  This iterator (compared to `tf.data.Dataset.as_numpy_iterator()`) does not
+  copy the data when converting `tf.Tensor`s to `np.ndarray`s.
+
+  TODO(b/178684359): Remove this when it is upstreamed into `tf.data`.
+  """
+
+  __slots__ = ['_iterator']
+
+  def __init__(self, dataset):
+    self._iterator: Iterator[types.NestedTensor] = iter(dataset)
+
+  def __iter__(self) -> 'NumpyIterator':
+    return self
+
+  def __next__(self) -> types.NestedArray:
+    return tree.map_structure(lambda t: np.asarray(memoryview(t)),
+                              next(self._iterator))
+
+  def next(self):
+    return self.__next__()
diff --git a/acme/acme/datasets/numpy_iterator_test.py b/acme/acme/datasets/numpy_iterator_test.py
new file mode 100644
index 00000000..500a4c3f
--- /dev/null
+++ b/acme/acme/datasets/numpy_iterator_test.py
@@ -0,0 +1,49 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for acme.datasets.numpy_iterator."""
+
+import collections
+
+from acme.datasets import numpy_iterator
+import tensorflow as tf
+
+from absl.testing import absltest
+
+
+class NumpyIteratorTest(absltest.TestCase):
+
+  def testBasic(self):
+    ds = tf.data.Dataset.range(3)
+    self.assertEqual([0, 1, 2], list(numpy_iterator.NumpyIterator(ds)))
+
+  def testNestedStructure(self):
+    point = collections.namedtuple('Point', ['x', 'y'])
+    ds = tf.data.Dataset.from_tensor_slices({
+        'a': ([1, 2], [3, 4]),
+        'b': [5, 6],
+        'c': point([7, 8], [9, 10])
+    })
+    self.assertEqual([{
+        'a': (1, 3),
+        'b': 5,
+        'c': point(7, 9)
+    }, {
+        'a': (2, 4),
+        'b': 6,
+        'c': point(8, 10)
+    }], list(numpy_iterator.NumpyIterator(ds)))
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/datasets/reverb.py b/acme/acme/datasets/reverb.py
new file mode 100644
index 00000000..77515660
--- /dev/null
+++ b/acme/acme/datasets/reverb.py
@@ -0,0 +1,157 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for making TensorFlow datasets for sampling from Reverb replay."""
+
+import collections
+import os
+from typing import Callable, Mapping, Optional, Union
+
+from acme import specs
+from acme import types
+from acme.adders import reverb as adders
+import reverb
+import tensorflow as tf
+
+Transform = Callable[[reverb.ReplaySample], reverb.ReplaySample]
+
+
+def make_reverb_dataset(
+    server_address: str,
+    batch_size: Optional[int] = None,
+    prefetch_size: Optional[int] = None,
+    table: Union[str, Mapping[str, float]] = adders.DEFAULT_PRIORITY_TABLE,
+    num_parallel_calls: Optional[int] = 12,
+    max_in_flight_samples_per_worker: Optional[int] = None,
+    postprocess: Optional[Transform] = None,
+    # Deprecated kwargs.
+    environment_spec: Optional[specs.EnvironmentSpec] = None,
+    extra_spec: Optional[types.NestedSpec] = None,
+    transition_adder: bool = False,
+    convert_zero_size_to_none: bool = False,
+    using_deprecated_adder: bool = False,
+    sequence_length: Optional[int] = None,
+) -> tf.data.Dataset:
+  """Make a TensorFlow dataset backed by a Reverb trajectory replay service.
+
+  Arguments:
+    server_address: Address of the Reverb server.
+    batch_size: Batch size of the returned dataset.
+    prefetch_size: The number of elements to prefetch from the original dataset.
+      Note that Reverb may do some internal prefetching in addition to this.
+    table: The name of the Reverb table to use, or a mapping of (table_name,
+      float_weight) for mixing multiple tables in the input (e.g. mixing online
+      and offline experiences).
+    num_parallel_calls: The parallelism to use. Setting it to `tf.data.AUTOTUNE`
+      will allow `tf.data` to automatically find a reasonable value.
+    max_in_flight_samples_per_worker: see reverb.TrajectoryDataset for details.
+    postprocess: User-specified transformation to be applied to the dataset (as
+      `ds.map(postprocess)`).
+    environment_spec: DEPRECATED! Do not use.
+    extra_spec: DEPRECATED! Do not use.
+    transition_adder: DEPRECATED! Do not use.
+    convert_zero_size_to_none: DEPRECATED! Do not use.
+    using_deprecated_adder: DEPRECATED! Do not use.
+    sequence_length: DEPRECATED! Do not use.
+
+  Returns:
+    A `tf.data.Dataset` iterating over the contents of the Reverb table.
+
+  Raises:
+    ValueError if `environment_spec` or `extra_spec` are set, or `table` is a
+    mapping with no positive weight values.
+  """
+
+  if environment_spec or extra_spec:
+    raise ValueError(
+        'The make_reverb_dataset factory function no longer requires specs,'
+        ' as they should be passed as a signature to the reverb.Table when it'
+        ' is created. Consider either updating your code or falling back to the'
+        ' deprecated dataset factory in acme/datasets/deprecated.')
+
+  # These are no longer used and are only kept in the call signature for
+  # backward compatibility.
+  del environment_spec
+  del extra_spec
+  del transition_adder
+  del convert_zero_size_to_none
+  del using_deprecated_adder
+  del sequence_length
+
+  # This is the default that used to be set by reverb.TFClient.dataset().
+  if max_in_flight_samples_per_worker is None and batch_size is None:
+    max_in_flight_samples_per_worker = 100
+  elif max_in_flight_samples_per_worker is None:
+    max_in_flight_samples_per_worker = 2 * batch_size
+
+  # Create mapping from tables to non-zero weights.
+  if isinstance(table, str):
+    tables = collections.OrderedDict([(table, 1.)])
+  else:
+    tables = collections.OrderedDict([
+        (name, weight) for name, weight in table.items() if weight > 0.
+    ])
+    if len(tables) <= 0:
+      raise ValueError(f'No positive weights in input tables {tables}')
+
+  # Normalize weights.
+  total_weight = sum(tables.values())
+  tables = collections.OrderedDict([
+      (name, weight / total_weight) for name, weight in tables.items()
+  ])
+
+  def _make_dataset(unused_idx: tf.Tensor) -> tf.data.Dataset:
+    datasets = ()
+    for table_name, weight in tables.items():
+      max_in_flight_samples = max(
+          1, int(max_in_flight_samples_per_worker * weight))
+      dataset = reverb.TrajectoryDataset.from_table_signature(
+          server_address=server_address,
+          table=table_name,
+          max_in_flight_samples_per_worker=max_in_flight_samples)
+      datasets += (dataset,)
+    if len(datasets) > 1:
+      dataset = tf.data.Dataset.sample_from_datasets(
+          datasets, weights=tables.values())
+    else:
+      dataset = datasets[0]
+
+    # Post-process each element if a post-processing function is passed, e.g.
+    # observation-stacking or data augmenting transformations.
+    if postprocess:
+      dataset = dataset.map(postprocess)
+
+    if batch_size:
+      dataset = dataset.batch(batch_size, drop_remainder=True)
+
+    return dataset
+
+  if num_parallel_calls is not None:
+    # Create a dataset and interleave it to create `num_parallel_calls`
+    # `TrajectoryDataset`s.
+    num_datasets_to_interleave = (
+        os.cpu_count()
+        if num_parallel_calls == tf.data.AUTOTUNE else num_parallel_calls)
+    dataset = tf.data.Dataset.range(num_datasets_to_interleave).interleave(
+        map_func=_make_dataset,
+        cycle_length=num_parallel_calls,
+        num_parallel_calls=num_parallel_calls,
+        deterministic=False)
+  else:
+    dataset = _make_dataset(tf.constant(0))
+
+  if prefetch_size:
+    dataset = dataset.prefetch(prefetch_size)
+
+  return dataset
diff --git a/acme/acme/datasets/reverb_benchmark.py b/acme/acme/datasets/reverb_benchmark.py
new file mode 100644
index 00000000..65af786f
--- /dev/null
+++ b/acme/acme/datasets/reverb_benchmark.py
@@ -0,0 +1,97 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Reverb dataset benchmark.
+
+Note: this is a no-GRPC layer setup.
+"""
+
+import time
+from typing import Sequence
+
+from absl import app
+from absl import logging
+from acme import adders
+from acme import specs
+from acme.adders import reverb as adders_reverb
+from acme.datasets import reverb as datasets
+from acme.testing import fakes
+import numpy as np
+import reverb
+from reverb import rate_limiters
+
+
+def make_replay_tables(environment_spec: specs.EnvironmentSpec
+                      ) -> Sequence[reverb.Table]:
+  """Create tables to insert data into."""
+  return [
+      reverb.Table(
+          name='default',
+          sampler=reverb.selectors.Uniform(),
+          remover=reverb.selectors.Fifo(),
+          max_size=1000000,
+          rate_limiter=rate_limiters.MinSize(1),
+          signature=adders_reverb.NStepTransitionAdder.signature(
+              environment_spec))
+  ]
+
+
+def make_adder(replay_client: reverb.Client) -> adders.Adder:
+  return adders_reverb.NStepTransitionAdder(
+      priority_fns={'default': None},
+      client=replay_client,
+      n_step=1,
+      discount=1)
+
+
+def main(_):
+  environment = fakes.ContinuousEnvironment(action_dim=8,
+                                            observation_dim=87,
+                                            episode_length=10000000)
+  spec = specs.make_environment_spec(environment)
+  replay_tables = make_replay_tables(spec)
+  replay_server = reverb.Server(replay_tables, port=None)
+  replay_client = reverb.Client(f'localhost:{replay_server.port}')
+  adder = make_adder(replay_client)
+
+  timestep = environment.reset()
+  adder.add_first(timestep)
+  # TODO(raveman): Consider also filling the table to say 1M (too slow).
+  for steps in range(10000):
+    if steps % 1000 == 0:
+      logging.info('Processed %s steps', steps)
+    action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32)
+    next_timestep = environment.step(action)
+    adder.add(action, next_timestep, extras=())
+
+  for batch_size in [256, 256 * 8, 256 * 64]:
+    for prefetch_size in [0, 1, 4]:
+      print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
+      ds = datasets.make_reverb_dataset(
+          table='default',
+          server_address=replay_client.server_address,
+          batch_size=batch_size,
+          prefetch_size=prefetch_size,
+      )
+      it = ds.as_numpy_iterator()
+
+      for iteration in range(3):
+        t = time.time()
+        for _ in range(1000):
+          _ = next(it)
+        print(f'Iteration {iteration} finished in {time.time() - t}s')
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/acme/acme/datasets/tfds.py b/acme/acme/datasets/tfds.py
new file mode 100644
index 00000000..80c79e24
--- /dev/null
+++ b/acme/acme/datasets/tfds.py
@@ -0,0 +1,209 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities related to loading TFDS datasets."""
+
+import logging
+from typing import Any, Iterator, Optional, Tuple, Sequence
+
+from acme import specs
+from acme import types
+from flax import jax_utils
+import jax
+import jax.numpy as jnp
+import numpy as np
+import rlds
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+
+def _batched_step_to_transition(step: rlds.BatchedStep) -> types.Transition:
+  return types.Transition(
+      observation=tf.nest.map_structure(lambda x: x[0], step[rlds.OBSERVATION]),
+      action=tf.nest.map_structure(lambda x: x[0], step[rlds.ACTION]),
+      reward=tf.nest.map_structure(lambda x: x[0], step[rlds.REWARD]),
+      discount=1.0 - tf.cast(step[rlds.IS_TERMINAL][1], dtype=tf.float32),
+      # If next step is terminal, then the observation may be arbitrary.
+      next_observation=tf.nest.map_structure(
+          lambda x: x[1], step[rlds.OBSERVATION])
+  )
+
+
+def _batch_steps(episode: rlds.Episode) -> tf.data.Dataset:
+  return rlds.transformations.batch(
+      episode[rlds.STEPS], size=2, shift=1, drop_remainder=True)
+
+
+def _dataset_size_upperbound(dataset: tf.data.Dataset) -> int:
+  if dataset.cardinality() != tf.data.experimental.UNKNOWN_CARDINALITY:
+    return dataset.cardinality()
+  return tf.cast(
+      dataset.batch(1000).reduce(0, lambda x, step: x + 1000), tf.int64)
+
+
+def load_tfds_dataset(
+    dataset_name: str,
+    num_episodes: Optional[int] = None,
+    env_spec: Optional[specs.EnvironmentSpec] = None) -> tf.data.Dataset:
+  """Returns a TFDS dataset with the given name."""
+  # Used only in tests.
+  del env_spec
+
+  dataset = tfds.load(dataset_name)['train']
+  if num_episodes:
+    dataset = dataset.take(num_episodes)
+  return dataset
+
+
+# TODO(sinopalnikov): replace get_tfds_dataset with a pair of load/transform.
+def get_tfds_dataset(
+    dataset_name: str,
+    num_episodes: Optional[int] = None,
+    env_spec: Optional[specs.EnvironmentSpec] = None) -> tf.data.Dataset:
+  """Returns a TFDS dataset transformed to a dataset of transitions."""
+  dataset = load_tfds_dataset(dataset_name, num_episodes, env_spec)
+  batched_steps = dataset.flat_map(_batch_steps)
+  return rlds.transformations.map_steps(batched_steps,
+                                        _batched_step_to_transition)
+
+
+# In order to avoid excessive copying on TPU one needs to make the last
+# dimension a multiple of this number.
+_BEST_DIVISOR = 128
+
+
+def _pad(x: jnp.ndarray) -> jnp.ndarray:
+  if len(x.shape) != 2:
+    return x
+  # Find a more principled way to choose this threshold (30); for small
+  # enough sizes the excessive copying is not triggered.
+  if x.shape[-1] % _BEST_DIVISOR != 0 and x.shape[-1] > 30:
+    n = _BEST_DIVISOR - (x.shape[-1] % _BEST_DIVISOR)
+    x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, n)], 'constant')
+  return x
+
+
+# Undo the padding.
+def _unpad(x: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
+  if len(shape) == 2 and x.shape[-1] != shape[-1]:
+    return x[..., :shape[-1]]
+  return x
+
+
+_PMAP_AXIS_NAME = 'data'
+
+
+class JaxInMemoryRandomSampleIterator(Iterator[Any]):
+  """In memory random sample iterator implemented in JAX.
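+
+  A rough usage sketch (`dataset` here is a placeholder tf.data.Dataset):
+
+    it = JaxInMemoryRandomSampleIterator(dataset, jax.random.PRNGKey(0),
+                                         batch_size=256)
+    batch = next(it)  # A pytree of arrays with leading dimension 256.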
+
+  Loads the whole dataset in memory and performs random sampling with
+  replacement of batches of `batch_size`.
+  This class provides much faster sampling functionality compared to using
+  an iterator on tf.data.Dataset.
+  """
+
+  def __init__(self,
+               dataset: tf.data.Dataset,
+               key: jnp.ndarray,
+               batch_size: int,
+               shard_dataset_across_devices: bool = False):
+    """Creates an iterator.
+
+    Args:
+      dataset: underlying tf Dataset
+      key: a key to be used for random number generation
+      batch_size: batch size
+      shard_dataset_across_devices: whether to use all available devices
+        for storing the underlying dataset. The upside is a larger
+        dataset capacity that fits into memory. Downsides are:
+        - execution of pmapped functions is usually slower than jitted ones
+        - the last few elements of the dataset might be dropped (if the
+          dataset size is not a multiple of the device count)
+        - sampling is not 100% uniform, since each core will be doing sampling
+          only within its data chunk
+        The number of available devices must divide the batch_size evenly.
+    """
+    # Read the whole dataset. We use artificially large batch_size to make sure
+    # we capture the whole dataset.
+    size = _dataset_size_upperbound(dataset)
+    data = next(dataset.batch(size).as_numpy_iterator())
+    self._dataset_size = jax.tree_flatten(
+        jax.tree_map(lambda x: x.shape[0], data))[0][0]
+    device = jax_utils._pmap_device_order()
+    if not shard_dataset_across_devices:
+      device = device[:1]
+    should_pmap = len(device) > 1
+    assert batch_size % len(device) == 0
+    self._dataset_size = self._dataset_size - self._dataset_size % len(device)
+    # len(device) needs to divide self._dataset_size evenly.
+    assert self._dataset_size % len(device) == 0
+    logging.info('Trying to load %s elements to %s', self._dataset_size, device)
+    logging.info('Dataset %s %s',
+                 ('before padding' if should_pmap else ''),
+                 jax.tree_map(lambda x: x.shape, data))
+    if should_pmap:
+      shapes = jax.tree_map(lambda x: x.shape, data)
+      # Padding to a multiple of 128 is needed to avoid excessive copying on TPU
+      data = jax.tree_map(_pad, data)
+      logging.info('Dataset after padding %s',
+                   jax.tree_map(lambda x: x.shape, data))
+      def split_and_put(x: jnp.ndarray) -> jnp.ndarray:
+        return jax.device_put_sharded(
+            np.split(x[:self._dataset_size], len(device)), devices=device)
+      self._jax_dataset = jax.tree_map(split_and_put, data)
+    else:
+      self._jax_dataset = jax.tree_map(jax.device_put, data)
+
+    self._key = (jnp.stack(jax.random.split(key, len(device)))
+                 if should_pmap else key)
+
+    def sample_per_shard(data: Any,
+                         key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]:
+      key1, key2 = jax.random.split(key)
+      indices = jax.random.randint(
+          key1, (batch_size // len(device),),
+          minval=0,
+          maxval=self._dataset_size // len(device))
+      data_sample = jax.tree_map(lambda d: jnp.take(d, indices, axis=0), data)
+      return data_sample, key2
+
+    if should_pmap:
+      def sample(data, key):
+        data_sample, key = sample_per_shard(data, key)
+        # Gathering data on TPUs is much more efficient than doing so on a host
+        # since it avoids Host - Device communications.
+        data_sample = jax.lax.all_gather(
+            data_sample, axis_name=_PMAP_AXIS_NAME, axis=0, tiled=True)
+        data_sample = jax.tree_map(_unpad, data_sample, shapes)
+        return data_sample, key
+
+      pmapped_sample = jax.pmap(sample, axis_name=_PMAP_AXIS_NAME)
+
+      def sample_and_postprocess(key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]:
+        data, key = pmapped_sample(self._jax_dataset, key)
+        # All pmapped devices return the same data, so we just take the one from
+        # the first device.
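+        # (The leading axis of `data` here is the device axis added by pmap.)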
+        return jax.tree_map(lambda x: x[0], data), key
+      self._sample = sample_and_postprocess
+    else:
+      self._sample = jax.jit(
+          lambda key: sample_per_shard(self._jax_dataset, key))
+
+  def __next__(self) -> Any:
+    data, self._key = self._sample(self._key)
+    return data
+
+  @property
+  def dataset_size(self) -> int:
+    """The number of elements in the underlying dataset."""
+    return self._dataset_size
diff --git a/acme/acme/environment_loop.py b/acme/acme/environment_loop.py
new file mode 100644
index 00000000..b9a016e8
--- /dev/null
+++ b/acme/acme/environment_loop.py
@@ -0,0 +1,189 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A simple agent-environment training loop."""
+
+import operator
+import time
+from typing import Optional, Sequence
+
+from acme import core
+from acme.utils import counting
+from acme.utils import loggers
+from acme.utils import observers as observers_lib
+from acme.utils import signals
+
+import dm_env
+from dm_env import specs
+import numpy as np
+import tree
+
+
+class EnvironmentLoop(core.Worker):
+  """A simple RL environment loop.
+
+  This takes `Environment` and `Actor` instances and coordinates their
+  interaction. The agent is updated if `should_update=True`. This can be used
+  as:
+
+    loop = EnvironmentLoop(environment, actor)
+    loop.run(num_episodes)
+
+  A `Counter` instance can optionally be given in order to maintain counts
+  between different Acme components. If not given a local Counter will be
+  created to maintain counts between calls to the `run` method.
+
+  A `Logger` instance can also be passed in order to control the output of the
+  loop. If not given a platform-specific default logger will be used as defined
+  by utils.loggers.make_default_logger. A string `label` can be passed to easily
+  change the label associated with the default logger; this is ignored if a
+  `Logger` instance is given.
+
+  A list of 'Observer' instances can be specified to generate additional metrics
+  to be logged by the logger. They have access to the 'Environment' instance,
+  the current timestep data structure and the current action.
+  """
+
+  def __init__(
+      self,
+      environment: dm_env.Environment,
+      actor: core.Actor,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      should_update: bool = True,
+      label: str = 'environment_loop',
+      observers: Sequence[observers_lib.EnvLoopObserver] = (),
+  ):
+    # Internalize agent and environment.
+    self._environment = environment
+    self._actor = actor
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.make_default_logger(
+        label, steps_key=self._counter.get_steps_key())
+    self._should_update = should_update
+    self._observers = observers
+
+  def run_episode(self) -> loggers.LoggingData:
+    """Run one episode.
+
+    Each episode is a loop which interacts first with the environment to get an
+    observation and then gives that observation to the agent in order to
+    retrieve an action.
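+    The loop continues until the environment returns a terminal timestep,
+    i.e. until `timestep.last()` is true.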
+ + Returns: + An instance of `loggers.LoggingData`. + """ + # Reset any counts and start the environment. + start_time = time.time() + episode_steps = 0 + + # For evaluation, this keeps track of the total undiscounted reward + # accumulated during the episode. + episode_return = tree.map_structure(_generate_zeros_from_spec, + self._environment.reward_spec()) + timestep = self._environment.reset() + # Make the first observation. + self._actor.observe_first(timestep) + for observer in self._observers: + # Initialize the observer with the current state of the env after reset + # and the initial timestep. + observer.observe_first(self._environment, timestep) + + # Run an episode. + while not timestep.last(): + # Generate an action from the agent's policy and step the environment. + action = self._actor.select_action(timestep.observation) + timestep = self._environment.step(action) + + # Have the agent observe the timestep and let the actor update itself. + self._actor.observe(action, next_timestep=timestep) + for observer in self._observers: + # One environment step was completed. Observe the current state of the + # environment, the current timestep and the action. + observer.observe(self._environment, timestep, action) + if self._should_update: + self._actor.update() + + # Book-keeping. + episode_steps += 1 + + # Equivalent to: episode_return += timestep.reward + # We capture the return value because if timestep.reward is a JAX + # DeviceArray, episode_return will not be mutated in-place. (In all other + # cases, the returned episode_return will be the same object as the + # argument episode_return.) + episode_return = tree.map_structure(operator.iadd, + episode_return, + timestep.reward) + + # Record counts. + counts = self._counter.increment(episodes=1, steps=episode_steps) + + # Collect the results and combine with counts. + steps_per_second = episode_steps / (time.time() - start_time) + result = { + 'episode_length': episode_steps, + 'episode_return': episode_return, + 'steps_per_second': steps_per_second, + } + result.update(counts) + for observer in self._observers: + result.update(observer.get_metrics()) + return result + + def run(self, + num_episodes: Optional[int] = None, + num_steps: Optional[int] = None): + """Perform the run loop. + + Run the environment loop either for `num_episodes` episodes or for at + least `num_steps` steps (the last episode is always run until completion, + so the total number of steps may be slightly more than `num_steps`). + At least one of these two arguments has to be None. + + Upon termination of an episode a new episode will be started. If the number + of episodes and the number of steps are not given then this will interact + with the environment infinitely. + + Args: + num_episodes: number of episodes to run the loop for. + num_steps: minimal number of steps to run the loop for. + + Returns: + Actual number of steps the loop executed. + + Raises: + ValueError: If both 'num_episodes' and 'num_steps' are not None. 
+ """ + + if not (num_episodes is None or num_steps is None): + raise ValueError('Either "num_episodes" or "num_steps" should be None.') + + def should_terminate(episode_count: int, step_count: int) -> bool: + return ((num_episodes is not None and episode_count >= num_episodes) or + (num_steps is not None and step_count >= num_steps)) + + episode_count, step_count = 0, 0 + with signals.runtime_terminator(): + while not should_terminate(episode_count, step_count): + result = self.run_episode() + episode_count += 1 + step_count += result['episode_length'] + # Log the given episode results. + self._logger.write(result) + + return step_count + + +def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray: + return np.zeros(spec.shape, spec.dtype) diff --git a/acme/acme/environment_loop_test.py b/acme/acme/environment_loop_test.py new file mode 100644 index 00000000..677ddb45 --- /dev/null +++ b/acme/acme/environment_loop_test.py @@ -0,0 +1,102 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the environment loop.""" + +from typing import Optional + +from acme import environment_loop +from acme import specs +from acme import types +from acme.testing import fakes +import numpy as np + +from absl.testing import absltest +from absl.testing import parameterized + +EPISODE_LENGTH = 10 + +# Discount specs +F32_2_MIN_0_MAX_1 = specs.BoundedArray( + dtype=np.float32, shape=(2,), minimum=0.0, maximum=1.0) +F32_2x1_MIN_0_MAX_1 = specs.BoundedArray( + dtype=np.float32, shape=(2, 1), minimum=0.0, maximum=1.0) +TREE_MIN_0_MAX_1 = {'a': F32_2_MIN_0_MAX_1, 'b': F32_2x1_MIN_0_MAX_1} + +# Reward specs +F32 = specs.Array(dtype=np.float32, shape=()) +F32_1x3 = specs.Array(dtype=np.float32, shape=(1, 3)) +TREE = {'a': F32, 'b': F32_1x3} + +TEST_CASES = ( + ('scalar_discount_scalar_reward', None, None), + ('vector_discount_scalar_reward', F32_2_MIN_0_MAX_1, F32), + ('matrix_discount_matrix_reward', F32_2x1_MIN_0_MAX_1, F32_1x3), + ('tree_discount_tree_reward', TREE_MIN_0_MAX_1, TREE), + ) + + +class EnvironmentLoopTest(parameterized.TestCase): + + @parameterized.named_parameters(*TEST_CASES) + def test_one_episode(self, discount_spec, reward_spec): + _, loop = _parameterized_setup(discount_spec, reward_spec) + result = loop.run_episode() + self.assertIn('episode_length', result) + self.assertEqual(EPISODE_LENGTH, result['episode_length']) + self.assertIn('episode_return', result) + self.assertIn('steps_per_second', result) + + @parameterized.named_parameters(*TEST_CASES) + def test_run_episodes(self, discount_spec, reward_spec): + actor, loop = _parameterized_setup(discount_spec, reward_spec) + + # Run the loop. There should be EPISODE_LENGTH update calls per episode. 
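+    # (The fake actor increments `num_updates` once per environment step.)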
+    loop.run(num_episodes=10)
+    self.assertEqual(actor.num_updates, 10 * EPISODE_LENGTH)
+
+  @parameterized.named_parameters(*TEST_CASES)
+  def test_run_steps(self, discount_spec, reward_spec):
+    actor, loop = _parameterized_setup(discount_spec, reward_spec)
+
+    # Run the loop. This will run 2 episodes so that the total number of steps
+    # is at least 15.
+    loop.run(num_steps=EPISODE_LENGTH + 5)
+    self.assertEqual(actor.num_updates, 2 * EPISODE_LENGTH)
+
+
+def _parameterized_setup(discount_spec: Optional[types.NestedSpec] = None,
+                         reward_spec: Optional[types.NestedSpec] = None):
+  """Common setup code that, unlike self.setUp, takes arguments.
+
+  Args:
+    discount_spec: None, or a (nested) specs.BoundedArray.
+    reward_spec: None, or a (nested) specs.Array.
+  Returns:
+    actor, loop
+  """
+  env_kwargs = {'episode_length': EPISODE_LENGTH}
+  if discount_spec:
+    env_kwargs['discount_spec'] = discount_spec
+  if reward_spec:
+    env_kwargs['reward_spec'] = reward_spec
+
+  environment = fakes.DiscreteEnvironment(**env_kwargs)
+  actor = fakes.Actor(specs.make_environment_spec(environment))
+  loop = environment_loop.EnvironmentLoop(environment, actor)
+  return actor, loop
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/environment_loops/__init__.py b/acme/acme/environment_loops/__init__.py
new file mode 100644
index 00000000..32a4e752
--- /dev/null
+++ b/acme/acme/environment_loops/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Specialized environment loops."""
+
+try:
+  # pylint: disable=g-import-not-at-top
+  from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
+except ImportError:
+  pass
diff --git a/acme/acme/environment_loops/open_spiel_environment_loop.py b/acme/acme/environment_loops/open_spiel_environment_loop.py
new file mode 100644
index 00000000..4e9c81a3
--- /dev/null
+++ b/acme/acme/environment_loops/open_spiel_environment_loop.py
@@ -0,0 +1,227 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An OpenSpiel multi-agent/environment training loop."""
+
+import operator
+import time
+from typing import Optional, Sequence
+
+from acme import core
+from acme.utils import counting
+from acme.utils import loggers
+from acme.wrappers import open_spiel_wrapper
+import dm_env
+from dm_env import specs
+import numpy as np
+import tree
+
+# pytype: disable=import-error
+import pyspiel
+# pytype: enable=import-error
+
+
+class OpenSpielEnvironmentLoop(core.Worker):
+  """An OpenSpiel RL environment loop.
+
+  This takes an `Environment` and a list of `Actor` instances and coordinates
+  their interaction. Agents are updated if `should_update=True`. This can be
+  used as:
+
+    loop = OpenSpielEnvironmentLoop(environment, actors)
+    loop.run(num_episodes)
+
+  A `Counter` instance can optionally be given in order to maintain counts
+  between different Acme components. If not given a local Counter will be
+  created to maintain counts between calls to the `run` method.
+
+  A `Logger` instance can also be passed in order to control the output of the
+  loop. If not given a platform-specific default logger will be used as defined
+  by utils.loggers.make_default_logger. A string `label` can be passed to easily
+  change the label associated with the default logger; this is ignored if a
+  `Logger` instance is given.
+  """
+
+  def __init__(
+      self,
+      environment: open_spiel_wrapper.OpenSpielWrapper,
+      actors: Sequence[core.Actor],
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None,
+      should_update: bool = True,
+      label: str = 'open_spiel_environment_loop',
+  ):
+    # Internalize agent and environment.
+    self._environment = environment
+    self._actors = actors
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.make_default_logger(label)
+    self._should_update = should_update
+
+    # Track information necessary to coordinate updates among multiple actors.
+    self._observed_first = [False] * len(self._actors)
+    self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors)
+
+  def _send_observation(self, timestep: dm_env.TimeStep, player: int):
+    # If terminal, all actors must update.
+    if player == pyspiel.PlayerId.TERMINAL:
+      for player_id in range(len(self._actors)):
+        # Note: we must account for situations where the first observation
+        # is a terminal state, e.g. if an opponent folds in poker before we get
+        # to act.
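+        # Only actors that have already observed a first timestep receive
+        # this terminal observation.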
+        if self._observed_first[player_id]:
+          player_timestep = self._get_player_timestep(timestep, player_id)
+          self._actors[player_id].observe(self._prev_actions[player_id],
+                                          player_timestep)
+          if self._should_update:
+            self._actors[player_id].update()
+      self._observed_first = [False] * len(self._actors)
+      self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors)
+    else:
+      if not self._observed_first[player]:
+        player_timestep = dm_env.TimeStep(
+            observation=timestep.observation[player],
+            reward=None,
+            discount=None,
+            step_type=dm_env.StepType.FIRST)
+        self._actors[player].observe_first(player_timestep)
+        self._observed_first[player] = True
+      else:
+        player_timestep = self._get_player_timestep(timestep, player)
+        self._actors[player].observe(self._prev_actions[player],
+                                     player_timestep)
+        if self._should_update:
+          self._actors[player].update()
+
+  def _get_action(self, timestep: dm_env.TimeStep, player: int) -> int:
+    self._prev_actions[player] = self._actors[player].select_action(
+        timestep.observation[player])
+    return self._prev_actions[player]
+
+  def _get_player_timestep(self, timestep: dm_env.TimeStep,
+                           player: int) -> dm_env.TimeStep:
+    return dm_env.TimeStep(observation=timestep.observation[player],
+                           reward=timestep.reward[player],
+                           discount=timestep.discount[player],
+                           step_type=timestep.step_type)
+
+  def run_episode(self) -> loggers.LoggingData:
+    """Run one episode.
+
+    Each episode is a loop which interacts first with the environment to get an
+    observation and then gives that observation to the agent in order to
+    retrieve an action.
+
+    Returns:
+      An instance of `loggers.LoggingData`.
+    """
+    # Reset any counts and start the environment.
+    start_time = time.time()
+    episode_steps = 0
+
+    # For evaluation, this keeps track of the total undiscounted reward
+    # for each player accumulated during the episode.
+    multiplayer_reward_spec = specs.BoundedArray(
+        (self._environment.game.num_players(),),
+        np.float32,
+        minimum=self._environment.game.min_utility(),
+        maximum=self._environment.game.max_utility())
+    episode_return = tree.map_structure(_generate_zeros_from_spec,
+                                        multiplayer_reward_spec)
+
+    timestep = self._environment.reset()
+
+    # Make the first observation.
+    self._send_observation(timestep, self._environment.current_player)
+
+    # Run an episode.
+    while not timestep.last():
+      # Generate an action from the agent's policy and step the environment.
+      if self._environment.is_turn_based:
+        action_list = [
+            self._get_action(timestep, self._environment.current_player)
+        ]
+      else:
+        # FIXME: Support simultaneous move games.
+        raise ValueError('Currently only supports sequential games.')
+
+      timestep = self._environment.step(action_list)
+
+      # Have the agent observe the timestep and let the actor update itself.
+      self._send_observation(timestep, self._environment.current_player)
+
+      # Book-keeping.
+      episode_steps += 1
+
+      # Equivalent to: episode_return += timestep.reward
+      # We capture the return value because if timestep.reward is a JAX
+      # DeviceArray, episode_return will not be mutated in-place. (In all other
+      # cases, the returned episode_return will be the same object as the
+      # argument episode_return.)
+      episode_return = tree.map_structure(operator.iadd,
+                                          episode_return,
+                                          timestep.reward)
+
+    # Record counts.
+    counts = self._counter.increment(episodes=1, steps=episode_steps)
+
+    # Collect the results and combine with counts.
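+    # Note that `episode_return` is a vector of per-player returns here.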
+    steps_per_second = episode_steps / (time.time() - start_time)
+    result = {
+        'episode_length': episode_steps,
+        'episode_return': episode_return,
+        'steps_per_second': steps_per_second,
+    }
+    result.update(counts)
+    return result
+
+  def run(self,
+          num_episodes: Optional[int] = None,
+          num_steps: Optional[int] = None):
+    """Perform the run loop.
+
+    Run the environment loop either for `num_episodes` episodes or for at
+    least `num_steps` steps (the last episode is always run until completion,
+    so the total number of steps may be slightly more than `num_steps`).
+    At least one of these two arguments has to be None.
+
+    Upon termination of an episode, a new episode will be started. If neither
+    the number of episodes nor the number of steps is given, then this will
+    interact with the environment indefinitely.
+
+    Args:
+      num_episodes: number of episodes to run the loop for.
+      num_steps: minimal number of steps to run the loop for.
+
+    Raises:
+      ValueError: If both 'num_episodes' and 'num_steps' are not None.
+    """
+
+    if not (num_episodes is None or num_steps is None):
+      raise ValueError('Either "num_episodes" or "num_steps" should be None.')
+
+    def should_terminate(episode_count: int, step_count: int) -> bool:
+      return ((num_episodes is not None and episode_count >= num_episodes) or
+              (num_steps is not None and step_count >= num_steps))
+
+    episode_count, step_count = 0, 0
+    while not should_terminate(episode_count, step_count):
+      result = self.run_episode()
+      episode_count += 1
+      step_count += result['episode_length']
+      # Log the given results.
+      self._logger.write(result)
+
+
+def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray:
+  return np.zeros(spec.shape, spec.dtype)
diff --git a/acme/acme/environment_loops/open_spiel_environment_loop_test.py b/acme/acme/environment_loops/open_spiel_environment_loop_test.py
new file mode 100644
index 00000000..e09d3d49
--- /dev/null
+++ b/acme/acme/environment_loops/open_spiel_environment_loop_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for OpenSpiel environment loop."""
+
+import unittest
+
+import acme
+from acme import core
+from acme import specs
+from acme import types
+from acme import wrappers
+import dm_env
+import numpy as np
+import tree
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+SKIP_OPEN_SPIEL_TESTS = False
+SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.'
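+
+# open_spiel is an optional dependency: the imports below are guarded so that
+# this module can still be loaded without it, in which case the tests are
+# skipped via the module-level flag instead of failing at import time.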
+
+try:
+  # pylint: disable=g-import-not-at-top
+  # pytype: disable=import-error
+  from acme.environment_loops import open_spiel_environment_loop
+  from acme.wrappers import open_spiel_wrapper
+  from open_spiel.python import rl_environment
+  # pytype: enable=import-error
+
+  class RandomActor(core.Actor):
+    """Fake actor which generates random actions and validates specs."""
+
+    def __init__(self, spec: specs.EnvironmentSpec):
+      self._spec = spec
+      self.num_updates = 0
+
+    def select_action(self, observation: open_spiel_wrapper.OLT) -> int:
+      _validate_spec(self._spec.observations, observation)
+      legals = np.array(np.nonzero(observation.legal_actions), dtype=np.int32)
+      return np.random.choice(legals[0])
+
+    def observe_first(self, timestep: dm_env.TimeStep):
+      _validate_spec(self._spec.observations, timestep.observation)
+
+    def observe(self, action: types.NestedArray,
+                next_timestep: dm_env.TimeStep):
+      _validate_spec(self._spec.actions, action)
+      _validate_spec(self._spec.rewards, next_timestep.reward)
+      _validate_spec(self._spec.discounts, next_timestep.discount)
+      _validate_spec(self._spec.observations, next_timestep.observation)
+
+    def update(self, wait: bool = False):
+      self.num_updates += 1
+
+except ModuleNotFoundError:
+  SKIP_OPEN_SPIEL_TESTS = True
+
+
+def _validate_spec(spec: types.NestedSpec, value: types.NestedArray):
+  """Validates a value against a potentially nested spec."""
+  tree.assert_same_structure(value, spec)
+  tree.map_structure(lambda s, v: s.validate(v), spec, value)
+
+
+@unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE)
+class OpenSpielEnvironmentLoopTest(parameterized.TestCase):
+
+  def test_loop_run(self):
+    raw_env = rl_environment.Environment('tic_tac_toe')
+    env = open_spiel_wrapper.OpenSpielWrapper(raw_env)
+    env = wrappers.SinglePrecisionWrapper(env)
+    environment_spec = acme.make_environment_spec(env)
+
+    actors = []
+    for _ in range(env.num_players):
+      actors.append(RandomActor(environment_spec))
+
+    loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors)
+    result = loop.run_episode()
+    self.assertIn('episode_length', result)
+    self.assertIn('episode_return', result)
+    self.assertIn('steps_per_second', result)
+
+    loop.run(num_episodes=10)
+    loop.run(num_steps=100)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/jax/__init__.py b/acme/acme/jax/__init__.py
new file mode 100644
index 00000000..240cb715
--- /dev/null
+++ b/acme/acme/jax/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/acme/acme/jax/experiments/__init__.py b/acme/acme/jax/experiments/__init__.py
new file mode 100644
index 00000000..9a1170f0
--- /dev/null
+++ b/acme/acme/jax/experiments/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX experiment utils.""" + +from acme.jax.experiments.config import CheckpointingConfig +from acme.jax.experiments.config import default_evaluator_factory +from acme.jax.experiments.config import DeprecatedPolicyFactory +from acme.jax.experiments.config import EvaluatorFactory +from acme.jax.experiments.config import ExperimentConfig +from acme.jax.experiments.config import make_policy +from acme.jax.experiments.config import MakeActorFn +from acme.jax.experiments.config import NetworkFactory +from acme.jax.experiments.config import OfflineExperimentConfig +from acme.jax.experiments.config import SnapshotModelFactory +from acme.jax.experiments.make_distributed_experiment import make_distributed_experiment +#from acme.jax.experiments.make_distributed_offline_experiment import make_distributed_offline_experiment +from acme.jax.experiments.run_experiment import run_experiment +from acme.jax.experiments.run_offline_experiment import run_offline_experiment \ No newline at end of file diff --git a/acme/acme/jax/experiments/config.py b/acme/acme/jax/experiments/config.py new file mode 100644 index 00000000..3723d396 --- /dev/null +++ b/acme/acme/jax/experiments/config.py @@ -0,0 +1,309 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX experiment config.""" + +import dataclasses +from typing import Any, Callable, Dict, Generic, Iterator, Optional, Protocol, Sequence + +from acme import core +from acme import environment_loop +from acme import specs +from acme.agents.jax import builders +from acme.jax import types +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import observers as observers_lib +from acme.utils import experiment_utils +import jax + +AgentNetwork = Any +PolicyNetwork = Any + +class MakeActorFn(Protocol, Generic[builders.Policy]): + + def __call__(self, random_key: types.PRNGKey, policy: builders.Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: core.VariableSource) -> core.Actor: + ... + + +class NetworkFactory(Protocol, Generic[builders.Networks]): + + def __call__(self, + environment_spec: specs.EnvironmentSpec) -> builders.Networks: + ... + + +class DeprecatedPolicyFactory(Protocol, Generic[builders.Networks, + builders.Policy]): + + def __call__(self, networks: builders.Networks) -> builders.Policy: + ... 
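+
+
+# Note: these Protocol classes only describe call signatures, so in practice a
+# plain function can serve as a factory. For example (a minimal sketch with
+# hypothetical names, for illustration only):
+#
+#   def make_networks(spec: specs.EnvironmentSpec) -> MyNetworks:
+#     """Builds network modules sized to the environment spec."""
+#     return MyNetworks(...)
+#
+# satisfies NetworkFactory without subclassing the Protocol.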
+
+
+class PolicyFactory(Protocol, Generic[builders.Networks, builders.Policy]):
+
+  def __call__(self, networks: builders.Networks,
+               environment_spec: specs.EnvironmentSpec,
+               evaluation: bool) -> builders.Policy:
+    ...
+
+
+class EvaluatorFactory(Protocol, Generic[builders.Policy]):
+
+  def __call__(self, random_key: types.PRNGKey,
+               variable_source: core.VariableSource, counter: counting.Counter,
+               make_actor_fn: MakeActorFn[builders.Policy]) -> core.Worker:
+    ...
+
+
+class SnapshotModelFactory(Protocol, Generic[builders.Networks]):
+
+  def __call__(
+      self, networks: builders.Networks, environment_spec: specs.EnvironmentSpec
+  ) -> Dict[str, Callable[[core.VariableSource], types.ModelToSnapshot]]:
+    ...
+
+
+@dataclasses.dataclass(frozen=True)
+class CheckpointingConfig:
+  """Configuration options for checkpointing.
+
+  Attributes:
+    max_to_keep: Maximum number of checkpoints to keep. Does not apply to
+      replay checkpointing.
+    directory: Where to store the checkpoints.
+    add_uid: Whether or not to add a unique identifier, see
+      `paths.get_unique_id()` for how it is generated.
+    replay_checkpointing_time_delta_minutes: How frequently to write replay
+      checkpoints; defaults to None, which disables periodic checkpointing.
+      Warning! These are written asynchronously so as not to interrupt other
+      replay duties; however, this does pose a risk of OOM since items that
+      would otherwise be removed are temporarily kept alive for checkpointing
+      purposes.
+      Note: Since replay buffers tend to be quite large O(100GiB), writing can
+      take up to 10 minutes, so keep that in mind when setting this frequency.
+    time_delta_minutes: How often to save the checkpoint, in minutes.
+  """
+  max_to_keep: int = 1
+  directory: str = '~/acme'
+  add_uid: bool = True
+  replay_checkpointing_time_delta_minutes: Optional[int] = None
+  time_delta_minutes: int = 5
+
+
+@dataclasses.dataclass(frozen=True)
+class ExperimentConfig(Generic[builders.Networks, builders.Policy,
+                               builders.Sample]):
+  """Config which defines aspects of constructing an experiment.
+
+  Attributes:
+    builder: Builds components of an RL agent (Learner, Actor...).
+    network_factory: Builds networks used by the agent.
+    environment_factory: Returns an instance of an environment.
+    max_num_actor_steps: How many environment steps to perform.
+    seed: Seed used for agent initialization.
+    policy_network_factory: Policy network factory which is used by actors to
+      perform inference.
+    evaluator_factories: Factories of policy evaluators. When not specified,
+      the default evaluators are constructed using
+      eval_policy_network_factory. Set to an empty list to disable evaluators.
+    eval_policy_network_factory: Policy network factory used by evaluators.
+      Should be specified to use the default evaluators (when
+      evaluator_factories is not provided).
+    environment_spec: Specification of the environment. Can be specified to
+      reduce the number of times environment_factory is invoked (for
+      performance or resource usage reasons).
+    observers: Observers used for extending logs with custom information.
+    logger_factory: Logger factory used to construct loggers for the learner,
+      actors and evaluators.
+    checkpointing: Configuration options for checkpointing. If None,
+      checkpointing and snapshotting are disabled.
+  """
+  # Below fields must be explicitly specified for any Agent.
+  builder: builders.ActorLearnerBuilder[builders.Networks, builders.Policy,
+                                        builders.Sample]
+  network_factory: NetworkFactory[builders.Networks]
+  environment_factory: types.EnvironmentFactory
+  max_num_actor_steps: int
+  seed: int
+  # policy_network_factory is deprecated. Use builder.make_policy to
+  # create the policy.
+  policy_network_factory: Optional[DeprecatedPolicyFactory[
+      builders.Networks, builders.Policy]] = None
+  # Fields below are optional. If you just started with Acme, do not worry
+  # about them. You might need them later when you want to customize your RL
+  # agent.
+  # TODO(stanczyk): Introduce a marker for the default value (instead of None).
+  evaluator_factories: Optional[Sequence[EvaluatorFactory[
+      builders.Policy]]] = None
+  # eval_policy_network_factory is deprecated. Use builder.make_policy to
+  # create the policy.
+  eval_policy_network_factory: Optional[DeprecatedPolicyFactory[
+      builders.Networks, builders.Policy]] = None
+  environment_spec: Optional[specs.EnvironmentSpec] = None
+  observers: Sequence[observers_lib.EnvLoopObserver] = ()
+  logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger
+  checkpointing: Optional[CheckpointingConfig] = CheckpointingConfig()
+
+  # TODO(stanczyk): Make get_evaluator_factories a standalone function.
+  def get_evaluator_factories(self):
+    """Constructs the evaluator factories."""
+    if self.evaluator_factories is not None:
+      return self.evaluator_factories
+
+    def eval_policy_factory(networks: builders.Networks,
+                            environment_spec: specs.EnvironmentSpec,
+                            evaluation: bool) -> builders.Policy:
+      del evaluation
+      # The config factory has precedence until all agents are migrated to use
+      # builder.make_policy.
+      if self.eval_policy_network_factory is not None:
+        return self.eval_policy_network_factory(networks)
+      else:
+        return self.builder.make_policy(
+            networks=networks,
+            environment_spec=environment_spec,
+            evaluation=True)
+
+    return [
+        default_evaluator_factory(
+            environment_factory=self.environment_factory,
+            network_factory=self.network_factory,
+            policy_factory=eval_policy_factory,
+            logger_factory=self.logger_factory,
+            observers=self.observers)
+    ]
+
+
+@dataclasses.dataclass
+class OfflineExperimentConfig(Generic[builders.Networks, builders.Policy,
+                                      builders.Sample]):
+  """Config which defines aspects of constructing an offline RL experiment.
+
+  This class is similar to the ExperimentConfig, but is tailored to the
+  offline RL setting, so it excludes attributes related to training via
+  interaction with the environment (max_num_actor_steps,
+  policy_network_factory) and instead includes attributes specific to learning
+  from demonstrations.
+
+  Attributes:
+    builder: Builds components of an offline RL agent (Learner and Evaluator).
+    network_factory: Builds networks used by the agent.
+    demonstration_dataset_factory: Function that returns an iterator over
+      demonstrations.
+    environment_spec: Specification of the environment.
+    max_num_learner_steps: How many learner steps to perform.
+    seed: Seed used for agent initialization.
+    evaluator_factories: Factories of policy evaluators. When not specified,
+      the default evaluators are constructed using
+      eval_policy_network_factory. Set to an empty list to disable evaluators.
+    eval_policy_factory: Policy factory used by evaluators. Should be specified
+      to use the default evaluators (when evaluator_factories is not provided).
+    environment_factory: Returns an instance of an environment to be used for
+      evaluation. Should be specified to use the default evaluators (when
+      evaluator_factories is not provided).
+    observers: Observers used for extending logs with custom information.
+    logger_factory: Logger factory used to construct loggers for the learner,
+      actors and evaluators.
+    checkpointing: Configuration options for checkpointing. If None,
+      checkpointing and snapshotting are disabled.
+  """
+  # Below fields must be explicitly specified for any Agent.
+  builder: builders.OfflineBuilder[builders.Networks, builders.Policy,
+                                   builders.Sample]
+  network_factory: Callable[[specs.EnvironmentSpec], builders.Networks]
+  demonstration_dataset_factory: Callable[[types.PRNGKey],
+                                          Iterator[builders.Sample]]
+  environment_factory: types.EnvironmentFactory
+  max_num_learner_steps: int
+  seed: int
+  # Fields below are optional. If you just started with Acme, do not worry
+  # about them. You might need them later when you want to customize your RL
+  # agent.
+  # TODO(stanczyk): Introduce a marker for the default value (instead of None).
+  evaluator_factories: Optional[Sequence[EvaluatorFactory]] = None
+  environment_spec: Optional[specs.EnvironmentSpec] = None
+  observers: Sequence[observers_lib.EnvLoopObserver] = ()
+  logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger
+  checkpointing: Optional[CheckpointingConfig] = CheckpointingConfig()
+
+  # TODO(stanczyk): Make get_evaluator_factories a standalone function.
+  def get_evaluator_factories(self):
+    """Constructs the evaluator factories."""
+    if self.evaluator_factories is not None:
+      return self.evaluator_factories
+    if self.environment_factory is None:
+      raise ValueError(
+          'You need to set `environment_factory` in `OfflineExperimentConfig` '
+          'when `evaluator_factories` are not specified. To disable evaluation '
+          'altogether, just set `evaluator_factories = []`.')
+
+    return [
+        default_evaluator_factory(
+            environment_factory=self.environment_factory,
+            network_factory=self.network_factory,
+            policy_factory=self.builder.make_policy,
+            logger_factory=self.logger_factory,
+            observers=self.observers)
+    ]
+
+
+def default_evaluator_factory(
+    environment_factory: types.EnvironmentFactory,
+    network_factory: NetworkFactory[builders.Networks],
+    policy_factory: PolicyFactory[builders.Networks, builders.Policy],
+    logger_factory: loggers.LoggerFactory,
+    observers: Sequence[observers_lib.EnvLoopObserver] = (),
+) -> EvaluatorFactory[builders.Policy]:
+  """Returns a default evaluator process."""
+
+  def evaluator(
+      random_key: types.PRNGKey,
+      variable_source: core.VariableSource,
+      counter: counting.Counter,
+      make_actor: MakeActorFn[builders.Policy],
+  ):
+    """The evaluation process."""
+
+    # Create the environment and evaluator networks.
+    environment_key, actor_key = jax.random.split(random_key)
+    # Environments normally require uint32 as a seed.
+    environment = environment_factory(utils.sample_uint32(environment_key))
+    environment_spec = specs.make_environment_spec(environment)
+    networks = network_factory(environment_spec)
+    policy = policy_factory(networks, environment_spec, True)
+    actor = make_actor(actor_key, policy, environment_spec, variable_source)
+
+    # Create logger and counter.
+    counter = counting.Counter(counter, 'evaluator')
+    logger = logger_factory('evaluator', 'actor_steps', 0)
+
+    # Create the run loop and return it.
+ return environment_loop.EnvironmentLoop( + environment, actor, counter, logger, observers=observers) + + return evaluator + + +def make_policy(experiment: ExperimentConfig[builders.Networks, builders.Policy, + Any], networks: builders.Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool) -> builders.Policy: + """Constructs a policy. It is only meant to be used internally.""" + # TODO(sabela): remove and update callers once all agents use + # builder.make_policy + if not evaluation and experiment.policy_network_factory: + return experiment.policy_network_factory(networks) + if evaluation and experiment.eval_policy_network_factory: + return experiment.eval_policy_network_factory(networks) + return experiment.builder.make_policy( + networks=networks, + environment_spec=environment_spec, + evaluation=evaluation) \ No newline at end of file diff --git a/acme/acme/jax/experiments/make_distributed_experiment.py b/acme/acme/jax/experiments/make_distributed_experiment.py new file mode 100644 index 00000000..356fb7a7 --- /dev/null +++ b/acme/acme/jax/experiments/make_distributed_experiment.py @@ -0,0 +1,286 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Program definition for a distributed layout based on a builder.""" + +import itertools +from typing import Any, Optional + +from acme import core +from acme import environment_loop +from acme import specs +from acme.agents.jax import builders +from acme.jax import networks as networks_lib +from acme.jax import savers +from acme.jax import utils +from acme.jax.experiments import config +from acme.jax import snapshotter +from acme.utils import counting +from acme.utils import lp_utils +import jax +import launchpad as lp +import reverb + +ActorId = int + + +def make_distributed_experiment( + experiment: config.ExperimentConfig[builders.Networks, Any, Any], + num_actors: int, + *, + num_learner_nodes: int = 1, + num_actors_per_node: int = 1, + multithreading_colocate_learner_and_reverb: bool = False, + make_snapshot_models: Optional[config.SnapshotModelFactory[ + builders.Networks]] = None, + name: str = 'agent', + program: Optional[lp.Program] = None) -> lp.Program: + """Builds a Launchpad program for running the experiment. + + Args: + experiment: configuration of the experiment. + num_actors: number of actors to run. + num_learner_nodes: number of learner nodes to run. When using multiple + learner nodes, make sure the learner class does the appropriate pmap/pmean + operations on the loss/gradients, respectively. + num_actors_per_node: number of actors per one program node. Actors within + one node are colocated in one process. + multithreading_colocate_learner_and_reverb: whether to colocate the learner + and reverb nodes in one process. Not supported if the learner is spread + across multiple nodes (num_learner_nodes > 1). False by default, which + means no colocation. + make_snapshot_models: a factory that defines what is saved in snapshots. 
+    name: name of the constructed program. Ignored if an existing program is
+      passed.
+    program: a program to which agent nodes are added. If None, a new program
+      is created.
+
+  Returns:
+    The Launchpad program with all the nodes needed for running the experiment.
+  """
+
+  if multithreading_colocate_learner_and_reverb and num_learner_nodes > 1:
+    raise ValueError(
+        'Replay and learner colocation is not yet supported when the learner is'
+        ' spread across multiple nodes (num_learner_nodes > 1). Please contact'
+        ' Acme devs if this is a feature you want. Got:'
+        '\tmultithreading_colocate_learner_and_reverb='
+        f'{multithreading_colocate_learner_and_reverb}'
+        f'\tnum_learner_nodes={num_learner_nodes}.')
+
+  def build_replay():
+    """The replay storage."""
+    dummy_seed = 1
+    spec = (
+        experiment.environment_spec or
+        specs.make_environment_spec(experiment.environment_factory(dummy_seed)))
+    network = experiment.network_factory(spec)
+    policy = config.make_policy(
+        experiment=experiment,
+        networks=network,
+        environment_spec=spec,
+        evaluation=False)
+    return experiment.builder.make_replay_tables(spec, policy)
+
+  def build_model_saver(variable_source: core.VariableSource):
+    assert experiment.checkpointing
+    environment = experiment.environment_factory(0)
+    spec = specs.make_environment_spec(environment)
+    networks = experiment.network_factory(spec)
+    models = make_snapshot_models(networks, spec)
+    # TODO(raveman): Decouple checkpointing and snapshotting configs.
+    return snapshotter.JAXSnapshotter(
+        variable_source=variable_source,
+        models=models,
+        path=experiment.checkpointing.directory,
+        subdirectory='snapshots',
+        add_uid=experiment.checkpointing.add_uid)
+
+  def build_counter():
+    counter = counting.Counter()
+    if experiment.checkpointing:
+      counter = savers.CheckpointingRunner(
+          counter,
+          key='counter',
+          subdirectory='counter',
+          time_delta_minutes=experiment.checkpointing.time_delta_minutes,
+          directory=experiment.checkpointing.directory,
+          add_uid=experiment.checkpointing.add_uid,
+          max_to_keep=experiment.checkpointing.max_to_keep)
+    return counter
+
+  def build_learner(
+      random_key: networks_lib.PRNGKey,
+      replay: reverb.Client,
+      counter: Optional[counting.Counter] = None,
+      primary_learner: Optional[core.Learner] = None,
+  ):
+    """The Learning part of the agent."""
+
+    dummy_seed = 1
+    spec = (
+        experiment.environment_spec or
+        specs.make_environment_spec(experiment.environment_factory(dummy_seed)))
+
+    # Creates the networks to optimize (online) and target networks.
+    networks = experiment.network_factory(spec)
+
+    iterator = experiment.builder.make_dataset_iterator(replay)
+    # make_dataset_iterator is responsible for putting data onto the
+    # appropriate training devices, so we apply prefetch here so that data is
+    # copied over in the background.
+    iterator = utils.prefetch(iterable=iterator, buffer_size=1)
+    counter = counting.Counter(counter, 'learner')
+    learner = experiment.builder.make_learner(random_key, networks, iterator,
+                                              experiment.logger_factory, spec,
+                                              replay, counter)
+
+    if experiment.checkpointing:
+      if primary_learner is None:
+        learner = savers.CheckpointingRunner(
+            learner,
+            key='learner',
+            subdirectory='learner',
+            time_delta_minutes=5,
+            directory=experiment.checkpointing.directory,
+            add_uid=experiment.checkpointing.add_uid,
+            max_to_keep=experiment.checkpointing.max_to_keep)
+      else:
+        learner.restore(primary_learner.save())
+      # NOTE: This initially synchronizes secondary learner states with the
+      # primary one.
Further synchronization should be handled by the learner + # properly doing a pmap/pmean on the loss/gradients, respectively. + + return learner + + def build_actor( + random_key: networks_lib.PRNGKey, + replay: reverb.Client, + variable_source: core.VariableSource, + counter: counting.Counter, + actor_id: ActorId, + ) -> environment_loop.EnvironmentLoop: + """The actor process.""" + environment_key, actor_key = jax.random.split(random_key) + # Create environment and policy core. + + # Environments normally require uint32 as a seed. + environment = experiment.environment_factory( + utils.sample_uint32(environment_key)) + environment_spec = specs.make_environment_spec(environment) + + networks = experiment.network_factory(environment_spec) + policy_network = config.make_policy( + experiment=experiment, + networks=networks, + environment_spec=environment_spec, + evaluation=False) + adder = experiment.builder.make_adder(replay, environment_spec, + policy_network) + actor = experiment.builder.make_actor(actor_key, policy_network, + environment_spec, variable_source, + adder) + + # Create logger and counter. + counter = counting.Counter(counter, 'actor') + logger = experiment.logger_factory('actor', counter.get_steps_key(), + actor_id) + # Create the loop to connect environment and agent. + return environment_loop.EnvironmentLoop( + environment, actor, counter, logger, observers=experiment.observers) + + if not program: + program = lp.Program(name=name) + + key = jax.random.PRNGKey(experiment.seed) + + checkpoint_time_delta_minutes: Optional[int] = ( + experiment.checkpointing.replay_checkpointing_time_delta_minutes + if experiment.checkpointing else None) + replay_node = lp.ReverbNode( + build_replay, checkpoint_time_delta_minutes=checkpoint_time_delta_minutes) + replay = replay_node.create_handle() + + counter = program.add_node(lp.CourierNode(build_counter), label='counter') + + if experiment.max_num_actor_steps is not None: + program.add_node( + lp.CourierNode(lp_utils.StepsLimiter, counter, + experiment.max_num_actor_steps), + label='counter') + + learner_key, key = jax.random.split(key) + learner_node = lp.CourierNode(build_learner, learner_key, replay, counter) + learner = learner_node.create_handle() + variable_sources = [learner] + + if multithreading_colocate_learner_and_reverb: + program.add_node( + lp.MultiThreadingColocation([learner_node, replay_node]), + label='learner') + else: + program.add_node(replay_node, label='replay') + + with program.group('learner'): + program.add_node(learner_node) + + # Maybe create secondary learners, necessary when using multi-host + # accelerators. + # Warning! If you set num_learner_nodes > 1, make sure the learner class + # does the appropriate pmap/pmean operations on the loss/gradients, + # respectively. + for _ in range(1, num_learner_nodes): + learner_key, key = jax.random.split(key) + variable_sources.append( + program.add_node( + lp.CourierNode( + build_learner, learner_key, replay, + primary_learner=learner))) + # NOTE: Secondary learners are used to load-balance get_variables calls, + # which is why they get added to the list of available variable sources. + # NOTE: Only the primary learner checkpoints. + # NOTE: Do not pass the counter to the secondary learners to avoid + # double counting of learner steps. + + with program.group('actor'): + # Create all actor threads. 
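+    # Each actor gets its own PRNG key, and `variable_sources` is cycled so
+    # that actors are assigned round-robin across the primary and any
+    # secondary learners.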
+    *actor_keys, key = jax.random.split(key, num_actors + 1)
+    variable_sources = itertools.cycle(variable_sources)
+    actor_nodes = [
+        lp.CourierNode(build_actor, akey, replay, vsource, counter, aid)
+        for aid, (akey,
+                  vsource) in enumerate(zip(actor_keys, variable_sources))
+    ]
+
+    # Create (maybe colocated) actor nodes.
+    if num_actors_per_node == 1:
+      for actor_node in actor_nodes:
+        program.add_node(actor_node)
+    else:
+      for i in range(0, num_actors, num_actors_per_node):
+        program.add_node(
+            lp.MultiThreadingColocation(actor_nodes[i:i + num_actors_per_node]))
+
+  for evaluator in experiment.get_evaluator_factories():
+    evaluator_key, key = jax.random.split(key)
+    program.add_node(
+        lp.CourierNode(evaluator, evaluator_key, learner, counter,
+                       experiment.builder.make_actor),
+        label='evaluator')
+
+  if make_snapshot_models and experiment.checkpointing:
+    program.add_node(
+        lp.CourierNode(build_model_saver, learner), label='model_saver')
+
+  return program
\ No newline at end of file
diff --git a/acme/acme/jax/experiments/make_distributed_offline_experiment.py b/acme/acme/jax/experiments/make_distributed_offline_experiment.py
new file mode 100644
index 00000000..dea486cf
--- /dev/null
+++ b/acme/acme/jax/experiments/make_distributed_offline_experiment.py
@@ -0,0 +1,145 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Program definition for a distributed layout for an offline RL experiment."""
+
+from typing import Callable, Dict, Optional
+
+from acme import core
+from acme import specs
+from acme.agents.jax import builders
+from acme.jax import networks as networks_lib
+from acme.jax import savers
+from acme.jax import types
+from acme.jax import utils
+from acme.jax.experiments import config
+from acme.jax import snapshotter
+from acme.utils import counting
+from acme.utils import lp_utils
+import jax
+import launchpad as lp
+
+
+SnapshotModelFactory = Callable[[builders.Networks, specs.EnvironmentSpec],
+                                Dict[str, Callable[[core.VariableSource],
+                                                   types.ModelToSnapshot]]]
+
+
+def make_distributed_offline_experiment(
+    experiment: config.OfflineExperimentConfig,
+    *,
+    checkpointing_config: Optional[config.CheckpointingConfig] = None,
+    make_snapshot_models: Optional[SnapshotModelFactory] = None,
+    name='agent',
+    program: Optional[lp.Program] = None):
+  """Builds a distributed agent based on a builder."""
+
+  if checkpointing_config is None:
+    checkpointing_config = config.CheckpointingConfig()
+
+  def build_model_saver(variable_source: core.VariableSource):
+    environment = experiment.environment_factory(0)
+    spec = specs.make_environment_spec(environment)
+    networks = experiment.network_factory(spec)
+    models = make_snapshot_models(networks, spec)
+    # TODO(raveman): Decouple checkpointing and snapshotting configs.
+    return snapshotter.JAXSnapshotter(
+        variable_source=variable_source,
+        models=models,
+        path=checkpointing_config.directory,
+        add_uid=checkpointing_config.add_uid)
+
+  def build_counter():
+    return savers.CheckpointingRunner(
+        counting.Counter(),
+        key='counter',
+        subdirectory='counter',
+        time_delta_minutes=5,
+        directory=checkpointing_config.directory,
+        add_uid=checkpointing_config.add_uid,
+        max_to_keep=checkpointing_config.max_to_keep)
+
+  def build_learner(
+      random_key: networks_lib.PRNGKey,
+      counter: Optional[counting.Counter] = None,
+  ):
+    """The Learning part of the agent."""
+
+    dummy_seed = 1
+    spec = (
+        experiment.environment_spec or
+        specs.make_environment_spec(experiment.environment_factory(dummy_seed)))
+
+    # Creates the networks to optimize (online) and target networks.
+    networks = experiment.network_factory(spec)
+
+    dataset_key, random_key = jax.random.split(random_key)
+    iterator = experiment.demonstration_dataset_factory(dataset_key)
+    # The demonstration dataset factory is responsible for putting data onto
+    # the appropriate training devices, so we apply prefetch here so that data
+    # is copied over in the background.
+    iterator = utils.prefetch(iterable=iterator, buffer_size=1)
+    counter = counting.Counter(counter, 'learner')
+    learner = experiment.builder.make_learner(
+        random_key=random_key,
+        networks=networks,
+        dataset=iterator,
+        logger_fn=experiment.logger_factory,
+        environment_spec=spec,
+        counter=counter)
+
+    learner = savers.CheckpointingRunner(
+        learner,
+        key='learner',
+        subdirectory='learner',
+        time_delta_minutes=5,
+        directory=checkpointing_config.directory,
+        add_uid=checkpointing_config.add_uid,
+        max_to_keep=checkpointing_config.max_to_keep)
+
+    return learner
+
+  if not program:
+    program = lp.Program(name=name)
+
+  key = jax.random.PRNGKey(experiment.seed)
+
+  counter = program.add_node(lp.CourierNode(build_counter), label='counter')
+
+  if experiment.max_num_learner_steps is not None:
+    program.add_node(
+        lp.CourierNode(
+            lp_utils.StepsLimiter,
+            counter,
+            experiment.max_num_learner_steps,
+            steps_key='learner_steps'),
+        label='counter')
+
+  learner_key, key = jax.random.split(key)
+  learner_node = lp.CourierNode(build_learner, learner_key, counter)
+  learner = learner_node.create_handle()
+  program.add_node(learner_node, label='learner')
+
+  for evaluator in experiment.get_evaluator_factories():
+    evaluator_key, key = jax.random.split(key)
+    program.add_node(
+        lp.CourierNode(evaluator, evaluator_key, learner, counter,
+                       experiment.builder.make_actor),
+        label='evaluator')
+
+  if make_snapshot_models and checkpointing_config:
+    program.add_node(lp.CourierNode(build_model_saver, learner),
+                     label='model_saver')
+
+  return program
diff --git a/acme/acme/jax/experiments/run_experiment.py b/acme/acme/jax/experiments/run_experiment.py
new file mode 100644
index 00000000..9f7457ad
--- /dev/null
+++ b/acme/acme/jax/experiments/run_experiment.py
@@ -0,0 +1,276 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runners used for executing local agents."""
+
+import sys
+import time
+from typing import Optional, Sequence, Tuple
+
+import acme
+from acme import core
+from acme import specs
+from acme import types
+from acme.jax import utils
+from acme.jax.experiments import config
+from acme.tf import savers
+from acme.utils import counting
+import dm_env
+import jax
+import reverb
+
+
+def run_experiment(experiment: config.ExperimentConfig,
+                   eval_every: int = 100,
+                   num_eval_episodes: int = 1):
+  """Runs a simple, single-threaded training loop using the default evaluators.
+
+  It targets simplicity of the code and so only the basic features of the
+  ExperimentConfig are supported.
+
+  Arguments:
+    experiment: Definition and configuration of the agent to run.
+    eval_every: After how many actor steps to perform evaluation.
+    num_eval_episodes: How many evaluation episodes to execute at each
+      evaluation step.
+  """
+
+  key = jax.random.PRNGKey(experiment.seed)
+
+  # Create the environment and get its spec.
+  environment = experiment.environment_factory(experiment.seed)
+  environment_spec = experiment.environment_spec or specs.make_environment_spec(
+      environment)
+
+  # Create the networks and policy.
+  networks = experiment.network_factory(environment_spec)
+  policy = config.make_policy(
+      experiment=experiment,
+      networks=networks,
+      environment_spec=environment_spec,
+      evaluation=False)
+
+  # Create the replay server and grab its address.
+  replay_tables = experiment.builder.make_replay_tables(environment_spec,
+                                                        policy)
+
+  # Disable blocking of inserts by tables' rate limiters, as this function
+  # executes learning (sampling from the table) and data generation
+  # (inserting into the table) sequentially from the same thread,
+  # which could result in a blocked insert making the algorithm hang.
+  replay_tables, rate_limiters_max_diff = _disable_insert_blocking(
+      replay_tables)
+
+  replay_server = reverb.Server(replay_tables, port=None)
+  replay_client = reverb.Client(f'localhost:{replay_server.port}')
+
+  # The parent counter allows step counts to be shared between the train and
+  # eval loops and the learner, so that it is possible to plot, for example,
+  # the evaluator's return as a function of the number of training episodes.
+  parent_counter = counting.Counter(time_delta=0.)
+
+  dataset = experiment.builder.make_dataset_iterator(replay_client)
+  # We always use prefetch as it provides an iterator with an additional
+  # 'ready' method.
+  dataset = utils.prefetch(dataset, buffer_size=1)
+
+  # Create actor, adder, and learner for generating, storing, and consuming
+  # data respectively.
+  # NOTE: These are created in reverse order as the actor needs to be given the
+  # adder and the learner (as a source of variables).
+  learner_key, key = jax.random.split(key)
+  learner = experiment.builder.make_learner(
+      random_key=learner_key,
+      networks=networks,
+      dataset=dataset,
+      logger_fn=experiment.logger_factory,
+      environment_spec=environment_spec,
+      replay_client=replay_client,
+      counter=counting.Counter(parent_counter, prefix='learner', time_delta=0.))
+
+  adder = experiment.builder.make_adder(replay_client, environment_spec, policy)
+
+  actor_key, key = jax.random.split(key)
+  actor = experiment.builder.make_actor(
+      actor_key, policy, environment_spec, variable_source=learner, adder=adder)
+
+  # Create the environment loop used for training.
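+  # Note: learning happens inside this loop as well; the actor is wrapped in a
+  # _LearningActor below, whose update() method runs learner steps whenever
+  # the replay tables have enough data.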
+  train_counter = counting.Counter(
+      parent_counter, prefix='actor', time_delta=0.)
+  train_logger = experiment.logger_factory('actor',
+                                           train_counter.get_steps_key(), 0)
+
+  checkpointer = None
+  if experiment.checkpointing is not None:
+    checkpointer = savers.Checkpointer(
+        objects_to_save={
+            'learner': learner,
+            'counter': parent_counter
+        },
+        time_delta_minutes=experiment.checkpointing.time_delta_minutes,
+        directory=experiment.checkpointing.directory,
+        add_uid=experiment.checkpointing.add_uid,
+        max_to_keep=experiment.checkpointing.max_to_keep)
+
+  # Replace the actor with a _LearningActor. This makes sure that every time
+  # `update` is called on the actor, it checks whether there is any new data
+  # to learn from and, if so, runs a learner step. The rate at which new data
+  # is released is controlled by the replay table's rate_limiter, which is
+  # created by the builder.make_replay_tables call above.
+  actor = _LearningActor(actor, learner, dataset, replay_tables,
+                         rate_limiters_max_diff, checkpointer)
+
+  train_loop = acme.EnvironmentLoop(
+      environment,
+      actor,
+      counter=train_counter,
+      logger=train_logger,
+      observers=experiment.observers)
+
+  max_num_actor_steps = (
+      experiment.max_num_actor_steps -
+      parent_counter.get_counts().get(train_counter.get_steps_key(), 0))
+
+  if num_eval_episodes == 0:
+    # No evaluation. Just run the training loop.
+    train_loop.run(num_steps=max_num_actor_steps)
+    return
+
+  # Create the evaluation actor and loop.
+  eval_counter = counting.Counter(
+      parent_counter, prefix='evaluator', time_delta=0.)
+  eval_logger = experiment.logger_factory('evaluator',
+                                          eval_counter.get_steps_key(), 0)
+  eval_policy = config.make_policy(
+      experiment=experiment,
+      networks=networks,
+      environment_spec=environment_spec,
+      evaluation=True)
+  eval_actor = experiment.builder.make_actor(
+      random_key=jax.random.PRNGKey(experiment.seed),
+      policy=eval_policy,
+      environment_spec=environment_spec,
+      variable_source=learner)
+  eval_loop = acme.EnvironmentLoop(
+      environment,
+      eval_actor,
+      counter=eval_counter,
+      logger=eval_logger,
+      observers=experiment.observers)
+
+  steps = 0
+  while steps < max_num_actor_steps:
+    eval_loop.run(num_episodes=num_eval_episodes)
+    steps += train_loop.run(num_steps=eval_every)
+  eval_loop.run(num_episodes=num_eval_episodes)
+
+
+class _LearningActor(core.Actor):
+  """Actor which learns (updates its parameters) when `update` is called.
+
+  This combines a base actor and a learner. Whenever `update` is called
+  on the wrapping actor, the learner will take a step (e.g. one step of
+  gradient descent) as long as there is data available for training
+  (the provided iterator and replay_tables are used to check for that).
+  Selecting actions and making observations are handled by the base actor.
+  Intended to be used only by `run_experiment`.
+  """
+
+  def __init__(self, actor: core.Actor, learner: core.Learner,
+               iterator: core.PrefetchingIterator,
+               replay_tables: Sequence[reverb.Table],
+               sample_sizes: Sequence[int],
+               checkpointer: Optional[savers.Checkpointer]):
+    """Initializes _LearningActor.
+
+    Args:
+      actor: Actor to be wrapped.
+      learner: Learner on which step() is to be called when there is data.
+      iterator: Iterator used by the Learner to fetch training data.
+      replay_tables: Collection of tables from which Learner fetches data
+        through the iterator.
+      sample_sizes: For each table in `replay_tables`, how many elements the
+        table must have available for sampling before we wait for the
+        `iterator` to prefetch a batch of data; otherwise more experience
+        needs to be collected by the actor.
+      checkpointer: Checkpointer to save the state on update.
+    """
+    self._actor = actor
+    self._learner = learner
+    self._iterator = iterator
+    self._replay_tables = replay_tables
+    self._sample_sizes = sample_sizes
+    self._learner_steps = 0
+    self._checkpointer = checkpointer
+
+  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
+    return self._actor.select_action(observation)
+
+  def observe_first(self, timestep: dm_env.TimeStep):
+    self._actor.observe_first(timestep)
+
+  def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
+    self._actor.observe(action, next_timestep)
+
+  def _maybe_train(self):
+    trained = False
+    while True:
+      if self._iterator.ready():
+        self._learner.step()
+        batches = self._iterator.retrieved_elements() - self._learner_steps
+        self._learner_steps += 1
+        assert batches == 1, (
+            'Learner step must retrieve exactly one element from the iterator'
+            f' (retrieved {batches}). Otherwise agent can deadlock. Example '
+            "cause is that your chosen agent's Builder has a `make_learner` "
+            "factory that prefetches the data but it shouldn't.")
+        trained = True
+      else:
+        # Wait for the iterator to fetch more data from the table(s) only
+        # if there is plenty of data to sample from each table.
+        for table, sample_size in zip(self._replay_tables, self._sample_sizes):
+          if not table.can_sample(sample_size):
+            return trained
+        # Let the iterator's prefetching thread get data from the table(s).
+        time.sleep(0.001)
+
+  def update(self):
+    if self._maybe_train():
+      # Update the actor weights only when the learner was updated.
+      self._actor.update()
+    if self._checkpointer:
+      self._checkpointer.save()
+
+
+def _disable_insert_blocking(
+    tables: Sequence[reverb.Table]
+) -> Tuple[Sequence[reverb.Table], Sequence[int]]:
+  """Disables blocking of insert operations for a given collection of tables."""
+  modified_tables = []
+  sample_sizes = []
+  for table in tables:
+    rate_limiter_info = table.info.rate_limiter_info
+    rate_limiter = reverb.rate_limiters.RateLimiter(
+        samples_per_insert=rate_limiter_info.samples_per_insert,
+        min_size_to_sample=rate_limiter_info.min_size_to_sample,
+        min_diff=rate_limiter_info.min_diff,
+        max_diff=sys.float_info.max)
+    modified_tables.append(table.replace(rate_limiter=rate_limiter))
+    # Target the middle of the rate limiter's insert-sample balance window.
+    sample_sizes.append(
+        max(1, int(
+            (rate_limiter_info.max_diff - rate_limiter_info.min_diff) / 2)))
+  return modified_tables, sample_sizes
\ No newline at end of file
diff --git a/acme/acme/jax/experiments/run_offline_experiment.py b/acme/acme/jax/experiments/run_offline_experiment.py
new file mode 100644
index 00000000..79a06616
--- /dev/null
+++ b/acme/acme/jax/experiments/run_offline_experiment.py
@@ -0,0 +1,96 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runner used for executing local offline RL agents."""
+
+import acme
+from acme import specs
+from acme.jax.experiments import config
+from acme.utils import counting
+import jax
+
+
+def run_offline_experiment(experiment: config.OfflineExperimentConfig,
+                           eval_every: int = 100,
+                           num_eval_episodes: int = 1):
+  """Runs a simple, single-threaded training loop using the default evaluators.
+
+  It targets simplicity of the code and so only the basic features of the
+  OfflineExperimentConfig are supported.
+
+  Arguments:
+    experiment: Definition and configuration of the agent to run.
+    eval_every: After how many learner steps to perform evaluation.
+    num_eval_episodes: How many evaluation episodes to execute at each
+      evaluation step.
+  """
+
+  key = jax.random.PRNGKey(experiment.seed)
+
+  # Create the environment and get its spec.
+  environment = experiment.environment_factory(experiment.seed)
+  environment_spec = experiment.environment_spec or specs.make_environment_spec(
+      environment)
+
+  # Create the networks and policy.
+  networks = experiment.network_factory(environment_spec)
+
+  # The parent counter allows step counts to be shared between the train and
+  # eval loops and the learner, so that it is possible to plot, for example,
+  # the evaluator's return as a function of the number of training episodes.
+  parent_counter = counting.Counter(time_delta=0.)
+
+  # Create the demonstrations dataset.
+  dataset_key, key = jax.random.split(key)
+  dataset = experiment.demonstration_dataset_factory(dataset_key)
+
+  # Create the learner.
+  learner_key, key = jax.random.split(key)
+  learner = experiment.builder.make_learner(
+      random_key=learner_key,
+      networks=networks,
+      dataset=dataset,
+      logger_fn=experiment.logger_factory,
+      environment_spec=environment_spec,
+      counter=counting.Counter(parent_counter, prefix='learner', time_delta=0.))
+
+  # Define the evaluation loop.
+  eval_loop = None
+  if num_eval_episodes > 0:
+    # Create the evaluation actor and loop.
+    eval_logger = experiment.logger_factory('eval', 'eval_steps', 0)
+    eval_key, key = jax.random.split(key)
+    eval_actor = experiment.builder.make_actor(
+        random_key=eval_key,
+        policy=experiment.builder.make_policy(networks, environment_spec, True),
+        environment_spec=environment_spec,
+        variable_source=learner)
+    eval_loop = acme.EnvironmentLoop(
+        environment,
+        eval_actor,
+        counter=counting.Counter(parent_counter, prefix='eval', time_delta=0.),
+        logger=eval_logger,
+        observers=experiment.observers)
+
+  # Run the training loop.
+  if eval_loop:
+    eval_loop.run(num_eval_episodes)
+  steps = 0
+  while steps < experiment.max_num_learner_steps:
+    learner_steps = min(eval_every, experiment.max_num_learner_steps - steps)
+    for _ in range(learner_steps):
+      learner.step()
+    if eval_loop:
+      eval_loop.run(num_eval_episodes)
+    steps += learner_steps
diff --git a/acme/acme/jax/imitation_learning_types.py b/acme/acme/jax/imitation_learning_types.py
new file mode 100644
index 00000000..996e6e26
--- /dev/null
+++ b/acme/acme/jax/imitation_learning_types.py
@@ -0,0 +1,22 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX type definitions for imitation and apprenticeship learning algorithms.""" + +from typing import TypeVar + +# Common TypeVars that correspond to various aspects of the direct RL algorithm. +DirectPolicyNetwork = TypeVar('DirectPolicyNetwork') +DirectRLNetworks = TypeVar('DirectRLNetworks') +DirectRLTrainingState = TypeVar('DirectRLTrainingState') diff --git a/acme/acme/jax/layouts/__init__.py b/acme/acme/jax/layouts/__init__.py new file mode 100644 index 00000000..240cb715 --- /dev/null +++ b/acme/acme/jax/layouts/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/acme/acme/jax/layouts/distributed_layout.py b/acme/acme/jax/layouts/distributed_layout.py new file mode 100644 index 00000000..ab9b9d3d --- /dev/null +++ b/acme/acme/jax/layouts/distributed_layout.py @@ -0,0 +1,193 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Program definition for a distributed layout based on a builder.""" + +from typing import Callable, Dict, Optional, Sequence + +from acme import core +from acme import environment_loop +from acme import specs +from acme.agents.jax import builders +from acme.jax import experiments +from acme.jax import types +from acme.jax import utils +from acme.utils import counting +from acme.utils import loggers +from acme.utils import observers as observers_lib +import jax +import launchpad as lp + +# TODO(stanczyk): Remove when use cases are ported to the new location. 
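+# The aliases below simply re-export experiments.config symbols under their
+# old names so that existing imports of this module keep working.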
+EvaluatorFactory = experiments.config.EvaluatorFactory
+AgentNetwork = experiments.config.AgentNetwork
+PolicyNetwork = experiments.config.PolicyNetwork
+NetworkFactory = experiments.config.NetworkFactory
+PolicyFactory = experiments.config.DeprecatedPolicyFactory
+MakeActorFn = experiments.config.MakeActorFn
+LoggerLabel = loggers.LoggerLabel
+LoggerStepsKey = loggers.LoggerStepsKey
+LoggerFn = Callable[[LoggerLabel, LoggerStepsKey], loggers.Logger]
+
+ActorId = int
+
+SnapshotModelFactory = Callable[
+    [experiments.config.AgentNetwork, specs.EnvironmentSpec],
+    Dict[str, Callable[[core.VariableSource], types.ModelToSnapshot]]]
+
+CheckpointingConfig = experiments.CheckpointingConfig
+
+
+def default_evaluator_factory(
+    environment_factory: types.EnvironmentFactory,
+    network_factory: NetworkFactory,
+    policy_factory: PolicyFactory,
+    observers: Sequence[observers_lib.EnvLoopObserver] = (),
+    save_logs: bool = False,
+    logger_fn: Optional[LoggerFn] = None) -> EvaluatorFactory:
+  """Returns a default evaluator process."""
+
+  def evaluator(
+      random_key: types.PRNGKey,
+      variable_source: core.VariableSource,
+      counter: counting.Counter,
+      make_actor: MakeActorFn,
+  ):
+    """The evaluation process."""
+
+    # Create the environment and evaluator networks.
+    environment_key, actor_key = jax.random.split(random_key)
+    # Environments normally require uint32 as a seed.
+    environment = environment_factory(utils.sample_uint32(environment_key))
+    environment_spec = specs.make_environment_spec(environment)
+    policy = policy_factory(network_factory(environment_spec))
+
+    actor = make_actor(actor_key, policy, environment_spec, variable_source)
+
+    # Create logger and counter.
+    counter = counting.Counter(counter, 'evaluator')
+    if logger_fn is not None:
+      logger = logger_fn('evaluator', 'actor_steps')
+    else:
+      logger = loggers.make_default_logger(
+          'evaluator', save_logs, steps_key='actor_steps')
+
+    # Create the run loop and return it.
+    return environment_loop.EnvironmentLoop(
+        environment, actor, counter, logger, observers=observers)
+
+  return evaluator
+
+
+def get_default_logger_fn(
+    save_logs: bool = False,
+    log_every: float = 10) -> Callable[[ActorId], loggers.Logger]:
+  """Creates an actor logger."""
+
+  def create_logger(actor_id: ActorId):
+    return loggers.make_default_logger(
+        'actor',
+        save_data=(save_logs and actor_id == 0),
+        time_delta=log_every,
+        steps_key='actor_steps')
+
+  return create_logger
+
+
+def logger_factory(
+    learner_logger_fn: Optional[Callable[[], loggers.Logger]] = None,
+    actor_logger_fn: Optional[Callable[[ActorId], loggers.Logger]] = None,
+    save_logs: bool = True,
+    log_every: float = 10.0) -> Callable[[str, str, int], loggers.Logger]:
+  """Builds a logger factory used by the experiments.config."""
+
+  def factory(label: str,
+              steps_key: Optional[str] = None,
+              task_id: Optional[int] = None):
+    if task_id is None:
+      task_id = 0
+    if steps_key is None:
+      steps_key = f'{label}_steps'
+    if label == 'learner' and learner_logger_fn:
+      return learner_logger_fn()
+    if label == 'actor':
+      if actor_logger_fn:
+        return actor_logger_fn(task_id)
+      else:
+        return get_default_logger_fn(save_logs)(task_id)
+    if label == 'evaluator':
+      return loggers.make_default_logger(
+          label, save_logs, time_delta=log_every, steps_key=steps_key)
+    return None
+
+  return factory
+
+
+class DistributedLayout:
+  """Program definition for a distributed agent based on a builder.
+
+  DEPRECATED: Use make_distributed_experiment directly.
+ """ + + def __init__( + self, + seed: int, + environment_factory: types.EnvironmentFactory, + network_factory: experiments.config.NetworkFactory, + builder: builders.ActorLearnerBuilder, + policy_network: experiments.config.DeprecatedPolicyFactory, + num_actors: int, + environment_spec: Optional[specs.EnvironmentSpec] = None, + learner_logger_fn: Optional[Callable[[], loggers.Logger]] = None, + actor_logger_fn: Optional[Callable[[ActorId], loggers.Logger]] = None, + evaluator_factories: Sequence[experiments.config.EvaluatorFactory] = (), + prefetch_size: int = 1, + save_logs: bool = False, + max_number_of_steps: Optional[int] = None, + observers: Sequence[observers_lib.EnvLoopObserver] = (), + multithreading_colocate_learner_and_reverb: bool = False, + checkpointing_config: Optional[CheckpointingConfig] = None, + make_snapshot_models: Optional[SnapshotModelFactory] = None): + del prefetch_size + self._experiment_config = experiments.config.ExperimentConfig( + builder=builder, + environment_factory=environment_factory, + environment_spec=environment_spec, + network_factory=network_factory, + policy_network_factory=policy_network, + evaluator_factories=evaluator_factories, + observers=observers, + seed=seed, + max_num_actor_steps=max_number_of_steps, + logger_factory=logger_factory(learner_logger_fn, actor_logger_fn, + save_logs)) + self._num_actors = num_actors + self._multithreading_colocate_learner_and_reverb = ( + multithreading_colocate_learner_and_reverb) + self._checkpointing_config = checkpointing_config + self._make_snapshot_models = make_snapshot_models + + def build(self, name='agent', program: Optional[lp.Program] = None): + """Build the distributed agent topology.""" + + return experiments.make_distributed_experiment( + self._experiment_config, + self._num_actors, + multithreading_colocate_learner_and_reverb=self + ._multithreading_colocate_learner_and_reverb, + checkpointing_config=self._checkpointing_config, + make_snapshot_models=self._make_snapshot_models, + name=name, + program=program) diff --git a/acme/acme/jax/layouts/local_layout.py b/acme/acme/jax/layouts/local_layout.py new file mode 100644 index 00000000..b288cea8 --- /dev/null +++ b/acme/acme/jax/layouts/local_layout.py @@ -0,0 +1,153 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Local agent based on builders.""" + +import sys +from typing import Any, Optional + +from acme import specs +from acme.agents import agent +from acme.agents.jax import builders +from acme.jax import utils +from acme.tf import savers +from acme.utils import counting +from acme.utils import loggers +import jax +import reverb + + +class LocalLayout(agent.Agent): + """An Agent that runs an algorithm defined by 'builder' on a single machine. 
+ """ + + def __init__( + self, + seed: int, + environment_spec: specs.EnvironmentSpec, + builder: builders.ActorLearnerBuilder, + networks: Any, + policy_network: Any, + learner_logger: Optional[loggers.Logger] = None, + workdir: Optional[str] = '~/acme', + batch_size: int = 256, + num_sgd_steps_per_step: int = 1, + prefetch_size: int = 1, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + ): + """Initialize the agent. + + Args: + seed: A random seed to use for this layout instance. + environment_spec: description of the actions, observations, etc. + builder: builder defining an RL algorithm to train. + networks: network objects to be passed to the learner. + policy_network: function that given an observation returns actions. + learner_logger: logger used by the learner. + workdir: if provided saves the state of the learner and the counter + (if the counter is not None) into workdir. + batch_size: batch size for updates. + num_sgd_steps_per_step: how many sgd steps a learner does per 'step' call. + For performance reasons (especially to reduce TPU host-device transfer + times) it is performance-beneficial to do multiple sgd updates at once, + provided that it does not hurt the training, which needs to be verified + empirically for each environment. + prefetch_size: whether to prefetch iterator. + counter: counter object used to keep track of steps. + checkpoint: boolean indicating whether to checkpoint the learner + and the counter (if the counter is not None). + """ + if prefetch_size < 0: + raise ValueError(f'Prefetch size={prefetch_size} should be non negative') + + key = jax.random.PRNGKey(seed) + + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec, policy_network) + + # Disable blocking of inserts by tables' rate limiters, as LocalLayout + # agents run inserts and sampling from the same thread and blocked insert + # would result in a hang. + new_tables = [] + for table in replay_tables: + rl_info = table.info.rate_limiter_info + rate_limiter = reverb.rate_limiters.RateLimiter( + samples_per_insert=rl_info.samples_per_insert, + min_size_to_sample=rl_info.min_size_to_sample, + min_diff=rl_info.min_diff, + max_diff=sys.float_info.max) + new_tables.append(table.replace(rate_limiter=rate_limiter)) + replay_tables = new_tables + + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f'localhost:{replay_server.port}') + + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client, environment_spec, policy_network) + + dataset = builder.make_dataset_iterator(replay_client) + # We always use prefetch, as it provides an iterator with additional + # 'ready' method. 
+    dataset = utils.prefetch(dataset, buffer_size=prefetch_size)
+    learner_key, key = jax.random.split(key)
+    learner = builder.make_learner(
+        random_key=learner_key,
+        networks=networks,
+        dataset=dataset,
+        logger_fn=(
+            lambda label, steps_key=None, task_instance=None: learner_logger),
+        environment_spec=environment_spec,
+        replay_client=replay_client,
+        counter=counter)
+    if not checkpoint or workdir is None:
+      self._checkpointer = None
+    else:
+      objects_to_save = {'learner': learner}
+      if counter is not None:
+        objects_to_save.update({'counter': counter})
+      self._checkpointer = savers.Checkpointer(
+          objects_to_save,
+          time_delta_minutes=30,
+          subdirectory='learner',
+          directory=workdir,
+          add_uid=(workdir == '~/acme'))
+
+    actor_key, key = jax.random.split(key)
+    actor = builder.make_actor(
+        actor_key,
+        policy_network,
+        environment_spec,
+        variable_source=learner,
+        adder=adder)
+
+    super().__init__(
+        actor=actor,
+        learner=learner,
+        iterator=dataset,
+        replay_tables=replay_tables)
+
+    # Save the replay server so we don't garbage collect it.
+    self._replay_server = replay_server
+
+  def update(self):
+    super().update()
+    if self._checkpointer:
+      self._checkpointer.save()
+
+  def save(self):
+    """Checkpoint the state of the agent."""
+    if self._checkpointer:
+      self._checkpointer.save(force=True)
diff --git a/acme/acme/jax/layouts/offline_distributed_layout.py b/acme/acme/jax/layouts/offline_distributed_layout.py
new file mode 100644
index 00000000..d6922731
--- /dev/null
+++ b/acme/acme/jax/layouts/offline_distributed_layout.py
@@ -0,0 +1,128 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Program definition for an offline distributed layout based on a builder."""
+
+from typing import Any, Callable, Dict, Optional, Union, Sequence
+
+from acme import core
+from acme.jax import networks as networks_lib
+from acme.jax import savers
+from acme.jax import types
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import loggers
+from acme.utils import lp_utils
+import jax
+import launchpad as lp
+
+AgentNetwork = Any
+NetworkFactory = Callable[[], AgentNetwork]
+# It will be treated as Dict[str, Any]. Proper support is tracked in
+# b/109648354.
+NestedLogger = Union[loggers.Logger, Dict[str, 'NestedLogger']]  # pytype: disable=not-supported-yet
+LearnerFactory = Callable[[
+    types.PRNGKey,
+    AgentNetwork,
+    Optional[counting.Counter],
+    Optional[NestedLogger],
+], core.Learner]
+EvaluatorFactory = Callable[
+    [types.PRNGKey, core.VariableSource, counting.Counter], core.Worker]
+
+
+class OfflineDistributedLayout:
+  """Program definition for an offline distributed agent based on a builder.
+
+  It is distributed in the sense that evaluators run on different machines
+  than the learner.
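+
+  The Launchpad program built by `build` has three groups: a checkpointed
+  counter, a learner, and one courier node per evaluator factory.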
+ """ + + def __init__( + self, + seed: int, + network_factory: NetworkFactory, + make_learner: LearnerFactory, + evaluator_factories: Sequence[EvaluatorFactory] = (), + save_logs: bool = False, + log_every: float = 10.0, + max_number_of_steps: Optional[int] = None, + workdir: str = '~/acme', + ): + + self._seed = seed + self._make_learner = make_learner + self._evaluator_factories = evaluator_factories + self._network_factory = network_factory + self._save_logs = save_logs + self._log_every = log_every + self._max_number_of_steps = max_number_of_steps + self._workdir = workdir + + def counter(self): + kwargs = {'directory': self._workdir, 'add_uid': self._workdir == '~/acme'} + return savers.CheckpointingRunner( + counting.Counter(), subdirectory='counter', time_delta_minutes=5, + **kwargs) + + def learner( + self, + random_key: networks_lib.PRNGKey, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + # Counter and logger. + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', self._save_logs, time_delta=self._log_every, + asynchronous=True, serialize_fn=utils.fetch_devicearray, + steps_key='learner_steps') + + # Create the learner. + networks = self._network_factory() + learner = self._make_learner(random_key, networks, counter, logger) + + kwargs = {'directory': self._workdir, 'add_uid': self._workdir == '~/acme'} + # Return the learning agent. + return savers.CheckpointingRunner( + learner, subdirectory='learner', time_delta_minutes=5, **kwargs) + + def coordinator(self, counter: counting.Counter, max_learner_steps: int): + return lp_utils.StepsLimiter(counter, max_steps=max_learner_steps, + steps_key='learner_steps') + + def build(self, name='agent'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + key = jax.random.PRNGKey(self._seed) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + if self._max_number_of_steps is not None: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, + self._max_number_of_steps)) + + learner_key, key = jax.random.split(key) + with program.group('learner'): + learner = program.add_node( + lp.CourierNode(self.learner, learner_key, counter)) + + with program.group('evaluator'): + for evaluator in self._evaluator_factories: + evaluator_key, key = jax.random.split(key) + program.add_node( + lp.CourierNode(evaluator, evaluator_key, learner, counter)) + + return program diff --git a/acme/acme/jax/losses/__init__.py b/acme/acme/jax/losses/__init__.py new file mode 100644 index 00000000..9a1f9d95 --- /dev/null +++ b/acme/acme/jax/losses/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Common loss functions.""" + +from acme.jax.losses.impala import impala_loss +from acme.jax.losses.mpo import MPO +from acme.jax.losses.mpo import MPOParams +from acme.jax.losses.mpo import MPOStats diff --git a/acme/acme/jax/losses/impala.py b/acme/acme/jax/losses/impala.py new file mode 100644 index 00000000..457f8aa7 --- /dev/null +++ b/acme/acme/jax/losses/impala.py @@ -0,0 +1,111 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loss function for IMPALA (Espeholt et al., 2018) [1]. + +[1] https://arxiv.org/abs/1802.01561 +""" + +from typing import Callable + +from acme.agents.jax.impala import types +from acme.jax import utils +import haiku as hk +import jax.numpy as jnp +import numpy as np +import reverb +import rlax +import tree + + +def impala_loss( + unroll_fn: types.PolicyValueFn, + *, + discount: float, + max_abs_reward: float = np.inf, + baseline_cost: float = 1., + entropy_cost: float = 0., +) -> Callable[[hk.Params, reverb.ReplaySample], jnp.DeviceArray]: + """Builds the standard entropy-regularised IMPALA loss function. + + Args: + unroll_fn: A `hk.Transformed` object containing a callable which maps + (params, observations_sequence, initial_state) -> ((logits, value), state) + discount: The standard geometric discount rate to apply. + max_abs_reward: Optional symmetric reward clipping to apply. + baseline_cost: Weighting of the critic loss relative to the policy loss. + entropy_cost: Weighting of the entropy regulariser relative to policy loss. + + Returns: + A loss function with signature (params, data) -> loss_scalar. + """ + + def loss_fn(params: hk.Params, + sample: reverb.ReplaySample) -> jnp.DeviceArray: + """Batched, entropy-regularised actor-critic loss with V-trace.""" + + # Extract the data. + data = sample.data + observations, actions, rewards, discounts, extra = (data.observation, + data.action, + data.reward, + data.discount, + data.extras) + initial_state = tree.map_structure(lambda s: s[0], extra['core_state']) + behaviour_logits = extra['logits'] + + # Apply reward clipping. + rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) + + # Unroll current policy over observations. + (logits, values), _ = unroll_fn(params, observations, initial_state) + + # Compute importance sampling weights: current policy / behavior policy. + rhos = rlax.categorical_importance_sampling_ratios(logits[:-1], + behaviour_logits[:-1], + actions[:-1]) + + # Critic loss. + vtrace_returns = rlax.vtrace_td_error_and_advantage( + v_tm1=values[:-1], + v_t=values[1:], + r_t=rewards[:-1], + discount_t=discounts[:-1] * discount, + rho_tm1=rhos) + critic_loss = jnp.square(vtrace_returns.errors) + + # Policy gradient loss. + policy_gradient_loss = rlax.policy_gradient_loss( + logits_t=logits[:-1], + a_t=actions[:-1], + adv_t=vtrace_returns.pg_advantage, + w_t=jnp.ones_like(rewards[:-1])) + + # Entropy regulariser. 
+ entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1])) + + # Combine weighted sum of actor & critic losses, averaged over the sequence. + mean_loss = jnp.mean(policy_gradient_loss + baseline_cost * critic_loss + + entropy_cost * entropy_loss) # [] + + metrics = { + 'policy_loss': jnp.mean(policy_gradient_loss), + 'critic_loss': jnp.mean(baseline_cost * critic_loss), + 'entropy_loss': jnp.mean(entropy_cost * entropy_loss), + 'entropy': jnp.mean(entropy_loss), + } + + return mean_loss, metrics + + return utils.mapreduce(loss_fn, in_axes=(None, 0)) diff --git a/acme/acme/jax/losses/impala_test.py b/acme/acme/jax/losses/impala_test.py new file mode 100644 index 00000000..b8f27396 --- /dev/null +++ b/acme/acme/jax/losses/impala_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the IMPALA loss function.""" + +from acme.adders import reverb as adders +from acme.jax.losses import impala +from acme.utils.tree_utils import tree_map +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import reverb + +from absl.testing import absltest + + +class ImpalaTest(absltest.TestCase): + + def test_shapes(self): + + # + batch_size = 2 + sequence_len = 3 + num_actions = 5 + hidden_size = 7 + + # Define a trivial recurrent actor-critic network. + @hk.without_apply_rng + @hk.transform + def unroll_fn_transformed(observations, state): + lstm = hk.LSTM(hidden_size) + embedding, state = hk.dynamic_unroll(lstm, observations, state) + logits = hk.Linear(num_actions)(embedding) + values = jnp.squeeze(hk.Linear(1)(embedding), axis=-1) + + return (logits, values), state + + @hk.without_apply_rng + @hk.transform + def initial_state_fn(): + return hk.LSTM(hidden_size).initial_state(None) + + # Initial recurrent network state. + initial_state = initial_state_fn.apply(None) + + # Make some fake data. + observations = np.ones(shape=(sequence_len, 50)) + actions = np.random.randint(num_actions, size=sequence_len) + rewards = np.random.rand(sequence_len) + discounts = np.ones(shape=(sequence_len,)) + + batch_tile = tree_map(lambda x: np.tile(x, [batch_size, *([1] * x.ndim)])) + seq_tile = tree_map(lambda x: np.tile(x, [sequence_len, *([1] * x.ndim)])) + + extras = { + 'logits': np.random.rand(sequence_len, num_actions), + 'core_state': seq_tile(initial_state), + } + + # Package up the data into a ReverbSample. + data = adders.Step( + observations, + actions, + rewards, + discounts, + extras=extras, + start_of_episode=()) + data = batch_tile(data) + sample = reverb.ReplaySample(info=None, data=data) + + # Initialise parameters. + rng = hk.PRNGSequence(1) + params = unroll_fn_transformed.init(next(rng), observations, initial_state) + + # Make loss function. + loss_fn = impala.impala_loss( + unroll_fn_transformed.apply, discount=0.99) + + # Return value should be scalar. 
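+    # (impala_loss wraps loss_fn with utils.mapreduce, which vmaps over the
+    # batch axis and then averages, so the loss and metrics are scalars.)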
+    loss, metrics = loss_fn(params, sample)
+    loss = jax.device_get(loss)
+    self.assertEqual(loss.shape, ())
+    for value in metrics.values():
+      value = jax.device_get(value)
+      self.assertEqual(value.shape, ())
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/jax/losses/mpo.py b/acme/acme/jax/losses/mpo.py
new file mode 100644
index 00000000..6902c00c
--- /dev/null
+++ b/acme/acme/jax/losses/mpo.py
@@ -0,0 +1,452 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implements the MPO loss.
+
+The MPO loss uses MPOParams, which can be initialized using init_params,
+to track the temperature and the dual variables.
+
+Tensor shapes are annotated, where helpful, as follows:
+  B: batch size,
+  N: number of sampled actions, see MPO paper for more details,
+  D: dimensionality of the action space.
+"""
+
+from typing import NamedTuple, Optional, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+import tensorflow_probability
+
+tfp = tensorflow_probability.substrates.jax
+tfd = tensorflow_probability.substrates.jax.distributions
+
+_MPO_FLOAT_EPSILON = 1e-8
+_MIN_LOG_TEMPERATURE = -18.0
+_MIN_LOG_ALPHA = -18.0
+
+Shape = Tuple[int, ...]
+DType = type(jnp.float32)  # _ScalarMeta, a private type.
+
+
+class MPOParams(NamedTuple):
+  """NamedTuple to store trainable loss parameters."""
+  log_temperature: jnp.ndarray
+  log_alpha_mean: jnp.ndarray
+  log_alpha_stddev: jnp.ndarray
+  log_penalty_temperature: Optional[jnp.ndarray] = None
+
+
+class MPOStats(NamedTuple):
+  """NamedTuple to store loss statistics."""
+  dual_alpha_mean: float
+  dual_alpha_stddev: float
+  dual_temperature: float
+
+  loss_policy: float
+  loss_alpha: float
+  loss_temperature: float
+  kl_q_rel: float
+
+  kl_mean_rel: float
+  kl_stddev_rel: float
+
+  q_min: float
+  q_max: float
+
+  pi_stddev_min: float
+  pi_stddev_max: float
+  pi_stddev_cond: float
+
+  penalty_kl_q_rel: Optional[float] = None
+
+
+class MPO:
+  """MPO loss with decoupled KL constraints as in (Abdolmaleki et al., 2018).
+
+  This implementation of the MPO loss includes the following features, as
+  options:
+  - Satisfying the KL-constraint on a per-dimension basis (on by default);
+  - Penalizing actions that fall outside of [-1, 1] (on by default) as a
+    special case of multi-objective MPO (MO-MPO; Abdolmaleki et al., 2020).
+  For best results on the control suite, keep both of these on.
+
+  (Abdolmaleki et al., 2018): https://arxiv.org/pdf/1812.02256.pdf
+  (Abdolmaleki et al., 2020): https://arxiv.org/pdf/2005.07513.pdf
+  """
+
+  def __init__(self,
+               epsilon: float,
+               epsilon_mean: float,
+               epsilon_stddev: float,
+               init_log_temperature: float,
+               init_log_alpha_mean: float,
+               init_log_alpha_stddev: float,
+               per_dim_constraining: bool = True,
+               action_penalization: bool = True,
+               epsilon_penalty: float = 0.001):
+    """Initialize and configure the MPO loss.
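+
+    The three epsilon arguments are the targets of the KL constraints, while
+    the init_log_* arguments initialize the corresponding dual variables
+    (Lagrange multipliers), stored in MPOParams and adapted during training.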
+
+    Args:
+      epsilon: KL constraint on the non-parametric auxiliary policy, the one
+        associated with the dual variable called temperature.
+      epsilon_mean: KL constraint on the mean of the Gaussian policy, the one
+        associated with the dual variable called alpha_mean.
+      epsilon_stddev: KL constraint on the stddev of the Gaussian policy, the
+        one associated with the dual variable called alpha_stddev.
+      init_log_temperature: initial value for the temperature in log-space,
+        note a softplus (rather than an exp) will be used to transform this.
+      init_log_alpha_mean: initial value for the alpha_mean in log-space, note
+        a softplus (rather than an exp) will be used to transform this.
+      init_log_alpha_stddev: initial value for the alpha_stddev in log-space,
+        note a softplus (rather than an exp) will be used to transform this.
+      per_dim_constraining: whether to enforce the KL constraint on each
+        dimension independently; this is the default. Otherwise the overall KL
+        is constrained, which allows some dimensions to change more at the
+        expense of others staying put.
+      action_penalization: whether to use a KL constraint to penalize actions
+        via the MO-MPO algorithm.
+      epsilon_penalty: KL constraint on the probability of violating the action
+        constraint.
+    """
+
+    # MPO constraint thresholds.
+    self._epsilon = epsilon
+    self._epsilon_mean = epsilon_mean
+    self._epsilon_stddev = epsilon_stddev
+
+    # Initial values for the constraints' dual variables.
+    self._init_log_temperature = init_log_temperature
+    self._init_log_alpha_mean = init_log_alpha_mean
+    self._init_log_alpha_stddev = init_log_alpha_stddev
+
+    # Whether to penalize out-of-bound actions via MO-MPO and its corresponding
+    # constraint threshold.
+    self._action_penalization = action_penalization
+    self._epsilon_penalty = epsilon_penalty
+
+    # Whether to ensure per-dimension KL constraint satisfaction.
+    self._per_dim_constraining = per_dim_constraining
+
+  @property
+  def per_dim_constraining(self):
+    return self._per_dim_constraining
+
+  def init_params(self, action_dim: int, dtype: DType = jnp.float32):
+    """Creates an initial set of parameters."""
+
+    if self._per_dim_constraining:
+      dual_variable_shape = [action_dim]
+    else:
+      dual_variable_shape = [1]
+
+    log_temperature = jnp.full([1], self._init_log_temperature, dtype=dtype)
+
+    log_alpha_mean = jnp.full(
+        dual_variable_shape, self._init_log_alpha_mean, dtype=dtype)
+
+    log_alpha_stddev = jnp.full(
+        dual_variable_shape, self._init_log_alpha_stddev, dtype=dtype)
+
+    if self._action_penalization:
+      log_penalty_temperature = jnp.full([1],
+                                         self._init_log_temperature,
+                                         dtype=dtype)
+    else:
+      log_penalty_temperature = None
+
+    return MPOParams(
+        log_temperature=log_temperature,
+        log_alpha_mean=log_alpha_mean,
+        log_alpha_stddev=log_alpha_stddev,
+        log_penalty_temperature=log_penalty_temperature)
+
+  def __call__(
+      self,
+      params: MPOParams,
+      online_action_distribution: Union[tfd.MultivariateNormalDiag,
+                                        tfd.Independent],
+      target_action_distribution: Union[tfd.MultivariateNormalDiag,
+                                        tfd.Independent],
+      actions: jnp.ndarray,  # Shape [N, B, D].
+      q_values: jnp.ndarray,  # Shape [N, B].
+  ) -> Tuple[jnp.ndarray, MPOStats]:
+    """Computes the decoupled MPO loss.
+
+    Args:
+      params: parameters tracking the temperature and the dual variables.
+      online_action_distribution: online distribution returned by the online
+        policy network; expects batch_dims of [B] and event_dims of [D].
+ target_action_distribution: target distribution returned by the target + policy network; expects same shapes as online distribution. + actions: actions sampled from the target policy; expects shape [N, B, D]. + q_values: Q-values associated with each action; expects shape [N, B]. + + Returns: + Loss, combining the policy loss, KL penalty, and dual losses required to + adapt the dual variables. + Stats, for diagnostics and tracking performance. + """ + + # Cast `MultivariateNormalDiag`s to Independent Normals. + # The latter allows us to satisfy KL constraints per-dimension. + if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): + target_action_distribution = tfd.Independent( + tfd.Normal(target_action_distribution.mean(), + target_action_distribution.stddev())) + online_action_distribution = tfd.Independent( + tfd.Normal(online_action_distribution.mean(), + online_action_distribution.stddev())) + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = jax.nn.softplus(params.log_temperature) + _MPO_FLOAT_EPSILON + alpha_mean = jax.nn.softplus(params.log_alpha_mean) + _MPO_FLOAT_EPSILON + alpha_stddev = jax.nn.softplus(params.log_alpha_stddev) + _MPO_FLOAT_EPSILON + + # Get online and target means and stddevs in preparation for decomposition. + online_mean = online_action_distribution.distribution.mean() + online_scale = online_action_distribution.distribution.stddev() + target_mean = target_action_distribution.distribution.mean() + target_scale = target_action_distribution.distribution.stddev() + + # Compute normalized importance weights, used to compute expectations with + # respect to the non-parametric policy; and the temperature loss, used to + # adapt the tempering of Q-values. + normalized_weights, loss_temperature = compute_weights_and_temperature_loss( + q_values, self._epsilon, temperature) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + normalized_weights) + + if self._action_penalization: + # Transform action penalization temperature. + penalty_temperature = jax.nn.softplus( + params.log_penalty_temperature) + _MPO_FLOAT_EPSILON + + # Compute action penalization cost. + # Note: the cost is zero in [-1, 1] and quadratic beyond. + diff_out_of_bound = actions - jnp.clip(actions, -1.0, 1.0) + cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1) + + penalty_normalized_weights, loss_penalty_temperature = compute_weights_and_temperature_loss( + cost_out_of_bound, self._epsilon_penalty, penalty_temperature) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + penalty_normalized_weights) + + # Combine normalized weights. + normalized_weights += penalty_normalized_weights + loss_temperature += loss_penalty_temperature + + # Decompose the online policy into fixed-mean & fixed-stddev distributions. + # This has been documented as having better performance in bandit settings, + # see e.g. https://arxiv.org/pdf/1812.02256.pdf. + fixed_stddev_distribution = tfd.Independent( + tfd.Normal(loc=online_mean, scale=target_scale)) + fixed_mean_distribution = tfd.Independent( + tfd.Normal(loc=target_mean, scale=online_scale)) + + # Compute the decomposed policy losses. 
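+    # Each term below is a weighted maximum-likelihood fit of one factor of
+    # the policy (mean or stddev) while the other factor is held at its
+    # target value.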
+ loss_policy_mean = compute_cross_entropy_loss(actions, normalized_weights, + fixed_stddev_distribution) + loss_policy_stddev = compute_cross_entropy_loss(actions, normalized_weights, + fixed_mean_distribution) + + # Compute the decomposed KL between the target and online policies. + if self._per_dim_constraining: + kl_mean = target_action_distribution.distribution.kl_divergence( + fixed_stddev_distribution.distribution) # Shape [B, D]. + kl_stddev = target_action_distribution.distribution.kl_divergence( + fixed_mean_distribution.distribution) # Shape [B, D]. + else: + kl_mean = target_action_distribution.kl_divergence( + fixed_stddev_distribution) # Shape [B]. + kl_stddev = target_action_distribution.kl_divergence( + fixed_mean_distribution) # Shape [B]. + + # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. + loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss( + kl_mean, alpha_mean, self._epsilon_mean) + loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss( + kl_stddev, alpha_stddev, self._epsilon_stddev) + + # Combine losses. + loss_policy = loss_policy_mean + loss_policy_stddev + loss_kl_penalty = loss_kl_mean + loss_kl_stddev + loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature + loss = loss_policy + loss_kl_penalty + loss_dual + + # Create statistics. + pi_stddev = online_action_distribution.distribution.stddev() + stats = MPOStats( + # Dual Variables. + dual_alpha_mean=jnp.mean(alpha_mean), + dual_alpha_stddev=jnp.mean(alpha_stddev), + dual_temperature=jnp.mean(temperature), + # Losses. + loss_policy=jnp.mean(loss), + loss_alpha=jnp.mean(loss_alpha_mean + loss_alpha_stddev), + loss_temperature=jnp.mean(loss_temperature), + # KL measurements. + kl_q_rel=jnp.mean(kl_nonparametric) / self._epsilon, + penalty_kl_q_rel=((jnp.mean(penalty_kl_nonparametric) / + self._epsilon_penalty) + if self._action_penalization else None), + kl_mean_rel=jnp.mean(kl_mean, axis=0) / self._epsilon_mean, + kl_stddev_rel=jnp.mean(kl_stddev, axis=0) / self._epsilon_stddev, + # Q measurements. + q_min=jnp.mean(jnp.min(q_values, axis=0)), + q_max=jnp.mean(jnp.max(q_values, axis=0)), + # If the policy has stddev, log summary stats for this as well. + pi_stddev_min=jnp.mean(jnp.min(pi_stddev, axis=-1)), + pi_stddev_max=jnp.mean(jnp.max(pi_stddev, axis=-1)), + # Condition number of the diagonal covariance (actually, stddev) matrix. + pi_stddev_cond=jnp.mean( + jnp.max(pi_stddev, axis=-1) / jnp.min(pi_stddev, axis=-1)), + ) + + return loss, stats + + +def compute_weights_and_temperature_loss( + q_values: jnp.ndarray, + epsilon: float, + temperature: jnp.ndarray, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Computes normalized importance weights for the policy optimization. + + Args: + q_values: Q-values associated with the actions sampled from the target + policy; expected shape [N, B]. + epsilon: Desired constraint on the KL between the target and non-parametric + policies. + temperature: Scalar used to temper the Q-values before computing normalized + importance weights from them. This is really the Lagrange dual variable in + the constrained optimization problem, the solution of which is the + non-parametric policy targeted by the policy loss. + + Returns: + Normalized importance weights, used for policy optimization. + Temperature loss, used to adapt the temperature. + """ + + # Temper the given Q-values using the current temperature. 
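+  # The resulting weights are softmax(Q / temperature) over the N sampled
+  # actions; the stop_gradient keeps this loss from training the critic.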
+  tempered_q_values = jax.lax.stop_gradient(q_values) / temperature
+
+  # Compute the normalized importance weights used to compute expectations with
+  # respect to the non-parametric policy.
+  normalized_weights = jax.nn.softmax(tempered_q_values, axis=0)
+  normalized_weights = jax.lax.stop_gradient(normalized_weights)
+
+  # Compute the temperature loss (dual of the E-step optimization problem).
+  q_logsumexp = jax.scipy.special.logsumexp(tempered_q_values, axis=0)
+  log_num_actions = jnp.log(q_values.shape[0] / 1.)
+  loss_temperature = epsilon + jnp.mean(q_logsumexp) - log_num_actions
+  loss_temperature = temperature * loss_temperature
+
+  return normalized_weights, loss_temperature
+
+
+def compute_nonparametric_kl_from_normalized_weights(
+    normalized_weights: jnp.ndarray) -> jnp.ndarray:
+  """Estimate the actualized KL between the non-parametric and target policies."""
+
+  # Compute integrand.
+  num_action_samples = normalized_weights.shape[0] / 1.
+  integrand = jnp.log(num_action_samples * normalized_weights + 1e-8)
+
+  # Return the expectation with respect to the non-parametric policy.
+  return jnp.sum(normalized_weights * integrand, axis=0)
+
+
+def compute_cross_entropy_loss(
+    sampled_actions: jnp.ndarray,
+    normalized_weights: jnp.ndarray,
+    online_action_distribution: tfd.Distribution,
+) -> jnp.ndarray:
+  """Compute the cross entropy between the online policy and the reweighted
+  target policy.
+
+  Args:
+    sampled_actions: samples used in the Monte Carlo integration in the policy
+      loss. Expected shape is [N, B, ...], where N is the number of sampled
+      actions and B is the number of sampled states.
+    normalized_weights: target policy multiplied by the exponentiated Q values
+      and normalized; expected shape is [N, B].
+    online_action_distribution: policy to be optimized.
+
+  Returns:
+    loss_policy_gradient: the cross-entropy loss that, when differentiated,
+      produces the policy gradient.
+  """
+
+  # Compute the M-step loss.
+  log_prob = online_action_distribution.log_prob(sampled_actions)
+
+  # Compute the weighted average log-prob using the normalized weights.
+  loss_policy_gradient = -jnp.sum(log_prob * normalized_weights, axis=0)
+
+  # Return the mean loss over the batch of states.
+  return jnp.mean(loss_policy_gradient, axis=0)
+
+
+def compute_parametric_kl_penalty_and_dual_loss(
+    kl: jnp.ndarray,
+    alpha: jnp.ndarray,
+    epsilon: float,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Computes the KL cost to be added to the Lagrangian and its dual loss.
+
+  The KL cost is simply the alpha-weighted KL divergence and it is added as a
+  regularizer to the policy loss. The dual variable alpha itself has a loss that
+  can be minimized to adapt the strength of the regularizer to keep the KL
+  between consecutive updates at the desired target value of epsilon.
+
+  Args:
+    kl: KL divergence between the target and online policies.
+    alpha: Lagrange multipliers (dual variables) for the KL constraints.
+    epsilon: Desired value for the KL.
+
+  Returns:
+    loss_kl: alpha-weighted KL regularization to be added to the policy loss.
+    loss_alpha: The Lagrange dual loss minimized to adapt alpha.
+  """
+
+  # Compute the mean KL over the batch.
+  mean_kl = jnp.mean(kl, axis=0)
+
+  # Compute the regularization.
+  loss_kl = jnp.sum(jax.lax.stop_gradient(alpha) * mean_kl)
+
+  # Compute the dual loss.
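+  # Gradient descent on this term increases alpha when the measured KL exceeds
+  # epsilon and decreases it otherwise; the stop_gradient ensures the KL itself
+  # is not minimized through this term.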
+ loss_alpha = jnp.sum(alpha * (epsilon - jax.lax.stop_gradient(mean_kl))) + + return loss_kl, loss_alpha + + +def clip_mpo_params(params: MPOParams, per_dim_constraining: bool) -> MPOParams: + clipped_params = MPOParams( + log_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, params.log_temperature), + log_alpha_mean=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_mean), + log_alpha_stddev=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_stddev)) + if not per_dim_constraining: + return clipped_params + else: + return clipped_params._replace( + log_penalty_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, + params.log_penalty_temperature)) diff --git a/acme/acme/jax/networks/__init__.py b/acme/acme/jax/networks/__init__.py new file mode 100644 index 00000000..eab07a00 --- /dev/null +++ b/acme/acme/jax/networks/__init__.py @@ -0,0 +1,52 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX networks implemented with Haiku.""" + +from acme.jax.networks.atari import AtariTorso +from acme.jax.networks.atari import DeepIMPALAAtariNetwork +from acme.jax.networks.atari import dqn_atari_network +from acme.jax.networks.atari import R2D2AtariNetwork +from acme.jax.networks.base import Action +from acme.jax.networks.base import FeedForwardNetwork +from acme.jax.networks.base import Logits +from acme.jax.networks.base import LogProb +from acme.jax.networks.base import LogProbFn +from acme.jax.networks.base import LSTMOutputs +from acme.jax.networks.base import NetworkOutput +from acme.jax.networks.base import Observation +from acme.jax.networks.base import Params +from acme.jax.networks.base import PolicyValueRNN +from acme.jax.networks.base import PRNGKey +from acme.jax.networks.base import QNetwork +from acme.jax.networks.base import RecurrentQNetwork +from acme.jax.networks.base import SampleFn +from acme.jax.networks.base import Value +from acme.jax.networks.continuous import LayerNormMLP +from acme.jax.networks.continuous import NearZeroInitializedLinear +from acme.jax.networks.distributional import CategoricalHead +from acme.jax.networks.distributional import CategoricalValueHead +from acme.jax.networks.distributional import DiscreteValued +from acme.jax.networks.distributional import GaussianMixture +from acme.jax.networks.distributional import MultivariateNormalDiagHead +from acme.jax.networks.distributional import NormalTanhDistribution +from acme.jax.networks.distributional import TanhTransformedDistribution +from acme.jax.networks.duelling import DuellingMLP +from acme.jax.networks.multiplexers import CriticMultiplexer +from acme.jax.networks.policy_value import PolicyValueHead +from acme.jax.networks.rescaling import ClipToSpec +from acme.jax.networks.rescaling import TanhToSpec +from acme.jax.networks.resnet import DownsamplingStrategy +from acme.jax.networks.resnet import ResidualBlock +from acme.jax.networks.resnet import ResNetTorso diff --git a/acme/acme/jax/networks/atari.py b/acme/acme/jax/networks/atari.py new file mode 
100644 index 00000000..4b8ccccf --- /dev/null +++ b/acme/acme/jax/networks/atari.py @@ -0,0 +1,182 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common networks for Atari. + +Glossary of shapes: +- T: Sequence length. +- B: Batch size. +- A: Number of actions. +- D: Embedding size. +- X?: X is optional (e.g. optional batch/sequence dimension). + +""" +from typing import Optional, Tuple, Sequence + +from acme.jax.networks import base +from acme.jax.networks import duelling +from acme.jax.networks import embedding +from acme.jax.networks import policy_value +from acme.jax.networks import resnet +from acme.wrappers import observation_action_reward +import haiku as hk +import jax +import jax.numpy as jnp + +# Useful type aliases. +Images = jnp.ndarray + + +class AtariTorso(hk.Module): + """Simple convolutional stack commonly used for Atari.""" + + def __init__(self): + super().__init__(name='atari_torso') + self._network = hk.Sequential([ + hk.Conv2D(32, [8, 8], 4), jax.nn.relu, + hk.Conv2D(64, [4, 4], 2), jax.nn.relu, + hk.Conv2D(64, [3, 3], 1), jax.nn.relu + ]) + + def __call__(self, inputs: Images) -> jnp.ndarray: + inputs_rank = jnp.ndim(inputs) + batched_inputs = inputs_rank == 4 + if inputs_rank < 3 or inputs_rank > 4: + raise ValueError('Expected input BHWC or HWC. Got rank %d' % inputs_rank) + + outputs = self._network(inputs) + + if batched_inputs: + return jnp.reshape(outputs, [outputs.shape[0], -1]) # [B, D] + return jnp.reshape(outputs, [-1]) # [D] + + +def dqn_atari_network(num_actions: int) -> base.QNetwork: + """A feed-forward network for use with Ape-X DQN.""" + + def network(inputs: Images) -> base.QValues: + model = hk.Sequential([ + AtariTorso(), + duelling.DuellingMLP(num_actions, hidden_sizes=[512]), + ]) + return model(inputs) + + return network + + +class DeepAtariTorso(hk.Module): + """Deep torso for Atari, from the IMPALA paper.""" + + def __init__( + self, + channels_per_group: Sequence[int] = (16, 32, 32), + blocks_per_group: Sequence[int] = (2, 2, 2), + downsampling_strategies: Sequence[resnet.DownsamplingStrategy] = ( + resnet.DownsamplingStrategy.CONV_MAX,) * 3, + hidden_sizes: Sequence[int] = (256,), + use_layer_norm: bool = False, + name: str = 'deep_atari_torso'): + super().__init__(name=name) + self._use_layer_norm = use_layer_norm + self.resnet = resnet.ResNetTorso( + channels_per_group=channels_per_group, + blocks_per_group=blocks_per_group, + downsampling_strategies=downsampling_strategies, + use_layer_norm=use_layer_norm) + # Make sure to activate the last layer as this torso is expected to feed + # into the rest of a bigger network. 
+ self.mlp_head = hk.nets.MLP(output_sizes=hidden_sizes, activate_final=True) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + output = self.resnet(x) + output = jax.nn.relu(output) + output = hk.Flatten(preserve_dims=-3)(output) + output = self.mlp_head(output) + return output + + +class DeepIMPALAAtariNetwork(hk.RNNCore): + """A recurrent network for use with IMPALA. + + See https://arxiv.org/pdf/1802.01561.pdf for more information. + """ + + def __init__(self, num_actions: int): + super().__init__(name='impala_atari_network') + self._embed = embedding.OAREmbedding( + DeepAtariTorso(use_layer_norm=True), num_actions) + self._core = hk.GRU(256) + self._head = policy_value.PolicyValueHead(num_actions) + self._num_actions = num_actions + + def __call__(self, inputs: observation_action_reward.OAR, + state: hk.LSTMState) -> base.LSTMOutputs: + + embeddings = self._embed(inputs) # [B?, D+A+1] + embeddings, new_state = self._core(embeddings, state) + logits, value = self._head(embeddings) # logits: [B?, A], value: [B?, 1] + + return (logits, value), new_state + + def initial_state(self, batch_size: Optional[int], + **unused_kwargs) -> hk.LSTMState: + return self._core.initial_state(batch_size) + + def unroll(self, inputs: observation_action_reward.OAR, + state: hk.LSTMState) -> base.LSTMOutputs: + """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" + embeddings = self._embed(inputs) + embeddings, new_states = hk.static_unroll(self._core, embeddings, state) + logits, values = self._head(embeddings) + + return (logits, values), new_states + + +class R2D2AtariNetwork(hk.RNNCore): + """A duelling recurrent network for use with Atari observations as seen in R2D2. + + See https://openreview.net/forum?id=r1lyTjAqYX for more information. + """ + + def __init__(self, num_actions: int): + super().__init__(name='r2d2_atari_network') + self._embed = embedding.OAREmbedding(DeepAtariTorso(), num_actions) + self._core = hk.LSTM(512) + self._duelling_head = duelling.DuellingMLP(num_actions, hidden_sizes=[512]) + self._num_actions = num_actions + + def __call__( + self, + inputs: observation_action_reward.OAR, # [B, ...] + state: hk.LSTMState # [B, ...] + ) -> Tuple[base.QValues, hk.LSTMState]: + embeddings = self._embed(inputs) # [B, D+A+1] + core_outputs, new_state = self._core(embeddings, state) + q_values = self._duelling_head(core_outputs) + return q_values, new_state + + def initial_state(self, batch_size: Optional[int], + **unused_kwargs) -> hk.LSTMState: + return self._core.initial_state(batch_size) + + def unroll( + self, + inputs: observation_action_reward.OAR, # [T, B, ...] + state: hk.LSTMState # [T, ...] + ) -> Tuple[base.QValues, hk.LSTMState]: + """Efficient unroll that applies torso, core, and duelling mlp in one pass.""" + embeddings = hk.BatchApply(self._embed)(inputs) # [T, B, D+A+1] + core_outputs, new_states = hk.static_unroll(self._core, embeddings, state) + q_values = hk.BatchApply(self._duelling_head)(core_outputs) # [T, B, A] + return q_values, new_states diff --git a/acme/acme/jax/networks/base.py b/acme/acme/jax/networks/base.py new file mode 100644 index 00000000..67d3210b --- /dev/null +++ b/acme/acme/jax/networks/base.py @@ -0,0 +1,60 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base interfaces for networks.""" + +import dataclasses +from typing import Callable, Tuple + +from acme import types +from acme.jax import types as jax_types +import haiku as hk +import jax.numpy as jnp + +# This definition is deprecated. Use jax_types.PRNGKey directly instead. +# TODO(sinopalnikov): migrate all users and remove this definition. +PRNGKey = jax_types.PRNGKey + +# Commonly-used types. +Observation = types.NestedArray +Action = types.NestedArray +Params = types.NestedArray +NetworkOutput = types.NestedArray +QValues = jnp.ndarray +Logits = jnp.ndarray +LogProb = jnp.ndarray +Value = jnp.ndarray + +# Commonly-used function/network signatures. +QNetwork = Callable[[Observation], QValues] +LSTMOutputs = Tuple[Tuple[Logits, Value], hk.LSTMState] +PolicyValueRNN = Callable[[Observation, hk.LSTMState], LSTMOutputs] +RecurrentQNetwork = Callable[[Observation, hk.LSTMState], + Tuple[QValues, hk.LSTMState]] +SampleFn = Callable[[NetworkOutput, PRNGKey], Action] +LogProbFn = Callable[[NetworkOutput, Action], LogProb] + + +@dataclasses.dataclass +class FeedForwardNetwork: + """Holds a pair of pure functions defining a feed-forward network. + + Attributes: + init: A pure function: ``params = init(rng, *a, **k)`` + apply: A pure function: ``out = apply(params, rng, *a, **k)`` + """ + # Initializes and returns the networks parameters. + init: Callable[..., Params] + # Computes and returns the outputs of a forward pass. + apply: Callable[..., NetworkOutput] diff --git a/acme/acme/jax/networks/continuous.py b/acme/acme/jax/networks/continuous.py new file mode 100644 index 00000000..71f88fa5 --- /dev/null +++ b/acme/acme/jax/networks/continuous.py @@ -0,0 +1,76 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Networks used in continuous control.""" + +from typing import Callable, Sequence + +import haiku as hk +import jax +import jax.numpy as jnp + +uniform_initializer = hk.initializers.UniformScaling(scale=0.333) + + +class NearZeroInitializedLinear(hk.Linear): + """Simple linear layer, initialized at near zero weights and zero biases.""" + + def __init__(self, output_size: int, scale: float = 1e-4): + super().__init__(output_size, w_init=hk.initializers.VarianceScaling(scale)) + + +class LayerNormMLP(hk.Module): + """Simple feedforward MLP torso with initial layer-norm. + + This MLP's first linear layer is followed by a LayerNorm layer and a tanh + non-linearity; subsequent layers use `activation`, which defaults to elu. + + Note! 
The default activation differs from the usual MLP default of ReLU for + legacy reasons. + """ + + def __init__(self, + layer_sizes: Sequence[int], + w_init: hk.initializers.Initializer = uniform_initializer, + activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu, + activate_final: bool = False, + name: str = 'feedforward_mlp_torso'): + """Construct the MLP. + + Args: + layer_sizes: a sequence of ints specifying the size of each layer. + w_init: initializer for Linear layers. + activation: nonlinearity to use in the MLP, defaults to elu. + Note! The default activation differs from the usual MLP default of ReLU + for legacy reasons. + activate_final: whether or not to use the activation function on the final + layer of the neural network. + name: a name for the module. + """ + super().__init__(name=name) + + self._network = hk.Sequential([ + hk.Linear(layer_sizes[0], w_init=w_init), + hk.LayerNorm(axis=-1, create_scale=True, create_offset=True), + jax.lax.tanh, + hk.nets.MLP( + layer_sizes[1:], + w_init=w_init, + activation=activation, + activate_final=activate_final), + ]) + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Forwards the policy network.""" + return self._network(inputs) diff --git a/acme/acme/jax/networks/distributional.py b/acme/acme/jax/networks/distributional.py new file mode 100644 index 00000000..06d7fc72 --- /dev/null +++ b/acme/acme/jax/networks/distributional.py @@ -0,0 +1,330 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Haiku modules that output tfd.Distributions.""" + +from typing import Any, List, Optional, Union + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow_probability +hk_init = hk.initializers +tfp = tensorflow_probability.substrates.jax +tfd = tfp.distributions + +_MIN_SCALE = 1e-4 +Initializer = hk.initializers.Initializer + + +class CategoricalHead(hk.Module): + """Module that produces a categorical distribution with the given number of values.""" + + def __init__( + self, + num_values: Union[int, List[int]], + dtype: Optional[Any] = jnp.int32, + w_init: Optional[Initializer] = None, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._dtype = dtype + self._logit_shape = num_values + self._linear = hk.Linear(np.prod(num_values), w_init=w_init) + + def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution: + logits = self._linear(inputs) + if not isinstance(self._logit_shape, int): + logits = hk.Reshape(self._logit_shape)(logits) + return tfd.Categorical(logits=logits, dtype=self._dtype) + + +class GaussianMixture(hk.Module): + """Module that outputs a Gaussian Mixture Distribution.""" + + def __init__(self, + num_dimensions: int, + num_components: int, + multivariate: bool, + init_scale: Optional[float] = None, + append_singleton_event_dim: bool = False, + reinterpreted_batch_ndims: Optional[int] = None, + name: str = 'GaussianMixture'): + """Initialization. 
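+
+    The head consists of three linear layers producing the mixture logits,
+    component locations, and component scales, which are packed into a
+    tfd.MixtureSameFamily distribution (see __call__ below).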
+
+    Args:
+      num_dimensions: dimensionality of the output distribution.
+      num_components: number of mixture components.
+      multivariate: whether the resulting distribution is multivariate or not.
+      init_scale: the initial scale for the Gaussian mixture components.
+      append_singleton_event_dim: (univariate only) Whether to add an extra
+        singleton dimension to the event shape.
+      reinterpreted_batch_ndims: (univariate only) Number of batch dimensions
+        to reinterpret as event dimensions.
+      name: name of the module passed to the hk.Module parent class.
+    """
+    super().__init__(name=name)
+
+    self._num_dimensions = num_dimensions
+    self._num_components = num_components
+    self._multivariate = multivariate
+    self._append_singleton_event_dim = append_singleton_event_dim
+    self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
+
+    if init_scale is not None:
+      self._scale_factor = init_scale / jax.nn.softplus(0.)
+    else:
+      self._scale_factor = 1.0  # Corresponds to init_scale = softplus(0).
+
+  def __call__(self,
+               inputs: jnp.ndarray,
+               low_noise_policy: bool = False) -> tfd.Distribution:
+    """Runs the network on the given inputs.
+
+    Args:
+      inputs: hidden activations of the policy network body.
+      low_noise_policy: whether to set vanishingly small scales for each
+        component. If this flag is set to True, the policy is effectively run
+        without Gaussian noise.
+
+    Returns:
+      Mixture Gaussian distribution.
+    """
+
+    # Define the weight initializer.
+    w_init = hk.initializers.VarianceScaling(scale=1e-5)
+
+    # Create a layer that outputs the unnormalized log-weights.
+    if self._multivariate:
+      logits_size = self._num_components
+    else:
+      logits_size = self._num_dimensions * self._num_components
+    logit_layer = hk.Linear(logits_size, w_init=w_init)
+
+    # Create two layers that output a location and a scale, respectively, for
+    # each dimension and each component.
+    loc_layer = hk.Linear(
+        self._num_dimensions * self._num_components, w_init=w_init)
+    scale_layer = hk.Linear(
+        self._num_dimensions * self._num_components, w_init=w_init)
+
+    # Compute logits, locs, and scales if necessary.
+    logits = logit_layer(inputs)
+    locs = loc_layer(inputs)
+
+    # When a low_noise_policy is requested, set the scales to their minimum
+    # value.
+    if low_noise_policy:
+      scales = jnp.full(locs.shape, _MIN_SCALE)
+    else:
+      scales = scale_layer(inputs)
+      scales = self._scale_factor * jax.nn.softplus(scales) + _MIN_SCALE
+
+    if self._multivariate:
+      components_class = tfd.MultivariateNormalDiag
+      shape = [-1, self._num_components, self._num_dimensions]  # [B, C, D]
+      # In this case, no need to reshape logits as they are in the correct
+      # shape already, namely [batch_size, num_components].
+    else:
+      components_class = tfd.Normal
+      shape = [-1, self._num_dimensions, self._num_components]  # [B, D, C]
+      if self._append_singleton_event_dim:
+        shape.insert(2, 1)  # [B, D, 1, C]
+      logits = logits.reshape(shape)
+
+    # Reshape the mixture's location and scale parameters appropriately.
+    locs = locs.reshape(shape)
+    scales = scales.reshape(shape)
+
+    # Create the mixture distribution.
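+    # In the multivariate case this has batch shape [B] and event shape [D];
+    # in the univariate case batch shape is [B, D] (or [B, D, 1]) and the
+    # Independent wrapper below moves batch dimensions back into the event.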
+    distribution = tfd.MixtureSameFamily(
+        mixture_distribution=tfd.Categorical(logits=logits),
+        components_distribution=components_class(loc=locs, scale=scales))
+
+    if not self._multivariate:
+      distribution = tfd.Independent(
+          distribution,
+          reinterpreted_batch_ndims=self._reinterpreted_batch_ndims)
+
+    return distribution
+
+
+class TanhTransformedDistribution(tfd.TransformedDistribution):
+  """Distribution followed by tanh."""
+
+  def __init__(self, distribution, threshold=.999, validate_args=False):
+    """Initialize the distribution.
+
+    Args:
+      distribution: The distribution to transform.
+      threshold: Clipping value of the action when computing the logprob.
+      validate_args: Passed to super class.
+    """
+    super().__init__(
+        distribution=distribution,
+        bijector=tfp.bijectors.Tanh(),
+        validate_args=validate_args)
+    # Computes the log of the average probability distribution outside the
+    # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for
+    # log_prob_left and [atanh(threshold), inf] for log_prob_right.
+    self._threshold = threshold
+    inverse_threshold = self.bijector.inverse(threshold)
+    # average(pdf) = p/epsilon
+    # So log(average(pdf)) = log(p) - log(epsilon)
+    log_epsilon = jnp.log(1. - threshold)
+    # These two values are differentiable w.r.t. model parameters, such that
+    # the gradient is defined everywhere.
+    self._log_prob_left = self.distribution.log_cdf(
+        -inverse_threshold) - log_epsilon
+    self._log_prob_right = self.distribution.log_survival_function(
+        inverse_threshold) - log_epsilon
+
+  def log_prob(self, event):
+    # Without this clip there would be NaNs in the inner jnp.where, which
+    # causes issues for reasons that are unclear.
+    event = jnp.clip(event, -self._threshold, self._threshold)
+    # The inverse image of {threshold} is the interval [atanh(threshold), inf]
+    # which has a probability of "log_prob_right" under the given distribution.
+    return jnp.where(
+        event <= -self._threshold, self._log_prob_left,
+        jnp.where(event >= self._threshold, self._log_prob_right,
+                  super().log_prob(event)))
+
+  def mode(self):
+    return self.bijector.forward(self.distribution.mode())
+
+  def entropy(self, seed=None):
+    # We return an estimate using a single sample of the log_det_jacobian.
+    # We can still do some backpropagation with this estimate.
+    return self.distribution.entropy() + self.bijector.forward_log_det_jacobian(
+        self.distribution.sample(seed=seed), event_ndims=0)
+
+  @classmethod
+  def _parameter_properties(cls, dtype: Optional[Any], num_classes=None):
+    td_properties = super()._parameter_properties(dtype,
+                                                  num_classes=num_classes)
+    del td_properties['bijector']
+    return td_properties
+
+
+class NormalTanhDistribution(hk.Module):
+  """Module that produces a TanhTransformedDistribution distribution."""
+
+  def __init__(self,
+               num_dimensions: int,
+               min_scale: float = 1e-3,
+               w_init: hk_init.Initializer = hk_init.VarianceScaling(
+                   1.0, 'fan_in', 'uniform'),
+               b_init: hk_init.Initializer = hk_init.Constant(0.)):
+    """Initialization.
+
+    Args:
+      num_dimensions: Number of dimensions of a distribution.
+      min_scale: Minimum standard deviation.
+      w_init: Initialization for linear layer weights.
+      b_init: Initialization for linear layer biases.
+    """
+    super().__init__(name='Normal')
+    self._min_scale = min_scale
+    self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
+    self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
+
+  def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution:
+    loc = self._loc_layer(inputs)
+    scale = self._scale_layer(inputs)
+    scale = jax.nn.softplus(scale) + self._min_scale
+    distribution = tfd.Normal(loc=loc, scale=scale)
+    return tfd.Independent(
+        TanhTransformedDistribution(distribution), reinterpreted_batch_ndims=1)
+
+
+class MultivariateNormalDiagHead(hk.Module):
+  """Module that produces a tfd.MultivariateNormalDiag distribution."""
+
+  def __init__(self,
+               num_dimensions: int,
+               init_scale: float = 0.3,
+               min_scale: float = 1e-6,
+               w_init: hk_init.Initializer = hk_init.VarianceScaling(1e-4),
+               b_init: hk_init.Initializer = hk_init.Constant(0.)):
+    """Initialization.
+
+    Args:
+      num_dimensions: Number of dimensions of MVN distribution.
+      init_scale: Initial standard deviation.
+      min_scale: Minimum standard deviation.
+      w_init: Initialization for linear layer weights.
+      b_init: Initialization for linear layer biases.
+    """
+    super().__init__(name='MultivariateNormalDiagHead')
+    self._min_scale = min_scale
+    self._init_scale = init_scale
+    self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
+    self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
+
+  def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution:
+    loc = self._loc_layer(inputs)
+    scale = jax.nn.softplus(self._scale_layer(inputs))
+    scale *= self._init_scale / jax.nn.softplus(0.)
+    scale += self._min_scale
+    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
+
+
+class CategoricalValueHead(hk.Module):
+  """Network head that produces a categorical distribution and value."""
+
+  def __init__(
+      self,
+      num_values: int,
+      name: Optional[str] = None,
+  ):
+    super().__init__(name=name)
+    self._logit_layer = hk.Linear(num_values)
+    self._value_layer = hk.Linear(1)
+
+  def __call__(self, inputs: jnp.ndarray):
+    logits = self._logit_layer(inputs)
+    value = jnp.squeeze(self._value_layer(inputs), axis=-1)
+    return (tfd.Categorical(logits=logits), value)
+
+
+class DiscreteValued(hk.Module):
+  """C51-style head.
+
+  For each action, it produces the logits for a discrete distribution over
+  atoms. Therefore, the returned logits represent several distributions, one
+  for each action.
+  """
+
+  def __init__(
+      self,
+      num_actions: int,
+      head_units: int = 512,
+      num_atoms: int = 51,
+      v_min: float = -1.0,
+      v_max: float = 1.0,
+  ):
+    super().__init__('DiscreteValued')
+    self._num_actions = num_actions
+    self._num_atoms = num_atoms
+    self._atoms = jnp.linspace(v_min, v_max, self._num_atoms)
+    self._network = hk.nets.MLP([head_units, num_actions * num_atoms])
+
+  def __call__(self, inputs: jnp.ndarray):
+    q_logits = self._network(inputs)
+    q_logits = jnp.reshape(q_logits, (-1, self._num_actions, self._num_atoms))
+    q_dist = jax.nn.softmax(q_logits)
+    q_values = jnp.sum(q_dist * self._atoms, axis=2)
+    q_values = jax.lax.stop_gradient(q_values)
+    return q_values, q_logits, self._atoms
diff --git a/acme/acme/jax/networks/duelling.py b/acme/acme/jax/networks/duelling.py
new file mode 100644
index 00000000..8db6d173
--- /dev/null
+++ b/acme/acme/jax/networks/duelling.py
@@ -0,0 +1,60 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A duelling network architecture, as described in [0]. + +[0] https://arxiv.org/abs/1511.06581 +""" + +from typing import Sequence, Optional + +import haiku as hk +import jax.numpy as jnp + + +class DuellingMLP(hk.Module): + """A Duelling MLP Q-network.""" + + def __init__( + self, + num_actions: int, + hidden_sizes: Sequence[int], + w_init: Optional[hk.initializers.Initializer] = None, + ): + super().__init__(name='duelling_q_network') + + self._value_mlp = hk.nets.MLP([*hidden_sizes, 1], w_init=w_init) + self._advantage_mlp = hk.nets.MLP([*hidden_sizes, num_actions], + w_init=w_init) + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Forward pass of the duelling network. + + Args: + inputs: 2-D tensor of shape [batch_size, embedding_size]. + + Returns: + q_values: 2-D tensor of action values of shape [batch_size, num_actions] + """ + + # Compute value & advantage for duelling. + value = self._value_mlp(inputs) # [B, 1] + advantages = self._advantage_mlp(inputs) # [B, A] + + # Advantages have zero mean. + advantages -= jnp.mean(advantages, axis=-1, keepdims=True) # [B, A] + + q_values = value + advantages # [B, A] + + return q_values diff --git a/acme/acme/jax/networks/embedding.py b/acme/acme/jax/networks/embedding.py new file mode 100644 index 00000000..2510c6a6 --- /dev/null +++ b/acme/acme/jax/networks/embedding.py @@ -0,0 +1,61 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules for computing custom embeddings.""" + +import dataclasses + +from acme.wrappers import observation_action_reward +import haiku as hk +import jax +import jax.numpy as jnp + + +@dataclasses.dataclass +class OAREmbedding(hk.Module): + """Module for embedding (observation, action, reward) inputs together.""" + + torso: hk.SupportsCall + num_actions: int + + def __call__(self, inputs: observation_action_reward.OAR) -> jnp.ndarray: + """Embed each of the (observation, action, reward) inputs & concatenate.""" + + # Add dummy batch dimension to observation if necessary. + # This is needed because Conv2D assumes a leading batch dimension, i.e. + # that inputs are in [B, H, W, C] format. 
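+    # (The [T?, ...] shape annotations below denote an optional leading time
+    # dimension, present when this module is applied to whole trajectories.)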
+    expand_obs = len(inputs.observation.shape) == 3
+    if expand_obs:
+      inputs = inputs._replace(
+          observation=jnp.expand_dims(inputs.observation, axis=0))
+    features = self.torso(inputs.observation)  # [T?, B, D]
+    if expand_obs:
+      features = jnp.squeeze(features, axis=0)
+
+    # Do a one-hot embedding of the actions.
+    action = jax.nn.one_hot(
+        inputs.action, num_classes=self.num_actions)  # [T?, B, A]
+
+    # Map rewards -> [-1, 1].
+    reward = jnp.tanh(inputs.reward)
+
+    # Add dummy trailing dimensions to rewards if necessary.
+    while reward.ndim < action.ndim:
+      reward = jnp.expand_dims(reward, axis=-1)
+
+    # Concatenate on final dimension.
+    embedding = jnp.concatenate(
+        [features, action, reward], axis=-1)  # [T?, B, D+A+1]
+
+    return embedding
diff --git a/acme/acme/jax/networks/multiplexers.py b/acme/acme/jax/networks/multiplexers.py
new file mode 100644
index 00000000..632f0b73
--- /dev/null
+++ b/acme/acme/jax/networks/multiplexers.py
@@ -0,0 +1,71 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multiplexers are networks that take multiple inputs."""
+
+from typing import Callable, Optional, Union
+
+from acme.jax import utils
+import haiku as hk
+import jax.numpy as jnp
+import tensorflow_probability
+
+tfd = tensorflow_probability.substrates.jax.distributions
+ModuleOrArrayTransform = Union[hk.Module, Callable[[jnp.ndarray], jnp.ndarray]]
+
+
+class CriticMultiplexer(hk.Module):
+  """Module connecting a critic torso to (transformed) observations/actions.
+
+  This takes as input a `critic_network`, an `observation_network`, and an
+  `action_network` and returns another network whose outputs are given by
+  `critic_network(observation_network(o), action_network(a))`.
+
+  The observations and actions passed to this module are assumed to have
+  matching batch dimensions.
+
+  Notes:
+  - Either the `observation_` or `action_network` can be `None`, in which case
+    the observation or action, resp., are passed to the critic network as is.
+  - If all `critic_`, `observation_` and `action_network` are `None`, this
+    module reduces to a simple `utils.batch_concat()`.
+  """
+
+  def __init__(self,
+               critic_network: Optional[ModuleOrArrayTransform] = None,
+               observation_network: Optional[ModuleOrArrayTransform] = None,
+               action_network: Optional[ModuleOrArrayTransform] = None):
+    self._critic_network = critic_network
+    self._observation_network = observation_network
+    self._action_network = action_network
+    super().__init__(name='critic_multiplexer')
+
+  def __call__(self,
+               observation: jnp.ndarray,
+               action: jnp.ndarray) -> jnp.ndarray:
+
+    # Maybe transform observations and actions before feeding them on.
+    if self._observation_network:
+      observation = self._observation_network(observation)
+    if self._action_network:
+      action = self._action_network(action)
+
+    # Concat observations and actions, with one batch dimension.
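+    # batch_concat flattens each input to [B, -1] and concatenates the results
+    # along the final axis, so arbitrarily nested observations reduce to a
+    # single rank-2 array before hitting the critic network.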
+ outputs = utils.batch_concat([observation, action]) + + # Maybe transform output before returning. + if self._critic_network: + outputs = self._critic_network(outputs) + + return outputs diff --git a/acme/acme/jax/networks/policy_value.py b/acme/acme/jax/networks/policy_value.py new file mode 100644 index 00000000..509d3fb1 --- /dev/null +++ b/acme/acme/jax/networks/policy_value.py @@ -0,0 +1,36 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Policy-value network head for actor-critic algorithms.""" + +from typing import Tuple + +import haiku as hk +import jax.numpy as jnp + + +class PolicyValueHead(hk.Module): + """A network with two linear layers, for policy and value respectively.""" + + def __init__(self, num_actions: int): + super().__init__(name='policy_value_network') + self._policy_layer = hk.Linear(num_actions) + self._value_layer = hk.Linear(1) + + def __call__(self, inputs: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Returns a (Logits, Value) tuple.""" + logits = self._policy_layer(inputs) # [B, A] + value = jnp.squeeze(self._value_layer(inputs), axis=-1) # [B] + + return logits, value diff --git a/acme/acme/jax/networks/rescaling.py b/acme/acme/jax/networks/rescaling.py new file mode 100644 index 00000000..4d63c845 --- /dev/null +++ b/acme/acme/jax/networks/rescaling.py @@ -0,0 +1,57 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rescaling layers (e.g. 
to match action specs).""" + +import dataclasses + +from acme import specs +from jax import lax +import jax.numpy as jnp + + +@dataclasses.dataclass +class ClipToSpec: + """Clips inputs to within a BoundedArraySpec.""" + spec: specs.BoundedArray + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + return jnp.clip(inputs, self.spec.minimum, self.spec.maximum) + + +@dataclasses.dataclass +class RescaleToSpec: + """Rescales inputs in [-1, 1] to match a BoundedArraySpec.""" + spec: specs.BoundedArray + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + scale = self.spec.maximum - self.spec.minimum + offset = self.spec.minimum + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * scale + offset # [minimum, maximum] + return output + + +@dataclasses.dataclass +class TanhToSpec: + """Squashes real-valued inputs to match a BoundedArraySpec.""" + spec: specs.BoundedArray + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + scale = self.spec.maximum - self.spec.minimum + offset = self.spec.minimum + inputs = lax.tanh(inputs) # [-1, 1] + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * scale + offset # [minimum, maximum] + return output diff --git a/acme/acme/jax/networks/resnet.py b/acme/acme/jax/networks/resnet.py new file mode 100644 index 00000000..a4e28ef0 --- /dev/null +++ b/acme/acme/jax/networks/resnet.py @@ -0,0 +1,159 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ResNet Modules.""" + +import enum +import functools +from typing import Callable, Sequence, Union +import haiku as hk +import jax +import jax.numpy as jnp + +InnerOp = Union[hk.Module, Callable[..., jnp.ndarray]] +MakeInnerOp = Callable[..., InnerOp] +NonLinearity = Callable[[jnp.ndarray], jnp.ndarray] + + +class ResidualBlock(hk.Module): + """Residual block of operations, e.g. convolutional or MLP.""" + + def __init__(self, + make_inner_op: MakeInnerOp, + non_linearity: NonLinearity = jax.nn.relu, + use_layer_norm: bool = False, + name: str = 'residual_block'): + super().__init__(name=name) + self.inner_op1 = make_inner_op() + self.inner_op2 = make_inner_op() + self.non_linearity = non_linearity + self.use_layer_norm = use_layer_norm + + if use_layer_norm: + self.layernorm1 = hk.LayerNorm( + axis=(1, 2, 3), create_scale=True, create_offset=True, eps=1e-6) + self.layernorm2 = hk.LayerNorm( + axis=(1, 2, 3), create_scale=True, create_offset=True, eps=1e-6) + + def __call__(self, x: jnp.ndarray): + output = x + + # First layer in residual block. + if self.use_layer_norm: + output = self.layernorm1(output) + output = self.non_linearity(output) + output = self.inner_op1(output) + + # Second layer in residual block. + if self.use_layer_norm: + output = self.layernorm2(output) + output = self.non_linearity(output) + output = self.inner_op2(output) + return x + output + + +# TODO(nikola): Remove this enum and configure downsampling with a layer factory +# instead. 
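+# (With the defaults used by ResNetTorso below, each group applies CONV_MAX
+# downsampling followed by two residual blocks, as in the IMPALA torso.)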
+class DownsamplingStrategy(enum.Enum):
+  AVG_POOL = 'avg_pool'
+  CONV_MAX = 'conv+max'  # Used in IMPALA
+  LAYERNORM_RELU_CONV = 'layernorm+relu+conv'  # Used in MuZero
+  CONV = 'conv'
+
+
+def make_downsampling_layer(
+    strategy: Union[str, DownsamplingStrategy],
+    output_channels: int,
+) -> hk.SupportsCall:
+  """Returns a sequence of modules corresponding to the desired downsampling."""
+  strategy = DownsamplingStrategy(strategy)
+
+  if strategy is DownsamplingStrategy.AVG_POOL:
+    return hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding='SAME')
+
+  elif strategy is DownsamplingStrategy.CONV:
+    return hk.Sequential([
+        hk.Conv2D(
+            output_channels,
+            kernel_shape=3,
+            stride=2,
+            w_init=hk.initializers.TruncatedNormal(1e-2)),
+    ])
+
+  elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV:
+    return hk.Sequential([
+        hk.LayerNorm(
+            axis=(1, 2, 3), create_scale=True, create_offset=True, eps=1e-6),
+        jax.nn.relu,
+        hk.Conv2D(
+            output_channels,
+            kernel_shape=3,
+            stride=2,
+            w_init=hk.initializers.TruncatedNormal(1e-2)),
+    ])
+
+  elif strategy is DownsamplingStrategy.CONV_MAX:
+    return hk.Sequential([
+        hk.Conv2D(output_channels, kernel_shape=3, stride=1),
+        hk.MaxPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding='SAME')
+    ])
+  else:
+    raise ValueError('Unrecognized downsampling strategy. Expected one of'
+                     f' {[strategy.value for strategy in DownsamplingStrategy]}'
+                     f' but received {strategy}.')
+
+
+class ResNetTorso(hk.Module):
+  """ResNetTorso for visual inputs, inspired by the IMPALA paper."""
+
+  def __init__(self,
+               channels_per_group: Sequence[int] = (16, 32, 32),
+               blocks_per_group: Sequence[int] = (2, 2, 2),
+               downsampling_strategies: Sequence[DownsamplingStrategy] = (
+                   DownsamplingStrategy.CONV_MAX,) * 3,
+               use_layer_norm: bool = False,
+               name: str = 'resnet_torso'):
+    super().__init__(name=name)
+    self._channels_per_group = channels_per_group
+    self._blocks_per_group = blocks_per_group
+    self._downsampling_strategies = downsampling_strategies
+    self._use_layer_norm = use_layer_norm
+
+    if (len(channels_per_group) != len(blocks_per_group) or
+        len(channels_per_group) != len(downsampling_strategies)):
+      raise ValueError('Length of channels_per_group, blocks_per_group, and '
+                       'downsampling_strategies must be equal. '
+                       f'Got channels_per_group={channels_per_group}, '
+                       f'blocks_per_group={blocks_per_group}, and '
+                       f'downsampling_strategies={downsampling_strategies}.')
+
+  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
+    output = inputs
+    channels_blocks_strategies = zip(self._channels_per_group,
+                                     self._blocks_per_group,
+                                     self._downsampling_strategies)
+
+    for i, (num_channels, num_blocks,
+            strategy) in enumerate(channels_blocks_strategies):
+      output = make_downsampling_layer(strategy, num_channels)(output)
+
+      for j in range(num_blocks):
+        output = ResidualBlock(
+            make_inner_op=functools.partial(
+                hk.Conv2D, output_channels=num_channels, kernel_shape=3),
+            use_layer_norm=self._use_layer_norm,
+            name=f'residual_{i}_{j}')(
+                output)
+
+    return output
diff --git a/acme/acme/jax/running_statistics.py b/acme/acme/jax/running_statistics.py
new file mode 100644
index 00000000..ad8b4511
--- /dev/null
+++ b/acme/acme/jax/running_statistics.py
@@ -0,0 +1,355 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions to compute running statistics.""" + +import dataclasses +from typing import Any, Optional, Tuple, Union + +from acme import types +from acme.utils import tree_utils +import chex +import jax +import jax.numpy as jnp +import numpy as np +import tree + + +Path = Tuple[Any, ...] +"""Path in a nested structure. + + A path is a tuple of indices (normally strings for maps and integers for + arrays and tuples) that uniquely identifies a subtree in the nested structure. + See + https://tree.readthedocs.io/en/latest/api.html#tree.map_structure_with_path + for more details. +""" + + +def _is_prefix(a: Path, b: Path) -> bool: + """Returns whether `a` is a prefix of `b`.""" + return b[:len(a)] == a + + +def _zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray: + return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) + + +def _ones_like(nest: types.Nest, dtype=None) -> types.NestedArray: + return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) + + +@chex.dataclass(frozen=True) +class NestedMeanStd: + """A container for running statistics (mean, std) of possibly nested data.""" + mean: types.NestedArray + std: types.NestedArray + + +@chex.dataclass(frozen=True) +class RunningStatisticsState(NestedMeanStd): + """Full state of running statistics computation.""" + count: Union[int, jnp.ndarray] + summed_variance: types.NestedArray + + +@dataclasses.dataclass(frozen=True) +class NestStatisticsConfig: + """Specifies how to compute statistics for Nests with the same structure. + + Attributes: + paths: A sequence of Nest paths to compute statistics for. If there is a + collision between paths (one is a prefix of the other), the shorter path + takes precedence. + """ + paths: Tuple[Path, ...] = ((),) + + +def _is_path_included(config: NestStatisticsConfig, path: Path) -> bool: + """Returns whether the path is included in the config.""" + # A path is included in the config if it corresponds to a tree node that + # belongs to a subtree rooted at the node corresponding to some path in + # the config. + return any(_is_prefix(config_path, path) for config_path in config.paths) + + +def init_state(nest: types.Nest) -> RunningStatisticsState: + """Initializes the running statistics for the given nested structure.""" + dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32 + + return RunningStatisticsState( + count=0., + mean=_zeros_like(nest, dtype=dtype), + summed_variance=_zeros_like(nest, dtype=dtype), + # Initialize with ones to make sure normalization works correctly + # in the initial state. + std=_ones_like(nest, dtype=dtype)) + + +def _validate_batch_shapes(batch: types.NestedArray, + reference_sample: types.NestedArray, + batch_dims: Tuple[int, ...]) -> None: + """Verifies shapes of the batch leaves against the reference sample. + + Checks that batch dimensions are the same in all leaves in the batch. + Checks that non-batch dimensions for all leaves in the batch are the same + as in the reference sample. + + Arguments: + batch: the nested batch of data to be verified. 
+    reference_sample: the nested array to check non-batch dimensions.
+    batch_dims: a Tuple of indices of batch dimensions in the batch shape.
+
+  Returns:
+    None.
+  """
+  def validate_node_shape(reference_sample: jnp.ndarray,
+                          batch: jnp.ndarray) -> None:
+    expected_shape = batch_dims + reference_sample.shape
+    assert batch.shape == expected_shape, f'{batch.shape} != {expected_shape}'
+
+  tree_utils.fast_map_structure(validate_node_shape, reference_sample, batch)
+
+
+def update(state: RunningStatisticsState,
+           batch: types.NestedArray,
+           *,
+           config: NestStatisticsConfig = NestStatisticsConfig(),
+           weights: Optional[jnp.ndarray] = None,
+           std_min_value: float = 1e-6,
+           std_max_value: float = 1e6,
+           pmap_axis_name: Optional[str] = None,
+           validate_shapes: bool = True) -> RunningStatisticsState:
+  """Updates the running statistics with the given batch of data.
+
+  Note: data batch and state elements (mean, etc.) must have the same
+  structure.
+
+  Note: by default will use int32 for counts and float32 for accumulated
+  variance. This results in an integer overflow after 2^31 data points and
+  degrading precision after 2^24 batch updates or even earlier if variance
+  updates have large dynamic range.
+  To improve precision, consider setting jax_enable_x64 to True, see
+  https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
+
+  Arguments:
+    state: The running statistics before the update.
+    batch: The data to be used to update the running statistics.
+    config: The config that specifies for which leaves of the nested structure
+      the running statistics should be computed.
+    weights: Weights of the batch data. Should match the batch dimensions.
+      Passing a weight of 2. should be equivalent to updating on the
+      corresponding data point twice.
+    std_min_value: Minimum value for the standard deviation.
+    std_max_value: Maximum value for the standard deviation.
+    pmap_axis_name: Name of the pmapped axis, if any.
+    validate_shapes: If true, the shapes of all leaves of the batch will be
+      validated. Enabled by default. Doesn't impact performance when jitted.
+
+  Returns:
+    Updated running statistics.
+  """
+  # We require exactly the same structure to avoid issues when flattened
+  # batch and state have different order of elements.
+  tree.assert_same_structure(batch, state.mean)
+  batch_shape = tree.flatten(batch)[0].shape
+  # We assume the batch dimensions always go first.
+  batch_dims = batch_shape[:len(batch_shape) - tree.flatten(state.mean)[0].ndim]
+  batch_axis = range(len(batch_dims))
+  if weights is None:
+    step_increment = np.prod(batch_dims)
+  else:
+    step_increment = jnp.sum(weights)
+  if pmap_axis_name is not None:
+    step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name)
+  count = state.count + step_increment
+
+  # Validation is important. If the shapes don't match exactly, but are
+  # compatible, arrays will be silently broadcasted resulting in incorrect
+  # statistics.
+  if validate_shapes:
+    if weights is not None:
+      if weights.shape != batch_dims:
+        raise ValueError(f'{weights.shape} != {batch_dims}')
+    _validate_batch_shapes(batch, state.mean, batch_dims)
+
+  def _compute_node_statistics(
+      path: Path, mean: jnp.ndarray, summed_variance: jnp.ndarray,
+      batch: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    assert isinstance(mean, jnp.ndarray), type(mean)
+    assert isinstance(summed_variance, jnp.ndarray), type(summed_variance)
+    if not _is_path_included(config, path):
+      # Return unchanged.
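+      # (Leaves excluded by the config keep their previous statistics.)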
+ return mean, summed_variance + # The mean and the sum of past variances are updated with Welford's + # algorithm using batches (see https://stackoverflow.com/q/56402955). + diff_to_old_mean = batch - mean + if weights is not None: + expanded_weights = jnp.reshape( + weights, + list(weights.shape) + [1] * (batch.ndim - weights.ndim)) + diff_to_old_mean = diff_to_old_mean * expanded_weights + mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count + if pmap_axis_name is not None: + mean_update = jax.lax.psum( + mean_update, axis_name=pmap_axis_name) + mean = mean + mean_update + + diff_to_new_mean = batch - mean + variance_update = diff_to_old_mean * diff_to_new_mean + variance_update = jnp.sum(variance_update, axis=batch_axis) + if pmap_axis_name is not None: + variance_update = jax.lax.psum(variance_update, axis_name=pmap_axis_name) + summed_variance = summed_variance + variance_update + return mean, summed_variance + + updated_stats = tree_utils.fast_map_structure_with_path( + _compute_node_statistics, state.mean, state.summed_variance, batch) + # map_structure_up_to is slow, so shortcut if we know the input is not + # structured. + if isinstance(state.mean, jnp.ndarray): + mean, summed_variance = updated_stats + else: + # Reshape the updated stats from `nest(mean, summed_variance)` to + # `nest(mean), nest(summed_variance)`. + mean, summed_variance = [ + tree.map_structure_up_to( + state.mean, lambda s, i=idx: s[i], updated_stats) + for idx in range(2) + ] + + def compute_std(path: Path, summed_variance: jnp.ndarray, + std: jnp.ndarray) -> jnp.ndarray: + assert isinstance(summed_variance, jnp.ndarray) + if not _is_path_included(config, path): + return std + # Summed variance can get negative due to rounding errors. + summed_variance = jnp.maximum(summed_variance, 0) + std = jnp.sqrt(summed_variance / count) + std = jnp.clip(std, std_min_value, std_max_value) + return std + + std = tree_utils.fast_map_structure_with_path(compute_std, summed_variance, + state.std) + + return RunningStatisticsState( + count=count, mean=mean, summed_variance=summed_variance, std=std) + + +def normalize(batch: types.NestedArray, + mean_std: NestedMeanStd, + max_abs_value: Optional[float] = None) -> types.NestedArray: + """Normalizes data using running statistics.""" + + def normalize_leaf(data: jnp.ndarray, mean: jnp.ndarray, + std: jnp.ndarray) -> jnp.ndarray: + # Only normalize inexact types. + if not jnp.issubdtype(data.dtype, jnp.inexact): + return data + data = (data - mean) / std + if max_abs_value is not None: + # TODO(b/124318564): remove pylint directive + data = jnp.clip(data, -max_abs_value, +max_abs_value) # pylint: disable=invalid-unary-operand-type + return data + + return tree_utils.fast_map_structure(normalize_leaf, batch, mean_std.mean, + mean_std.std) + + +def denormalize(batch: types.NestedArray, + mean_std: NestedMeanStd) -> types.NestedArray: + """Denormalizes values in a nested structure using the given mean/std. + + Only values of inexact types are denormalized. + See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for Numpy type + hierarchy. + + Args: + batch: a nested structure containing batch of data. + mean_std: mean and standard deviation used for denormalization. + + Returns: + Nested structure with denormalized values. + """ + + def denormalize_leaf(data: jnp.ndarray, mean: jnp.ndarray, + std: jnp.ndarray) -> jnp.ndarray: + # Only denormalize inexact types. 
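+    # (Integer and boolean leaves pass through unchanged, mirroring the
+    # behavior of normalize_leaf above.)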
+    if not np.issubdtype(data.dtype, np.inexact):
+      return data
+    return data * std + mean
+
+  return tree_utils.fast_map_structure(denormalize_leaf, batch, mean_std.mean,
+                                       mean_std.std)
+
+
+@dataclasses.dataclass(frozen=True)
+class NestClippingConfig:
+  """Specifies how to clip Nests with the same structure.
+
+  Attributes:
+    path_map: A map that specifies how to clip values in Nests with the same
+      structure. Keys correspond to paths in the nest. Values are maximum
+      absolute values to use for clipping. If there is a collision between
+      paths (one path is a prefix of the other), the behavior is undefined.
+  """
+  path_map: Tuple[Tuple[Path, float], ...] = ()
+
+
+def get_clip_config_for_path(config: NestClippingConfig,
+                             path: Path) -> NestClippingConfig:
+  """Returns the clipping config for the subtree rooted at the given path."""
+  # Start with an empty config.
+  path_map = []
+  for map_path, max_abs_value in config.path_map:
+    if _is_prefix(map_path, path):
+      return NestClippingConfig(path_map=(((), max_abs_value),))
+    if _is_prefix(path, map_path):
+      path_map.append((map_path[len(path):], max_abs_value))
+  return NestClippingConfig(path_map=tuple(path_map))
+
+
+def clip(batch: types.NestedArray,
+         clipping_config: NestClippingConfig) -> types.NestedArray:
+  """Clips the batch."""
+
+  def max_abs_value_for_path(path: Path, x: jnp.ndarray) -> Optional[float]:
+    del x  # Unused, needed by interface.
+    return next((max_abs_value
+                 for clipping_path, max_abs_value in clipping_config.path_map
+                 if _is_prefix(clipping_path, path)), None)
+
+  max_abs_values = tree_utils.fast_map_structure_with_path(
+      max_abs_value_for_path, batch)
+
+  def clip_leaf(data: jnp.ndarray,
+                max_abs_value: Optional[float]) -> jnp.ndarray:
+    if max_abs_value is not None:
+      # TODO(b/124318564): remove pylint directive
+      data = jnp.clip(data, -max_abs_value, +max_abs_value)  # pylint: disable=invalid-unary-operand-type
+    return data
+
+  return tree_utils.fast_map_structure(clip_leaf, batch, max_abs_values)
+
+
+@dataclasses.dataclass(frozen=True)
+class NestNormalizationConfig:
+  """Specifies how to normalize Nests with the same structure.
+
+  Attributes:
+    stats_config: A config that defines how to compute running statistics to
+      be used for normalization.
+    clip_config: A config that defines how to clip normalized values.
+  """
+  stats_config: NestStatisticsConfig = NestStatisticsConfig()
+  clip_config: NestClippingConfig = NestClippingConfig()
diff --git a/acme/acme/jax/running_statistics_test.py b/acme/acme/jax/running_statistics_test.py
new file mode 100644
index 00000000..21195157
--- /dev/null
+++ b/acme/acme/jax/running_statistics_test.py
@@ -0,0 +1,305 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for running statistics utilities.""" + +import functools +import math +from typing import NamedTuple + +from acme import specs +from acme.jax import running_statistics +import jax +from jax.config import config as jax_config +import jax.numpy as jnp +import numpy as np +import tree + +from absl.testing import absltest + +update_and_validate = functools.partial( + running_statistics.update, validate_shapes=True) + + +class TestNestedSpec(NamedTuple): + # Note: the fields are intentionally in reverse order to test ordering. + a: specs.Array + b: specs.Array + + +class RunningStatisticsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + jax_config.update('jax_enable_x64', False) + + def assert_allclose(self, + actual: jnp.ndarray, + desired: jnp.ndarray, + err_msg: str = '') -> None: + np.testing.assert_allclose( + actual, desired, atol=1e-5, rtol=1e-5, err_msg=err_msg) + + def test_normalize(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + + x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) + x1, x2, x3, x4 = jnp.split(x, 4, axis=0) + + state = update_and_validate(state, x1) + state = update_and_validate(state, x2) + state = update_and_validate(state, x3) + state = update_and_validate(state, x4) + normalized = running_statistics.normalize(x, state) + + mean = jnp.mean(normalized) + std = jnp.std(normalized) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std)) + + def test_init_normalize(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + + x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) + normalized = running_statistics.normalize(x, state) + + self.assert_allclose(normalized, x) + + def test_one_batch_dim(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + + x = jnp.arange(10, dtype=jnp.float32).reshape(2, 5) + + state = update_and_validate(state, x) + normalized = running_statistics.normalize(x, state) + + mean = jnp.mean(normalized, axis=0) + std = jnp.std(normalized, axis=0) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std)) + + def test_clip(self): + state = running_statistics.init_state(specs.Array((), jnp.float32)) + + x = jnp.arange(5, dtype=jnp.float32) + + state = update_and_validate(state, x) + normalized = running_statistics.normalize(x, state, max_abs_value=1.0) + + mean = jnp.mean(normalized) + std = jnp.std(normalized) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std) * math.sqrt(0.6)) + + def test_nested_normalize(self): + state = running_statistics.init_state({ + 'a': specs.Array((5,), jnp.float32), + 'b': specs.Array((2,), jnp.float32) + }) + + x1 = { + 'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5), + 'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + } + x2 = { + 'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20, + 'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8 + } + x3 = { + 'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5), + 'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2) + } + + state = update_and_validate(state, x1) + state = update_and_validate(state, x2) + state = update_and_validate(state, x3) + normalized = running_statistics.normalize(x3, state) + + mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized) + std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized) + tree.map_structure( + lambda x: 
self.assert_allclose(x, jnp.zeros_like(x)), + mean) + tree.map_structure( + lambda x: self.assert_allclose(x, jnp.ones_like(x)), + std) + + def test_validation(self): + state = running_statistics.init_state(specs.Array((1, 2, 3), jnp.float32)) + + x = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3) + with self.assertRaises(AssertionError): + update_and_validate(state, x) + + x = jnp.arange(3, dtype=jnp.float32).reshape(1, 1, 3) + with self.assertRaises(AssertionError): + update_and_validate(state, x) + + def test_int_not_normalized(self): + state = running_statistics.init_state(specs.Array((), jnp.int32)) + + x = jnp.arange(5, dtype=jnp.int32) + + state = update_and_validate(state, x) + normalized = running_statistics.normalize(x, state) + + np.testing.assert_array_equal(normalized, x) + + def test_pmap_update_nested(self): + local_device_count = jax.local_device_count() + state = running_statistics.init_state({ + 'a': specs.Array((5,), jnp.float32), + 'b': specs.Array((2,), jnp.float32) + }) + + x = { + 'a': (jnp.arange(15 * local_device_count, + dtype=jnp.float32)).reshape(local_device_count, 3, 5), + 'b': (jnp.arange(6 * local_device_count, + dtype=jnp.float32)).reshape(local_device_count, 3, 2), + } + + devices = jax.local_devices() + state = jax.device_put_replicated(state, devices) + pmap_axis_name = 'i' + state = jax.pmap( + functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name), + pmap_axis_name)(state, x) + state = jax.pmap( + functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name), + pmap_axis_name)(state, x) + normalized = jax.pmap(running_statistics.normalize)(x, state) + + mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized) + std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized) + tree.map_structure( + lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean) + tree.map_structure( + lambda x: self.assert_allclose(x, jnp.ones_like(x)), std) + + def test_different_structure_normalize(self): + spec = TestNestedSpec( + a=specs.Array((5,), jnp.float32), b=specs.Array((2,), jnp.float32)) + state = running_statistics.init_state(spec) + + x = { + 'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5), + 'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + } + + with self.assertRaises(TypeError): + state = update_and_validate(state, x) + + def test_weights(self): + state = running_statistics.init_state(specs.Array((), jnp.float32)) + + x = jnp.arange(5, dtype=jnp.float32) + x_weights = jnp.ones_like(x) + y = 2 * x + 5 + y_weights = 2 * x_weights + z = jnp.concatenate([x, y]) + weights = jnp.concatenate([x_weights, y_weights]) + + state = update_and_validate(state, z, weights=weights) + + self.assertEqual(state.mean, (jnp.mean(x) + 2 * jnp.mean(y)) / 3) + big_z = jnp.concatenate([x, y, y]) + normalized = running_statistics.normalize(big_z, state) + self.assertAlmostEqual(jnp.mean(normalized), 0., places=6) + self.assertAlmostEqual(jnp.std(normalized), 1., places=6) + + def test_normalize_config(self): + x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) + x_split = jnp.split(x, 5, axis=0) + + y = jnp.arange(160, dtype=jnp.float32).reshape(20, 2, 4) + y_split = jnp.split(y, 5, axis=0) + + z = {'a': x, 'b': y} + + z_split = [{'a': xx, 'b': yy} for xx, yy in zip(x_split, y_split)] + + update = jax.jit(running_statistics.update, static_argnames=('config',)) + + config = running_statistics.NestStatisticsConfig((('a',),)) + state = running_statistics.init_state({ + 'a': specs.Array((5,), jnp.float32), + 'b': 
specs.Array((4,), jnp.float32) + }) + # Test initialization from the first element. + state = update(state, z_split[0], config=config) + state = update(state, z_split[1], config=config) + state = update(state, z_split[2], config=config) + state = update(state, z_split[3], config=config) + state = update(state, z_split[4], config=config) + + normalize = jax.jit(running_statistics.normalize) + normalized = normalize(z, state) + + for key in normalized: + mean = jnp.mean(normalized[key], axis=(0, 1)) + std = jnp.std(normalized[key], axis=(0, 1)) + if key == 'a': + self.assert_allclose( + mean, + jnp.zeros_like(mean), + err_msg=f'key:{key} mean:{mean} normalized:{normalized[key]}') + self.assert_allclose( + std, + jnp.ones_like(std), + err_msg=f'key:{key} std:{std} normalized:{normalized[key]}') + else: + assert key == 'b' + np.testing.assert_array_equal( + normalized[key], + z[key], + err_msg=f'z:{z[key]} normalized:{normalized[key]}') + + def test_clip_config(self): + x = jnp.arange(10, dtype=jnp.float32) - 5 + y = jnp.arange(8, dtype=jnp.float32) - 4 + + z = {'x': x, 'y': y} + + max_abs_x = 2 + config = running_statistics.NestClippingConfig(((('x',), max_abs_x),)) + + clipped_z = running_statistics.clip(z, config) + + clipped_x = jnp.clip(a=x, a_min=-max_abs_x, a_max=max_abs_x) + np.testing.assert_array_equal(clipped_z['x'], clipped_x) + + np.testing.assert_array_equal(clipped_z['y'], z['y']) + + def test_denormalize(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + + x = jnp.arange(100, dtype=jnp.float32).reshape(10, 2, 5) + x1, x2 = jnp.split(x, 2, axis=0) + + state = update_and_validate(state, x1) + state = update_and_validate(state, x2) + normalized = running_statistics.normalize(x, state) + + mean = jnp.mean(normalized) + std = jnp.std(normalized) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std)) + + denormalized = running_statistics.denormalize(normalized, state) + self.assert_allclose(denormalized, x) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/jax/savers.py b/acme/acme/jax/savers.py new file mode 100644 index 00000000..5c844205 --- /dev/null +++ b/acme/acme/jax/savers.py @@ -0,0 +1,96 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility classes for saving model checkpoints.""" + +import datetime +import os +import pickle +from typing import Any + +from absl import logging +from acme import core +from acme.tf import savers as tf_savers +import jax.numpy as jnp +import numpy as np +import tree + +# Internal imports. 
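+
+# A checkpoint is stored as two files: an npz archive (`array_nest`) holding
+# the flattened leaves of the state, and a pickled exemplar (`nest_exemplar`)
+# recording the nest structure used to unflatten and restore them.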
+ +CheckpointState = Any + +_DEFAULT_CHECKPOINT_TTL = int(datetime.timedelta(days=5).total_seconds()) +_ARRAY_NAME = 'array_nest' +_EXEMPLAR_NAME = 'nest_exemplar' + + +def restore_from_path(ckpt_dir: str) -> CheckpointState: + """Restore the state stored in ckpt_dir.""" + array_path = os.path.join(ckpt_dir, _ARRAY_NAME) + exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME) + + with open(exemplar_path, 'rb') as f: + exemplar = pickle.load(f) + + with open(array_path, 'rb') as f: + files = np.load(f, allow_pickle=True) + flat_state = [files[key] for key in files.files] + unflattened_tree = tree.unflatten_as(exemplar, flat_state) + + def maybe_convert_to_python(value, numpy): + return value if numpy else value.item() + + return tree.map_structure(maybe_convert_to_python, unflattened_tree, exemplar) + + +def save_to_path(ckpt_dir: str, state: CheckpointState): + """Save the state in ckpt_dir.""" + + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + + is_numpy = lambda x: isinstance(x, (np.ndarray, jnp.DeviceArray)) + flat_state = tree.flatten(state) + nest_exemplar = tree.map_structure(is_numpy, state) + + array_path = os.path.join(ckpt_dir, _ARRAY_NAME) + logging.info('Saving flattened array nest to %s', array_path) + def _disabled_seek(*_): + raise AttributeError('seek() is disabled on this object.') + with open(array_path, 'wb') as f: + setattr(f, 'seek', _disabled_seek) + np.savez(f, *flat_state) + + exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME) + logging.info('Saving nest exemplar to %s', exemplar_path) + with open(exemplar_path, 'wb') as f: + pickle.dump(nest_exemplar, f) + + +# Use TF checkpointer. +class Checkpointer(tf_savers.Checkpointer): + + def __init__( + self, + object_to_save: core.Saveable, + directory: str = '~/acme', + subdirectory: str = 'default', + **tf_checkpointer_kwargs): + super().__init__(dict(saveable=object_to_save), + directory=directory, + subdirectory=subdirectory, + **tf_checkpointer_kwargs) + + +CheckpointingRunner = tf_savers.CheckpointingRunner diff --git a/acme/acme/jax/savers_test.py b/acme/acme/jax/savers_test.py new file mode 100644 index 00000000..d67403dd --- /dev/null +++ b/acme/acme/jax/savers_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Tests for savers."""
+
+from unittest import mock
+
+from acme import core
+from acme.jax import savers
+from acme.testing import test_utils
+from acme.utils import paths
+import jax.numpy as jnp
+import numpy as np
+import tree
+
+from absl.testing import absltest
+
+
+class DummySaveable(core.Saveable):
+
+  def __init__(self, state):
+    self.state = state
+
+  def save(self):
+    return self.state
+
+  def restore(self, state):
+    self.state = state
+
+
+def nest_assert_equal(a, b):
+  tree.map_structure(np.testing.assert_array_equal, a, b)
+
+
+class SaverTest(test_utils.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._test_state = {
+        'foo': jnp.ones(shape=(8, 4), dtype=jnp.float32),
+        'bar': [jnp.zeros(shape=(3, 2), dtype=jnp.int32)],
+        'baz': 3,
+    }
+
+  def test_save_restore(self):
+    """Checks that we can save and restore state."""
+    directory = self.get_tempdir()
+    savers.save_to_path(directory, self._test_state)
+    result = savers.restore_from_path(directory)
+    nest_assert_equal(result, self._test_state)
+
+  def test_checkpointer(self):
+    """Checks that the Checkpointer class saves and restores as expected."""
+
+    with mock.patch.object(paths, 'get_unique_id') as mock_unique_id:
+      mock_unique_id.return_value = ('test',)
+
+      # Given a path and some stateful object...
+      directory = self.get_tempdir()
+      x = DummySaveable(self._test_state)
+
+      # If we checkpoint it...
+      checkpointer = savers.Checkpointer(x, directory, time_delta_minutes=0)
+      checkpointer.save()
+
+      # The checkpointer should restore the object's state.
+      x.state = None
+      checkpointer.restore()
+      nest_assert_equal(x.state, self._test_state)
+
+      # Checkpointers should also attempt a restore at construction time.
+      x.state = None
+      savers.Checkpointer(x, directory, time_delta_minutes=0)
+      nest_assert_equal(x.state, self._test_state)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/jax/snapshotter.py b/acme/acme/jax/snapshotter.py
new file mode 100644
index 00000000..81ca4784
--- /dev/null
+++ b/acme/acme/jax/snapshotter.py
@@ -0,0 +1,116 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility classes for snapshotting models."""
+
+import os
+import time
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
+
+from absl import logging
+from acme import core
+from acme.jax import types
+from acme.utils import paths
+from acme.utils import signals
+from jax.experimental import jax2tf
+import tensorflow as tf
+
+# Internal imports.
+
+
+class JAXSnapshotter(core.Worker):
+  """Periodically fetches new versions of params and stores tf.saved_models."""
+
+  # NOTE: External contributors, please refrain from modifying the high-level
+  # API defined here.
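+  #
+  # Snapshots are written under `path` into per-save directories named with a
+  # time.strftime('%Y%m%d-%H%M%S') timestamp, one tf.saved_model subdirectory
+  # per model; at most `max_to_keep` snapshots are retained.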
+
+  def __init__(self,
+               variable_source: core.VariableSource,
+               models: Dict[str, Callable[[core.VariableSource],
+                                          types.ModelToSnapshot]],
+               path: str,
+               subdirectory: Optional[str] = None,
+               max_to_keep: Optional[int] = None,
+               add_uid: bool = False):
+    self._variable_source = variable_source
+    self._models = models
+    if subdirectory is not None:
+      self._path = paths.process_path(path, subdirectory, add_uid=add_uid)
+    else:
+      self._path = paths.process_path(path, add_uid=add_uid)
+    self._max_to_keep = max_to_keep
+    self._snapshot_paths: Optional[List[str]] = None
+
+  # Handle preemption signal. Note that this must happen in the main thread.
+  def _signal_handler(self):
+    logging.info('Caught SIGTERM: forcing models save.')
+    self._save()
+
+  def _save(self):
+    if not self._snapshot_paths:
+      # Lazy discovery of already existing snapshots.
+      self._snapshot_paths = os.listdir(self._path)
+      self._snapshot_paths.sort(reverse=True)
+
+    snapshot_location = os.path.join(self._path,
+                                     time.strftime('%Y%m%d-%H%M%S'))
+    if self._snapshot_paths and self._snapshot_paths[0] == snapshot_location:
+      logging.info('Snapshot for the current time already exists.')
+      return
+
+    # To make sure models are captured as close as possible to the same time,
+    # we gather all the `ModelToSnapshot`s in a first loop. We then
+    # convert/save them in another loop, as this operation can be slow.
+    models_and_paths = self._get_models_and_paths(path=snapshot_location)
+    self._snapshot_paths.insert(0, snapshot_location)
+
+    for model, saving_path in models_and_paths:
+      self._snapshot_model(model=model, saving_path=saving_path)
+
+    # Delete any excess snapshots.
+    while self._max_to_keep and len(self._snapshot_paths) > self._max_to_keep:
+      paths.rmdir(os.path.join(self._path, self._snapshot_paths.pop()))
+
+  def _get_models_and_paths(
+      self, path: str) -> Sequence[Tuple[types.ModelToSnapshot, str]]:
+    """Gets the models to save associated with their saving paths."""
+    models_and_paths = []
+    for name, model_fn in self._models.items():
+      model = model_fn(self._variable_source)
+      model_path = os.path.join(path, name)
+      models_and_paths.append((model, model_path))
+    return models_and_paths
+
+  def _snapshot_model(self, model: types.ModelToSnapshot,
+                      saving_path: str) -> None:
+    module = model_to_tf_module(model)
+    tf.saved_model.save(module, saving_path)
+
+  def run(self):
+    """Runs the saver."""
+    with signals.runtime_terminator(self._signal_handler):
+      while True:
+        self._save()
+        time.sleep(5 * 60)
+
+
+def model_to_tf_module(model: types.ModelToSnapshot) -> tf.Module:
+
+  def jax_fn_to_save(**kwargs):
+    return model.model(model.params, **kwargs)
+
+  module = tf.Module()
+  module.f = tf.function(jax2tf.convert(jax_fn_to_save), autograph=False)
+  # Traces the input to ensure the model is saved with the correct shapes.
+  module.f(**model.dummy_kwargs)
+  return module
diff --git a/acme/acme/jax/snapshotter_test.py b/acme/acme/jax/snapshotter_test.py
new file mode 100644
index 00000000..e833437d
--- /dev/null
+++ b/acme/acme/jax/snapshotter_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for snapshotter."""
+
+import os
+import time
+from typing import Any, Sequence
+
+from acme import core
+from acme.jax import snapshotter
+from acme.jax import types
+from acme.testing import test_utils
+import jax.numpy as jnp
+
+from absl.testing import absltest
+
+
+def _model0(params, x1, x2):
+  return params['w0'] * jnp.sin(x1) + params['w1'] * jnp.cos(x2)
+
+
+def _model1(params, x):
+  return params['p0'] * jnp.log(x)
+
+
+class _DummyVariableSource(core.VariableSource):
+
+  def __init__(self):
+    self._params_model0 = {
+        'w0': jnp.ones([2, 3], dtype=jnp.float32),
+        'w1': 2 * jnp.ones([2, 3], dtype=jnp.float32),
+    }
+
+    self._params_model1 = {
+        'p0': jnp.ones([3, 1], dtype=jnp.float32),
+    }
+
+  def get_variables(self, names: Sequence[str]) -> Sequence[Any]:
+    variables = []
+    for n in names:
+      if n == 'params_model0':
+        variables.append(self._params_model0)
+      elif n == 'params_model1':
+        variables.append(self._params_model1)
+      else:
+        raise ValueError(f'Unknown variable name: {n}')
+    return variables
+
+
+def _get_model0(variable_source: core.VariableSource) -> types.ModelToSnapshot:
+  return types.ModelToSnapshot(
+      model=_model0,
+      params=variable_source.get_variables(['params_model0'])[0],
+      dummy_kwargs={
+          'x1': jnp.ones([2, 3], dtype=jnp.float32),
+          'x2': jnp.ones([2, 3], dtype=jnp.float32),
+      },
+  )
+
+
+def _get_model1(variable_source: core.VariableSource) -> types.ModelToSnapshot:
+  return types.ModelToSnapshot(
+      model=_model1,
+      params=variable_source.get_variables(['params_model1'])[0],
+      dummy_kwargs={
+          'x': jnp.ones([3, 1], dtype=jnp.float32),
+      },
+  )
+
+
+class SnapshotterTest(test_utils.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._test_models = {'model0': _get_model0, 'model1': _get_model1}
+
+  def _check_snapshot(self, directory: str, name: str):
+    self.assertTrue(os.path.exists(os.path.join(directory, name, 'model0')))
+    self.assertTrue(os.path.exists(os.path.join(directory, name, 'model1')))
+
+  def test_snapshotter(self):
+    """Checks that the Snapshotter class saves as expected."""
+    directory = self.get_tempdir()
+
+    models_snapshotter = snapshotter.JAXSnapshotter(
+        variable_source=_DummyVariableSource(),
+        models=self._test_models,
+        path=directory,
+        max_to_keep=2,
+        add_uid=False,
+    )
+    models_snapshotter._save()
+
+    # The snapshots are written in a folder of the form:
+    # PATH/{time.strftime}/MODEL_NAME
+    first_snapshots = os.listdir(directory)
+    self.assertEqual(len(first_snapshots), 1)
+    self._check_snapshot(directory, first_snapshots[0])
+    # Make sure that the second snapshot is constructed.
+    time.sleep(1.1)
+    models_snapshotter._save()
+    snapshots = os.listdir(directory)
+    self.assertEqual(len(snapshots), 2)
+    self._check_snapshot(directory, snapshots[0])
+    self._check_snapshot(directory, snapshots[1])
+
+    # Make sure that a new snapshotter deletes the oldest snapshot upon
+    # _save().
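+    # (The sleeps guarantee a new second-resolution timestamp, and hence a
+    # distinct snapshot directory name, for each save.)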
+ time.sleep(1.1) + models_snapshotter2 = snapshotter.JAXSnapshotter( + variable_source=_DummyVariableSource(), + models=self._test_models, + path=directory, + max_to_keep=2, + add_uid=False, + ) + self.assertEqual(snapshots, os.listdir(directory)) + time.sleep(1.1) + models_snapshotter2._save() + snapshots = os.listdir(directory) + self.assertNotIn(first_snapshots[0], snapshots) + self.assertEqual(len(snapshots), 2) + self._check_snapshot(directory, snapshots[0]) + self._check_snapshot(directory, snapshots[1]) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/jax/types.py b/acme/acme/jax/types.py new file mode 100644 index 00000000..6e979d69 --- /dev/null +++ b/acme/acme/jax/types.py @@ -0,0 +1,68 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common JAX type definitions.""" + +import dataclasses +from typing import Any, Callable, Dict, Generic, Mapping, TypeVar + +from acme import types +import chex +import dm_env +import jax +import jax.numpy as jnp + +PRNGKey = jax.random.KeyArray +Networks = TypeVar('Networks') +PolicyNetwork = TypeVar('PolicyNetwork') +Sample = TypeVar('Sample') +TrainingState = TypeVar('TrainingState') + +TrainingMetrics = Mapping[str, jnp.ndarray] +"""Metrics returned by the training step. + +Typically these are logged, so the values are expected to be scalars. +""" + +Variables = Mapping[str, types.NestedArray] +"""Mapping of variable collections. + +A mapping of variable collections, as defined by Learner.get_variables. +The keys are the collection names, the values are nested arrays representing +the values of the corresponding collection variables. +""" + + +@chex.dataclass(frozen=True, mappable_dataclass=False) +class TrainingStepOutput(Generic[TrainingState]): + state: TrainingState + metrics: TrainingMetrics + + +Seed = int +EnvironmentFactory = Callable[[Seed], dm_env.Environment] + + +@dataclasses.dataclass +class ModelToSnapshot: + """Stores all necessary info to be able to save a model. + + Attributes: + model: a jax function to be saved. + params: fixed params to be passed to the function. + dummy_kwargs: arguments to be passed to the function. + """ + model: Any # Callable[params, **dummy_kwargs] + params: Any + dummy_kwargs: Dict[str, Any] diff --git a/acme/acme/jax/utils.py b/acme/acme/jax/utils.py new file mode 100644 index 00000000..d3176217 --- /dev/null +++ b/acme/acme/jax/utils.py @@ -0,0 +1,580 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for JAX.""" + +import functools +import itertools +import queue +import threading +from typing import Callable, Iterable, Iterator, NamedTuple, Optional, Sequence, Tuple, TypeVar + +from absl import logging +from acme import core +from acme import types +from acme.jax import types as jax_types +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import tree + + +F = TypeVar('F', bound=Callable) +N = TypeVar('N', bound=types.NestedArray) +T = TypeVar('T') + + +NUM_PREFETCH_THREADS = 1 + + +def add_batch_dim(values: types.Nest) -> types.NestedArray: + return jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), values) + + +def _flatten(x: jnp.ndarray, num_batch_dims: int) -> jnp.ndarray: + """Flattens the input, preserving the first ``num_batch_dims`` dimensions. + + If the input has fewer than ``num_batch_dims`` dimensions, it is returned + unchanged. + If the input has exactly ``num_batch_dims`` dimensions, an extra dimension + is added. This is needed to handle batched scalars. + + Arguments: + x: the input array to flatten. + num_batch_dims: number of dimensions to preserve. + + Returns: + flattened input. + """ + # TODO(b/173492429): consider throwing an error instead. + if x.ndim < num_batch_dims: + return x + return jnp.reshape(x, list(x.shape[:num_batch_dims]) + [-1]) + + +def batch_concat( + values: types.NestedArray, + num_batch_dims: int = 1, +) -> jnp.ndarray: + """Flatten and concatenate nested array structure, keeping batch dims.""" + flatten_fn = lambda x: _flatten(x, num_batch_dims) + flat_leaves = tree.map_structure(flatten_fn, values) + return jnp.concatenate(tree.flatten(flat_leaves), axis=-1) + + +def zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray: + return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) + + +def ones_like(nest: types.Nest, dtype=None) -> types.NestedArray: + return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) + + +def squeeze_batch_dim(nest: types.Nest) -> types.NestedArray: + return jax.tree_map(lambda x: jnp.squeeze(x, axis=0), nest) + + +def to_numpy_squeeze(values: types.Nest) -> types.NestedArray: + """Converts to numpy and squeezes out dummy batch dimension.""" + return jax.tree_map(lambda x: np.asarray(x).squeeze(axis=0), values) + + +def to_numpy(values: types.Nest) -> types.NestedArray: + return jax.tree_map(np.asarray, values) + + +def fetch_devicearray(values: types.Nest) -> types.Nest: + """Fetches and converts any DeviceArrays to np.ndarrays.""" + return tree.map_structure(_fetch_devicearray, values) + + +def _fetch_devicearray(x): + if isinstance(x, jax.xla.DeviceArray): + return np.asarray(x) + return x + + +def batch_to_sequence(values: types.Nest) -> types.NestedArray: + return jax.tree_map( + lambda x: jnp.transpose(x, axes=(1, 0, *range(2, len(x.shape)))), values) + + +def tile_array(array: jnp.ndarray, multiple: int) -> jnp.ndarray: + """Tiles `multiple` copies of `array` along a new leading axis.""" + return jnp.stack([array] * multiple) + + +def tile_nested(inputs: types.Nest, multiple: int) -> types.Nest: + """Tiles tensors in a nested structure along a new leading axis.""" + tile = functools.partial(tile_array, multiple=multiple) + return jax.tree_map(tile, inputs) + + +def maybe_recover_lstm_type(state: types.NestedArray) -> types.NestedArray: + """Recovers the type hk.LSTMState if LSTMState is in the type name. 
+
+  When the recurrent state of recurrent neural networks (RNN) is deserialized,
+  for example when it is sampled from replay, it is sometimes repacked in a type
+  that is identical to the source type but not the correct type itself. When
+  using this state as the initial state in an hk.dynamic_unroll, this will
+  cause hk.dynamic_unroll to raise an error as it requires its input and output
+  states to be identical.
+
+  Args:
+    state: a nested structure of arrays representing the state of an RNN.
+
+  Returns:
+    Either the state unchanged if it is anything but an LSTMState, otherwise
+    returns the state arrays properly contained in an hk.LSTMState.
+  """
+  return hk.LSTMState(*state) if type(state).__name__ == 'LSTMState' else state
+
+
+def prefetch(
+    iterable: Iterable[T],
+    buffer_size: int = 5,
+    device: Optional[jax.xla.Device] = None,
+    num_threads: int = NUM_PREFETCH_THREADS,
+) -> core.PrefetchingIterator[T]:
+  """Returns prefetching iterator with additional 'ready' method."""
+
+  return PrefetchIterator(iterable, buffer_size, device, num_threads)
+
+
+class PrefetchingSplit(NamedTuple):
+  host: types.NestedArray
+  device: types.NestedArray
+
+
+_SplitFunction = Callable[[types.NestedArray], PrefetchingSplit]
+
+
+def device_put(
+    iterable: Iterable[types.NestedArray],
+    device: jax.xla.Device,
+    split_fn: Optional[_SplitFunction] = None,
+):
+  """Returns iterator that samples an item and places it on the device."""
+
+  return PutToDevicesIterable(
+      iterable=iterable,
+      pmapped_user=False,
+      devices=[device],
+      split_fn=split_fn)
+
+
+def multi_device_put(
+    iterable: Iterable[types.NestedArray],
+    devices: Sequence[jax.xla.Device],
+    split_fn: Optional[_SplitFunction] = None,
+):
+  """Returns iterator that, per device, samples an item and places on device."""
+
+  return PutToDevicesIterable(
+      iterable=iterable, pmapped_user=True, devices=devices, split_fn=split_fn)
+
+
+class PutToDevicesIterable(Iterable[types.NestedArray]):
+  """Per device, samples an item from the iterator and places it on the device.
+
+  if pmapped_user:
+    Items from the resulting generator are intended to be used in a pmapped
+    function. Every element is a ShardedDeviceArray or (nested) Python container
+    thereof. A single next() call to this iterator results in len(devices)
+    calls to the underlying iterator. The returned items are put one on each
+    device.
+  if not pmapped_user:
+    Places a sample from the iterator on the given device.
+
+  Yields:
+    If no split_fn is specified:
+      DeviceArray/ShardedDeviceArray or (nested) Python container thereof
+      representing the elements of shards stacked together, with each shard
+      backed by physical device memory specified by the corresponding entry in
+      devices.
+
+    If split_fn is specified:
+      PrefetchingSplit where the .host element is a stacked numpy array or
+      (nested) Python container thereof. The .device element is a
+      DeviceArray/ShardedDeviceArray or (nested) Python container thereof.
+
+  Raises:
+    StopIteration: if there are not enough items left in the iterator to place
+      one sample on each device.
+    Any error thrown by the iterable_function. Note this is not raised inside
+      the producer, but after it finishes executing.
+  """
+
+  def __init__(
+      self,
+      iterable: Iterable[types.NestedArray],
+      pmapped_user: bool,
+      devices: Sequence[jax.xla.Device],
+      split_fn: Optional[_SplitFunction] = None,
+  ):
+    """Constructs PutToDevicesIterable.
+
+    Args:
+      iterable: A python iterable. This is used to build the python prefetcher.
+        Note that each iterable should only be passed to this function once as
+        iterables aren't thread safe.
+      pmapped_user: whether the user of data from this iterator is implemented
+        using pmapping.
+      devices: Devices used for prefetching.
+      split_fn: Optional function applied to every element from the iterable to
+        split the parts of it that will be kept on the host and the parts that
+        will be sent to the device.
+
+    Raises:
+      ValueError: If the devices list is empty, or if pmapped_user=False and
+        more than 1 device is provided.
+    """
+    self.num_devices = len(devices)
+    if self.num_devices == 0:
+      raise ValueError('At least one device must be specified.')
+    if (not pmapped_user) and (self.num_devices != 1):
+      raise ValueError('User is not implemented with pmapping but len(devices) '
+                       f'= {len(devices)} is not equal to 1! Devices given are:'
+                       f'\n{devices}')
+
+    self.iterable = iterable
+    self.pmapped_user = pmapped_user
+    self.split_fn = split_fn
+    self.devices = devices
+    self.iterator = iter(self.iterable)
+
+  def __iter__(self) -> Iterator[types.NestedArray]:
+    # It is important to structure the Iterable like this, because in
+    # JustPrefetchIterator we must build a new iterable for each thread.
+    # This is crucial if working with tensorflow datasets because tf.Graph
+    # objects are thread local.
+    self.iterator = iter(self.iterable)
+    return self
+
+  def __next__(self) -> types.NestedArray:
+    try:
+      if not self.pmapped_user:
+        item = next(self.iterator)
+        if self.split_fn is None:
+          return jax.device_put(item, self.devices[0])
+        item_split = self.split_fn(item)
+        return PrefetchingSplit(
+            host=item_split.host,
+            device=jax.device_put(item_split.device, self.devices[0]))
+
+      items = itertools.islice(self.iterator, self.num_devices)
+      items = tuple(items)
+      if len(items) < self.num_devices:
+        raise StopIteration
+      if self.split_fn is None:
+        return jax.device_put_sharded(tuple(items), self.devices)
+      else:
+        # ((host: x1, device: y1), ..., (host: xN, device: yN)).
+        items_split = (self.split_fn(item) for item in items)
+        # (host: (x1, ..., xN), device: (y1, ..., yN)).
+        split = tree.map_structure_up_to(
+            PrefetchingSplit(None, None), lambda *x: x, *items_split)
+
+        return PrefetchingSplit(
+            host=np.stack(split.host),
+            device=jax.device_put_sharded(split.device, self.devices))
+
+    except StopIteration:
+      raise
+
+    except Exception:  # pylint: disable=broad-except
+      logging.exception('Error for %s', self.iterable)
+      raise
+
+
+def sharded_prefetch(
+    iterable: Iterable[types.NestedArray],
+    buffer_size: int = 5,
+    num_threads: int = 1,
+    split_fn: Optional[_SplitFunction] = None,
+    devices: Optional[Sequence[jax.xla.Device]] = None,
+) -> core.PrefetchingIterator:
+  """Performs sharded prefetching from an iterable in separate threads.
+
+  Elements from the resulting generator are intended to be used in a jax.pmap
+  call. Every element is a sharded prefetched array with an additional replica
+  dimension and corresponds to jax.local_device_count() elements from the
+  original iterable.
+
+  Args:
+    iterable: A python iterable. This is used to build the python prefetcher.
+      Note that each iterable should only be passed to this function once as
+      iterables aren't thread safe.
+    buffer_size (int): Number of elements to keep in the prefetch buffer.
+    num_threads (int): Number of threads.
+    split_fn: Optional function applied to every element from the iterable to
+      split the parts of it that will be kept on the host and the parts that
+      will be sent to the device.
+    devices: Devices used for prefetching. Optional, jax.local_devices() by
+      default.
+
+  Returns:
+    Prefetched elements from the original iterable with an additional replica
+    dimension.
+
+  Raises:
+    ValueError if the buffer_size < 1.
+    Any error thrown by the iterable_function. Note this is not raised inside
+      the producer, but after it finishes executing.
+  """
+
+  devices = devices or jax.local_devices()
+
+  iterable = PutToDevicesIterable(
+      iterable=iterable, pmapped_user=True, devices=devices, split_fn=split_fn)
+
+  return prefetch(iterable, buffer_size, device=None, num_threads=num_threads)
+
+
+def replicate_in_all_devices(nest: N,
+                             devices: Optional[Sequence[jax.xla.Device]] = None
+                            ) -> N:
+  """Replicate array nest in all available devices."""
+  devices = devices or jax.local_devices()
+  return jax.device_put_sharded([nest] * len(devices), devices)
+
+
+def get_from_first_device(nest: N, as_numpy: bool = True) -> N:
+  """Gets the first array of a nest of `jax.pxla.ShardedDeviceArray`s.
+
+  Args:
+    nest: A nest of `jax.pxla.ShardedDeviceArray`s.
+    as_numpy: If `True` then each `DeviceArray` that is retrieved is transformed
+      (and copied if not on the host machine) into a `np.ndarray`.
+
+  Returns:
+    The first array of a nest of `jax.pxla.ShardedDeviceArray`s. Note that if
+    `as_numpy=False` then the array will be a `DeviceArray` (which will live on
+    the same device as the sharded device array). If `as_numpy=True` then the
+    array will be copied to the host machine and converted into a `np.ndarray`.
+  """
+  zeroth_nest = jax.tree_map(lambda x: x[0], nest)
+  return jax.device_get(zeroth_nest) if as_numpy else zeroth_nest
+
+
+def mapreduce(
+    f: F,
+    reduce_fn: Optional[Callable[[jnp.DeviceArray], jnp.DeviceArray]] = None,
+    **vmap_kwargs,
+) -> F:
+  """A simple decorator that transforms `f` into (`reduce_fn` o vmap o f).
+
+  By default, we vmap over axis 0, and the `reduce_fn` is jnp.mean over axis 0.
+  Note that the call signature of `f` is invariant under this transformation.
+
+  If, for example, f has shape signature [H, W] -> [N], then mapreduce(f)
+  (with the default arguments) will have shape signature [B, H, W] -> [N].
+
+  Args:
+    f: A pure function over examples.
+    reduce_fn: A pure function that reduces DeviceArrays -> DeviceArrays.
+    **vmap_kwargs: Keyword arguments to forward to `jax.vmap`.
+
+  Returns:
+    g: A pure function over batches of examples.
+  """
+
+  if reduce_fn is None:
+    reduce_fn = lambda x: jnp.mean(x, axis=0)
+
+  vmapped_f = jax.vmap(f, **vmap_kwargs)
+
+  def g(*args, **kwargs):
+    return jax.tree_map(reduce_fn, vmapped_f(*args, **kwargs))
+
+  return g
+
+
+_TrainingState = TypeVar('_TrainingState')
+_TrainingData = TypeVar('_TrainingData')
+_TrainingAux = TypeVar('_TrainingAux')
+
+
+# TODO(b/192806089): migrate all callers to process_many_batches and remove this
+# method.
+def process_multiple_batches(
+    process_one_batch: Callable[[_TrainingState, _TrainingData],
+                                Tuple[_TrainingState, _TrainingAux]],
+    num_batches: int,
+    postprocess_aux: Optional[Callable[[_TrainingAux], _TrainingAux]] = None
+) -> Callable[[_TrainingState, _TrainingData], Tuple[_TrainingState,
+                                                     _TrainingAux]]:
+  """Makes 'process_one_batch' process multiple batches at once.
+
+  Args:
+    process_one_batch: a function that takes 'state' and 'data', and returns
+      'new_state' and 'aux' (for example 'metrics').
+    num_batches: how many batches to process at once.
+    postprocess_aux: how to merge the extra information, defaults to taking the
+      mean.
+ + Returns: + A function with the same interface as 'process_one_batch' which processes + multiple batches at once. + """ + assert num_batches >= 1 + if num_batches == 1: + if not postprocess_aux: + return process_one_batch + def _process_one_batch(state, data): + state, aux = process_one_batch(state, data) + return state, postprocess_aux(aux) + return _process_one_batch + + if postprocess_aux is None: + postprocess_aux = lambda x: jax.tree_map(jnp.mean, x) + + def _process_multiple_batches(state, data): + data = jax.tree_map( + lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data) + + state, aux = jax.lax.scan( + process_one_batch, state, data, length=num_batches) + return state, postprocess_aux(aux) + + return _process_multiple_batches + + +def process_many_batches( + process_one_batch: Callable[[_TrainingState, _TrainingData], + jax_types.TrainingStepOutput[_TrainingState]], + num_batches: int, + postprocess_aux: Optional[Callable[[jax_types.TrainingMetrics], + jax_types.TrainingMetrics]] = None +) -> Callable[[_TrainingState, _TrainingData], + jax_types.TrainingStepOutput[_TrainingState]]: + """The version of 'process_multiple_batches' with stronger typing.""" + + def _process_one_batch( + state: _TrainingState, + data: _TrainingData) -> Tuple[_TrainingState, jax_types.TrainingMetrics]: + result = process_one_batch(state, data) + return result.state, result.metrics + + func = process_multiple_batches(_process_one_batch, num_batches, + postprocess_aux) + + def _process_many_batches( + state: _TrainingState, + data: _TrainingData) -> jax_types.TrainingStepOutput[_TrainingState]: + state, aux = func(state, data) + return jax_types.TrainingStepOutput(state, aux) + + return _process_many_batches + + +def weighted_softmax(x: jnp.ndarray, weights: jnp.ndarray, axis: int = 0): + x = x - jnp.max(x, axis=axis) + return weights * jnp.exp(x) / jnp.sum(weights * jnp.exp(x), + axis=axis, keepdims=True) + + +def sample_uint32(random_key: jax_types.PRNGKey) -> int: + """Returns an integer uniformly distributed in 0..2^32-1.""" + iinfo = jnp.iinfo(jnp.int32) + # randint only accepts int32 values as min and max. + jax_random = jax.random.randint( + random_key, shape=(), minval=iinfo.min, maxval=iinfo.max, dtype=jnp.int32) + return np.uint32(jax_random).item() + + +class PrefetchIterator(core.PrefetchingIterator): + """Performs prefetching from an iterable in separate threads. + + Its interface is additionally extended with `ready` method which tells whether + there is any data waiting for processing and a `retrieved_elements` method + specifying number of elements retrieved from the iterator. + + Yields: + Prefetched elements from the original iterable. + + Raises: + ValueError: if the buffer_size < 1. + StopIteration: If the iterable contains no more items. + Any error thrown by the iterable_function. Note this is not raised inside + the producer, but after it finishes executing. + """ + + def __init__( + self, + iterable: Iterable[types.NestedArray], + buffer_size: int = 5, + device: Optional[jax.xla.Device] = None, + num_threads: int = NUM_PREFETCH_THREADS, + ): + """Constructs PrefetchIterator. + + Args: + iterable: A python iterable. This is used to build the python prefetcher. + Note that each iterable should only be passed to this function once as + iterables aren't thread safe. + buffer_size (int): Number of elements to keep in the prefetch buffer. + device (deprecated): Optionally place items from the iterable on the given + device. 
If None, the items are returns as given by the iterable. This + argument is deprecated and the recommended usage is to wrap the + iterables using utils.device_put or utils.multi_device_put before using + utils.prefetch. + num_threads (int): Number of threads. + """ + + if buffer_size < 1: + raise ValueError('the buffer_size should be >= 1') + self.buffer = queue.Queue(maxsize=buffer_size) + self.producer_error = [] + self.end = object() + self.iterable = iterable + self.device = device + self.count = 0 + + # Start producer threads. + for _ in range(num_threads): + threading.Thread(target=self.producer, daemon=True).start() + + def producer(self): + """Enqueues items from `iterable` on a given thread.""" + try: + # Build a new iterable for each thread. This is crucial if working with + # tensorflow datasets because tf.Graph objects are thread local. + for item in self.iterable: + if self.device: + jax.device_put(item, self.device) + self.buffer.put(item) + except Exception as e: # pylint: disable=broad-except + logging.exception('Error in producer thread for %s', self.iterable) + self.producer_error.append(e) + finally: + self.buffer.put(self.end) + + def __iter__(self): + return self + + def ready(self): + return not self.buffer.empty() + + def retrieved_elements(self): + return self.count + + def __next__(self): + value = self.buffer.get() + if value is self.end: + if self.producer_error: + raise self.producer_error[0] from self.producer_error[0] + raise StopIteration + self.count += 1 + return value diff --git a/acme/acme/jax/utils_test.py b/acme/acme/jax/utils_test.py new file mode 100644 index 00000000..775bfaed --- /dev/null +++ b/acme/acme/jax/utils_test.py @@ -0,0 +1,87 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for utils.""" + +from acme.jax import utils +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from absl.testing import absltest + +chex.set_n_cpu_devices(4) + + +class JaxUtilsTest(absltest.TestCase): + + def test_batch_concat(self): + batch_size = 32 + inputs = [ + jnp.zeros(shape=(batch_size, 2)), + { + 'foo': jnp.zeros(shape=(batch_size, 5, 3)) + }, + [jnp.zeros(shape=(batch_size, 1))], + jnp.zeros(shape=(batch_size,)), + ] + + output_shape = utils.batch_concat(inputs).shape + expected_shape = [batch_size, 2 + 5 * 3 + 1 + 1] + self.assertSequenceEqual(output_shape, expected_shape) + + def test_mapreduce(self): + + @utils.mapreduce + def f(y, x): + return jnp.square(x + y) + + z = f(jnp.ones(shape=(32,)), jnp.ones(shape=(32,))) + z = jax.device_get(z) + self.assertEqual(z, 4) + + def test_get_from_first_device(self): + sharded = { + 'a': + jax.device_put_sharded( + list(jnp.arange(16).reshape([jax.local_device_count(), 4])), + jax.local_devices()), + 'b': + jax.device_put_sharded( + list(jnp.arange(8).reshape([jax.local_device_count(), 2])), + jax.local_devices(), + ), + } + + want = { + 'a': jnp.arange(4), + 'b': jnp.arange(2), + } + + # Get zeroth device content as DeviceArray. + device_arrays = utils.get_from_first_device(sharded, as_numpy=False) + jax.tree_map( + lambda x: self.assertIsInstance(x, jax.xla.DeviceArray), + device_arrays) + jax.tree_map(np.testing.assert_array_equal, want, device_arrays) + + # Get the zeroth device content as numpy arrays. + numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True) + jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays) + jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/jax/variable_utils.py b/acme/acme/jax/variable_utils.py new file mode 100644 index 00000000..46cbd51b --- /dev/null +++ b/acme/acme/jax/variable_utils.py @@ -0,0 +1,152 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Variable utilities for JAX.""" + +from concurrent import futures +import datetime +import time +from typing import List, NamedTuple, Optional, Sequence, Union + +from acme import core +from acme.jax import networks as network_types +import jax + + +class VariableReference(NamedTuple): + variable_name: str + + +class ReferenceVariableSource(core.VariableSource): + """Variable source which returns references instead of values. + + This is passed to each actor when using a centralized inference server. The + actor uses this special variable source to get references rather than values. + These references are then passed to calls to the inference server, which will + dereference them to obtain the value of the corresponding variables at + inference time. This avoids passing around copies of variables from each + actor to the inference server. 
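+
+  For example, get_variables(['policy']) returns [VariableReference('policy')]
+  rather than the actual policy parameters.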
+ """ + + def get_variables(self, names: Sequence[str]) -> List[VariableReference]: + return [VariableReference(name) for name in names] + + +class VariableClient: + """A variable client for updating variables from a remote source.""" + + def __init__(self, + client: core.VariableSource, + key: Union[str, Sequence[str]], + update_period: Union[int, datetime.timedelta] = 1, + device: Optional[Union[str, jax.xla.Device]] = None): + """Initializes the variable client. + + Args: + client: A variable source from which we fetch variables. + key: Which variables to request. When multiple keys are used, params + property will return a list of params. + update_period: Interval between fetches, specified as either (int) a + number of calls to update() between actual fetches or (timedelta) a time + interval that has to pass since the last fetch. + device: The name of a JAX device to put variables on. If None (default), + VariableClient won't put params on any device. + """ + self._update_period = update_period + self._call_counter = 0 + self._last_call = time.time() + self._client = client + self._params: Sequence[network_types.Params] = None + + self._device = device + if isinstance(self._device, str): + self._device = jax.devices(device)[0] + + self._executor = futures.ThreadPoolExecutor(max_workers=1) + + if isinstance(key, str): + key = [key] + + self._key = key + self._request = lambda k=key: client.get_variables(k) + self._future: Optional[futures.Future] = None # pylint: disable=g-bare-generic + self._async_request = lambda: self._executor.submit(self._request) + + def update(self, wait: bool = False) -> None: + """Periodically updates the variables with the latest copy from the source. + + If wait is True, a blocking request is executed. Any active request will be + cancelled. + If wait is False, this method makes an asynchronous request for variables. + + Args: + wait: Whether to execute asynchronous (False) or blocking updates (True). + Defaults to False. + """ + # Track calls (we only update periodically). + self._call_counter += 1 + + # Return if it's not time to fetch another update. + if isinstance(self._update_period, datetime.timedelta): + if self._update_period.total_seconds() + self._last_call > time.time(): + return + else: + if self._call_counter < self._update_period: + return + + if wait: + if self._future is not None: + if self._future.running(): + self._future.cancel() + self._future = None + self._call_counter = 0 + self._last_call = time.time() + self.update_and_wait() + return + + # Return early if we are still waiting for a previous request to come back. + if self._future and not self._future.done(): + return + + # Get a future and add the copy function as a callback. + self._call_counter = 0 + self._last_call = time.time() + self._future = self._async_request() + self._future.add_done_callback(lambda f: self._callback(f.result())) + + def update_and_wait(self): + """Immediately update and block until we get the result.""" + self._callback(self._request()) + + def _callback(self, params_list: List[network_types.Params]): + if self._device and not isinstance(self._client, ReferenceVariableSource): + # Move variables to a proper device. 
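+      # (Variable references are placeholders rather than actual arrays, so
+      # they are left as-is; the inference server dereferences them later.)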
+ self._params = jax.device_put(params_list, self._device) + else: + self._params = params_list + + @property + def device(self) -> Optional[jax.xla.Device]: + return self._device + + @property + def params(self) -> Union[network_types.Params, List[network_types.Params]]: + """Returns the first params for one key, otherwise the whole params list.""" + if self._params is None: + self.update_and_wait() + + if len(self._params) == 1: + return self._params[0] + else: + return self._params diff --git a/acme/acme/jax/variable_utils_test.py b/acme/acme/jax/variable_utils_test.py new file mode 100644 index 00000000..826807ec --- /dev/null +++ b/acme/acme/jax/variable_utils_test.py @@ -0,0 +1,63 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for variable utilities.""" + +from acme.jax import variable_utils +from acme.testing import fakes +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import tree + +from absl.testing import absltest + + +def dummy_network(x): + return hk.nets.MLP([50, 10])(x) + + +class VariableClientTest(absltest.TestCase): + + def test_update(self): + init_fn, _ = hk.without_apply_rng( + hk.transform(dummy_network)) + params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32))) + variable_source = fakes.VariableSource(params) + variable_client = variable_utils.VariableClient( + variable_source, key='policy') + variable_client.update_and_wait() + tree.map_structure(np.testing.assert_array_equal, variable_client.params, + params) + + def test_multiple_keys(self): + init_fn, _ = hk.without_apply_rng( + hk.transform(dummy_network)) + params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32))) + steps = jnp.zeros(shape=1) + variables = {'network': params, 'steps': steps} + variable_source = fakes.VariableSource(variables, use_default_key=False) + variable_client = variable_utils.VariableClient( + variable_source, key=['network', 'steps']) + variable_client.update_and_wait() + + tree.map_structure(np.testing.assert_array_equal, variable_client.params[0], + params) + tree.map_structure(np.testing.assert_array_equal, variable_client.params[1], + steps) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/multiagent/__init__.py b/acme/acme/multiagent/__init__.py new file mode 100644 index 00000000..9c16d296 --- /dev/null +++ b/acme/acme/multiagent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiagent helpers.""" diff --git a/acme/acme/multiagent/types.py b/acme/acme/multiagent/types.py new file mode 100644 index 00000000..c33251c4 --- /dev/null +++ b/acme/acme/multiagent/types.py @@ -0,0 +1,51 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Types for multiagent setups.""" + +from typing import Any, Callable, Dict, Tuple + +from acme import specs +from acme.agents.jax import builders as jax_builders +from acme.utils.loggers import base +import reverb + + +# Sub-agent types +AgentID = str +EvalMode = bool +GenericAgent = Any +AgentConfig = Any +Networks = Any +PolicyNetwork = Any +LoggerFn = Callable[[], base.Logger] +InitNetworkFn = Callable[[GenericAgent, specs.EnvironmentSpec], Networks] +InitPolicyNetworkFn = Callable[ + [GenericAgent, Networks, specs.EnvironmentSpec, AgentConfig, bool], + Networks] +InitBuilderFn = Callable[[GenericAgent, AgentConfig], + jax_builders.GenericActorLearnerBuilder] + +# Multiagent types +MultiAgentLoggerFn = Dict[AgentID, LoggerFn] +MultiAgentNetworks = Dict[AgentID, Networks] +MultiAgentPolicyNetworks = Dict[AgentID, PolicyNetwork] +MultiAgentSample = Tuple[reverb.ReplaySample, ...] +NetworkFactory = Callable[[specs.EnvironmentSpec], MultiAgentNetworks] +PolicyFactory = Callable[[MultiAgentNetworks, EvalMode], + MultiAgentPolicyNetworks] +BuilderFactory = Callable[[ + Dict[AgentID, GenericAgent], + Dict[AgentID, AgentConfig], +], Dict[AgentID, jax_builders.GenericActorLearnerBuilder]] diff --git a/acme/acme/multiagent/utils.py b/acme/acme/multiagent/utils.py new file mode 100644 index 00000000..4c91d90b --- /dev/null +++ b/acme/acme/multiagent/utils.py @@ -0,0 +1,48 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiagent utilities.""" + +from acme import specs +from acme.multiagent import types +import dm_env + + +def get_agent_spec(env_spec: specs.EnvironmentSpec, + agent_id: types.AgentID) -> specs.EnvironmentSpec: + """Returns a single agent spec from environment spec. + + Args: + env_spec: environment spec, wherein observation, action, and reward specs + are simply lists (with each entry specifying the respective spec for the + given agent index). Discounts are scalars shared amongst agents. + agent_id: agent index. 
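+
+  Returns:
+    The environment spec for the given agent, with the discount spec shared
+    across all agents.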
+ """ + return specs.EnvironmentSpec( + actions=env_spec.actions[agent_id], + discounts=env_spec.discounts, + observations=env_spec.observations[agent_id], + rewards=env_spec.rewards[agent_id]) + + +def get_agent_timestep(timestep: dm_env.TimeStep, + agent_id: types.AgentID) -> dm_env.TimeStep: + """Returns the extracted timestep for a particular agent.""" + # Discounts are assumed to be shared amongst agents + reward = None if timestep.reward is None else timestep.reward[agent_id] + return dm_env.TimeStep( + observation=timestep.observation[agent_id], + reward=reward, + discount=timestep.discount, + step_type=timestep.step_type) diff --git a/acme/acme/multiagent/utils_test.py b/acme/acme/multiagent/utils_test.py new file mode 100644 index 00000000..7c325fb2 --- /dev/null +++ b/acme/acme/multiagent/utils_test.py @@ -0,0 +1,59 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for multiagent_utils.""" + +from acme import specs +from acme.multiagent import utils as multiagent_utils +from acme.testing import fakes +from acme.testing import multiagent_fakes +import dm_env +from absl.testing import absltest + + +class UtilsTest(absltest.TestCase): + + def test_get_agent_spec(self): + agent_indices = ['a', '99', 'Z'] + spec = multiagent_fakes.make_multiagent_environment_spec(agent_indices) + for agent_id in spec.actions.keys(): + single_agent_spec = multiagent_utils.get_agent_spec( + spec, agent_id=agent_id) + expected_spec = specs.EnvironmentSpec( + actions=spec.actions[agent_id], + discounts=spec.discounts, + observations=spec.observations[agent_id], + rewards=spec.rewards[agent_id] + ) + self.assertEqual(single_agent_spec, expected_spec) + + def test_get_agent_timestep(self): + agent_indices = ['a', '99', 'Z'] + spec = multiagent_fakes.make_multiagent_environment_spec(agent_indices) + env = fakes.Environment(spec) + timestep = env.reset() + for agent_id in spec.actions.keys(): + single_agent_timestep = multiagent_utils.get_agent_timestep( + timestep, agent_id) + expected_timestep = dm_env.TimeStep( + observation=timestep.observation[agent_id], + reward=None, + discount=None, + step_type=timestep.step_type + ) + self.assertEqual(single_agent_timestep, expected_timestep) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/specs.py b/acme/acme/specs.py new file mode 100644 index 00000000..1d568436 --- /dev/null +++ b/acme/acme/specs.py @@ -0,0 +1,48 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Objects which specify the input/output spaces of an environment. + +This module exposes the same spec classes as `dm_env` as well as providing an +additional `EnvironmentSpec` class which collects all of the specs for a given +environment. An `EnvironmentSpec` instance can be created directly or by using +the `make_environment_spec` helper given a `dm_env.Environment` instance. +""" + +from typing import Any, NamedTuple + +import dm_env +from dm_env import specs + +Array = specs.Array +BoundedArray = specs.BoundedArray +DiscreteArray = specs.DiscreteArray + + +class EnvironmentSpec(NamedTuple): + """Full specification of the domains used by a given environment.""" + # TODO(b/144758674): Use NestedSpec type here. + observations: Any + actions: Any + rewards: Any + discounts: Any + + +def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec: + """Returns an `EnvironmentSpec` describing values used by an environment.""" + return EnvironmentSpec( + observations=environment.observation_spec(), + actions=environment.action_spec(), + rewards=environment.reward_spec(), + discounts=environment.discount_spec()) diff --git a/acme/acme/testing/__init__.py b/acme/acme/testing/__init__.py new file mode 100644 index 00000000..d79d3bb0 --- /dev/null +++ b/acme/acme/testing/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testing helpers.""" diff --git a/acme/acme/testing/fakes.py b/acme/acme/testing/fakes.py new file mode 100644 index 00000000..41fcb3fc --- /dev/null +++ b/acme/acme/testing/fakes.py @@ -0,0 +1,474 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fake (mock) components. + +Minimal implementations of fake Acme components which can be instantiated in +order to test or interact with other components. 
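+
+The fake Environment below validates every action against the environment
+spec and generates spec-conforming observations, rewards and discounts.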
+""" + +import threading +from typing import List, Mapping, Optional, Sequence, Callable, Iterator + +from acme import core +from acme import specs +from acme import types +from acme import wrappers +import dm_env +import numpy as np +import reverb +from rlds import rlds_types +import tensorflow as tf +import tree + + +class Actor(core.Actor): + """Fake actor which generates random actions and validates specs.""" + + def __init__(self, spec: specs.EnvironmentSpec): + self._spec = spec + self.num_updates = 0 + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + _validate_spec(self._spec.observations, observation) + return _generate_from_spec(self._spec.actions) + + def observe_first(self, timestep: dm_env.TimeStep): + _validate_spec(self._spec.observations, timestep.observation) + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + _validate_spec(self._spec.actions, action) + _validate_spec(self._spec.rewards, next_timestep.reward) + _validate_spec(self._spec.discounts, next_timestep.discount) + _validate_spec(self._spec.observations, next_timestep.observation) + + def update(self, wait: bool = False): + self.num_updates += 1 + + +class VariableSource(core.VariableSource): + """Fake variable source.""" + + def __init__(self, + variables: Optional[types.NestedArray] = None, + barrier: Optional[threading.Barrier] = None, + use_default_key: bool = True): + # Add dummy variables so we can expose them in get_variables. + if use_default_key: + self._variables = {'policy': [] if variables is None else variables} + else: + self._variables = variables + self._barrier = barrier + + def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: + if self._barrier is not None: + self._barrier.wait() + return [self._variables[name] for name in names] + + +class Learner(core.Learner, VariableSource): + """Fake Learner.""" + + def __init__(self, + variables: Optional[types.NestedArray] = None, + barrier: Optional[threading.Barrier] = None): + super().__init__(variables=variables, barrier=barrier) + self.step_counter = 0 + + def step(self): + self.step_counter += 1 + + +class Environment(dm_env.Environment): + """A fake environment with a given spec.""" + + def __init__( + self, + spec: specs.EnvironmentSpec, + *, + episode_length: int = 25, + ): + # Assert that the discount spec is a BoundedArray with range [0, 1]. + def check_discount_spec(path, discount_spec): + if (not isinstance(discount_spec, specs.BoundedArray) or + not np.isclose(discount_spec.minimum, 0) or + not np.isclose(discount_spec.maximum, 1)): + if path: + path_str = ' ' + '/'.join(str(p) for p in path) + else: + path_str = '' + raise ValueError( + 'discount_spec {}isn\'t a BoundedArray in [0, 1].'.format(path_str)) + + tree.map_structure_with_path(check_discount_spec, spec.discounts) + + self._spec = spec + self._episode_length = episode_length + self._step = 0 + + def _generate_fake_observation(self): + return _generate_from_spec(self._spec.observations) + + def _generate_fake_reward(self): + return _generate_from_spec(self._spec.rewards) + + def _generate_fake_discount(self): + return _generate_from_spec(self._spec.discounts) + + def reset(self) -> dm_env.TimeStep: + observation = self._generate_fake_observation() + self._step = 1 + return dm_env.restart(observation) + + def step(self, action) -> dm_env.TimeStep: + # Return a reset timestep if we haven't touched the environment yet. 
+ if not self._step: + return self.reset() + + _validate_spec(self._spec.actions, action) + + observation = self._generate_fake_observation() + reward = self._generate_fake_reward() + discount = self._generate_fake_discount() + + if self._episode_length and (self._step == self._episode_length): + self._step = 0 + # We can't use dm_env.termination directly because then the discount + # wouldn't necessarily conform to the spec (if eg. we want float32). + return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, + observation) + else: + self._step += 1 + return dm_env.transition( + reward=reward, observation=observation, discount=discount) + + def action_spec(self): + return self._spec.actions + + def observation_spec(self): + return self._spec.observations + + def reward_spec(self): + return self._spec.rewards + + def discount_spec(self): + return self._spec.discounts + + +class _BaseDiscreteEnvironment(Environment): + """Discrete action fake environment.""" + + def __init__(self, + *, + num_actions: int = 1, + action_dtype=np.int32, + observation_spec: types.NestedSpec, + discount_spec: Optional[types.NestedSpec] = None, + reward_spec: Optional[types.NestedSpec] = None, + **kwargs): + """Initialize the environment.""" + if reward_spec is None: + reward_spec = specs.Array((), np.float32) + + if discount_spec is None: + discount_spec = specs.BoundedArray((), np.float32, 0.0, 1.0) + + actions = specs.DiscreteArray(num_actions, dtype=action_dtype) + + super().__init__( + spec=specs.EnvironmentSpec( + observations=observation_spec, + actions=actions, + rewards=reward_spec, + discounts=discount_spec), + **kwargs) + + +class DiscreteEnvironment(_BaseDiscreteEnvironment): + """Discrete state and action fake environment.""" + + def __init__(self, + *, + num_actions: int = 1, + num_observations: int = 1, + action_dtype=np.int32, + obs_dtype=np.int32, + obs_shape: Sequence[int] = (), + discount_spec: Optional[types.NestedSpec] = None, + reward_spec: Optional[types.NestedSpec] = None, + **kwargs): + """Initialize the environment.""" + observations_spec = specs.BoundedArray( + shape=obs_shape, + dtype=obs_dtype, + minimum=obs_dtype(0), + maximum=obs_dtype(num_observations - 1)) + + super().__init__( + num_actions=num_actions, + action_dtype=action_dtype, + observation_spec=observations_spec, + discount_spec=discount_spec, + reward_spec=reward_spec, + **kwargs) + + +class NestedDiscreteEnvironment(_BaseDiscreteEnvironment): + """Discrete action fake environment with nested discrete state.""" + + def __init__(self, + *, + num_observations: Mapping[str, int], + num_actions: int = 1, + action_dtype=np.int32, + obs_dtype=np.int32, + obs_shape: Sequence[int] = (), + discount_spec: Optional[types.NestedSpec] = None, + reward_spec: Optional[types.NestedSpec] = None, + **kwargs): + """Initialize the environment.""" + + observations_spec = {} + for key in num_observations: + observations_spec[key] = specs.BoundedArray( + shape=obs_shape, + dtype=obs_dtype, + minimum=obs_dtype(0), + maximum=obs_dtype(num_observations[key] - 1)) + + super().__init__( + num_actions=num_actions, + action_dtype=action_dtype, + observation_spec=observations_spec, + discount_spec=discount_spec, + reward_spec=reward_spec, + **kwargs) + + +class ContinuousEnvironment(Environment): + """Continuous state and action fake environment.""" + + def __init__(self, + *, + action_dim: int = 1, + observation_dim: int = 1, + bounded: bool = False, + dtype=np.float32, + reward_dtype=np.float32, + **kwargs): + """Initialize the environment. 
+ + Args: + action_dim: number of action dimensions. + observation_dim: number of observation dimensions. + bounded: whether or not the actions are bounded in [-1, 1]. + dtype: dtype of the action and observation spaces. + reward_dtype: dtype of the reward and discounts. + **kwargs: additional kwargs passed to the Environment base class. + """ + + action_shape = () if action_dim == 0 else (action_dim,) + observation_shape = () if observation_dim == 0 else (observation_dim,) + + observations = specs.Array(observation_shape, dtype) + rewards = specs.Array((), reward_dtype) + discounts = specs.BoundedArray((), reward_dtype, 0.0, 1.0) + + if bounded: + actions = specs.BoundedArray(action_shape, dtype, -1.0, 1.0) + else: + actions = specs.Array(action_shape, dtype) + + super().__init__( + spec=specs.EnvironmentSpec( + observations=observations, + actions=actions, + rewards=rewards, + discounts=discounts), + **kwargs) + + +def _validate_spec(spec: types.NestedSpec, value: types.NestedArray): + """Validate a value from a potentially nested spec.""" + tree.assert_same_structure(value, spec) + tree.map_structure(lambda s, v: s.validate(v), spec, value) + + +def _normalize_array(array: specs.Array) -> specs.Array: + """Converts bounded arrays with (-inf,+inf) bounds to unbounded arrays. + + The returned array should be mostly equivalent to the input, except that + `generate_value()` returns -infs on arrays bounded to (-inf,+inf) and zeros + on unbounded arrays. + + Args: + array: the array to be normalized. + + Returns: + normalized array. + """ + if isinstance(array, specs.DiscreteArray): + return array + if not isinstance(array, specs.BoundedArray): + return array + if not (array.minimum == float('-inf')).all(): + return array + if not (array.maximum == float('+inf')).all(): + return array + return specs.Array(array.shape, array.dtype, array.name) + + +def _generate_from_spec(spec: types.NestedSpec) -> types.NestedArray: + """Generate a value from a potentially nested spec.""" + return tree.map_structure(lambda s: _normalize_array(s).generate_value(), + spec) + + +def transition_dataset_from_spec( + spec: specs.EnvironmentSpec) -> tf.data.Dataset: + """Constructs fake dataset of Reverb N-step transition samples. + + Args: + spec: Constructed fake transitions match the provided specification. + + Returns: + tf.data.Dataset that produces the same fake N-step transition ReverbSample + object indefinitely. + """ + + observation = _generate_from_spec(spec.observations) + action = _generate_from_spec(spec.actions) + reward = _generate_from_spec(spec.rewards) + discount = _generate_from_spec(spec.discounts) + data = types.Transition(observation, action, reward, discount, observation) + + info = tree.map_structure( + lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype), + reverb.SampleInfo.tf_dtypes()) + sample = reverb.ReplaySample(info=info, data=data) + + return tf.data.Dataset.from_tensors(sample).repeat() + + +def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset: + """Constructs fake dataset of Reverb N-step transition samples. + + Args: + environment: Constructed fake transitions will match the specification of + this environment. + + Returns: + tf.data.Dataset that produces the same fake N-step transition ReverbSample + object indefinitely. 
+  """
+  return transition_dataset_from_spec(specs.make_environment_spec(environment))
+
+
+def transition_iterator_from_spec(
+    spec: specs.EnvironmentSpec) -> Callable[[int], Iterator[types.Transition]]:
+  """Constructs fake iterator of transitions.
+
+  Args:
+    spec: Constructed fake transitions match the provided specification.
+
+  Returns:
+    A callable that, given a batch_size, returns an iterator of transitions.
+  """
+
+  observation = _generate_from_spec(spec.observations)
+  action = _generate_from_spec(spec.actions)
+  reward = _generate_from_spec(spec.rewards)
+  discount = _generate_from_spec(spec.discounts)
+  data = types.Transition(observation, action, reward, discount, observation)
+
+  dataset = tf.data.Dataset.from_tensors(data).repeat()
+
+  return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()
+
+
+def transition_iterator(
+    environment: dm_env.Environment
+) -> Callable[[int], Iterator[types.Transition]]:
+  """Constructs fake iterator of transitions.
+
+  Args:
+    environment: Constructed fake transitions will match the specification of
+      this environment.
+
+  Returns:
+    A callable that, given a batch_size, returns an iterator of transitions.
+  """
+  return transition_iterator_from_spec(specs.make_environment_spec(environment))
+
+
+def fake_atari_wrapped(episode_length: int = 10,
+                       oar_wrapper: bool = False) -> dm_env.Environment:
+  """Builds fake version of the environment to be used by tests.
+
+  Args:
+    episode_length: The length of episodes produced by this environment.
+    oar_wrapper: Should ObservationActionRewardWrapper be applied.
+
+  Returns:
+    Fake version of the environment equivalent to the one returned by
+    env_loader.load_atari_wrapped.
+  """
+  env = DiscreteEnvironment(
+      num_actions=18,
+      num_observations=2,
+      obs_shape=(84, 84, 4),
+      obs_dtype=np.float32,
+      episode_length=episode_length)
+
+  if oar_wrapper:
+    env = wrappers.ObservationActionRewardWrapper(env)
+  return env
+
+
+def rlds_dataset_from_env_spec(
+    spec: specs.EnvironmentSpec,
+    *,
+    episode_count: int = 10,
+    episode_length: int = 25,
+) -> tf.data.Dataset:
+  """Constructs a fake RLDS dataset with the given spec.
+
+  Args:
+    spec: specification to use for generation of fake steps.
+    episode_count: number of episodes in the dataset.
+    episode_length: length of the episode in the dataset.
+
+  Returns:
+    a fake RLDS dataset.
+  """
+
+  fake_steps = {
+      rlds_types.OBSERVATION:
+          ([_generate_from_spec(spec.observations)] * episode_length),
+      rlds_types.ACTION: ([_generate_from_spec(spec.actions)] * episode_length),
+      rlds_types.REWARD: ([_generate_from_spec(spec.rewards)] * episode_length),
+      rlds_types.DISCOUNT:
+          ([_generate_from_spec(spec.discounts)] * episode_length),
+      rlds_types.IS_TERMINAL: [False] * (episode_length - 1) + [True],
+      rlds_types.IS_FIRST: [True] + [False] * (episode_length - 1),
+      rlds_types.IS_LAST: [False] * (episode_length - 1) + [True],
+  }
+  steps_dataset = tf.data.Dataset.from_tensor_slices(fake_steps)
+
+  return tf.data.Dataset.from_tensor_slices(
+      {rlds_types.STEPS: [steps_dataset] * episode_count})
diff --git a/acme/acme/testing/multiagent_fakes.py b/acme/acme/testing/multiagent_fakes.py
new file mode 100644
index 00000000..0bfe11b6
--- /dev/null
+++ b/acme/acme/testing/multiagent_fakes.py
@@ -0,0 +1,50 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fake (mock) components for multiagent testing.""" + +from typing import Dict, List + +from acme import specs +import numpy as np + + +def _make_multiagent_spec(agent_indices: List[str]) -> Dict[str, specs.Array]: + """Returns dummy multiagent sub-spec (e.g., observation or action spec). + + Args: + agent_indices: a list of agent indices. + """ + return { + agent_id: specs.BoundedArray((1,), np.float32, 0, 1) + for agent_id in agent_indices + } + + +def make_multiagent_environment_spec( + agent_indices: List[str]) -> specs.EnvironmentSpec: + """Returns dummy multiagent environment spec. + + Args: + agent_indices: a list of agent indices. + """ + action_spec = _make_multiagent_spec(agent_indices) + discount_spec = specs.BoundedArray((), np.float32, 0.0, 1.0) + observation_spec = _make_multiagent_spec(agent_indices) + reward_spec = _make_multiagent_spec(agent_indices) + return specs.EnvironmentSpec( + actions=action_spec, + discounts=discount_spec, + observations=observation_spec, + rewards=reward_spec) diff --git a/acme/acme/testing/test_utils.py b/acme/acme/testing/test_utils.py new file mode 100644 index 00000000..576c9b0c --- /dev/null +++ b/acme/acme/testing/test_utils.py @@ -0,0 +1,33 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testing utilities.""" + +import sys +from typing import Optional + +from absl import flags +from absl.testing import parameterized + + +class TestCase(parameterized.TestCase): + """A custom TestCase which handles FLAG parsing for pytest compatibility.""" + + def get_tempdir(self, name: Optional[str] = None) -> str: + try: + flags.FLAGS.test_tmpdir + except flags.UnparsedFlagAccessError: + # Need to initialize flags when running `pytest`. + flags.FLAGS(sys.argv, known_only=True) + return self.create_tempdir(name).full_path diff --git a/acme/acme/tf/__init__.py b/acme/acme/tf/__init__.py new file mode 100644 index 00000000..240cb715 --- /dev/null +++ b/acme/acme/tf/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/acme/acme/tf/losses/__init__.py b/acme/acme/tf/losses/__init__.py new file mode 100644 index 00000000..70d51bf6 --- /dev/null +++ b/acme/acme/tf/losses/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Various losses for training agent components (policies, critics, etc).""" + +from acme.tf.losses.distributional import categorical +from acme.tf.losses.distributional import multiaxis_categorical +from acme.tf.losses.dpg import dpg +from acme.tf.losses.huber import huber +from acme.tf.losses.mompo import KLConstraint +from acme.tf.losses.mompo import MultiObjectiveMPO +from acme.tf.losses.mpo import MPO +from acme.tf.losses.r2d2 import transformed_n_step_loss + +# Internal imports. +# pylint: disable=g-bad-import-order,g-import-not-at-top +from acme.tf.losses.quantile import NonUniformQuantileRegression +from acme.tf.losses.quantile import QuantileDistribution diff --git a/acme/acme/tf/losses/distributional.py b/acme/acme/tf/losses/distributional.py new file mode 100644 index 00000000..54c0560c --- /dev/null +++ b/acme/acme/tf/losses/distributional.py @@ -0,0 +1,236 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Losses and projection operators relevant to distributional RL.""" + +from acme.tf import networks +import tensorflow as tf + + +def categorical(q_tm1: networks.DiscreteValuedDistribution, r_t: tf.Tensor, + d_t: tf.Tensor, + q_t: networks.DiscreteValuedDistribution) -> tf.Tensor: + """Implements the Categorical Distributional TD(0)-learning loss.""" + + z_t = tf.reshape(r_t, (-1, 1)) + tf.reshape(d_t, (-1, 1)) * q_t.values + p_t = tf.nn.softmax(q_t.logits) + + # Performs L2 projection. + target = tf.stop_gradient(l2_project(z_t, p_t, q_t.values)) + + # Calculates loss. + loss = tf.nn.softmax_cross_entropy_with_logits( + logits=q_tm1.logits, labels=target) + + return loss + + +# Use an old version of the l2 projection which is probably slower on CPU +# but will run on GPUs. +def l2_project( # pylint: disable=invalid-name + Zp: tf.Tensor, + P: tf.Tensor, + Zq: tf.Tensor, +) -> tf.Tensor: + """Project distribution (Zp, P) onto support Zq under the L2-metric over CDFs. + + This projection works for any support Zq. + Let Kq be len(Zq) and Kp be len(Zp). 
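+
+  For intuition: the projection splits each source atom's probability mass
+  between the two nearest atoms of Zq, in proportion to closeness. For
+  example, unit mass at 0.5 projected onto the support [0., 1.] yields
+  probabilities (0.5, 0.5).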
+
+  Args:
+    Zp: (batch_size, Kp) Support of distribution P
+    P: (batch_size, Kp) Probability values for P(Zp[i])
+    Zq: (Kq,) Support to project onto
+
+  Returns:
+    L2 projection of (Zp, P) onto Zq.
+  """
+
+  # Squeezes a leading dimension of size 1 from Zq, if present.
+  if Zq.get_shape().ndims > 1:
+    Zq = tf.squeeze(Zq, axis=0)
+
+  # Extracts vmin and vmax and constructs helper tensors from Zq.
+  vmin, vmax = Zq[0], Zq[-1]
+  d_pos = tf.concat([Zq, vmin[None]], 0)[1:]
+  d_neg = tf.concat([vmax[None], Zq], 0)[:-1]
+
+  # Clips Zp to be in new support range (vmin, vmax).
+  clipped_zp = tf.clip_by_value(Zp, vmin, vmax)[:, None, :]
+  clipped_zq = Zq[None, :, None]
+
+  # Gets the distance between atom values in support.
+  d_pos = (d_pos - Zq)[None, :, None]  # Zq[i+1] - Zq[i]
+  d_neg = (Zq - d_neg)[None, :, None]  # Zq[i] - Zq[i-1]
+
+  delta_qp = clipped_zp - clipped_zq  # Zp[j] - Zq[i]
+
+  d_sign = tf.cast(delta_qp >= 0., dtype=P.dtype)
+  delta_hat = (d_sign * delta_qp / d_pos) - ((1. - d_sign) * delta_qp / d_neg)
+  P = P[:, None, :]
+  return tf.reduce_sum(tf.clip_by_value(1. - delta_hat, 0., 1.) * P, 2)
+
+
+def multiaxis_categorical(  # pylint: disable=invalid-name
+    q_tm1: networks.DiscreteValuedDistribution,
+    r_t: tf.Tensor,
+    d_t: tf.Tensor,
+    q_t: networks.DiscreteValuedDistribution) -> tf.Tensor:
+  """Implements a multi-axis categorical distributional TD(0)-learning loss.
+
+  All arguments may have a leading batch axis, but q_tm1.logits and at least
+  one of r_t or d_t *must* have a leading batch axis.
+
+  Args:
+    q_tm1: Previous timestep's value distribution.
+    r_t: Reward.
+    d_t: Discount.
+    q_t: Current timestep's value distribution.
+
+  Returns:
+    Cross-entropy Bellman loss between q_tm1 and q_t + r_t * d_t.
+    Shape: (B, *E), where
+      B is the batch size.
+      E is the broadcasted shape of r_t, d_t, and q_t.values[:-1].
+  """
+  tf.assert_equal(tf.rank(r_t), tf.rank(d_t))
+
+  # Append a singleton axis corresponding to the axis that indexes the atoms
+  # in q_t.values.
+  r_t = r_t[..., None]  # shape: (B, *R, 1)
+  d_t = d_t[..., None]  # shape: (B, *D, 1)
+
+  z_t = r_t + d_t * q_t.values  # shape: (B, *E, N)
+
+  p_t = tf.nn.softmax(q_t.logits)
+
+  # Performs L2 projection.
+  target = tf.stop_gradient(multiaxis_l2_project(z_t, p_t, q_t.values))
+
+  # Calculates loss.
+  loss = tf.nn.softmax_cross_entropy_with_logits(
+      logits=q_tm1.logits, labels=target)
+
+  return loss
+
+
+# A modification of l2_project that allows multi-axis support arguments.
+def multiaxis_l2_project(  # pylint: disable=invalid-name
+    Zp: tf.Tensor,
+    P: tf.Tensor,
+    Zq: tf.Tensor,
+) -> tf.Tensor:
+  """Project distribution (Zp, P) onto support Zq under the L2-metric over CDFs.
+
+  Let source support Zp's shape be described as (B, *C, M), where:
+    B is the batch size.
+    C contains the sizes of any axes in between the first and last axes.
+    M is the number of atoms in the support.
+
+  Let destination support Zq's shape be described as (*D, N), where:
+    D contains the sizes of any axes before the last axis.
+    N is the number of atoms in the support.
+
+  Shapes C and D need not have the same number of dimensions, but they must be
+  broadcastable with each other.
+
+  Args:
+    Zp: Support of source distribution. Shape: (B, *C, M).
+    P: Probability values of source distribution p(Zp[i]). Shape: (B, *C, M).
+    Zq: Support to project P onto. Shape: (*D, N).
+
+  Returns:
+    The L2 projection of P from support Zp to support Zq.
+    Shape: (B, *E, N), where E is the broadcast-merged shape of C and D.
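+
+    For example (shapes mirroring the unit tests below): Zp with shape
+    (2, 3, 4, 11) and Zq with shape (4, 5) produce an output of shape
+    (2, 3, 4, 5), since C = (3, 4) and D = (4,) broadcast to E = (3, 4).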
+ """ + + tf.assert_equal(tf.shape(Zp), tf.shape(P)) + + # Shapes C, D, and E as defined in the docstring above. + shape_c = tf.shape(Zp)[1:-1] # drop the batch and atom axes + shape_d = tf.shape(Zq)[:-1] # drop the atom axis + shape_e = tf.broadcast_dynamic_shape(shape_c, shape_d) + + # If Zq has fewer inner axes than the broadcasted output shape, insert some + # size-1 axes to broadcast. + ndim_c = tf.size(shape_c) + ndim_e = tf.size(shape_e) + Zp = tf.reshape( + Zp, + tf.concat([tf.shape(Zp)[:1], # B + tf.ones(tf.math.maximum(ndim_e - ndim_c, 0), dtype=tf.int32), + shape_c, # C + tf.shape(Zp)[-1:]], # M + axis=0)) + P = tf.reshape(P, tf.shape(Zp)) + + # Broadcast Zp, P, and Zq's common axes to the same shape: E. + # + # Normally it'd be sufficient to ensure that these args have the same number + # of axes, then let the arithmetic operators broadcast as necessary. Instead, + # we need to explicitly broadcast them here, because there's a call to + # tf.clip_by_value(t, vmin, vmax) below, which doesn't allow t's dimensions + # to be expanded to match vmin and vmax. + + # Shape: (B, *E, M) + Zp = tf.broadcast_to( + Zp, + tf.concat([tf.shape(Zp)[:1], # B + shape_e, # E + tf.shape(Zp)[-1:]], # M + axis=0)) + + # Shape: (B, *E, M) + P = tf.broadcast_to(P, tf.shape(Zp)) + + # Shape: (*E, N) + Zq = tf.broadcast_to(Zq, tf.concat([shape_e, tf.shape(Zq)[-1:]], axis=0)) + + # Extracts vmin and vmax and construct helper tensors from Zq. + # These have shape shape_q, except the last axis has size 1. + # Shape: (*E, 1) + vmin, vmax = Zq[..., :1], Zq[..., -1:] + + # The distances between neighboring atom values in the target support. + # Shape: (*E, N) + d_pos = tf.roll(Zq, shift=-1, axis=-1) - Zq # d_pos[i] := Zq[i+1] - Zq[i] + d_neg = Zq - tf.roll(Zq, shift=1, axis=-1) # d_neg[i] := Zq[i] - Zq[i-1] + + # Clips Zp to be in new support range (vmin, vmax). + # Shape: (B, *E, 1, M) + clipped_zp = tf.clip_by_value(Zp, vmin, vmax)[..., None, :] + + # Shape: (1, *E, N, 1) + clipped_zq = Zq[None, ..., :, None] + + # Shape: (B, *E, N, M) + delta_qp = clipped_zp - clipped_zq # Zp[j] - Zq[i] + + # Shape: (B, *E, N, M) + d_sign = tf.cast(delta_qp >= 0., dtype=P.dtype) + + # Insert singleton axes to d_pos and d_neg to maintain the same shape as + # clipped_zq. + # Shape: (1, *E, N, 1) + d_pos = d_pos[None, ..., :, None] + d_neg = d_neg[None, ..., :, None] + + # Shape: (B, *E, N, M) + delta_hat = (d_sign * delta_qp / d_pos) - ((1. - d_sign) * delta_qp / d_neg) + + # Shape: (B, *E, 1, M) + P = P[..., None, :] + + # Shape: (B, *E, N) + return tf.reduce_sum(tf.clip_by_value(1. - delta_hat, 0., 1.) * P, axis=-1) diff --git a/acme/acme/tf/losses/distributional_test.py b/acme/acme/tf/losses/distributional_test.py new file mode 100644 index 00000000..3a368c07 --- /dev/null +++ b/acme/acme/tf/losses/distributional_test.py @@ -0,0 +1,177 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for acme.tf.losses.distributional.""" + +from acme.tf.losses import distributional +import numpy as np +from numpy import testing as npt +import tensorflow as tf + +from absl.testing import absltest +from absl.testing import parameterized + + +def _reference_l2_project(src_support, src_probs, dst_support): + """Multi-axis l2_project, implemented using single-axis l2_project. + + This is for testing multiaxis_l2_project's consistency with l2_project, + when used with multi-axis support vs single-axis support. + + Args: + src_support: Zp in l2_project. + src_probs: P in l2_project. + dst_support: Zq in l2_project. + + Returns: + src_probs, projected onto dst_support. + """ + assert src_support.shape == src_probs.shape + + # Remove the batch and value axes, and broadcast the rest to a common shape. + common_shape = np.broadcast(src_support[0, ..., 0], + dst_support[..., 0]).shape + + # If src_* have fewer internal axes than len(common_shape), insert size-1 + # axes. + while src_support.ndim-2 < len(common_shape): + src_support = src_support[:, None, ...] + + src_probs = np.reshape(src_probs, src_support.shape) + + # Broadcast args' non-batch, non-value axes to common_shape. + src_support = np.broadcast_to( + src_support, + src_support.shape[:1] + common_shape + src_support.shape[-1:]) + src_probs = np.broadcast_to(src_probs, src_support.shape) + dst_support = np.broadcast_to( + dst_support, + common_shape + dst_support.shape[-1:]) + + output_shape = (src_support.shape[0],) + dst_support.shape + + # Collapse all but the first (batch) and last (atom) axes. + src_support = src_support.reshape( + [src_support.shape[0], -1, src_support.shape[-1]]) + src_probs = src_probs.reshape( + [src_probs.shape[0], -1, src_probs.shape[-1]]) + + # Collapse all but the last (atom) axes. + dst_support = dst_support.reshape([-1, dst_support.shape[-1]]) + + dst_probs = np.zeros(src_support.shape[:1] + dst_support.shape, + dtype=src_probs.dtype) + + # iterate over all supports + for i in range(src_support.shape[1]): + s_support = tf.convert_to_tensor(src_support[:, i, :]) + s_probs = tf.convert_to_tensor(src_probs[:, i, :]) + d_support = tf.convert_to_tensor(dst_support[i, :]) + d_probs = distributional.l2_project(s_support, s_probs, d_support) + dst_probs[:, i, :] = d_probs.numpy() + + return dst_probs.reshape(output_shape) + + +class L2ProjectTest(parameterized.TestCase): + + @parameterized.parameters( + [(2, 11), (11,)], # C = (), D = (), matching num_atoms (11 and 11) + [(2, 11), (5,)], # C = (), D = (), differing num_atoms (11 and 5). + [(2, 3, 11), (3, 5)], # C = (3,), D = (3,) + [(2, 1, 11), (3, 5)], # C = (1,), D = (3,) + [(2, 3, 11), (1, 5)], # (C = (3,), D = (1,) + [(2, 3, 4, 11), (3, 4, 5)], # C = (3, 4), D = (3, 4) + [(2, 3, 4, 11), (4, 5)], # C = (3, 4), D = (4,) + [(2, 4, 11), (3, 4, 5)], # C = (4,), D = (3, 4) + ) + def test_multiaxis(self, src_shape, dst_shape): + """Tests consistency between multi-axis and single-axis l2_project. + + This calls l2_project on multi-axis supports, and checks that it gets + the same outcomes as many calls to single-axis supports. + + Args: + src_shape: Shape of source support. Includes a leading batch axis. + dst_shape: Shape of destination support. + Does not include a leading batch axis. + """ + # src_shape includes a leading batch axis, whereas dst_shape does not. 
+ # assert len(src_shape) >= (1 + len(dst_shape)) + + def make_support(shape, minimum): + """Creates a ndarray of supports.""" + values = np.linspace(start=minimum, stop=minimum+100, num=shape[-1]) + offsets = np.arange(np.prod(shape[:-1])) + result = values[None, :] + offsets[:, None] + return result.reshape(shape) + + src_support = make_support(src_shape, -1) + dst_support = make_support(dst_shape, -.75) + + rng = np.random.RandomState(1) + src_probs = rng.uniform(low=1.0, high=2.0, size=src_shape) + src_probs /= src_probs.sum() + + # Repeated calls to l2_project using single-axis supports. + expected_dst_probs = _reference_l2_project(src_support, + src_probs, + dst_support) + + # A single call to l2_project, with multi-axis supports. + dst_probs = distributional.multiaxis_l2_project( + tf.convert_to_tensor(src_support), + tf.convert_to_tensor(src_probs), + tf.convert_to_tensor(dst_support)).numpy() + + npt.assert_allclose(dst_probs, expected_dst_probs) + + @parameterized.parameters( + # Same src and dst support shape, dst support is shifted by +.25 + ([[0., 1, 2, 3]], + [[0., 1, 0, 0]], + [.25, 1.25, 2.25, 3.25], + [[.25, .75, 0, 0]]), + # Similar to above, but with batched src. + ([[0., 1, 2, 3], + [0., 1, 2, 3]], + [[0., 1, 0, 0], + [0., 0, 1, 0]], + [.25, 1.25, 2.25, 3.25], + [[.25, .75, 0, 0], + [0., .25, .75, 0]]), + # Similar to above, but src_probs has two 0.5's, instead of being one-hot. + ([[0., 1, 2, 3]], + [[0., .5, .5, 0]], + [.25, 1.25, 2.25, 3.25], + 0.5 * (np.array([[.25, .75, 0, 0]]) + np.array([[0., .25, .75, 0]]))), + # src and dst support have differing sizes + ([[0., 1, 2, 3]], + [[0., 1, 0, 0]], + [0.00, 0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50], + [[0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]]), + ) + def test_l2_projection( + self, src_support, src_probs, dst_support, expected_dst_probs): + + dst_probs = distributional.multiaxis_l2_project( + tf.convert_to_tensor(src_support), + tf.convert_to_tensor(src_probs), + tf.convert_to_tensor(dst_support)).numpy() + npt.assert_allclose(dst_probs, expected_dst_probs) + + +if __name__ == '__main__': + absltest.main() + diff --git a/acme/acme/tf/losses/dpg.py b/acme/acme/tf/losses/dpg.py new file mode 100644 index 00000000..e268e45a --- /dev/null +++ b/acme/acme/tf/losses/dpg.py @@ -0,0 +1,59 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Losses for Deterministic Policy Gradients.""" + +from typing import Optional +import tensorflow as tf + + +def dpg( + q_max: tf.Tensor, + a_max: tf.Tensor, + tape: tf.GradientTape, + dqda_clipping: Optional[float] = None, + clip_norm: bool = False, +) -> tf.Tensor: + """Deterministic policy gradient loss, similar to trfl.dpg.""" + + # Calculate the gradient dq/da. + dqda = tape.gradient([q_max], [a_max])[0] + + if dqda is None: + raise ValueError('q_max needs to be a function of a_max.') + + # Clipping the gradient dq/da. 
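+  # Depending on `clip_norm`, the clipping below either limits each partial
+  # derivative to [-dqda_clipping, dqda_clipping] element-wise, or rescales
+  # dqda so that its L2 norm along the last axis does not exceed
+  # dqda_clipping.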
+  if dqda_clipping is not None:
+    if dqda_clipping <= 0:
+      raise ValueError('dqda_clipping should be bigger than 0, {} found'.format(
+          dqda_clipping))
+    if clip_norm:
+      dqda = tf.clip_by_norm(dqda, dqda_clipping, axes=-1)
+    else:
+      dqda = tf.clip_by_value(dqda, -1. * dqda_clipping, dqda_clipping)
+
+  # target_a ensures that the correct gradient is calculated during backprop.
+  target_a = dqda + a_max
+  # Stop the gradient from flowing through the Q-network during backprop.
+  target_a = tf.stop_gradient(target_a)
+  # The gradient therefore only flows through the actor network.
+  loss = 0.5 * tf.reduce_sum(tf.square(target_a - a_max), axis=-1)
+  # This recovers the DPG because (letting w be the actor network weights):
+  # d(loss)/dw = 0.5 * (2 * (target_a - a_max) * d(target_a - a_max)/dw)
+  #            = (target_a - a_max) * [d(target_a)/dw - d(a_max)/dw]
+  #            = dq/da * [d(target_a)/dw - d(a_max)/dw]  # by defn of target_a
+  #            = dq/da * [0 - d(a_max)/dw]               # by stop_gradient
+  #            = - dq/da * da/dw
+
+  return loss
diff --git a/acme/acme/tf/losses/huber.py b/acme/acme/tf/losses/huber.py
new file mode 100644
index 00000000..5d97c243
--- /dev/null
+++ b/acme/acme/tf/losses/huber.py
@@ -0,0 +1,56 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Huber loss."""
+
+import tensorflow as tf
+
+
+def huber(inputs: tf.Tensor, quadratic_linear_boundary: float) -> tf.Tensor:
+  """Calculates the huber loss of `inputs`.
+
+  For each value x in `inputs`, the following is calculated:
+
+  ```
+  0.5 * x^2                  if |x| <= d
+  0.5 * d^2 + d * (|x| - d)  if |x| > d
+  ```
+
+  where d is `quadratic_linear_boundary`.
+
+  Args:
+    inputs: Input Tensor to calculate the huber loss on.
+    quadratic_linear_boundary: The point where the huber loss function changes
+      from quadratic to linear.
+
+  Returns:
+    `Tensor` of the same shape as `inputs`, containing values calculated in
+    the manner described above.
+
+  Raises:
+    ValueError: if quadratic_linear_boundary < 0.
+  """
+  if quadratic_linear_boundary < 0:
+    raise ValueError("quadratic_linear_boundary must be >= 0.")
+
+  abs_x = tf.abs(inputs)
+  delta = tf.constant(quadratic_linear_boundary)
+  quad = tf.minimum(abs_x, delta)
+  # The following expression is the same in value as
+  # tf.maximum(abs_x - delta, 0), but importantly the gradient for the
+  # expression when abs_x == delta is 0 (for tf.maximum it would be 1). This
+  # is necessary to avoid doubling the gradient, since there is already a
+  # nonzero contribution to the gradient from the quadratic term.
+  lin = (abs_x - quad)
+  return 0.5 * quad**2 + delta * lin
diff --git a/acme/acme/tf/losses/mompo.py b/acme/acme/tf/losses/mompo.py
new file mode 100644
index 00000000..591a2786
--- /dev/null
+++ b/acme/acme/tf/losses/mompo.py
@@ -0,0 +1,353 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implements the multi-objective MPO (MO-MPO) loss.
+
+This loss was proposed in (Abdolmaleki, Huang et al., 2020).
+
+The loss is implemented as a Sonnet module rather than a function so that it
+can hold its own dual variables, as instances of `tf.Variable`, which it
+creates the first time the module is called.
+
+Tensor shapes are annotated, where helpful, as follows:
+  B: batch size,
+  N: number of sampled actions, see MO-MPO paper for more details,
+  D: dimensionality of the action space,
+  K: number of objectives.
+
+(Abdolmaleki, Huang et al., 2020): https://arxiv.org/pdf/2005.07513.pdf
+"""
+
+import dataclasses
+from typing import Dict, Sequence, Tuple, Union
+
+from acme.tf.losses import mpo
+import sonnet as snt
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+
+_MPO_FLOAT_EPSILON = 1e-8
+
+
+@dataclasses.dataclass
+class KLConstraint:
+  """Defines a per-objective policy improvement step constraint for MO-MPO."""
+
+  name: str
+  value: float
+
+  def __post_init__(self):
+    if self.value < 0:
+      raise ValueError("KL constraint epsilon must be non-negative.")
+
+
+class MultiObjectiveMPO(snt.Module):
+  """Multi-objective MPO loss with decoupled KL constraints.
+
+  This implementation of the MO-MPO loss is based on the approach proposed in
+  (Abdolmaleki, Huang et al., 2020). The following features are included as
+  options:
+  - Satisfying the KL-constraint on a per-dimension basis (on by default)
+
+  (Abdolmaleki, Huang et al., 2020): https://arxiv.org/pdf/2005.07513.pdf
+  """
+
+  def __init__(self,
+               epsilons: Sequence[KLConstraint],
+               epsilon_mean: float,
+               epsilon_stddev: float,
+               init_log_temperature: float,
+               init_log_alpha_mean: float,
+               init_log_alpha_stddev: float,
+               per_dim_constraining: bool = True,
+               name: str = "MOMPO"):
+    """Initializes and configures the MO-MPO loss.
+
+    Args:
+      epsilons: per-objective KL constraints on the non-parametric auxiliary
+        policy, the one associated with the dual variables called temperature;
+        expected length K.
+      epsilon_mean: KL constraint on the mean of the Gaussian policy, the one
+        associated with the dual variable called alpha_mean.
+      epsilon_stddev: KL constraint on the stddev of the Gaussian policy, the
+        one associated with the dual variable called alpha_stddev.
+      init_log_temperature: initial value for the temperature in log-space;
+        note that a softplus (rather than an exp) will be used to transform
+        this.
+      init_log_alpha_mean: initial value for the alpha_mean in log-space; note
+        that a softplus (rather than an exp) will be used to transform this.
+      init_log_alpha_stddev: initial value for the alpha_stddev in log-space;
+        note that a softplus (rather than an exp) will be used to transform
+        this.
+      per_dim_constraining: whether to enforce the KL constraint on each
+        dimension independently; this is the default. Otherwise the overall KL
+        is constrained, which allows some dimensions to change more at the
+        expense of others staying put.
+      name: a name for the module, passed directly to snt.Module.
+    """
+    super().__init__(name=name)
+
+    # MO-MPO constraint thresholds.
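+    # (One epsilon, and hence one temperature dual variable below, per
+    # objective; the alpha_mean/alpha_stddev constraints are shared across
+    # objectives.)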
+ self._epsilons = tf.constant([x.value for x in epsilons]) + self._epsilon_mean = tf.constant(epsilon_mean) + self._epsilon_stddev = tf.constant(epsilon_stddev) + + # Initial values for the constraints' dual variables. + self._init_log_temperature = init_log_temperature + self._init_log_alpha_mean = init_log_alpha_mean + self._init_log_alpha_stddev = init_log_alpha_stddev + + # Whether to ensure per-dimension KL constraint satisfication. + self._per_dim_constraining = per_dim_constraining + + # Remember the number of objectives + self._num_objectives = len(epsilons) # K = number of objectives + self._objective_names = [x.name for x in epsilons] + + # Make sure there are no duplicate objective names + if len(self._objective_names) != len(set(self._objective_names)): + raise ValueError("Duplicate objective names are not allowed.") + + @property + def objective_names(self): + return self._objective_names + + @snt.once + def create_dual_variables_once(self, shape: tf.TensorShape, dtype: tf.DType): + """Creates the dual variables the first time the loss module is called.""" + + # Create the dual variables. + self._log_temperature = tf.Variable( + initial_value=[self._init_log_temperature] * self._num_objectives, + dtype=dtype, + name="log_temperature", + shape=(self._num_objectives,)) + self._log_alpha_mean = tf.Variable( + initial_value=tf.fill(shape, self._init_log_alpha_mean), + dtype=dtype, + name="log_alpha_mean", + shape=shape) + self._log_alpha_stddev = tf.Variable( + initial_value=tf.fill(shape, self._init_log_alpha_stddev), + dtype=dtype, + name="log_alpha_stddev", + shape=shape) + + # Cast constraint thresholds to the expected dtype. + self._epsilons = tf.cast(self._epsilons, dtype) + self._epsilon_mean = tf.cast(self._epsilon_mean, dtype) + self._epsilon_stddev = tf.cast(self._epsilon_stddev, dtype) + + def __call__( + self, + online_action_distribution: Union[tfd.MultivariateNormalDiag, + tfd.Independent], + target_action_distribution: Union[tfd.MultivariateNormalDiag, + tfd.Independent], + actions: tf.Tensor, # Shape [N, B, D]. + q_values: tf.Tensor, # Shape [N, B, K]. + ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + """Computes the decoupled MO-MPO loss. + + Args: + online_action_distribution: online distribution returned by the online + policy network; expects batch_dims of [B] and event_dims of [D]. + target_action_distribution: target distribution returned by the target + policy network; expects same shapes as online distribution. + actions: actions sampled from the target policy; expects shape [N, B, D]. + q_values: Q-values associated with each action; expects shape [N, B, K]. + + Returns: + Loss, combining the policy loss, KL penalty, and dual losses required to + adapt the dual variables. + Stats, for diagnostics and tracking performance. + """ + + # Make sure the Q-values are per-objective + q_values.get_shape().assert_has_rank(3) + if q_values.get_shape()[-1] != self._num_objectives: + raise ValueError("Q-values do not match expected number of objectives.") + + # Cast `MultivariateNormalDiag`s to Independent Normals. + # The latter allows us to satisfy KL constraints per-dimension. + if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): + target_action_distribution = tfd.Independent( + tfd.Normal(target_action_distribution.mean(), + target_action_distribution.stddev())) + online_action_distribution = tfd.Independent( + tfd.Normal(online_action_distribution.mean(), + online_action_distribution.stddev())) + + # Infer the shape and dtype of dual variables. 
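+    # For per-dimension constraining the KL between the underlying Normals
+    # has shape [B, D], so the dual variables take the trailing shape [D];
+    # otherwise the overall KL is used and a single alpha per constraint
+    # suffices.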
+ scalar_dtype = q_values.dtype + if self._per_dim_constraining: + dual_variable_shape = target_action_distribution.distribution.kl_divergence( + online_action_distribution.distribution).shape[1:] # Should be [D]. + else: + dual_variable_shape = target_action_distribution.kl_divergence( + online_action_distribution).shape[1:] # Should be [1]. + + # Create dual variables for the KL constraints; only happens the first call. + self.create_dual_variables_once(dual_variable_shape, scalar_dtype) + + # Project dual variables to ensure they stay positive. + min_log_temperature = tf.constant(-18.0, scalar_dtype) + min_log_alpha = tf.constant(-18.0, scalar_dtype) + self._log_temperature.assign( + tf.maximum(min_log_temperature, self._log_temperature)) + self._log_alpha_mean.assign(tf.maximum(min_log_alpha, self._log_alpha_mean)) + self._log_alpha_stddev.assign( + tf.maximum(min_log_alpha, self._log_alpha_stddev)) + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = tf.math.softplus(self._log_temperature) + _MPO_FLOAT_EPSILON + alpha_mean = tf.math.softplus(self._log_alpha_mean) + _MPO_FLOAT_EPSILON + alpha_stddev = tf.math.softplus(self._log_alpha_stddev) + _MPO_FLOAT_EPSILON + + # Get online and target means and stddevs in preparation for decomposition. + online_mean = online_action_distribution.distribution.mean() + online_scale = online_action_distribution.distribution.stddev() + target_mean = target_action_distribution.distribution.mean() + target_scale = target_action_distribution.distribution.stddev() + + # Compute normalized importance weights, used to compute expectations with + # respect to the non-parametric policy; and the temperature loss, used to + # adapt the tempering of Q-values. + normalized_weights, loss_temperature = compute_weights_and_temperature_loss( + q_values, self._epsilons, temperature) # Shapes [N, B, K] and [1, K]. + normalized_weights_sum = tf.reduce_sum(normalized_weights, axis=-1) + loss_temperature_mean = tf.reduce_mean(loss_temperature) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = mpo.compute_nonparametric_kl_from_normalized_weights( + normalized_weights) + + # Decompose the online policy into fixed-mean & fixed-stddev distributions. + # This has been documented as having better performance in bandit settings, + # see e.g. https://arxiv.org/pdf/1812.02256.pdf. + fixed_stddev_distribution = tfd.Independent( + tfd.Normal(loc=online_mean, scale=target_scale)) + fixed_mean_distribution = tfd.Independent( + tfd.Normal(loc=target_mean, scale=online_scale)) + + # Compute the decomposed policy losses. + loss_policy_mean = mpo.compute_cross_entropy_loss( + actions, normalized_weights_sum, fixed_stddev_distribution) + loss_policy_stddev = mpo.compute_cross_entropy_loss( + actions, normalized_weights_sum, fixed_mean_distribution) + + # Compute the decomposed KL between the target and online policies. + if self._per_dim_constraining: + kl_mean = target_action_distribution.distribution.kl_divergence( + fixed_stddev_distribution.distribution) # Shape [B, D]. + kl_stddev = target_action_distribution.distribution.kl_divergence( + fixed_mean_distribution.distribution) # Shape [B, D]. + else: + kl_mean = target_action_distribution.kl_divergence( + fixed_stddev_distribution) # Shape [B]. + kl_stddev = target_action_distribution.kl_divergence( + fixed_mean_distribution) # Shape [B]. 
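+    # kl_mean measures movement of the online mean (with the stddev pinned to
+    # the target's), and kl_stddev movement of the stddev (with the mean
+    # pinned); each is regularized against its own epsilon via a separate
+    # alpha below.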
+ + # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. + loss_kl_mean, loss_alpha_mean = mpo.compute_parametric_kl_penalty_and_dual_loss( + kl_mean, alpha_mean, self._epsilon_mean) + loss_kl_stddev, loss_alpha_stddev = mpo.compute_parametric_kl_penalty_and_dual_loss( + kl_stddev, alpha_stddev, self._epsilon_stddev) + + # Combine losses. + loss_policy = loss_policy_mean + loss_policy_stddev + loss_kl_penalty = loss_kl_mean + loss_kl_stddev + loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature_mean + loss = loss_policy + loss_kl_penalty + loss_dual + + stats = {} + # Dual Variables. + stats["dual_alpha_mean"] = tf.reduce_mean(alpha_mean) + stats["dual_alpha_stddev"] = tf.reduce_mean(alpha_stddev) + # Losses. + stats["loss_policy"] = tf.reduce_mean(loss) + stats["loss_alpha"] = tf.reduce_mean(loss_alpha_mean + loss_alpha_stddev) + # KL measurements. + stats["kl_mean_rel"] = tf.reduce_mean(kl_mean, axis=0) / self._epsilon_mean + stats["kl_stddev_rel"] = tf.reduce_mean( + kl_stddev, axis=0) / self._epsilon_stddev + # If the policy has standard deviation, log summary stats for this as well. + pi_stddev = online_action_distribution.distribution.stddev() + stats["pi_stddev_min"] = tf.reduce_mean(tf.reduce_min(pi_stddev, axis=-1)) + stats["pi_stddev_max"] = tf.reduce_mean(tf.reduce_max(pi_stddev, axis=-1)) + + # Condition number of the diagonal covariance (actually, stddev) matrix. + stats["pi_stddev_cond"] = tf.reduce_mean( + tf.reduce_max(pi_stddev, axis=-1) / tf.reduce_min(pi_stddev, axis=-1)) + + # Log per-objective values. + for i, name in enumerate(self._objective_names): + stats["{}_dual_temperature".format(name)] = temperature[i] + stats["{}_loss_temperature".format(name)] = loss_temperature[i] + stats["{}_kl_q_rel".format(name)] = tf.reduce_mean( + kl_nonparametric[:, i]) / self._epsilons[i] + + # Q measurements. + stats["{}_q_min".format(name)] = tf.reduce_mean(tf.reduce_min( + q_values, axis=0)[:, i]) + stats["{}_q_mean".format(name)] = tf.reduce_mean(tf.reduce_mean( + q_values, axis=0)[:, i]) + stats["{}_q_max".format(name)] = tf.reduce_mean(tf.reduce_max( + q_values, axis=0)[:, i]) + + return loss, stats + + +def compute_weights_and_temperature_loss( + q_values: tf.Tensor, + epsilons: tf.Tensor, + temperature: tf.Variable, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes normalized importance weights for the policy optimization. + + Args: + q_values: Q-values associated with the actions sampled from the target + policy; expected shape [N, B, K]. + epsilons: Desired per-objective constraints on the KL between the target + and non-parametric policies; expected shape [K]. + temperature: Per-objective scalar used to temper the Q-values before + computing normalized importance weights from them; expected shape [K]. + This is really the Lagrange dual variable in the constrained optimization + problem, the solution of which is the non-parametric policy targeted by + the policy loss. + + Returns: + Normalized importance weights, used for policy optimization; shape [N,B,K]. + Temperature loss, used to adapt the temperature; shape [1, K]. + """ + + # Temper the given Q-values using the current temperature. + tempered_q_values = tf.stop_gradient(q_values) / temperature[None, None, :] + + # Compute the normalized importance weights used to compute expectations with + # respect to the non-parametric policy. 
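+  # The softmax runs over the N sampled actions (axis 0), yielding one weight
+  # distribution over samples per batch element and per objective.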
+  normalized_weights = tf.nn.softmax(tempered_q_values, axis=0)
+  normalized_weights = tf.stop_gradient(normalized_weights)
+
+  # Compute the temperature loss (dual of the E-step optimization problem).
+  q_logsumexp = tf.reduce_logsumexp(tempered_q_values, axis=0)
+  log_num_actions = tf.math.log(tf.cast(q_values.shape[0], tf.float32))
+  loss_temperature = (
+      epsilons + tf.reduce_mean(q_logsumexp, axis=0) - log_num_actions)
+  loss_temperature = temperature * loss_temperature
+
+  return normalized_weights, loss_temperature
diff --git a/acme/acme/tf/losses/mpo.py b/acme/acme/tf/losses/mpo.py
new file mode 100644
index 00000000..0b956cd1
--- /dev/null
+++ b/acme/acme/tf/losses/mpo.py
@@ -0,0 +1,439 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implements the MPO losses.
+
+The MPO loss is implemented as a Sonnet module rather than a function so that
+it can hold its own dual variables, as instances of `tf.Variable`, which it
+creates the first time the module is called.
+
+Tensor shapes are annotated, where helpful, as follows:
+  B: batch size,
+  N: number of sampled actions, see MPO paper for more details,
+  D: dimensionality of the action space.
+"""
+
+from typing import Dict, Tuple, Union
+
+import sonnet as snt
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+
+_MPO_FLOAT_EPSILON = 1e-8
+
+
+class MPO(snt.Module):
+  """MPO loss with decoupled KL constraints as in (Abdolmaleki et al., 2018).
+
+  This implementation of the MPO loss includes the following features, as
+  options:
+  - Satisfying the KL-constraint on a per-dimension basis (on by default);
+  - Penalizing actions that fall outside of [-1, 1] (on by default) as a
+    special case of multi-objective MPO (MO-MPO; Abdolmaleki et al., 2020).
+
+  For best results on the control suite, keep both of these on.
+
+  (Abdolmaleki et al., 2018): https://arxiv.org/pdf/1812.02256.pdf
+  (Abdolmaleki et al., 2020): https://arxiv.org/pdf/2005.07513.pdf
+  """
+
+  def __init__(self,
+               epsilon: float,
+               epsilon_mean: float,
+               epsilon_stddev: float,
+               init_log_temperature: float,
+               init_log_alpha_mean: float,
+               init_log_alpha_stddev: float,
+               per_dim_constraining: bool = True,
+               action_penalization: bool = True,
+               epsilon_penalty: float = 0.001,
+               name: str = "MPO"):
+    """Initializes and configures the MPO loss.
+
+    Args:
+      epsilon: KL constraint on the non-parametric auxiliary policy, the one
+        associated with the dual variable called temperature.
+      epsilon_mean: KL constraint on the mean of the Gaussian policy, the one
+        associated with the dual variable called alpha_mean.
+      epsilon_stddev: KL constraint on the stddev of the Gaussian policy, the
+        one associated with the dual variable called alpha_stddev.
+      init_log_temperature: initial value for the temperature in log-space;
+        note that a softplus (rather than an exp) will be used to transform
+        this.
+      init_log_alpha_mean: initial value for the alpha_mean in log-space; note
+        that a softplus (rather than an exp) will be used to transform this.
+      init_log_alpha_stddev: initial value for the alpha_stddev in log-space;
+        note that a softplus (rather than an exp) will be used to transform
+        this.
+      per_dim_constraining: whether to enforce the KL constraint on each
+        dimension independently; this is the default. Otherwise the overall KL
+        is constrained, which allows some dimensions to change more at the
+        expense of others staying put.
+      action_penalization: whether to use a KL constraint to penalize actions
+        via the MO-MPO algorithm.
+      epsilon_penalty: KL constraint on the probability of violating the
+        action constraint.
+      name: a name for the module, passed directly to snt.Module.
+    """
+    super().__init__(name=name)
+
+    # MPO constraint thresholds.
+    self._epsilon = tf.constant(epsilon)
+    self._epsilon_mean = tf.constant(epsilon_mean)
+    self._epsilon_stddev = tf.constant(epsilon_stddev)
+
+    # Initial values for the constraints' dual variables.
+    self._init_log_temperature = init_log_temperature
+    self._init_log_alpha_mean = init_log_alpha_mean
+    self._init_log_alpha_stddev = init_log_alpha_stddev
+
+    # Whether to penalize out-of-bound actions via MO-MPO and its
+    # corresponding constraint threshold.
+    self._action_penalization = action_penalization
+    self._epsilon_penalty = tf.constant(epsilon_penalty)
+
+    # Whether to ensure per-dimension KL constraint satisfaction.
+    self._per_dim_constraining = per_dim_constraining
+
+  @snt.once
+  def create_dual_variables_once(self, shape: tf.TensorShape, dtype: tf.DType):
+    """Creates the dual variables the first time the loss module is called."""
+
+    # Create the dual variables.
+    self._log_temperature = tf.Variable(
+        initial_value=[self._init_log_temperature],
+        dtype=dtype,
+        name="log_temperature",
+        shape=(1,))
+    self._log_alpha_mean = tf.Variable(
+        initial_value=tf.fill(shape, self._init_log_alpha_mean),
+        dtype=dtype,
+        name="log_alpha_mean",
+        shape=shape)
+    self._log_alpha_stddev = tf.Variable(
+        initial_value=tf.fill(shape, self._init_log_alpha_stddev),
+        dtype=dtype,
+        name="log_alpha_stddev",
+        shape=shape)
+
+    # Cast constraint thresholds to the expected dtype.
+    self._epsilon = tf.cast(self._epsilon, dtype)
+    self._epsilon_mean = tf.cast(self._epsilon_mean, dtype)
+    self._epsilon_stddev = tf.cast(self._epsilon_stddev, dtype)
+
+    # Maybe create the action penalization dual variable.
+    if self._action_penalization:
+      self._epsilon_penalty = tf.cast(self._epsilon_penalty, dtype)
+      self._log_penalty_temperature = tf.Variable(
+          initial_value=[self._init_log_temperature],
+          dtype=dtype,
+          name="log_penalty_temperature",
+          shape=(1,))
+
+  def __call__(
+      self,
+      online_action_distribution: Union[tfd.MultivariateNormalDiag,
+                                        tfd.Independent],
+      target_action_distribution: Union[tfd.MultivariateNormalDiag,
+                                        tfd.Independent],
+      actions: tf.Tensor,  # Shape [N, B, D].
+      q_values: tf.Tensor,  # Shape [N, B].
+  ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
+    """Computes the decoupled MPO loss.
+
+    Args:
+      online_action_distribution: online distribution returned by the online
+        policy network; expects batch_dims of [B] and event_dims of [D].
+      target_action_distribution: target distribution returned by the target
+        policy network; expects same shapes as online distribution.
+      actions: actions sampled from the target policy; expects shape [N, B, D].
+      q_values: Q-values associated with each action; expects shape [N, B].
+
+    Returns:
+      Loss, combining the policy loss, KL penalty, and dual losses required to
+      adapt the dual variables.
+      Stats, for diagnostics and tracking performance.
+    """
+
+    # Cast `MultivariateNormalDiag`s to Independent Normals.
+    # The latter allows us to satisfy KL constraints per-dimension.
+    if isinstance(target_action_distribution, tfd.MultivariateNormalDiag):
+      target_action_distribution = tfd.Independent(
+          tfd.Normal(target_action_distribution.mean(),
+                     target_action_distribution.stddev()))
+      online_action_distribution = tfd.Independent(
+          tfd.Normal(online_action_distribution.mean(),
+                     online_action_distribution.stddev()))
+
+    # Infer the shape and dtype of dual variables.
+    scalar_dtype = q_values.dtype
+    if self._per_dim_constraining:
+      dual_variable_shape = target_action_distribution.distribution.kl_divergence(
+          online_action_distribution.distribution).shape[1:]  # Should be [D].
+    else:
+      dual_variable_shape = target_action_distribution.kl_divergence(
+          online_action_distribution).shape[1:]  # Should be [1].
+
+    # Create dual variables for the KL constraints; only happens the first call.
+    self.create_dual_variables_once(dual_variable_shape, scalar_dtype)
+
+    # Project dual variables to ensure they stay positive.
+    min_log_temperature = tf.constant(-18.0, scalar_dtype)
+    min_log_alpha = tf.constant(-18.0, scalar_dtype)
+    self._log_temperature.assign(
+        tf.maximum(min_log_temperature, self._log_temperature))
+    self._log_alpha_mean.assign(tf.maximum(min_log_alpha, self._log_alpha_mean))
+    self._log_alpha_stddev.assign(
+        tf.maximum(min_log_alpha, self._log_alpha_stddev))
+
+    # Transform dual variables from log-space.
+    # Note: using softplus instead of exponential for numerical stability.
+    temperature = tf.math.softplus(self._log_temperature) + _MPO_FLOAT_EPSILON
+    alpha_mean = tf.math.softplus(self._log_alpha_mean) + _MPO_FLOAT_EPSILON
+    alpha_stddev = tf.math.softplus(self._log_alpha_stddev) + _MPO_FLOAT_EPSILON
+
+    # Get online and target means and stddevs in preparation for decomposition.
+    online_mean = online_action_distribution.distribution.mean()
+    online_scale = online_action_distribution.distribution.stddev()
+    target_mean = target_action_distribution.distribution.mean()
+    target_scale = target_action_distribution.distribution.stddev()
+
+    # Compute normalized importance weights, used to compute expectations with
+    # respect to the non-parametric policy; and the temperature loss, used to
+    # adapt the tempering of Q-values.
+    normalized_weights, loss_temperature = compute_weights_and_temperature_loss(
+        q_values, self._epsilon, temperature)
+
+    # Only needed for diagnostics: Compute estimated actualized KL between the
+    # non-parametric and current target policies.
+    kl_nonparametric = compute_nonparametric_kl_from_normalized_weights(
+        normalized_weights)
+
+    if self._action_penalization:
+      # Project and transform action penalization temperature.
+      self._log_penalty_temperature.assign(
+          tf.maximum(min_log_temperature, self._log_penalty_temperature))
+      penalty_temperature = tf.math.softplus(
+          self._log_penalty_temperature) + _MPO_FLOAT_EPSILON
+
+      # Compute action penalization cost.
+      # Note: the cost is zero inside [-1, 1]; beyond the bounds it is the
+      # negated Euclidean norm of the out-of-bound action components.
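+      # E.g. a 1-D action of 1.3 is clipped to 1.0, leaving a violation of
+      # 0.3 and a cost of -0.3.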
+ diff_out_of_bound = actions - tf.clip_by_value(actions, -1.0, 1.0) + cost_out_of_bound = -tf.norm(diff_out_of_bound, axis=-1) + + penalty_normalized_weights, loss_penalty_temperature = compute_weights_and_temperature_loss( + cost_out_of_bound, self._epsilon_penalty, penalty_temperature) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + penalty_normalized_weights) + + # Combine normalized weights. + normalized_weights += penalty_normalized_weights + loss_temperature += loss_penalty_temperature + # Decompose the online policy into fixed-mean & fixed-stddev distributions. + # This has been documented as having better performance in bandit settings, + # see e.g. https://arxiv.org/pdf/1812.02256.pdf. + fixed_stddev_distribution = tfd.Independent( + tfd.Normal(loc=online_mean, scale=target_scale)) + fixed_mean_distribution = tfd.Independent( + tfd.Normal(loc=target_mean, scale=online_scale)) + + # Compute the decomposed policy losses. + loss_policy_mean = compute_cross_entropy_loss( + actions, normalized_weights, fixed_stddev_distribution) + loss_policy_stddev = compute_cross_entropy_loss( + actions, normalized_weights, fixed_mean_distribution) + + # Compute the decomposed KL between the target and online policies. + if self._per_dim_constraining: + kl_mean = target_action_distribution.distribution.kl_divergence( + fixed_stddev_distribution.distribution) # Shape [B, D]. + kl_stddev = target_action_distribution.distribution.kl_divergence( + fixed_mean_distribution.distribution) # Shape [B, D]. + else: + kl_mean = target_action_distribution.kl_divergence( + fixed_stddev_distribution) # Shape [B]. + kl_stddev = target_action_distribution.kl_divergence( + fixed_mean_distribution) # Shape [B]. + + # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. + loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss( + kl_mean, alpha_mean, self._epsilon_mean) + loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss( + kl_stddev, alpha_stddev, self._epsilon_stddev) + + # Combine losses. + loss_policy = loss_policy_mean + loss_policy_stddev + loss_kl_penalty = loss_kl_mean + loss_kl_stddev + loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature + loss = loss_policy + loss_kl_penalty + loss_dual + + stats = {} + # Dual Variables. + stats["dual_alpha_mean"] = tf.reduce_mean(alpha_mean) + stats["dual_alpha_stddev"] = tf.reduce_mean(alpha_stddev) + stats["dual_temperature"] = tf.reduce_mean(temperature) + # Losses. + stats["loss_policy"] = tf.reduce_mean(loss) + stats["loss_alpha"] = tf.reduce_mean(loss_alpha_mean + loss_alpha_stddev) + stats["loss_temperature"] = tf.reduce_mean(loss_temperature) + # KL measurements. + stats["kl_q_rel"] = tf.reduce_mean(kl_nonparametric) / self._epsilon + + if self._action_penalization: + stats["penalty_kl_q_rel"] = tf.reduce_mean( + penalty_kl_nonparametric) / self._epsilon_penalty + + stats["kl_mean_rel"] = tf.reduce_mean(kl_mean) / self._epsilon_mean + stats["kl_stddev_rel"] = tf.reduce_mean(kl_stddev) / self._epsilon_stddev + if self._per_dim_constraining: + # When KL is constrained per-dimension, we also log per-dimension min and + # max of mean/std of the realized KL costs. 
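+      # (The *_rel stats divide realized KLs by their epsilons, so values
+      # near 1.0 indicate the corresponding constraint is tight.)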
+ stats["kl_mean_rel_min"] = tf.reduce_min(tf.reduce_mean( + kl_mean, axis=0)) / self._epsilon_mean + stats["kl_mean_rel_max"] = tf.reduce_max(tf.reduce_mean( + kl_mean, axis=0)) / self._epsilon_mean + stats["kl_stddev_rel_min"] = tf.reduce_min( + tf.reduce_mean(kl_stddev, axis=0)) / self._epsilon_stddev + stats["kl_stddev_rel_max"] = tf.reduce_max( + tf.reduce_mean(kl_stddev, axis=0)) / self._epsilon_stddev + # Q measurements. + stats["q_min"] = tf.reduce_mean(tf.reduce_min(q_values, axis=0)) + stats["q_max"] = tf.reduce_mean(tf.reduce_max(q_values, axis=0)) + # If the policy has standard deviation, log summary stats for this as well. + pi_stddev = online_action_distribution.distribution.stddev() + stats["pi_stddev_min"] = tf.reduce_mean(tf.reduce_min(pi_stddev, axis=-1)) + stats["pi_stddev_max"] = tf.reduce_mean(tf.reduce_max(pi_stddev, axis=-1)) + # Condition number of the diagonal covariance (actually, stddev) matrix. + stats["pi_stddev_cond"] = tf.reduce_mean( + tf.reduce_max(pi_stddev, axis=-1) / tf.reduce_min(pi_stddev, axis=-1)) + + return loss, stats + + +def compute_weights_and_temperature_loss( + q_values: tf.Tensor, + epsilon: float, + temperature: tf.Variable, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes normalized importance weights for the policy optimization. + + Args: + q_values: Q-values associated with the actions sampled from the target + policy; expected shape [N, B]. + epsilon: Desired constraint on the KL between the target and non-parametric + policies. + temperature: Scalar used to temper the Q-values before computing normalized + importance weights from them. This is really the Lagrange dual variable + in the constrained optimization problem, the solution of which is the + non-parametric policy targeted by the policy loss. + + Returns: + Normalized importance weights, used for policy optimization. + Temperature loss, used to adapt the temperature. + """ + + # Temper the given Q-values using the current temperature. + tempered_q_values = tf.stop_gradient(q_values) / temperature + + # Compute the normalized importance weights used to compute expectations with + # respect to the non-parametric policy. + normalized_weights = tf.nn.softmax(tempered_q_values, axis=0) + normalized_weights = tf.stop_gradient(normalized_weights) + + # Compute the temperature loss (dual of the E-step optimization problem). + q_logsumexp = tf.reduce_logsumexp(tempered_q_values, axis=0) + log_num_actions = tf.math.log(tf.cast(q_values.shape[0], tf.float32)) + loss_temperature = epsilon + tf.reduce_mean(q_logsumexp) - log_num_actions + loss_temperature = temperature * loss_temperature + + return normalized_weights, loss_temperature + + +def compute_nonparametric_kl_from_normalized_weights( + normalized_weights: tf.Tensor) -> tf.Tensor: + """Estimate the actualized KL between the non-parametric and target policies.""" + + # Compute integrand. + num_action_samples = tf.cast(normalized_weights.shape[0], tf.float32) + integrand = tf.math.log(num_action_samples * normalized_weights + 1e-8) + + # Return the expectation with respect to the non-parametric policy. + return tf.reduce_sum(normalized_weights * integrand, axis=0) + + +def compute_cross_entropy_loss( + sampled_actions: tf.Tensor, + normalized_weights: tf.Tensor, + online_action_distribution: tfp.distributions.Distribution, +) -> tf.Tensor: + """Compute cross-entropy online and the reweighted target policy. + + Args: + sampled_actions: samples used in the Monte Carlo integration in the policy + loss. 
Expected shape is [N, B, ...], where N is the number of sampled + actions and B is the number of sampled states. + normalized_weights: target policy multiplied by the exponentiated Q values + and normalized; expected shape is [N, B]. + online_action_distribution: policy to be optimized. + + Returns: + loss_policy_gradient: the cross-entropy loss that, when differentiated, + produces the policy gradient. + """ + + # Compute the M-step loss. + log_prob = online_action_distribution.log_prob(sampled_actions) + + # Compute the weighted average log-prob using the normalized weights. + loss_policy_gradient = -tf.reduce_sum(log_prob * normalized_weights, axis=0) + + # Return the mean loss over the batch of states. + return tf.reduce_mean(loss_policy_gradient, axis=0) + + +def compute_parametric_kl_penalty_and_dual_loss( + kl: tf.Tensor, + alpha: tf.Variable, + epsilon: float, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes the KL cost to be added to the Lagragian and its dual loss. + + The KL cost is simply the alpha-weighted KL divergence and it is added as a + regularizer to the policy loss. The dual variable alpha itself has a loss that + can be minimized to adapt the strength of the regularizer to keep the KL + between consecutive updates at the desired target value of epsilon. + + Args: + kl: KL divergence between the target and online policies. + alpha: Lagrange multipliers (dual variables) for the KL constraints. + epsilon: Desired value for the KL. + + Returns: + loss_kl: alpha-weighted KL regularization to be added to the policy loss. + loss_alpha: The Lagrange dual loss minimized to adapt alpha. + """ + + # Compute the mean KL over the batch. + mean_kl = tf.reduce_mean(kl, axis=0) + + # Compute the regularization. + loss_kl = tf.reduce_sum(tf.stop_gradient(alpha) * mean_kl) + + # Compute the dual loss. + loss_alpha = tf.reduce_sum(alpha * (epsilon - tf.stop_gradient(mean_kl))) + + return loss_kl, loss_alpha diff --git a/acme/acme/tf/losses/quantile.py b/acme/acme/tf/losses/quantile.py new file mode 100644 index 00000000..bfbb18b2 --- /dev/null +++ b/acme/acme/tf/losses/quantile.py @@ -0,0 +1,94 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Losses for quantile regression.""" + +from typing import NamedTuple + +from .huber import huber +import sonnet as snt +import tensorflow as tf + + +class QuantileDistribution(NamedTuple): + values: tf.Tensor + logits: tf.Tensor + + +class NonUniformQuantileRegression(snt.Module): + """Compute the quantile regression loss for the distributional TD error.""" + + def __init__( + self, + huber_param: float = 0., + name: str = 'NUQuantileRegression'): + """Initializes the module. + + Args: + huber_param: The point where the huber loss function changes from a + quadratic to linear. + name: name to use for grouping operations. 
+ """ + super().__init__(name=name) + self._huber_param = huber_param + + def __call__( + self, + q_tm1: QuantileDistribution, + r_t: tf.Tensor, + pcont_t: tf.Tensor, + q_t: QuantileDistribution, + tau: tf.Tensor, + ) -> tf.Tensor: + """Calculates the loss. + + Note that this is only defined for discrete quantile-valued distributions. + In particular we assume that the distributions define q.logits and + q.values. + + Args: + q_tm1: the distribution at time t-1. + r_t: the reward at time t. + pcont_t: the discount factor at time t. + q_t: the target distribution. + tau: the quantile regression targets. + + Returns: + Value of the loss. + """ + # Distributional Bellman update + values_t = (tf.reshape(r_t, (-1, 1)) + + tf.reshape(pcont_t, (-1, 1)) * q_t.values) + values_t = tf.stop_gradient(values_t) + probs_t = tf.nn.softmax(q_t.logits) + + # Quantile regression loss + # Tau gives the quantile regression targets, where in the sample + # space [0, 1] each output should train towards + # Tau applies along the second dimension in delta (below) + tau = tf.expand_dims(tau, -1) + + # quantile td-error and assymmetric weighting + delta = values_t[:, None, :] - q_tm1.values[:, :, None] + delta_neg = tf.cast(delta < 0., dtype=tf.float32) + # This stop_gradient is very important, do not remove + weight = tf.stop_gradient(tf.abs(tau - delta_neg)) + + # loss + loss = huber(delta, self._huber_param) * weight + loss = tf.reduce_sum(loss * probs_t[:, None, :], 2) + + # Have not been able to get quite as good performance with mean vs. sum + loss = tf.reduce_sum(loss, -1) + return loss diff --git a/acme/acme/tf/losses/r2d2.py b/acme/acme/tf/losses/r2d2.py new file mode 100644 index 00000000..7c82868d --- /dev/null +++ b/acme/acme/tf/losses/r2d2.py @@ -0,0 +1,179 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loss functions for R2D2.""" + +from typing import Iterable, NamedTuple, Sequence + +import tensorflow as tf +import trfl + + +class LossCoreExtra(NamedTuple): + targets: tf.Tensor + errors: tf.Tensor + + +def transformed_n_step_loss( + qs: tf.Tensor, + targnet_qs: tf.Tensor, + actions: tf.Tensor, + rewards: tf.Tensor, + pcontinues: tf.Tensor, + target_policy_probs: tf.Tensor, + bootstrap_n: int, + stop_targnet_gradients: bool = True, + name: str = 'transformed_n_step_loss', +) -> trfl.base_ops.LossOutput: + """Helper function for computing transformed loss on sequences. + + Args: + qs: 3-D tensor corresponding to the Q-values to be learned. Shape is [T+1, + B, A]. + targnet_qs: Like `qs`, but in the target network setting, these values + should be computed by the target network. Shape is [T+1, B, A]. + actions: 2-D tensor holding the indices of actions executed during the + transition that corresponds to each major index. Shape is [T+1, B]. + rewards: 2-D tensor holding rewards received during the transition that + corresponds to each major index. Shape is [T, B]. 
+    pcontinues: 2-D tensor holding pcontinue values received during the
+      transition that corresponds to each major index. Shape is [T, B].
+    target_policy_probs: 3-D tensor holding per-action policy probabilities for
+      the states encountered just before taking the transitions that correspond
+      to each major index, according to the target policy (i.e. the policy we
+      wish to learn). For standard Q-learning the probabilities should form a
+      one-hot vector over actions where the nonzero index corresponds to the
+      max Q. Shape is [T+1, B, A].
+    bootstrap_n: Transition length for N-step bootstrapping.
+    stop_targnet_gradients: `bool` indicating whether to apply
+      tf.stop_gradient to the target values. This should usually be True.
+    name: name to prefix ops created by this function.
+
+  Returns:
+    a tuple of:
+      * `loss`: the transformed Q-learning loss summed over `T`.
+      * `LossCoreExtra`: namedtuple containing the fields `targets` and
+        `errors`.
+  """
+
+  with tf.name_scope(name):
+    # Require correct tensor ranks---as long as we have shape information
+    # available to check. If there isn't any, we raise an error.
+    def check_rank(tensors: Iterable[tf.Tensor], ranks: Sequence[int]):
+      for i, (tensor, rank) in enumerate(zip(tensors, ranks)):
+        if tensor.get_shape():
+          trfl.assert_rank_and_shape_compatibility([tensor], rank)
+        else:
+          raise ValueError(
+              f'Tensor "{tensor.name}", which was offered as '
+              f'transformed_n_step_loss parameter {i+1}, has no rank at '
+              'construction time, so cannot verify that it has the '
+              f'necessary rank of {rank}')
+
+    check_rank(
+        [qs, targnet_qs, actions, rewards, pcontinues, target_policy_probs],
+        [3, 3, 2, 2, 2, 3])
+
+    # Construct arguments to compute bootstrap target.
+    a_tm1 = actions[:-1]  # (0:T) x B
+    r_t, pcont_t = rewards, pcontinues  # (1:T+1) x B
+    q_tm1 = qs[:-1]  # (0:T) x B x A
+    target_policy_t = target_policy_probs[1:]  # (1:T+1) x B x A
+    targnet_q_t = targnet_qs[1:]  # (1:T+1) x B x A
+
+    bootstrap_value = tf.reduce_sum(
+        target_policy_t * _signed_parabolic_tx(targnet_q_t), -1)
+    target = _compute_n_step_sequence_targets(
+        r_t=r_t,
+        pcont_t=pcont_t,
+        bootstrap_value=bootstrap_value,
+        n=bootstrap_n)
+
+    if stop_targnet_gradients:
+      target = tf.stop_gradient(target)
+
+    # tx/inv_tx may result in numerical instabilities, so mask any NaNs.
+    finite_mask = tf.math.is_finite(target)
+    target = tf.where(finite_mask, target, tf.zeros_like(target))
+
+    qa_tm1 = trfl.batched_index(q_tm1, a_tm1)
+    errors = qa_tm1 - _signed_hyperbolic_tx(target)
+
+    # Only compute n-step errors w.r.t. finite targets.
+    errors = tf.where(finite_mask, errors, tf.zeros_like(errors))
+
+    # Sum over time dimension.
+    loss = 0.5 * tf.reduce_sum(tf.square(errors), axis=0)
+
+    return trfl.base_ops.LossOutput(
+        loss, LossCoreExtra(targets=target, errors=errors))
+
+
+def _compute_n_step_sequence_targets(
+    r_t: tf.Tensor,
+    pcont_t: tf.Tensor,
+    bootstrap_value: tf.Tensor,
+    n: int,
+) -> tf.Tensor:
+  """Computes n-step bootstrapped returns over a sequence.
+
+  Args:
+    r_t: 2-D tensor of shape [T, B] corresponding to rewards.
+    pcont_t: 2-D tensor of shape [T, B] corresponding to pcontinues.
+    bootstrap_value: 2-D tensor of shape [T, B] corresponding to bootstrap
+      values.
+    n: number of steps over which to accumulate reward before bootstrapping.
+
+  Returns:
+    2-D tensor of shape [T, B] corresponding to bootstrapped returns.
+  """
+  time_size, batch_size = r_t.shape.as_list()
+
+  # Pad r_t and pcont_t so we can use static slice shapes in the loop below.
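+  # Each target accumulates up to n rewards (discounted by the running
+  # product of pcontinues) before adding the bootstrap value. Padding rewards
+  # with zeros and pcontinues with ones keeps every slice in the backward
+  # loop at a static [time_size, batch_size] shape without changing the
+  # computed returns.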
+ r_t = tf.concat([r_t, tf.zeros((n - 1, batch_size))], 0) + pcont_t = tf.concat([pcont_t, tf.ones((n - 1, batch_size))], 0) + + # We need to use tf.slice with static shapes for TPU compatibility. + def _slice(tensor, index, size): + return tf.slice(tensor, [index, 0], [size, batch_size]) + + # Construct correct bootstrap targets for each time slice t, which are exactly + # the target values at timestep min(t+n-1, time_size-1). + last_bootstrap_value = _slice(bootstrap_value, time_size - 1, 1) + if time_size > n - 1: + full_bootstrap_steps = [_slice(bootstrap_value, n - 1, time_size - (n - 1))] + truncated_bootstrap_steps = [last_bootstrap_value] * (n - 1) + else: + # Only truncated steps, since n > time_size. + full_bootstrap_steps = [] + truncated_bootstrap_steps = [last_bootstrap_value] * time_size + bootstrap_value = tf.concat(full_bootstrap_steps + truncated_bootstrap_steps, + 0) + + # Iterate backwards for n steps to construct n-step return targets. + targets = bootstrap_value + for i in range(n - 1, -1, -1): + this_pcont_t = _slice(pcont_t, i, time_size) + this_r_t = _slice(r_t, i, time_size) + targets = this_r_t + this_pcont_t * targets + return targets + + +def _signed_hyperbolic_tx(x: tf.Tensor, eps: float = 1e-3) -> tf.Tensor: + """Signed hyperbolic transform, inverse of signed_parabolic.""" + return tf.sign(x) * (tf.sqrt(abs(x) + 1) - 1) + eps * x + + +def _signed_parabolic_tx(x: tf.Tensor, eps: float = 1e-3) -> tf.Tensor: + """Signed parabolic transform, inverse of signed_hyperbolic.""" + z = tf.sqrt(1 + 4 * eps * (eps + 1 + abs(x))) / 2 / eps - 1 / 2 / eps + return tf.sign(x) * (tf.square(z) - 1) diff --git a/acme/acme/tf/networks/__init__.py b/acme/acme/tf/networks/__init__.py new file mode 100644 index 00000000..67242196 --- /dev/null +++ b/acme/acme/tf/networks/__init__.py @@ -0,0 +1,66 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Useful network definitions.""" + +from acme.tf.networks.atari import AtariTorso +from acme.tf.networks.atari import DeepIMPALAAtariNetwork +from acme.tf.networks.atari import DQNAtariNetwork +from acme.tf.networks.atari import IMPALAAtariNetwork +from acme.tf.networks.atari import R2D2AtariNetwork +from acme.tf.networks.base import DistributionalModule +from acme.tf.networks.base import Module +from acme.tf.networks.base import RNNCore +from acme.tf.networks.continuous import LayerNormAndResidualMLP +from acme.tf.networks.continuous import LayerNormMLP +from acme.tf.networks.continuous import NearZeroInitializedLinear +from acme.tf.networks.discrete import DiscreteFilteredQNetwork +from acme.tf.networks.distributional import ApproximateMode +from acme.tf.networks.distributional import DiscreteValuedHead +from acme.tf.networks.distributional import MultivariateGaussianMixture +from acme.tf.networks.distributional import MultivariateNormalDiagHead +from acme.tf.networks.distributional import UnivariateGaussianMixture +from acme.tf.networks.distributions import DiscreteValuedDistribution +from acme.tf.networks.duelling import DuellingMLP +from acme.tf.networks.masked_epsilon_greedy import NetworkWithMaskedEpsilonGreedy +from acme.tf.networks.multihead import Multihead +from acme.tf.networks.multiplexers import CriticMultiplexer +from acme.tf.networks.noise import ClippedGaussian +from acme.tf.networks.policy_value import PolicyValueHead +from acme.tf.networks.recurrence import CriticDeepRNN +from acme.tf.networks.recurrence import DeepRNN +from acme.tf.networks.recurrence import LSTM +from acme.tf.networks.recurrence import RecurrentExpQWeightedPolicy +from acme.tf.networks.rescaling import ClipToSpec +from acme.tf.networks.rescaling import RescaleToSpec +from acme.tf.networks.rescaling import TanhToSpec +from acme.tf.networks.stochastic import ExpQWeightedPolicy +from acme.tf.networks.stochastic import StochasticMeanHead +from acme.tf.networks.stochastic import StochasticModeHead +from acme.tf.networks.stochastic import StochasticSamplingHead +from acme.tf.networks.vision import DrQTorso +from acme.tf.networks.vision import ResNetTorso + +# For backwards compatibility. +GaussianMixtureHead = UnivariateGaussianMixture + +try: + # pylint: disable=g-bad-import-order,g-import-not-at-top + from acme.tf.networks.legal_actions import MaskedSequential + from acme.tf.networks.legal_actions import EpsilonGreedy +except ImportError: + pass + +# Internal imports. +from acme.tf.networks.quantile import IQNNetwork diff --git a/acme/acme/tf/networks/atari.py b/acme/acme/tf/networks/atari.py new file mode 100644 index 00000000..2b722e47 --- /dev/null +++ b/acme/acme/tf/networks/atari.py @@ -0,0 +1,191 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Commonly-used networks for running on Atari.""" + +from typing import Optional, Tuple + +from acme.tf.networks import base +from acme.tf.networks import duelling +from acme.tf.networks import embedding +from acme.tf.networks import policy_value +from acme.tf.networks import recurrence +from acme.tf.networks import vision +from acme.wrappers import observation_action_reward + +import sonnet as snt +import tensorflow as tf + +Images = tf.Tensor +QValues = tf.Tensor +Logits = tf.Tensor +Value = tf.Tensor + + +class AtariTorso(base.Module): + """Simple convolutional stack commonly used for Atari.""" + + def __init__(self): + super().__init__(name='atari_torso') + self._network = snt.Sequential([ + snt.Conv2D(32, [8, 8], [4, 4]), + tf.nn.relu, + snt.Conv2D(64, [4, 4], [2, 2]), + tf.nn.relu, + snt.Conv2D(64, [3, 3], [1, 1]), + tf.nn.relu, + snt.Flatten(), + ]) + + def __call__(self, inputs: Images) -> tf.Tensor: + return self._network(inputs) + + +class DQNAtariNetwork(base.Module): + """A feed-forward network for use with Ape-X DQN. + + See https://arxiv.org/pdf/1803.00933.pdf for more information. + """ + + def __init__(self, num_actions: int): + super().__init__(name='dqn_atari_network') + self._network = snt.Sequential([ + AtariTorso(), + duelling.DuellingMLP(num_actions, hidden_sizes=[512]), + ]) + + def __call__(self, inputs: Images) -> QValues: + return self._network(inputs) + + +class R2D2AtariNetwork(base.RNNCore): + """A recurrent network for use with R2D2. + + See https://openreview.net/forum?id=r1lyTjAqYX for more information. + """ + + def __init__(self, num_actions: int, core: Optional[base.RNNCore] = None): + super().__init__(name='r2d2_atari_network') + self._embed = embedding.OAREmbedding( + torso=AtariTorso(), num_actions=num_actions) + self._core = core if core is not None else recurrence.LSTM(512) + self._head = duelling.DuellingMLP(num_actions, hidden_sizes=[512]) + + def __call__( + self, + inputs: observation_action_reward.OAR, + state: base.State, + ) -> Tuple[QValues, base.State]: + + embeddings = self._embed(inputs) + embeddings, new_state = self._core(embeddings, state) + action_values = self._head(embeddings) # [B, A] + + return action_values, new_state + + # TODO(b/171287329): Figure out why return type annotation causes error. + def initial_state(self, batch_size: int, **unused_kwargs) -> base.State: # pytype: disable=invalid-annotation + return self._core.initial_state(batch_size) + + def unroll( + self, + inputs: observation_action_reward.OAR, + state: base.State, + sequence_length: int, + ) -> Tuple[QValues, base.State]: + """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" + embeddings = snt.BatchApply(self._embed)(inputs) # [T, B, D+A+1] + embeddings, new_state = self._core.unroll(embeddings, state, + sequence_length) + action_values = snt.BatchApply(self._head)(embeddings) + + return action_values, new_state + + +class IMPALAAtariNetwork(snt.RNNCore): + """A recurrent network for use with IMPALA. + + See https://arxiv.org/pdf/1802.01561.pdf for more information. 
+ """ + + def __init__(self, num_actions: int): + super().__init__(name='impala_atari_network') + self._embed = embedding.OAREmbedding( + torso=AtariTorso(), num_actions=num_actions) + self._core = snt.LSTM(256) + self._head = snt.Sequential([ + snt.Linear(256), + tf.nn.relu, + policy_value.PolicyValueHead(num_actions), + ]) + self._num_actions = num_actions + + def __call__( + self, inputs: observation_action_reward.OAR, + state: snt.LSTMState) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: + + embeddings = self._embed(inputs) + embeddings, new_state = self._core(embeddings, state) + logits, value = self._head(embeddings) # [B, A] + + return (logits, value), new_state + + def initial_state(self, batch_size: int, **unused_kwargs) -> snt.LSTMState: + return self._core.initial_state(batch_size) + + +class DeepIMPALAAtariNetwork(base.RNNCore): + """A recurrent network for use with IMPALA. + + See https://arxiv.org/pdf/1802.01561.pdf for more information. + """ + + def __init__(self, num_actions: int): + super().__init__(name='deep_impala_atari_network') + self._embed = embedding.OAREmbedding( + torso=vision.ResNetTorso(), num_actions=num_actions) + self._core = snt.LSTM(256) + self._head = snt.Sequential([ + snt.Linear(256), + tf.nn.relu, + policy_value.PolicyValueHead(num_actions), + ]) + self._num_actions = num_actions + + def __call__( + self, inputs: observation_action_reward.OAR, + state: snt.LSTMState) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: + + embeddings = self._embed(inputs) + embeddings, new_state = self._core(embeddings, state) + logits, value = self._head(embeddings) # [B, A] + + return (logits, value), new_state + + def initial_state(self, batch_size: int, **unused_kwargs) -> snt.LSTMState: + return self._core.initial_state(batch_size) + + def unroll( + self, + inputs: observation_action_reward.OAR, + states: snt.LSTMState, + sequence_length: int, + ) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: + """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" + embeddings = snt.BatchApply(self._embed)(inputs) # [T, B, D+A+1] + embeddings, new_states = snt.static_unroll(self._core, embeddings, states, + sequence_length) + logits, values = snt.BatchApply(self._head)(embeddings) + + return (logits, values), new_states diff --git a/acme/acme/tf/networks/base.py b/acme/acme/tf/networks/base.py new file mode 100644 index 00000000..ca6c0190 --- /dev/null +++ b/acme/acme/tf/networks/base.py @@ -0,0 +1,66 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Convenient base classes for custom networks.""" + +import abc +from typing import Tuple, TypeVar + +from acme import types +import sonnet as snt +import tensorflow_probability as tfp + +State = TypeVar('State') + + +class Module(snt.Module, abc.ABC): + """A base class for module with abstract __call__ method.""" + + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> types.NestedTensor: + """Forward pass of the module.""" + + +class DistributionalModule(snt.Module, abc.ABC): + """A base class for modules that output distributions.""" + + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> tfp.distributions.Distribution: + """Forward pass of the module.""" + + +class RNNCore(snt.RNNCore, abc.ABC): + """An RNN core with a custom `unroll` function.""" + + @abc.abstractmethod + def unroll(self, + inputs: types.NestedTensor, + state: State, + sequence_length: int, + ) -> Tuple[types.NestedTensor, State]: + """A custom function for doing static unrolls over sequences. + + This has the same API as `snt.static_unroll`, but allows the user to specify + their own implementation to take advantage of the structure of the network + for better performance, e.g. by batching the feed-forward pass over the + whole sequence. + + Args: + inputs: A nest of `tf.Tensor` in time-major format. + state: The RNN core state. + sequence_length: How long the static_unroll should go for. + + Returns: + Nested sequence output of RNN, and final state. + """ diff --git a/acme/acme/tf/networks/continuous.py b/acme/acme/tf/networks/continuous.py new file mode 100644 index 00000000..e003ac8d --- /dev/null +++ b/acme/acme/tf/networks/continuous.py @@ -0,0 +1,138 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Networks used in continuous control.""" + +from typing import Callable, Optional, Sequence + +from acme import types +from acme.tf import utils as tf2_utils +from acme.tf.networks import base +import sonnet as snt +import tensorflow as tf + + +def _uniform_initializer(): + return tf.initializers.VarianceScaling( + distribution='uniform', mode='fan_out', scale=0.333) + + +class NearZeroInitializedLinear(snt.Linear): + """Simple linear layer, initialized at near zero weights and zero biases.""" + + def __init__(self, output_size: int, scale: float = 1e-4): + super().__init__(output_size, w_init=tf.initializers.VarianceScaling(scale)) + + +class LayerNormMLP(snt.Module): + """Simple feedforward MLP torso with initial layer-norm. + + This module is an MLP which uses LayerNorm (with a tanh normalizer) on the + first layer and non-linearities (elu) on all but the last remaining layers. + """ + + def __init__(self, + layer_sizes: Sequence[int], + w_init: Optional[snt.initializers.Initializer] = None, + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu, + activate_final: bool = False): + """Construct the MLP. + + Args: + layer_sizes: a sequence of ints specifying the size of each layer. 
+ w_init: initializer for Linear weights. + activation: activation function to apply between linear layers. Defaults + to ELU. Note! This is different from snt.nets.MLP's default. + activate_final: whether or not to use the activation function on the final + layer of the neural network. + """ + super().__init__(name='feedforward_mlp_torso') + + self._network = snt.Sequential([ + snt.Linear(layer_sizes[0], w_init=w_init or _uniform_initializer()), + snt.LayerNorm( + axis=slice(1, None), create_scale=True, create_offset=True), + tf.nn.tanh, + snt.nets.MLP( + layer_sizes[1:], + w_init=w_init or _uniform_initializer(), + activation=activation, + activate_final=activate_final), + ]) + + def __call__(self, observations: types.Nest) -> tf.Tensor: + """Forwards the policy network.""" + return self._network(tf2_utils.batch_concat(observations)) + + +class ResidualLayernormWrapper(snt.Module): + """Wrapper that applies residual connections and layer norm.""" + + def __init__(self, layer: base.Module): + """Creates the Wrapper Class. + + Args: + layer: module to wrap. + """ + + super().__init__(name='ResidualLayernormWrapper') + self._layer = layer + + self._layer_norm = snt.LayerNorm( + axis=-1, create_scale=True, create_offset=True) + + def __call__(self, inputs: tf.Tensor): + """Returns the result of the residual and layernorm computation. + + Args: + inputs: inputs to the main module. + """ + + # Apply main module. + outputs = self._layer(inputs) + outputs = self._layer_norm(outputs + inputs) + + return outputs + + +class LayerNormAndResidualMLP(snt.Module): + """MLP with residual connections and layer norm. + + An MLP which applies residual connection and layer normalisation every two + linear layers. Similar to Resnet, but with FC layers instead of convolutions. + """ + + def __init__(self, hidden_size: int, num_blocks: int): + """Create the model. + + Args: + hidden_size: width of each hidden layer. + num_blocks: number of blocks, each block being MLP([hidden_size, + hidden_size]) + layer norm + residual connection. + """ + super().__init__(name='LayerNormAndResidualMLP') + + # Create initial MLP layer. + layers = [snt.nets.MLP([hidden_size], w_init=_uniform_initializer())] + + # Follow it up with num_blocks MLPs with layernorm and residual connections. + for _ in range(num_blocks): + mlp = snt.nets.MLP([hidden_size, hidden_size], + w_init=_uniform_initializer()) + layers.append(ResidualLayernormWrapper(mlp)) + + self._module = snt.Sequential(layers) + + def __call__(self, inputs: tf.Tensor): + return self._module(inputs) diff --git a/acme/acme/tf/networks/discrete.py b/acme/acme/tf/networks/discrete.py new file mode 100644 index 00000000..b4134c35 --- /dev/null +++ b/acme/acme/tf/networks/discrete.py @@ -0,0 +1,45 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Networks used in discrete-action agents.""" + +import sonnet as snt +import tensorflow as tf + + +class DiscreteFilteredQNetwork(snt.Module): + """Discrete filtered Q-network. + + This produces filtered Q values according to the method used in the discrete + BCQ algorithm (https://arxiv.org/pdf/1910.01708.pdf - section 4). + """ + + def __init__(self, + g_network: snt.Module, + q_network: snt.Module, + threshold: float): + super().__init__(name='discrete_filtered_qnet') + assert threshold >= 0 and threshold <= 1 + self.g_network = g_network + self.q_network = q_network + self._threshold = threshold + + def __call__(self, o_t: tf.Tensor) -> tf.Tensor: + q_t = self.q_network(o_t) + g_t = tf.nn.softmax(self.g_network(o_t)) + normalized_g_t = g_t / tf.reduce_max(g_t, axis=-1, keepdims=True) + + # Filter actions based on g_network outputs. + min_q = tf.reduce_min(q_t, axis=-1, keepdims=True) + return tf.where(normalized_g_t >= self._threshold, q_t, min_q) diff --git a/acme/acme/tf/networks/distributional.py b/acme/acme/tf/networks/distributional.py new file mode 100644 index 00000000..e96dc29f --- /dev/null +++ b/acme/acme/tf/networks/distributional.py @@ -0,0 +1,314 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distributional modules: these are modules that return a tfd.Distribution. + +There are useful modules in `acme.networks.stochastic` to either sample or +take the mean of these distributions. +""" + +import types +from typing import Optional, Union +from absl import logging +from acme.tf.networks import distributions as ad +import numpy as np +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions +snt_init = snt.initializers + +_MIN_SCALE = 1e-4 + + +class DiscreteValuedHead(snt.Module): + """Represents a parameterized discrete valued distribution. + + The returned distribution is essentially a `tfd.Categorical`, but one which + knows its support and so can compute the mean value. + """ + + def __init__(self, + vmin: Union[float, np.ndarray, tf.Tensor], + vmax: Union[float, np.ndarray, tf.Tensor], + num_atoms: int, + w_init: Optional[snt.initializers.Initializer] = None, + b_init: Optional[snt.initializers.Initializer] = None): + """Initialization. + + If vmin and vmax have shape S, this will store the category values as a + Tensor of shape (S*, num_atoms). + + Args: + vmin: Minimum of the value range + vmax: Maximum of the value range + num_atoms: The atom values associated with each bin. + w_init: Initialization for linear layer weights. + b_init: Initialization for linear layer biases. 
+ """ + super().__init__(name='DiscreteValuedHead') + vmin = tf.convert_to_tensor(vmin) + vmax = tf.convert_to_tensor(vmax) + self._values = tf.linspace(vmin, vmax, num_atoms, axis=-1) + self._distributional_layer = snt.Linear(tf.size(self._values), + w_init=w_init, + b_init=b_init) + + def __call__(self, inputs: tf.Tensor) -> tfd.Distribution: + logits = self._distributional_layer(inputs) + logits = tf.reshape(logits, + tf.concat([tf.shape(logits)[:1], # batch size + tf.shape(self._values)], + axis=0)) + values = tf.cast(self._values, logits.dtype) + + return ad.DiscreteValuedDistribution(values=values, logits=logits) + + +class MultivariateNormalDiagHead(snt.Module): + """Module that produces a multivariate normal distribution using tfd.Independent or tfd.MultivariateNormalDiag.""" + + def __init__( + self, + num_dimensions: int, + init_scale: float = 0.3, + min_scale: float = 1e-6, + tanh_mean: bool = False, + fixed_scale: bool = False, + use_tfd_independent: bool = False, + w_init: snt_init.Initializer = tf.initializers.VarianceScaling(1e-4), + b_init: snt_init.Initializer = tf.initializers.Zeros()): + """Initialization. + + Args: + num_dimensions: Number of dimensions of MVN distribution. + init_scale: Initial standard deviation. + min_scale: Minimum standard deviation. + tanh_mean: Whether to transform the mean (via tanh) before passing it to + the distribution. + fixed_scale: Whether to use a fixed variance. + use_tfd_independent: Whether to use tfd.Independent or + tfd.MultivariateNormalDiag class + w_init: Initialization for linear layer weights. + b_init: Initialization for linear layer biases. + """ + super().__init__(name='MultivariateNormalDiagHead') + self._init_scale = init_scale + self._min_scale = min_scale + self._tanh_mean = tanh_mean + self._mean_layer = snt.Linear(num_dimensions, w_init=w_init, b_init=b_init) + self._fixed_scale = fixed_scale + + if not fixed_scale: + self._scale_layer = snt.Linear( + num_dimensions, w_init=w_init, b_init=b_init) + self._use_tfd_independent = use_tfd_independent + + def __call__(self, inputs: tf.Tensor) -> tfd.Distribution: + zero = tf.constant(0, dtype=inputs.dtype) + mean = self._mean_layer(inputs) + + if self._fixed_scale: + scale = tf.ones_like(mean) * self._init_scale + else: + scale = tf.nn.softplus(self._scale_layer(inputs)) + scale *= self._init_scale / tf.nn.softplus(zero) + scale += self._min_scale + + # Maybe transform the mean. + if self._tanh_mean: + mean = tf.tanh(mean) + + if self._use_tfd_independent: + dist = tfd.Independent(tfd.Normal(loc=mean, scale=scale)) + else: + dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=scale) + + return dist + + +class GaussianMixture(snt.Module): + """Module that outputs a Gaussian Mixture Distribution.""" + + def __init__(self, + num_dimensions: int, + num_components: int, + multivariate: bool, + init_scale: Optional[float] = None, + name: str = 'GaussianMixture'): + """Initialization. + + Args: + num_dimensions: dimensionality of the output distribution + num_components: number of mixture components. + multivariate: whether the resulting distribution is multivariate or not. + init_scale: the initial scale for the Gaussian mixture components. + name: name of the module passed to snt.Module parent class. + """ + super().__init__(name=name) + + self._num_dimensions = num_dimensions + self._num_components = num_components + self._multivariate = multivariate + + if init_scale is not None: + self._scale_factor = init_scale / tf.nn.softplus(0.) 
+    else:
+      self._scale_factor = 1.0  # Corresponds to init_scale = softplus(0).
+
+    # Define the weight initializer.
+    w_init = tf.initializers.VarianceScaling(1e-5)
+
+    # Create a layer that outputs the unnormalized log-weights.
+    if self._multivariate:
+      logits_size = self._num_components
+    else:
+      logits_size = self._num_dimensions * self._num_components
+    self._logit_layer = snt.Linear(logits_size, w_init=w_init)
+
+    # Create two layers that output a location and a scale, respectively, for
+    # each dimension and each component.
+    self._loc_layer = snt.Linear(
+        self._num_dimensions * self._num_components, w_init=w_init)
+    self._scale_layer = snt.Linear(
+        self._num_dimensions * self._num_components, w_init=w_init)
+
+  def __call__(self,
+               inputs: tf.Tensor,
+               low_noise_policy: bool = False) -> tfd.Distribution:
+    """Runs the network on the inputs.
+
+    Args:
+      inputs: hidden activations of the policy network body.
+      low_noise_policy: whether to set vanishingly small scales for each
+        component. If this flag is set to True, the policy is effectively run
+        without Gaussian noise.
+
+    Returns:
+      Mixture of Gaussians distribution.
+    """
+
+    # Compute logits, locs, and scales if necessary.
+    logits = self._logit_layer(inputs)
+    locs = self._loc_layer(inputs)
+
+    # When a low-noise policy is requested, set the scales to their minimum
+    # value.
+    if low_noise_policy:
+      scales = tf.fill(locs.shape, _MIN_SCALE)
+    else:
+      scales = self._scale_layer(inputs)
+      scales = self._scale_factor * tf.nn.softplus(scales) + _MIN_SCALE
+
+    if self._multivariate:
+      shape = [-1, self._num_components, self._num_dimensions]
+      # Reshape the mixture's location and scale parameters appropriately.
+      locs = tf.reshape(locs, shape)
+      scales = tf.reshape(scales, shape)
+      # In this case there is no need to reshape the logits, as they are
+      # already in the correct shape, namely [batch_size, num_components].
+      components_distribution = tfd.MultivariateNormalDiag(
+          loc=locs, scale_diag=scales)
+    else:
+      shape = [-1, self._num_dimensions, self._num_components]
+      # Reshape the mixture's location and scale parameters appropriately.
+      locs = tf.reshape(locs, shape)
+      scales = tf.reshape(scales, shape)
+      components_distribution = tfd.Normal(loc=locs, scale=scales)
+      logits = tf.reshape(logits, shape)
+
+    # Create the mixture distribution.
+    distribution = tfd.MixtureSameFamily(
+        mixture_distribution=tfd.Categorical(logits=logits),
+        components_distribution=components_distribution)
+
+    if not self._multivariate:
+      distribution = tfd.Independent(distribution)
+
+    return distribution
+
+
+class UnivariateGaussianMixture(GaussianMixture):
+  """Head which outputs a mixture of Gaussians distribution."""
+
+  def __init__(self,
+               num_dimensions: int,
+               num_components: int = 5,
+               init_scale: Optional[float] = None,
+               num_mixtures: Optional[int] = None):
+    """Creates a mixture of Gaussians actor head.
+
+    Args:
+      num_dimensions: dimensionality of the output distribution. Each
+        dimension is going to be an independent 1d GMM model.
+      num_components: number of mixture components.
+      init_scale: the initial scale for the Gaussian mixture components.
+      num_mixtures: deprecated argument which overwrites num_components.
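+
+    A minimal usage sketch (shapes are illustrative assumptions):
+
+      head = UnivariateGaussianMixture(num_dimensions=4, num_components=5)
+      distribution = head(torso_output)  # torso_output: [B, D]
+      action = distribution.sample()     # [B, 4]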
+ """ + if num_mixtures is not None: + logging.warning("""the num_mixtures parameter has been deprecated; use + num_components instead; the value of num_components is being + ignored""") + num_components = num_mixtures + super().__init__(num_dimensions=num_dimensions, + num_components=num_components, + multivariate=False, + init_scale=init_scale, + name='UnivariateGaussianMixture') + + +class MultivariateGaussianMixture(GaussianMixture): + """Head which outputs a mixture of multivariate Gaussians distribution.""" + + def __init__(self, + num_dimensions: int, + num_components: int = 5, + init_scale: Optional[float] = None): + """Initialization. + + Args: + num_dimensions: dimensionality of the output distribution + (also the dimensionality of the multivariate Gaussian model). + num_components: number of mixture components. + init_scale: the initial scale for the Gaussian mixture components. + """ + super().__init__(num_dimensions=num_dimensions, + num_components=num_components, + multivariate=True, + init_scale=init_scale, + name='MultivariateGaussianMixture') + + +class ApproximateMode(snt.Module): + """Override the mode function of the distribution. + + For non-constant Jacobian transformed distributions, the mode is non-trivial + to compute, so for these distributions the mode function is not supported in + TFP. A frequently used approximation is to forward transform the mode of the + untransformed distribution. + + Otherwise (an untransformed distribution or a transformed distribution with a + constant Jacobian), this is a no-op. + """ + + def __call__(self, inputs: tfd.Distribution) -> tfd.Distribution: + if isinstance(inputs, tfd.TransformedDistribution): + if not inputs.bijector.is_constant_jacobian: + def _mode(self, **kwargs): + distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs) + x = self.distribution.mode(**distribution_kwargs) + y = self.bijector.forward(x, **bijector_kwargs) + return y + inputs._mode = types.MethodType(_mode, inputs) + return inputs diff --git a/acme/acme/tf/networks/distributional_test.py b/acme/acme/tf/networks/distributional_test.py new file mode 100644 index 00000000..448dc2de --- /dev/null +++ b/acme/acme/tf/networks/distributional_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for acme.tf.networks.distributional.""" + +from acme.tf.networks import distributional +import numpy as np +from numpy import testing as npt + +from absl.testing import absltest +from absl.testing import parameterized + + +class DistributionalTest(parameterized.TestCase): + + @parameterized.parameters( + ((2, 3), (), (), 5, (2, 5)), + ((2, 3), (4, 1), (1, 5), 6, (2, 4, 5, 6)), + ) + def test_discrete_valued_head( + self, + input_shape, + vmin_shape, + vmax_shape, + num_atoms, + expected_logits_shape): + + vmin = np.zeros(vmin_shape, float) + vmax = np.ones(vmax_shape, float) + head = distributional.DiscreteValuedHead( + vmin=vmin, + vmax=vmax, + num_atoms=num_atoms) + input_array = np.zeros(input_shape, dtype=float) + output_distribution = head(input_array) + self.assertEqual(output_distribution.logits_parameter().shape, + expected_logits_shape) + + values = output_distribution._values + + # Can't do assert_allclose(values[..., 0], vmin), because the args may + # have broadcast-compatible but unequal shapes. Do the following instead: + npt.assert_allclose(values[..., 0] - vmin, np.zeros_like(values[..., 0])) + npt.assert_allclose(values[..., -1] - vmax, np.zeros_like(values[..., -1])) + + # Check that values are monotonically increasing. + intervals = values[..., 1:] - values[..., :-1] + npt.assert_array_less(np.zeros_like(intervals), intervals) + + # Check that the values are equally spaced. + npt.assert_allclose(intervals[..., 1:] - intervals[..., :1], + np.zeros_like(intervals[..., 1:]), + atol=1e-7) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/tf/networks/distributions.py b/acme/acme/tf/networks/distributions.py new file mode 100644 index 00000000..bda8e973 --- /dev/null +++ b/acme/acme/tf/networks/distributions.py @@ -0,0 +1,107 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distributions, for use in acme/networks/distributional.py.""" + +from typing import Optional +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions + + +@tfp.experimental.register_composite +class DiscreteValuedDistribution(tfd.Categorical): + """This is a generalization of a categorical distribution. + + The support for the DiscreteValued distribution can be any real valued range, + whereas the categorical distribution has support [0, n_categories - 1] or + [1, n_categories]. This generalization allows us to take the mean of the + distribution over its support. + """ + + def __init__(self, + values: tf.Tensor, + logits: Optional[tf.Tensor] = None, + probs: Optional[tf.Tensor] = None, + name: str = 'DiscreteValuedDistribution'): + """Initialization. + + Args: + values: Values making up support of the distribution. Should have a shape + compatible with logits. + logits: An N-D Tensor, N >= 1, representing the log probabilities of a set + of Categorical distributions. 
The first N - 1 dimensions index into a + batch of independent distributions and the last dimension indexes into + the classes. + probs: An N-D Tensor, N >= 1, representing the probabilities of a set of + Categorical distributions. The first N - 1 dimensions index into a batch + of independent distributions and the last dimension represents a vector + of probabilities for each class. Only one of logits or probs should be + passed in. + name: Name of the distribution object. + """ + self._values = tf.convert_to_tensor(values) + shape_strings = [f'D{i}' for i, _ in enumerate(values.shape)] + + if logits is not None: + logits = tf.convert_to_tensor(logits) + tf.debugging.assert_shapes([(values, shape_strings), + (logits, [..., *shape_strings])]) + if probs is not None: + probs = tf.convert_to_tensor(probs) + tf.debugging.assert_shapes([(values, shape_strings), + (probs, [..., *shape_strings])]) + + super().__init__(logits=logits, probs=probs, name=name) + + self._parameters = dict(values=values, + logits=logits, + probs=probs, + name=name) + + @property + def values(self) -> tf.Tensor: + return self._values + + @classmethod + def _parameter_properties(cls, dtype, num_classes=None): + return dict( + values=tfp.util.ParameterProperties(event_ndims=None), + logits=tfp.util.ParameterProperties( + event_ndims=lambda self: self.values.shape.rank), + probs=tfp.util.ParameterProperties( + event_ndims=lambda self: self.values.shape.rank, + is_preferred=False)) + + def _sample_n(self, n, seed=None) -> tf.Tensor: + indices = super()._sample_n(n, seed=seed) + return tf.gather(self.values, indices, axis=-1) + + def _mean(self) -> tf.Tensor: + """Overrides the Categorical mean by incorporating category values.""" + return tf.reduce_sum(self.probs_parameter() * self.values, axis=-1) + + def _variance(self) -> tf.Tensor: + """Overrides the Categorical variance by incorporating category values.""" + dist_squared = tf.square(tf.expand_dims(self.mean(), -1) - self.values) + return tf.reduce_sum(self.probs_parameter() * dist_squared, axis=-1) + + def _event_shape(self): + # Omit the atoms axis, to return just the shape of a single (i.e. unbatched) + # sample value. + return self._values.shape[:-1] + + def _event_shape_tensor(self): + return tf.shape(self._values)[:-1] diff --git a/acme/acme/tf/networks/distributions_test.py b/acme/acme/tf/networks/distributions_test.py new file mode 100644 index 00000000..211eed8e --- /dev/null +++ b/acme/acme/tf/networks/distributions_test.py @@ -0,0 +1,67 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for acme.tf.networks.distributions.""" + +from acme.tf.networks import distributions +import numpy as np +from numpy import testing as npt + +from absl.testing import absltest +from absl.testing import parameterized + + +class DiscreteValuedDistributionTest(parameterized.TestCase): + + @parameterized.parameters( + ((), (), 5), + ((2,), (), 5), + ((), (3, 4), 5), + ((2,), (3, 4), 5), + ((2, 6), (3, 4), 5), + ) + def test_constructor(self, batch_shape, event_shape, num_values): + logits_shape = batch_shape + event_shape + (num_values,) + logits_size = np.prod(logits_shape) + logits = np.arange(logits_size, dtype=float).reshape(logits_shape) + values = np.linspace(start=-np.ones(event_shape, dtype=float), + stop=np.ones(event_shape, dtype=float), + num=num_values, + axis=-1) + distribution = distributions.DiscreteValuedDistribution(values=values, + logits=logits) + + # Check batch and event shapes. + self.assertEqual(distribution.batch_shape, batch_shape) + self.assertEqual(distribution.event_shape, event_shape) + self.assertEqual(distribution.logits_parameter().shape.as_list(), + list(logits.shape)) + self.assertEqual(distribution.logits_parameter().shape.as_list()[-1], + logits.shape[-1]) + + # Test slicing + if len(batch_shape) == 1: + slice_0_logits = distribution[1:3].logits_parameter().numpy() + expected_slice_0_logits = distribution.logits_parameter().numpy()[1:3] + npt.assert_allclose(slice_0_logits, expected_slice_0_logits) + elif len(batch_shape) == 2: + slice_logits = distribution[0, 1:3].logits_parameter().numpy() + expected_slice_logits = distribution.logits_parameter().numpy()[0, 1:3] + npt.assert_allclose(slice_logits, expected_slice_logits) + else: + assert not batch_shape + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/tf/networks/duelling.py b/acme/acme/tf/networks/duelling.py new file mode 100644 index 00000000..ed891f5a --- /dev/null +++ b/acme/acme/tf/networks/duelling.py @@ -0,0 +1,58 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A duelling network architecture, as described in [0]. + +[0] https://arxiv.org/abs/1511.06581 +""" + +from typing import Sequence + +import sonnet as snt +import tensorflow as tf + + +class DuellingMLP(snt.Module): + """A Duelling MLP Q-network.""" + + def __init__( + self, + num_actions: int, + hidden_sizes: Sequence[int], + ): + super().__init__(name='duelling_q_network') + + self._value_mlp = snt.nets.MLP([*hidden_sizes, 1]) + self._advantage_mlp = snt.nets.MLP([*hidden_sizes, num_actions]) + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + """Forward pass of the duelling network. + + Args: + inputs: 2-D tensor of shape [batch_size, embedding_size]. + + Returns: + q_values: 2-D tensor of action values of shape [batch_size, num_actions] + """ + + # Compute value & advantage for duelling. + value = self._value_mlp(inputs) # [B, 1] + advantages = self._advantage_mlp(inputs) # [B, A] + + # Advantages have zero mean. 
+ advantages -= tf.reduce_mean(advantages, axis=-1, keepdims=True) # [B, A] + + q_values = value + advantages # [B, A] + + return q_values diff --git a/acme/acme/tf/networks/embedding.py b/acme/acme/tf/networks/embedding.py new file mode 100644 index 00000000..c11b765d --- /dev/null +++ b/acme/acme/tf/networks/embedding.py @@ -0,0 +1,45 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules for computing custom embeddings.""" + +from acme.tf.networks import base +from acme.wrappers import observation_action_reward + +import sonnet as snt +import tensorflow as tf + + +class OAREmbedding(snt.Module): + """Module for embedding (observation, action, reward) inputs together.""" + + def __init__(self, torso: base.Module, num_actions: int): + super().__init__(name='oar_embedding') + self._num_actions = num_actions + self._torso = torso + + def __call__(self, inputs: observation_action_reward.OAR) -> tf.Tensor: + """Embed each of the (observation, action, reward) inputs & concatenate.""" + + # Add dummy trailing dimension to rewards if necessary. + if len(inputs.reward.shape.dims) == 1: + inputs = inputs._replace(reward=tf.expand_dims(inputs.reward, axis=-1)) + + features = self._torso(inputs.observation) # [T?, B, D] + action = tf.one_hot(inputs.action, depth=self._num_actions) # [T?, B, A] + reward = tf.nn.tanh(inputs.reward) # [T?, B, 1] + + embedding = tf.concat([features, action, reward], axis=-1) # [T?, B, D+A+1] + + return embedding diff --git a/acme/acme/tf/networks/legal_actions.py b/acme/acme/tf/networks/legal_actions.py new file mode 100644 index 00000000..84273905 --- /dev/null +++ b/acme/acme/tf/networks/legal_actions.py @@ -0,0 +1,127 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Networks used for handling illegal actions.""" + +from typing import Any, Callable, Iterable, Optional, Union + +# pytype: disable=import-error +from acme.wrappers import open_spiel_wrapper +# pytype: enable=import-error + +import numpy as np +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions + + +class MaskedSequential(snt.Module): + """Applies a legal actions mask to a linear chain of modules / callables. + + It is assumed the trailing dimension of the final layer (representing + action values) is the same as the trailing dimension of legal_actions. 
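+
+  A minimal usage sketch (layer sizes are illustrative assumptions):
+
+    network = MaskedSequential(
+        [snt.Linear(256), tf.nn.relu, snt.Linear(num_actions)])
+    q_values = network(olt)  # olt: an open_spiel_wrapper.OLT namedtuple.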
+ """ + + def __init__(self, + layers: Optional[Iterable[Callable[..., Any]]] = None, + name: str = 'MaskedSequential'): + super().__init__(name=name) + self._layers = list(layers) if layers is not None else [] + self._illegal_action_penalty = -1e9 + # Note: illegal_action_penalty cannot be -np.inf because trfl's qlearning + # ops utilize a batched_index function that returns NaN whenever -np.inf + # is present among action values. + + def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: + # Extract observation, legal actions, and terminal + outputs = inputs.observation + legal_actions = inputs.legal_actions + terminal = inputs.terminal + + for mod in self._layers: + outputs = mod(outputs) + + # Apply legal actions mask + outputs = tf.where(tf.equal(legal_actions, 1), outputs, + tf.fill(tf.shape(outputs), self._illegal_action_penalty)) + + # When computing the Q-learning target (r_t + d_t * max q_t) we need to + # ensure max q_t = 0 in terminal states. + outputs = tf.where(tf.equal(terminal, 1), tf.zeros_like(outputs), outputs) + + return outputs + + +# FIXME: Add functionality to support decaying epsilon parameter. +# FIXME: This is a modified version of trfl's epsilon_greedy() which +# incorporates code from the bug fix described here +# https://github.com/deepmind/trfl/pull/28 +class EpsilonGreedy(snt.Module): + """Computes an epsilon-greedy distribution over actions. + + This policy does the following: + - With probability 1 - epsilon, take the action corresponding to the highest + action value, breaking ties uniformly at random. + - With probability epsilon, take an action uniformly at random. + """ + + def __init__(self, + epsilon: Union[tf.Tensor, float], + threshold: float, + name: str = 'EpsilonGreedy'): + """Initialize the policy. + + Args: + epsilon: Exploratory param with value between 0 and 1. + threshold: Action values must exceed this value to qualify as a legal + action and possibly be selected by the policy. + name: Name of the network. + + Returns: + policy: tfp.distributions.Categorical distribution representing the + policy. + """ + super().__init__(name=name) + self._epsilon = tf.Variable(epsilon, trainable=False) + self._threshold = threshold + + def __call__(self, action_values: tf.Tensor) -> tfd.Categorical: + legal_actions_mask = tf.where( + tf.math.less_equal(action_values, self._threshold), + tf.fill(tf.shape(action_values), 0.), + tf.fill(tf.shape(action_values), 1.)) + + # Dithering action distribution. + dither_probs = 1 / tf.reduce_sum(legal_actions_mask, axis=-1, + keepdims=True) * legal_actions_mask + masked_action_values = tf.where(tf.equal(legal_actions_mask, 1), + action_values, + tf.fill(tf.shape(action_values), -np.inf)) + # Greedy action distribution, breaking ties uniformly at random. + max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True) + greedy_probs = tf.cast( + tf.equal(action_values * legal_actions_mask, max_value), + action_values.dtype) + + greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True) + + # Epsilon-greedy action distribution. + probs = self._epsilon * dither_probs + (1 - self._epsilon) * greedy_probs + + # Make the policy object. + policy = tfd.Categorical(probs=probs) + + return policy diff --git a/acme/acme/tf/networks/masked_epsilon_greedy.py b/acme/acme/tf/networks/masked_epsilon_greedy.py new file mode 100644 index 00000000..bf707dac --- /dev/null +++ b/acme/acme/tf/networks/masked_epsilon_greedy.py @@ -0,0 +1,55 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapping trfl epsilon_greedy with legal action masking.""" + +from typing import Optional, Mapping, Union + +import sonnet as snt +import tensorflow as tf +import trfl + + +class NetworkWithMaskedEpsilonGreedy(snt.Module): + """Epsilon greedy sampling with action masking on network outputs.""" + + def __init__(self, + network: snt.Module, + epsilon: Optional[tf.Tensor] = None): + """Initialize the network and epsilon. + + Usage: + Wrap an observation in a dictionary in your environment as follows: + + observation <-- {"your_key_for_observation": observation, + "legal_actions_mask": your_action_mask_tensor} + + and update your network to use 'observation["your_key_for_observation"]' + rather than 'observation'. + + Args: + network: the online Q network (the one being optimized) + epsilon: probability of taking a random action. + """ + super().__init__() + self._network = network + self._epsilon = epsilon + + def __call__( + self, observation: Union[Mapping[str, tf.Tensor], + tf.Tensor]) -> tf.Tensor: + q = self._network(observation) + return trfl.epsilon_greedy( + q, epsilon=self._epsilon, + legal_actions_mask=observation['legal_actions_mask']).sample() diff --git a/acme/acme/tf/networks/multihead.py b/acme/acme/tf/networks/multihead.py new file mode 100644 index 00000000..49a731d8 --- /dev/null +++ b/acme/acme/tf/networks/multihead.py @@ -0,0 +1,52 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihead networks apply separate networks to the input.""" + +from typing import Callable, Union, Sequence + +from acme import types + +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions +TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor], + tf.Tensor]] + + +class Multihead(snt.Module): + """Multi-head network module. + + This takes as input a list of N `network_heads`, and returns another network + whose output is the stacked outputs of each of these network heads separately + applied to the module input. The dimension of the output is [..., N]. 
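+
+  A minimal usage sketch (two single-output critic heads; sizes are
+  illustrative assumptions):
+
+    multihead = Multihead([snt.Linear(1), snt.Linear(1)])
+    stacked = multihead(inputs)  # inputs: [B, D] -> stacked: [B, 1, 2]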
+ """ + + def __init__(self, + network_heads: Sequence[TensorTransformation]): + if not network_heads: + raise ValueError('Must specify non-empty, non-None critic_network_heads.') + self._network_heads = network_heads + super().__init__(name='multihead') + + def __call__(self, + inputs: tf.Tensor) -> Union[tf.Tensor, Sequence[tf.Tensor]]: + outputs = [network_head(inputs) for network_head in self._network_heads] + if isinstance(outputs[0], tfd.Distribution): + # Cannot stack distributions + return outputs + outputs = tf.stack(outputs, axis=-1) + return outputs diff --git a/acme/acme/tf/networks/multiplexers.py b/acme/acme/tf/networks/multiplexers.py new file mode 100644 index 00000000..32815373 --- /dev/null +++ b/acme/acme/tf/networks/multiplexers.py @@ -0,0 +1,79 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiplexers are networks that take multiple inputs.""" + +from typing import Callable, Optional, Union + +from acme import types +from acme.tf import utils as tf2_utils + +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions +TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor], + tf.Tensor]] + + +class CriticMultiplexer(snt.Module): + """Module connecting a critic torso to (transformed) observations/actions. + + This takes as input a `critic_network`, an `observation_network`, and an + `action_network` and returns another network whose outputs are given by + `critic_network(observation_network(o), action_network(a))`. + + The observations and actions passed to this module are assumed to have a batch + dimension that match. + + Notes: + - Either the `observation_` or `action_network` can be `None`, in which case + the observation or action, resp., are passed to the critic network as is. + - If all `critic_`, `observation_` and `action_network` are `None`, this + module reduces to a simple `tf2_utils.batch_concat()`. + """ + + def __init__(self, + critic_network: Optional[TensorTransformation] = None, + observation_network: Optional[TensorTransformation] = None, + action_network: Optional[TensorTransformation] = None): + self._critic_network = critic_network + self._observation_network = observation_network + self._action_network = action_network + super().__init__(name='critic_multiplexer') + + def __call__(self, + observation: types.NestedTensor, + action: types.NestedTensor) -> tf.Tensor: + + # Maybe transform observations and actions before feeding them on. + if self._observation_network: + observation = self._observation_network(observation) + if self._action_network: + action = self._action_network(action) + + if hasattr(observation, 'dtype') and hasattr(action, 'dtype'): + if observation.dtype != action.dtype: + # Observation and action must be the same type for concat to work + action = tf.cast(action, observation.dtype) + + # Concat observations and actions, with one batch dimension. 
+ outputs = tf2_utils.batch_concat([observation, action]) + + # Maybe transform output before returning. + if self._critic_network: + outputs = self._critic_network(outputs) + + return outputs diff --git a/acme/acme/tf/networks/noise.py b/acme/acme/tf/networks/noise.py new file mode 100644 index 00000000..6a482233 --- /dev/null +++ b/acme/acme/tf/networks/noise.py @@ -0,0 +1,40 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Noise layers (for exploration).""" + +from acme import types +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp +import tree + +tfd = tfp.distributions + + +class ClippedGaussian(snt.Module): + """Sonnet module for adding clipped Gaussian noise to each output.""" + + def __init__(self, stddev: float, name: str = 'clipped_gaussian'): + super().__init__(name=name) + self._noise = tfd.Normal(loc=0., scale=stddev) + + def __call__(self, inputs: types.NestedTensor) -> types.NestedTensor: + def add_noise(tensor: tf.Tensor): + output = tensor + tf.cast(self._noise.sample(tensor.shape), + dtype=tensor.dtype) + output = tf.clip_by_value(output, -1.0, 1.0) + return output + + return tree.map_structure(add_noise, inputs) diff --git a/acme/acme/tf/networks/policy_value.py b/acme/acme/tf/networks/policy_value.py new file mode 100644 index 00000000..cecbe305 --- /dev/null +++ b/acme/acme/tf/networks/policy_value.py @@ -0,0 +1,36 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Policy-value network head for actor-critic algorithms.""" + +from typing import Tuple + +import sonnet as snt +import tensorflow as tf + + +class PolicyValueHead(snt.Module): + """A network with two linear layers, for policy and value respectively.""" + + def __init__(self, num_actions: int): + super().__init__(name='policy_value_network') + self._policy_layer = snt.Linear(num_actions) + self._value_layer = snt.Linear(1) + + def __call__(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """Returns a (Logits, Value) tuple.""" + logits = self._policy_layer(inputs) # [B, A] + value = tf.squeeze(self._value_layer(inputs), axis=-1) # [B] + + return logits, value diff --git a/acme/acme/tf/networks/quantile.py b/acme/acme/tf/networks/quantile.py new file mode 100644 index 00000000..89bd7bfa --- /dev/null +++ b/acme/acme/tf/networks/quantile.py @@ -0,0 +1,94 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An implicit quantile network, as described in [0].
+
+[0] https://arxiv.org/abs/1806.06923
+"""
+
+import numpy as np
+import sonnet as snt
+import tensorflow as tf
+
+
+class IQNNetwork(snt.Module):
+ """A feedforward network for use with IQN.
+
+ IQN extends the Q-network of regular DQN, which consists of torso and head
+ networks. IQN embeds sampled quantile thresholds into the output space of the
+ torso network and merges them with the torso output.
+
+ Outputs a tuple consisting of (mean) Q-values, Q-value quantiles, and sampled
+ quantile thresholds.
+ """
+
+ def __init__(self,
+ torso: snt.Module,
+ head: snt.Module,
+ latent_dim: int,
+ num_quantile_samples: int,
+ name: str = 'iqn_network'):
+ """Initializes the network.
+
+ Args:
+ torso: Network producing an intermediate representation, typically a
+ convolutional network.
+ head: Network producing Q-value quantiles, typically an MLP.
+ latent_dim: Dimension of latent variables.
+ num_quantile_samples: Number of quantile thresholds to sample.
+ name: Module name.
+ """
+ super().__init__(name)
+ self._torso = torso
+ self._head = head
+ self._latent_dim = latent_dim
+ self._num_quantile_samples = num_quantile_samples
+
+ @snt.once
+ def _create_embedding(self, size):
+ self._embedding = snt.Linear(size)
+
+ def __call__(self, observations):
+ # Transform observations into intermediate representations (the torso is
+ # typically a convolutional network).
+ torso_output = self._torso(observations)
+
+ # Now that the dimension of the intermediate representation is known,
+ # initialize the embedding of sampled quantile thresholds (this is only
+ # done once, via snt.once).
+ self._create_embedding(torso_output.shape[-1])
+
+ # Sample quantile thresholds.
+ batch_size = tf.shape(observations)[0]
+ tau_shape = tf.stack([batch_size, self._num_quantile_samples])
+ tau = tf.random.uniform(tau_shape)
+ indices = tf.range(1, self._latent_dim+1, dtype=tf.float32)
+
+ # Embed sampled quantile thresholds in intermediate representation space.
+ tau_tiled = tf.tile(tau[:, :, None], (1, 1, self._latent_dim))
+ indices_tiled = tf.tile(indices[None, None, :],
+ tf.concat([tau_shape, [1]], 0))
+ tau_embedding = tf.cos(tau_tiled * indices_tiled * np.pi)
+ tau_embedding = snt.BatchApply(self._embedding)(tau_embedding)
+ tau_embedding = tf.nn.relu(tau_embedding)
+
+ # Merge intermediate representations with embeddings, and apply the head
+ # network (typically an MLP). 
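+ # The torso output is tiled across the sample axis to [B, N, D] so it can
+ # be combined elementwise with the quantile embeddings of the same shape.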
+ torso_output = tf.tile(torso_output[:, None, :],
+ (1, self._num_quantile_samples, 1))
+ q_value_quantiles = snt.BatchApply(self._head)(tau_embedding * torso_output)
+ q_dist = tf.transpose(q_value_quantiles, (0, 2, 1))
+ q_values = tf.reduce_mean(q_value_quantiles, axis=1)
+ q_values = tf.stop_gradient(q_values)
+
+ return q_values, q_dist, tau diff --git a/acme/acme/tf/networks/recurrence.py b/acme/acme/tf/networks/recurrence.py new file mode 100644 index 00000000..b07bf9ea --- /dev/null +++ b/acme/acme/tf/networks/recurrence.py @@ -0,0 +1,375 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+
+"""Networks useful for building recurrent agents."""
+
+import functools
+from typing import NamedTuple, Optional, Sequence, Tuple
+from absl import logging
+from acme import types
+from acme.tf import savers
+from acme.tf import utils
+from acme.tf.networks import base
+import sonnet as snt
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tree
+
+RNNState = types.NestedTensor
+
+
+class PolicyCriticRNNState(NamedTuple):
+ """Consists of two RNNStates called 'policy' and 'critic'."""
+ policy: RNNState
+ critic: RNNState
+
+
+class UnpackWrapper(snt.Module):
+ """Takes a list of arguments and passes them as separate arguments.
+
+ Example
+ ```
+ class Critic(snt.Module):
+ def __call__(self, o, a):
+ pass
+
+ critic = Critic()
+ UnpackWrapper(critic)((o, a))
+ ```
+ calls critic(o, a)
+ """
+
+ def __init__(self, module: snt.Module, name: str = 'UnpackWrapper'):
+ super().__init__(name=name)
+ self._module = module
+
+ def __call__(self,
+ inputs: Sequence[types.NestedTensor]) -> types.NestedTensor:
+ # Unpack the inputs before passing to the underlying module.
+ return self._module(*inputs)
+
+
+class RNNUnpackWrapper(snt.RNNCore):
+ """Takes a list of arguments and passes them as separate arguments.
+
+ Example
+ ```
+ class Critic(snt.RNNCore):
+ def __call__(self, o, a, prev_state):
+ pass
+
+ critic = Critic()
+ RNNUnpackWrapper(critic)((o, a), prev_state)
+ ```
+ calls critic(o, a, prev_state)
+ """
+
+ def __init__(self, module: snt.RNNCore, name: str = 'RNNUnpackWrapper'):
+ super().__init__(name=name)
+ self._module = module
+
+ def __call__(self, inputs: Sequence[types.NestedTensor],
+ prev_state: RNNState) -> Tuple[types.NestedTensor, RNNState]:
+ # Unpack the inputs before passing to the underlying module.
+ return self._module(*inputs, prev_state)
+
+ def initial_state(self, batch_size):
+ return self._module.initial_state(batch_size)
+
+
+class CriticDeepRNN(snt.DeepRNN):
+ """Same as snt.DeepRNN, but takes three inputs (obs, act, prev_state).
+ """
+
+ def __init__(self, layers: Sequence[snt.Module]):
+ # Make the first layer take a single input instead of a list of arguments. 
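+ # (DeepRNN threads a single value between layers, so __call__ below packs
+ # (inputs, action) into one tuple and the wrapper unpacks it again when
+ # calling the underlying first layer.)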
+ if isinstance(layers[0], snt.RNNCore):
+ first_layer = RNNUnpackWrapper(layers[0])
+ else:
+ first_layer = UnpackWrapper(layers[0])
+ super().__init__([first_layer] + list(layers[1:]))
+
+ self._unwrapped_first_layer = layers[0]
+ self.__input_signature = None
+
+ def __call__(self, inputs: types.NestedTensor, action: tf.Tensor,
+ prev_state: RNNState) -> Tuple[types.NestedTensor, RNNState]:
+ # Pack the inputs into a tuple and then use the inherited DeepRNN logic to
+ # pass them through the layers.
+ # This in turn will pass the packed inputs into the first layer
+ # (UnpackWrapper) which unpacks them again.
+ return super().__call__((inputs, action), prev_state)
+
+ @property
+ def _input_signature(self) -> Optional[tf.TensorSpec]:
+ """Return input signature for Acme snapshotting.
+
+ The Acme way of snapshotting works as follows: you first create your network
+ variables via the utility function `acme.tf.utils.create_variables()`, which
+ adds an `_input_signature` attribute to your module. This attribute is
+ critical for proper snapshot saving and loading.
+
+ If a module with such an attribute is wrapped into e.g. DeepRNN, Acme
+ descends into the `_layers[0]` of that DeepRNN to find the input
+ signature.
+
+ This implementation allows CriticDeepRNN to work seamlessly like DeepRNN for
+ the following two use-cases:
+
+ 1) Creating variables *before* wrapping:
+ ```
+ unwrapped_critic = Critic()
+ acme.tf.utils.create_variables(unwrapped_critic, specs)
+ critic = CriticDeepRNN([unwrapped_critic])
+ ```
+
+ 2) Creating variables *after* wrapping:
+ ```
+ unwrapped_critic = Critic()
+ critic = CriticDeepRNN([unwrapped_critic])
+ acme.tf.utils.create_variables(critic, specs)
+ ```
+
+ Returns:
+ The input_signature of the module, or None if it is not known (i.e. the
+ variables were not created by acme.tf.utils.create_variables, either for
+ this module or for any of its descendants).
+ """
+
+ if self.__input_signature is not None:
+ # To make case (2) (see above) work, we need to allow create_variables to
+ # assign an _input_signature attribute to this module, which is why we
+ # create an additional __input_signature attribute with a setter (see
+ # below).
+ return self.__input_signature
+
+ # To make case (1) work, we descend into self._unwrapped_first_layer
+ # and try to get its input signature (if it exists) by calling
+ # savers._get_input_signature.
+
+ # Ideally, savers._get_input_signature should automatically descend into
+ # DeepRNN. But in this case it breaks on CriticDeepRNN because
+ # CriticDeepRNN._layers[0] is an UnpackWrapper around the underlying module
+ # and not the module itself.
+ input_signature = savers._get_input_signature(self._unwrapped_first_layer) # pylint: disable=protected-access
+ if input_signature is None:
+ return None
+ # Since adding recurrent modules via CriticDeepRNN changes the recurrent
+ # state, we need to update its spec here. 
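+ # (The last element of the input signature holds the recurrent state spec.)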
+ state = self.initial_state(1) + input_signature[-1] = tree.map_structure( + lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state) + self.__input_signature = input_signature + return input_signature + + @_input_signature.setter + def _input_signature(self, new_spec: tf.TensorSpec): + self.__input_signature = new_spec + + +class RecurrentExpQWeightedPolicy(snt.RNNCore): + """Recurrent exponentially Q-weighted policy.""" + + def __init__(self, + policy_network: snt.Module, + critic_network: snt.Module, + temperature_beta: float = 1.0, + num_action_samples: int = 16): + super().__init__(name='RecurrentExpQWeightedPolicy') + self._policy_network = policy_network + self._critic_network = critic_network + self._num_action_samples = num_action_samples + self._temperature_beta = temperature_beta + + def __call__(self, + observation: types.NestedTensor, + prev_state: PolicyCriticRNNState + ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]: + + return tf.vectorized_map(self._call, (observation, prev_state)) + + def _call( + self, observation_and_state: Tuple[types.NestedTensor, + PolicyCriticRNNState] + ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]: + """Computes a forward step for a single element. + + The observation and state are packed together in order to use + `tf.vectorized_map` to handle batches of observations. + See this module's __call__() function. + + Args: + observation_and_state: the observation and state packed in a tuple. + + Returns: + The selected action and the corresponding state. + """ + observation, prev_state = observation_and_state + + # Tile input observations and states to allow multiple policy predictions. + tiled_observation, tiled_prev_state = utils.tile_nested( + (observation, prev_state), self._num_action_samples) + actions, policy_states = self._policy_network( + tiled_observation, tiled_prev_state.policy) + + # Evaluate multiple critic predictions with the sampled actions. + value_distribution, critic_states = self._critic_network( + tiled_observation, actions, tiled_prev_state.critic) + value_estimate = value_distribution.mean() + + # Resample a single action of the sampled actions according to logits given + # by the tempered Q-values. + selected_action_idx = tfp.distributions.Categorical( + probs=tf.nn.softmax(value_estimate / self._temperature_beta)).sample() + selected_action = actions[selected_action_idx] + + # Select and return the RNN state that corresponds to the selected action. + states = PolicyCriticRNNState( + policy=policy_states, critic=critic_states) + selected_state = tree.map_structure( + lambda x: x[selected_action_idx], states) + + return selected_action, selected_state + + def initial_state(self, batch_size: int) -> PolicyCriticRNNState: + return PolicyCriticRNNState( + policy=self._policy_network.initial_state(batch_size), + critic=self._critic_network.initial_state(batch_size) + ) + + +class DeepRNN(snt.DeepRNN, base.RNNCore): + """Unroll-aware deep RNN module. + + Sonnet's DeepRNN steps through RNNCores sequentially which can result in a + performance hit, in particular when using Transformers. This module adds an + unroll() method which unrolls each module in the DeepRNN individually, + allowing efficient implementation of the unroll operation. For example, a + Transformer can 'unroll' by evaluating the whole sequence at once (this being + one of the advantages of Transformers over e.g. LSTMs). + + Any RNNCore passed to this module should implement unroll(). 
Failure to do so
+ may cause the RNNCore not to be called properly. For example, passing a
+ partial function application of a snt.RNNCore to this module will fail (this
+ is also true for snt.DeepRNN). However, the special case of passing in an
+ RNNCore object that does not implement unroll() is supported and will be
+ dynamically unrolled. Implement unroll() to override this behavior with
+ static unrolling.
+
+ Stateless modules (i.e. anything other than an RNNCore) which do not
+ implement unroll() are batch applied over the time and batch axes
+ simultaneously. Effectively, this means that such modules may be applied to
+ fairly large batches, potentially leading to out-of-memory issues.
+ """
+
+ def __init__(self, layers, name: Optional[str] = None):
+ """Initializes the module."""
+ super().__init__(layers, name=name)
+
+ self.__input_signature = None
+ self._num_unrollable = 0
+
+ # As a convenience, check for snt.RNNCore modules and dynamically unroll
+ # them if they don't already support unrolling. This check can fail, e.g.
+ # if a partially applied RNNCore is passed in. Sonnet's implementation of
+ # DeepRNN suffers from the same problem.
+ for layer in self._layers:
+ if hasattr(layer, 'unroll'):
+ self._num_unrollable += 1
+ elif isinstance(layer, snt.RNNCore):
+ self._num_unrollable += 1
+ layer.unroll = functools.partial(snt.dynamic_unroll, layer)
+ logging.warning(
+ 'Acme DeepRNN detected a Sonnet RNNCore. '
+ 'This will be dynamically unrolled. Please implement unroll() '
+ 'to suppress this warning.')
+
+ def unroll(self,
+ inputs: types.NestedTensor,
+ state: base.State,
+ sequence_length: int,
+ ) -> Tuple[types.NestedTensor, base.State]:
+ """Unrolls each layer individually.
+
+ Calls unroll() on layers which support it; all other layers are
+ batch-applied over the first two axes (assumed to be the time and batch
+ axes).
+
+ Args:
+ inputs: A nest of `tf.Tensor` in time-major format.
+ state: The RNN core state.
+ sequence_length: The number of time steps over which to unroll.
+
+ Returns:
+ Nested sequence output of the RNN, and the final state.
+
+ Raises:
+ ValueError: If the length of `state` does not match the number of
+ unrollable layers.
+ """
+ if len(state) != self._num_unrollable:
+ raise ValueError(
+ 'DeepRNN was called with the wrong number of states. The length of '
+ '`state` does not match the number of unrollable layers.')
+
+ states = iter(state)
+ outputs = inputs
+ next_states = []
+ for layer in self._layers:
+ if hasattr(layer, 'unroll'):
+ # The length of the `states` list was checked above.
+ outputs, next_state = layer.unroll(outputs, next(states),
+ sequence_length)
+ next_states.append(next_state)
+ else:
+ # Couldn't unroll(); assume that this is a stateless module. 
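+ # snt.BatchApply with num_dims=2 folds the time and batch axes together,
+ # applies the layer, and then unfolds the result.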
+ outputs = snt.BatchApply(layer, num_dims=2)(outputs) + + return outputs, tuple(next_states) + + @property + def _input_signature(self) -> Optional[tf.TensorSpec]: + """Return input signature for Acme snapshotting, see CriticDeepRNN.""" + + if self.__input_signature is not None: + return self.__input_signature + + input_signature = savers._get_input_signature(self._layers[0]) # pylint: disable=protected-access + if input_signature is None: + return None + + state = self.initial_state(1) + input_signature[-1] = tree.map_structure( + lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state) + self.__input_signature = input_signature + return input_signature + + @_input_signature.setter + def _input_signature(self, new_spec: tf.TensorSpec): + self.__input_signature = new_spec + + +class LSTM(snt.LSTM, base.RNNCore): + """Unrollable interface to LSTM. + + This module is supposed to be used with the DeepRNN class above, and more + generally in networks which support unroll(). + """ + + def unroll(self, + inputs: types.NestedTensor, + state: base.State, + sequence_length: int, + ) -> Tuple[types.NestedTensor, base.State]: + return snt.static_unroll(self, inputs, state, sequence_length) diff --git a/acme/acme/tf/networks/recurrence_test.py b/acme/acme/tf/networks/recurrence_test.py new file mode 100644 index 00000000..2c97c8fe --- /dev/null +++ b/acme/acme/tf/networks/recurrence_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test networks for building recurrent agents.""" + +import os + +from acme import specs +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf.networks import recurrence +import numpy as np +import sonnet as snt +import tensorflow as tf +import tree + +from absl.testing import absltest + + +# Simple critic-like modules for testing. +class Critic(snt.Module): + + def __call__(self, o, a): + return o * a + + +class RNNCritic(snt.RNNCore): + + def __call__(self, o, a, prev_state): + return o * a, prev_state + + def initial_state(self, batch_size): + return () + + +class NetsTest(tf.test.TestCase): + + def test_criticdeeprnn_snapshot(self): + """Test that CriticDeepRNN works correctly with snapshotting.""" + # Create a test network. + critic = Critic() + rnn_critic = RNNCritic() + + for base_net in [critic, rnn_critic]: + net = recurrence.CriticDeepRNN([base_net, snt.LSTM(10)]) + obs = specs.Array([10], dtype=np.float32) + actions = specs.Array([10], dtype=np.float32) + spec = [obs, actions] + tf2_utils.create_variables(net, spec) + + # Test that if you add some postprocessing without rerunning + # create_variables, it still works. + wrapped_net = recurrence.CriticDeepRNN([net, lambda x: x]) + + for curr_net in [net, wrapped_net]: + # Save the test network. 
+ directory = absltest.get_default_test_tmpdir() + objects_to_save = {'net': curr_net} + snapshotter = tf2_savers.Snapshotter( + objects_to_save, directory=directory) + snapshotter.save() + + # Reload the test network. + net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net')) + + obs = tf.ones((2, 10)) + actions = tf.ones((2, 10)) + state = curr_net.initial_state(2) + outputs1, next_state1 = curr_net(obs, actions, state) + outputs2, next_state2 = net2(obs, actions, state) + + assert np.allclose(outputs1, outputs2) + assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/tf/networks/rescaling.py b/acme/acme/tf/networks/rescaling.py new file mode 100644 index 00000000..d661d18b --- /dev/null +++ b/acme/acme/tf/networks/rescaling.py @@ -0,0 +1,73 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rescaling layers (e.g. to match action specs).""" + +from typing import Union +from acme import specs +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions +tfb = tfp.bijectors + + +class ClipToSpec(snt.Module): + """Sonnet module clipping inputs to within a BoundedArraySpec.""" + + def __init__(self, spec: specs.BoundedArray, name: str = 'clip_to_spec'): + super().__init__(name=name) + self._min = spec.minimum + self._max = spec.maximum + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + return tf.clip_by_value(inputs, self._min, self._max) + + +class RescaleToSpec(snt.Module): + """Sonnet module rescaling inputs in [-1, 1] to match a BoundedArraySpec.""" + + def __init__(self, spec: specs.BoundedArray, name: str = 'rescale_to_spec'): + super().__init__(name=name) + self._scale = spec.maximum - spec.minimum + self._offset = spec.minimum + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * self._scale + self._offset # [minimum, maximum] + + return output + + +class TanhToSpec(snt.Module): + """Sonnet module squashing real-valued inputs to match a BoundedArraySpec.""" + + def __init__(self, spec: specs.BoundedArray, name: str = 'tanh_to_spec'): + super().__init__(name=name) + self._scale = spec.maximum - spec.minimum + self._offset = spec.minimum + + def __call__( + self, inputs: Union[tf.Tensor, tfd.Distribution] + ) -> Union[tf.Tensor, tfd.Distribution]: + if isinstance(inputs, tfd.Distribution): + inputs = tfb.Tanh()(inputs) + inputs = tfb.ScaleMatvecDiag(0.5 * self._scale)(inputs) + output = tfb.Shift(self._offset + 0.5 * self._scale)(inputs) + else: + inputs = tf.tanh(inputs) # [-1, 1] + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * self._scale + self._offset # [minimum, maximum] + return output diff --git a/acme/acme/tf/networks/stochastic.py b/acme/acme/tf/networks/stochastic.py new file mode 100644 index 00000000..264d270a --- /dev/null +++ b/acme/acme/tf/networks/stochastic.py @@ 
-0,0 +1,104 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Useful sonnet modules to chain after distributional module outputs.""" + +from acme import types +from acme.tf import utils as tf2_utils +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp +import tree + +tfd = tfp.distributions + + +class StochasticModeHead(snt.Module): + """Simple sonnet module to produce the mode of a tfp.Distribution.""" + + def __call__(self, distribution: tfd.Distribution): + return distribution.mode() + + +class StochasticMeanHead(snt.Module): + """Simple sonnet module to produce the mean of a tfp.Distribution.""" + + def __call__(self, distribution: tfd.Distribution): + return distribution.mean() + + +class StochasticSamplingHead(snt.Module): + """Simple sonnet module to sample from a tfp.Distribution.""" + + def __call__(self, distribution: tfd.Distribution): + return distribution.sample() + + +class ExpQWeightedPolicy(snt.Module): + """Exponentially Q-weighted policy. + + Given a stochastic policy and a critic, returns a (stochastic) policy which + samples multiple actions from the underlying policy, computes the Q-values for + each action, and chooses the final action among the sampled ones with + probability proportional to the exponentiated Q values, tempered by + a parameter beta. + """ + + def __init__(self, + actor_network: snt.Module, + critic_network: snt.Module, + beta: float = 1.0, + num_action_samples: int = 16): + super().__init__(name='ExpQWeightedPolicy') + self._actor_network = actor_network + self._critic_network = critic_network + self._num_action_samples = num_action_samples + self._beta = beta + + def __call__(self, inputs: types.NestedTensor) -> tf.Tensor: + # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...]. + tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples) + shape = tf.shape(tree.flatten(tiled_inputs)[0]) + n, b = shape[0], shape[1] + tf.debugging.assert_equal(n, self._num_action_samples, + 'Internal Error. Unexpected tiled_inputs shape.') + dummy_zeros_n_b = tf.zeros((n, b)) + # Reshape to [N * B, ...]. + merge = lambda x: snt.merge_leading_dims(x, 2) + tiled_inputs = tree.map_structure(merge, tiled_inputs) + + tiled_actions = self._actor_network(tiled_inputs) + + # Compute Q-values and the resulting tempered probabilities. + q = self._critic_network(tiled_inputs, tiled_actions) + boltzmann_logits = q / self._beta + + boltzmann_logits = snt.split_leading_dim(boltzmann_logits, dummy_zeros_n_b, + 2) + # [B, N] + boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0)) + # Resample one action per batch according to the Boltzmann distribution. + action_idx = tfp.distributions.Categorical(logits=boltzmann_logits).sample() + # [B, 2], where the first column is 0, 1, 2,... corresponding to indices to + # the batch dimension. 
+ action_idx = tf.stack((tf.range(b), action_idx), axis=1) + + tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b, 2) + action_dim = len(tiled_actions.get_shape().as_list()) + tiled_actions = tf.transpose(tiled_actions, + perm=[1, 0] + list(range(2, action_dim))) + # [B, ...] + action_sample = tf.gather_nd(tiled_actions, action_idx) + + return action_sample diff --git a/acme/acme/tf/networks/vision.py b/acme/acme/tf/networks/vision.py new file mode 100644 index 00000000..7065dba7 --- /dev/null +++ b/acme/acme/tf/networks/vision.py @@ -0,0 +1,236 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visual networks for processing pixel inputs.""" + +from typing import Callable, Optional, Sequence, Union +import sonnet as snt +import tensorflow as tf + + +class ResNetTorso(snt.Module): + """ResNet architecture used in IMPALA paper.""" + + def __init__( + self, + num_channels: Sequence[int] = (16, 32, 32), # default to IMPALA resnet. + num_blocks: Sequence[int] = (2, 2, 2), # default to IMPALA resnet. + num_output_hidden: Sequence[int] = (256,), # default to IMPALA resnet. + conv_shape: Union[int, Sequence[int]] = 3, + conv_stride: Union[int, Sequence[int]] = 1, + pool_size: Union[int, Sequence[int]] = 3, + pool_stride: Union[int, Sequence[int], Sequence[Sequence[int]]] = 2, + data_format: str = 'NHWC', + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, + output_dtype: tf.DType = tf.float32, + name: str = 'resnet_torso'): + """Builds an IMPALA-style ResNet. + + The arguments' default values construct the IMPALA resnet. + + Args: + num_channels: The number of convolutional channels for each layer. + num_blocks: The number of resnet blocks in each "layer". + num_output_hidden: The output size(s) of the MLP layer(s) on top. + conv_shape: The convolution filter size (int), or size dimensions (H, W). + conv_stride: the convolution stride (int), or strides (row, column). + pool_size: The pooling footprint size (int), or size dimensions (H, W). + pool_stride: The pooling stride (int) or strides (row, column), or + strides for each of the N layers ((r1, c1), (r2, c2), ..., (rN, cN)). + data_format: The axis order of the input. + activation: The activation function. + output_dtype: the output dtype. + name: The Sonnet module name. + """ + super().__init__(name=name) + + self._output_dtype = output_dtype + self._num_layers = len(num_blocks) + + if isinstance(pool_stride, int): + pool_stride = (pool_stride, pool_stride) + + if isinstance(pool_stride[0], int): + pool_stride = self._num_layers * (pool_stride,) + + # Create sequence of residual blocks. + blocks = [] + for i in range(self._num_layers): + blocks.append( + ResidualBlockGroup( + num_blocks[i], + num_channels[i], + conv_shape, + conv_stride, + pool_size, + pool_stride[i], + data_format=data_format, + activation=activation)) + + # Create output layer. 
+ out_layer = snt.nets.MLP(num_output_hidden, activation=activation) + + # Compose blocks and final layer. + self._resnet = snt.Sequential( + blocks + [activation, snt.Flatten(), out_layer]) + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + """Evaluates the ResidualPixelCore.""" + + # Convert to floats. + preprocessed_inputs = _preprocess_inputs(inputs, self._output_dtype) + torso_output = self._resnet(preprocessed_inputs) + + return torso_output + + +class ResidualBlockGroup(snt.Module): + """Higher level block for ResNet implementation.""" + + def __init__(self, + num_blocks: int, + num_output_channels: int, + conv_shape: Union[int, Sequence[int]], + conv_stride: Union[int, Sequence[int]], + pool_shape: Union[int, Sequence[int]], + pool_stride: Union[int, Sequence[int]], + data_format: str = 'NHWC', + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, + name: Optional[str] = None): + super().__init__(name=name) + + self._num_blocks = num_blocks + self._data_format = data_format + self._activation = activation + + # The pooling operation expects a 2-rank shape/stride (height and width). + if isinstance(pool_shape, int): + pool_shape = 2 * [pool_shape] + if isinstance(pool_stride, int): + pool_stride = 2 * [pool_stride] + + # Create a Conv2D factory since we'll be making quite a few. + def build_conv_layer(name: str): + return snt.Conv2D( + num_output_channels, + conv_shape, + stride=conv_stride, + padding='SAME', + data_format=data_format, + name=name) + + # Create a pooling layer. + def pooling_layer(inputs: tf.Tensor) -> tf.Tensor: + return tf.nn.pool( + inputs, + pool_shape, + pooling_type='MAX', + strides=pool_stride, + padding='SAME', + data_format=data_format) + + # Create an initial conv layer and pooling to scale the image down. + self._downscale = snt.Sequential( + [build_conv_layer('downscale'), pooling_layer]) + + # Residual block(s). + self._convs = [] + for i in range(self._num_blocks): + name = 'residual_block_%d' % i + self._convs.append( + [build_conv_layer(name + '_0'), + build_conv_layer(name + '_1')]) + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + # Downscale the inputs. + conv_out = self._downscale(inputs) + + # Apply (sequence of) residual block(s). + for i in range(self._num_blocks): + block_input = conv_out + conv_out = self._activation(conv_out) + conv_out = self._convs[i][0](conv_out) + conv_out = self._activation(conv_out) + conv_out = self._convs[i][1](conv_out) + conv_out += block_input + return conv_out + + +def _preprocess_inputs(inputs: tf.Tensor, output_dtype: tf.DType) -> tf.Tensor: + """Returns the `Tensor` corresponding to the preprocessed inputs.""" + rank = inputs.shape.rank + if rank < 4: + raise ValueError( + 'Input Tensor must have at least 4 dimensions (for ' + 'batch size, height, width, and channels), but it only has ' + '{}'.format(rank)) + + flattened_inputs = snt.Flatten(preserve_dims=3)(inputs) + processed_inputs = tf.image.convert_image_dtype( + flattened_inputs, dtype=output_dtype) + return processed_inputs + + +class DrQTorso(snt.Module): + """DrQ Torso inspired by the second DrQ paper [Yarats et al., 2021]. + + [Yarats et al., 2021] https://arxiv.org/abs/2107.09645 + """ + + def __init__( + self, + data_format: str = 'NHWC', + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, + output_dtype: tf.DType = tf.float32, + name: str = 'resnet_torso'): + super().__init__(name=name) + + self._output_dtype = output_dtype + + # Create a Conv2D factory since we'll be making quite a few. 
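+ # Orthogonal weight initialization with gain sqrt(2) is the usual pairing
+ # for ReLU activations; for other activations we fall back to gain 1.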
+ gain = 2**0.5 if activation == tf.nn.relu else 1. + def build_conv_layer(name: str, + output_channels: int = 32, + kernel_shape: Sequence[int] = (3, 3), + stride: int = 1): + return snt.Conv2D( + output_channels=output_channels, + kernel_shape=kernel_shape, + stride=stride, + padding='SAME', + data_format=data_format, + w_init=snt.initializers.Orthogonal(gain=gain, seed=None), + b_init=snt.initializers.Zeros(), + name=name) + + self._network = snt.Sequential( + [build_conv_layer('conv_0', stride=2), + activation, + build_conv_layer('conv_1', stride=1), + activation, + build_conv_layer('conv_2', stride=1), + activation, + build_conv_layer('conv_3', stride=1), + activation, + snt.Flatten()]) + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + """Evaluates the ResidualPixelCore.""" + + # Normalize to -0.5 to 0.5 + preprocessed_inputs = _preprocess_inputs(inputs, self._output_dtype) - 0.5 + + torso_output = self._network(preprocessed_inputs) + + return torso_output diff --git a/acme/acme/tf/savers.py b/acme/acme/tf/savers.py new file mode 100644 index 00000000..760618cd --- /dev/null +++ b/acme/acme/tf/savers.py @@ -0,0 +1,460 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility classes for saving model checkpoints and snapshots.""" + +import abc +import datetime +import os +import pickle +import time +from typing import Mapping, Optional, Union + +from absl import logging +from acme import core +from acme.utils import signals +from acme.utils import paths +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp +import tree + +from tensorflow.python.saved_model import revived_types + +PythonState = tf.train.experimental.PythonState +Checkpointable = Union[tf.Module, tf.Variable, PythonState] + +_DEFAULT_CHECKPOINT_TTL = int(datetime.timedelta(days=5).total_seconds()) +_DEFAULT_SNAPSHOT_TTL = int(datetime.timedelta(days=90).total_seconds()) + + +class TFSaveable(abc.ABC): + """An interface for objects that expose their checkpointable TF state.""" + + @property + @abc.abstractmethod + def state(self) -> Mapping[str, Checkpointable]: + """Returns TensorFlow checkpointable state.""" + + +class Checkpointer: + """Convenience class for periodically checkpointing. + + This can be used to checkpoint any object with trackable state (e.g. + tensorflow variables or modules); see tf.train.Checkpoint for + details. Objects inheriting from tf.train.experimental.PythonState can also + be checkpointed. + + Typically people use Checkpointer to make sure that they can correctly recover + from a machine going down during learning. For more permanent storage of self- + contained "networks" see the Snapshotter object. + + Usage example: + + ```python + model = snt.Linear(10) + checkpointer = Checkpointer(objects_to_save={'model': model}) + + for _ in range(100): + # ... 
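+    # (training step elided; save() only writes a checkpoint once
+    # time_delta_minutes has elapsed, unless called with force=True.)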
+ checkpointer.save()
+ ```
+ """
+
+ def __init__(
+ self,
+ objects_to_save: Mapping[str, Union[Checkpointable, core.Saveable]],
+ *,
+ directory: str = '~/acme/',
+ subdirectory: str = 'default',
+ time_delta_minutes: float = 10.0,
+ enable_checkpointing: bool = True,
+ add_uid: bool = True,
+ max_to_keep: int = 1,
+ checkpoint_ttl_seconds: int = _DEFAULT_CHECKPOINT_TTL,
+ keep_checkpoint_every_n_hours: Optional[int] = None,
+ ):
+ """Builds the saver object.
+
+ Args:
+ objects_to_save: Mapping specifying what to checkpoint.
+ directory: Which directory to put the checkpoint in.
+ subdirectory: Sub-directory to use (e.g. if multiple checkpoints are being
+ saved).
+ time_delta_minutes: How often to save the checkpoint, in minutes.
+ enable_checkpointing: Whether to checkpoint or not.
+ add_uid: If True, adds a UID to the checkpoint path; see
+ `paths.get_unique_id()` for how this UID is generated.
+ max_to_keep: The maximum number of checkpoints to keep.
+ checkpoint_ttl_seconds: TTL (time to live) in seconds for checkpoints.
+ keep_checkpoint_every_n_hours: keep_checkpoint_every_n_hours passed to
+ tf.train.CheckpointManager.
+ """
+
+ # Convert `Saveable` objects to TF `Checkpointable` first, if necessary.
+ def to_ckptable(x: Union[Checkpointable, core.Saveable]) -> Checkpointable:
+ if isinstance(x, core.Saveable):
+ return SaveableAdapter(x)
+ return x
+
+ objects_to_save = {k: to_ckptable(v) for k, v in objects_to_save.items()}
+
+ self._time_delta_minutes = time_delta_minutes
+ self._last_saved = 0.
+ self._enable_checkpointing = enable_checkpointing
+ self._checkpoint_manager = None
+
+ if enable_checkpointing:
+ # Checkpoint object that handles saving/restoring.
+ self._checkpoint = tf.train.Checkpoint(**objects_to_save)
+ self._checkpoint_dir = paths.process_path(
+ directory,
+ 'checkpoints',
+ subdirectory,
+ ttl_seconds=checkpoint_ttl_seconds,
+ backups=False,
+ add_uid=add_uid)
+
+ # Create a manager to maintain different checkpoints.
+ self._checkpoint_manager = tf.train.CheckpointManager(
+ self._checkpoint,
+ directory=self._checkpoint_dir,
+ max_to_keep=max_to_keep,
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
+
+ self.restore()
+
+ def save(self, force: bool = False) -> bool:
+ """Saves the checkpoint if it's the appropriate time; otherwise no-ops.
+
+ Args:
+ force: Whether to force a save regardless of time elapsed since last save.
+
+ Returns:
+ A boolean indicating if a save event happened.
+ """
+ if not self._enable_checkpointing:
+ return False
+
+ if (not force and
+ time.time() - self._last_saved < 60 * self._time_delta_minutes):
+ return False
+
+ # Save any checkpoints.
+ logging.info('Saving checkpoint: %s', self._checkpoint_manager.directory)
+ self._checkpoint_manager.save()
+ self._last_saved = time.time()
+
+ return True
+
+ def restore(self):
+ # Restore from the most recent checkpoint (if it exists).
+ checkpoint_to_restore = self._checkpoint_manager.latest_checkpoint
+ logging.info('Attempting to restore checkpoint: %s',
+ checkpoint_to_restore)
+ self._checkpoint.restore(checkpoint_to_restore)
+
+ @property
+ def directory(self):
+ return self._checkpoint_manager.directory
+
+
+class CheckpointingRunner(core.Worker):
+ """Wrap an object and expose a run method which checkpoints periodically.
+
+ This internally creates a Checkpointer around the `wrapped` object and
+ exposes all of the methods of `wrapped`. Additionally, any `**kwargs` passed
+ to the runner are forwarded to the internal Checkpointer. 
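+
+ A minimal usage sketch (`my_learner` below stands in for any core.Learner
+ or TFSaveable instance):
+
+ ```python
+ runner = CheckpointingRunner(my_learner, time_delta_minutes=30)
+ runner.run()  # Loops forever, stepping the learner and checkpointing.
+ ```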
+ """ + + def __init__( + self, + wrapped: Union[Checkpointable, core.Saveable, TFSaveable], + key: str = 'wrapped', + *, + time_delta_minutes: int = 30, + **kwargs, + ): + + if isinstance(wrapped, TFSaveable): + # If the object to be wrapped exposes its TF State, checkpoint that. + objects_to_save = wrapped.state + else: + # Otherwise checkpoint the wrapped object itself. + objects_to_save = wrapped + + self._wrapped = wrapped + self._time_delta_minutes = time_delta_minutes + self._checkpointer = Checkpointer( + objects_to_save={key: objects_to_save}, + time_delta_minutes=time_delta_minutes, + **kwargs) + + # Handle preemption signal. Note that this must happen in the main thread. + def _signal_handler(self): + logging.info('Caught SIGTERM: forcing a checkpoint save.') + self._checkpointer.save(force=True) + + def step(self): + if isinstance(self._wrapped, core.Learner): + # Learners have a step() method, so alternate between that and ckpt call. + self._wrapped.step() + self._checkpointer.save() + else: + # Wrapped object doesn't have a run method; set our run method to ckpt. + self.checkpoint() + + def run(self): + """Runs the checkpointer.""" + with signals.runtime_terminator(self._signal_handler): + while True: + self.step() + + def __dir__(self): + return dir(self._wrapped) + ['get_directory'] + + # TODO(b/195915583) : Throw when wrapped object has get_directory() method. + def __getattr__(self, name): + if name == 'get_directory': + return self.get_directory + return getattr(self._wrapped, name) + + def checkpoint(self): + self._checkpointer.save() + # Do not sleep for a long period of time to avoid LaunchPad program + # termination hangs (time.sleep is not interruptible). + for _ in range(self._time_delta_minutes * 60): + time.sleep(1) + + def get_directory(self): + return self._checkpointer.directory + + +class Snapshotter: + """Convenience class for periodically snapshotting. + + Objects which can be snapshotted are limited to Sonnet or tensorflow Modules + which implement a __call__ method. This will save the module's graph and + variables such that they can be loaded later using `tf.saved_model.load`. See + https://www.tensorflow.org/guide/saved_model for more details. + + The Snapshotter is typically used to save infrequent permanent self-contained + snapshots which can be loaded later for inspection. For frequent saving of + model parameters in order to guard against pre-emption of the learning process + see the Checkpointer class. + + Usage example: + + ```python + model = snt.Linear(10) + snapshotter = Snapshotter(objects_to_save={'model': model}) + + for _ in range(100): + # ... + snapshotter.save() + ``` + """ + + def __init__( + self, + objects_to_save: Mapping[str, snt.Module], + *, + directory: str = '~/acme/', + time_delta_minutes: float = 30.0, + snapshot_ttl_seconds: int = _DEFAULT_SNAPSHOT_TTL, + ): + """Builds the saver object. + + Args: + objects_to_save: Mapping specifying what to snapshot. + directory: Which directory to put the snapshot in. + time_delta_minutes: How often to save the snapshot, in minutes. + snapshot_ttl_seconds: TTL (time to leave) in seconds for snapshots. + """ + objects_to_save = objects_to_save or {} + + self._time_delta_minutes = time_delta_minutes + self._last_saved = 0. + self._snapshots = {} + + # Save the base directory path so we can refer to it if needed. + self.directory = paths.process_path( + directory, 'snapshots', ttl_seconds=snapshot_ttl_seconds) + + # Save a dictionary mapping paths to snapshot capable models. 
+ for name, module in objects_to_save.items(): + path = os.path.join(self.directory, name) + self._snapshots[path] = make_snapshot(module) + + def save(self, force: bool = False) -> bool: + """Snapshots if it's the appropriate time, otherwise no-ops. + + Args: + force: If True, save new snapshot no matter how long it's been since the + last one. + + Returns: + A boolean indicating if a save event happened. + """ + seconds_since_last = time.time() - self._last_saved + if (self._snapshots and + (force or seconds_since_last >= 60 * self._time_delta_minutes)): + # Save any snapshots. + for path, snapshot in self._snapshots.items(): + tf.saved_model.save(snapshot, path) + + # Record the time we finished saving. + self._last_saved = time.time() + + return True + + return False + + +class Snapshot(tf.Module): + """Thin wrapper which allows the module to be saved.""" + + def __init__(self): + super().__init__() + self._module = None + self._variables = None + self._trainable_variables = None + + @tf.function + def __call__(self, *args, **kwargs): + return self._module(*args, **kwargs) + + @property + def submodules(self): + return [self._module] + + @property + def variables(self): + return self._variables + + @property + def trainable_variables(self): + return self._trainable_variables + + +# Registers the Snapshot object above such that when it is restored by +# tf.saved_model.load it will be restored as a Snapshot. This is important +# because it allows us to expose the __call__, and *_variables properties. +revived_types.register_revived_type( + 'acme_snapshot', + lambda obj: isinstance(obj, Snapshot), + versions=[ + revived_types.VersionedTypeRegistration( + object_factory=lambda proto: Snapshot(), + version=1, + min_producer_version=1, + min_consumer_version=1, + setter=setattr, + ) + ]) + + +def make_snapshot(module: snt.Module): + """Create a thin wrapper around a module to make it snapshottable.""" + # Get the input signature as long as it has been created. + input_signature = _get_input_signature(module) + if input_signature is None: + raise ValueError( + ('module instance "{}" has no input_signature attribute, ' + 'which is required for snapshotting; run ' + 'create_variables to add this annotation.').format(module.name)) + + # This function will return the object as a composite tensor if it is a + # distribution and will otherwise return it with no changes. + def as_composite(obj): + if isinstance(obj, tfp.distributions.Distribution): + return tfp.experimental.as_composite(obj) + else: + return obj + + # Replace any distributions returned by the module with composite tensors and + # wrap it up in tf.function so we can process it properly. + @tf.function + def wrapped_module(*args, **kwargs): + return tree.map_structure(as_composite, module(*args, **kwargs)) + + # pylint: disable=protected-access + snapshot = Snapshot() + snapshot._module = wrapped_module + snapshot._variables = module.variables + snapshot._trainable_variables = module.trainable_variables + # pylint: disable=protected-access + + # Make sure the snapshot has the proper input signature. + snapshot.__call__.get_concrete_function(*input_signature) + + # If we are an RNN also save the initial-state generating function. 
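+ # This lets users of the restored snapshot call initial_state(batch_size)
+ # without access to the original Sonnet class.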
+ if isinstance(module, snt.RNNCore):
+ snapshot.initial_state = tf.function(module.initial_state)
+ snapshot.initial_state.get_concrete_function(
+ tf.TensorSpec(shape=(), dtype=tf.int32))
+
+ return snapshot
+
+
+def _get_input_signature(module: snt.Module) -> Optional[tf.TensorSpec]:
+ """Get module input signature.
+
+ Works even if the module with the signature is wrapped in snt.Sequential or
+ snt.DeepRNN.
+
+ Args:
+ module: the module whose input signature we need to get. The module has to
+ either have input_signature itself (i.e. you have to run create_variables
+ on the module), or it has to be a module (with input_signature) wrapped in
+ (one or multiple) snt.Sequential or snt.DeepRNNs.
+
+ Returns:
+ Input signature of the module or None if it's not available.
+ """
+ if hasattr(module, '_input_signature'):
+ return module._input_signature # pylint: disable=protected-access
+
+ if isinstance(module, snt.Sequential):
+ first_layer = module._layers[0] # pylint: disable=protected-access
+ return _get_input_signature(first_layer)
+
+ if isinstance(module, snt.DeepRNN):
+ first_layer = module._layers[0] # pylint: disable=protected-access
+ input_signature = _get_input_signature(first_layer)
+
+ # Wrapping a module in DeepRNN changes its state shape, so we need to bring
+ # it up to date.
+ state = module.initial_state(1)
+ input_signature[-1] = tree.map_structure(
+ lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state)
+
+ return input_signature
+
+ return None
+
+
+class SaveableAdapter(tf.train.experimental.PythonState):
+ """Adapter which allows a `Saveable` object to be checkpointed by TensorFlow."""
+
+ def __init__(self, object_to_save: core.Saveable):
+ self._object_to_save = object_to_save
+
+ def serialize(self):
+ state = self._object_to_save.save()
+ return pickle.dumps(state)
+
+ def deserialize(self, pickled: bytes):
+ state = pickle.loads(pickled)
+ self._object_to_save.restore(state) diff --git a/acme/acme/tf/savers_test.py b/acme/acme/tf/savers_test.py new file mode 100644 index 00000000..cd075a18 --- /dev/null +++ b/acme/acme/tf/savers_test.py @@ -0,0 +1,294 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for TF2 savers.""" + +import os +import re +import time +from unittest import mock + +from acme import specs +from acme.testing import test_utils +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import paths +import launchpad +import numpy as np +import sonnet as snt +import tensorflow as tf +import tree + +from absl.testing import absltest + + +class DummySaveable(tf2_savers.TFSaveable): + + _state: tf.Variable + + def __init__(self): + self._state = tf.Variable(0, dtype=tf.int32) + + @property + def state(self): + return {'state': self._state} + + +class CheckpointerTest(test_utils.TestCase): + + def test_save_and_restore(self): + """Test that checkpointer correctly calls save and restore.""" + + x = tf.Variable(0, dtype=tf.int32) + directory = self.get_tempdir() + checkpointer = tf2_savers.Checkpointer( + objects_to_save={'x': x}, time_delta_minutes=0., directory=directory) + + for _ in range(10): + saved = checkpointer.save() + self.assertTrue(saved) + x.assign_add(1) + checkpointer.restore() + np.testing.assert_array_equal(x.numpy(), np.int32(0)) + + def test_save_and_new_restore(self): + """Tests that a fresh checkpointer correctly restores an existing ckpt.""" + with mock.patch.object(paths, 'get_unique_id') as mock_unique_id: + mock_unique_id.return_value = ('test',) + x = tf.Variable(0, dtype=tf.int32) + directory = self.get_tempdir() + checkpointer1 = tf2_savers.Checkpointer( + objects_to_save={'x': x}, time_delta_minutes=0., directory=directory) + checkpointer1.save() + x.assign_add(1) + # Simulate a preemption: x is changed, and we make a new Checkpointer. + checkpointer2 = tf2_savers.Checkpointer( + objects_to_save={'x': x}, time_delta_minutes=0., directory=directory) + checkpointer2.restore() + np.testing.assert_array_equal(x.numpy(), np.int32(0)) + + def test_save_and_restore_time_based(self): + """Test that checkpointer correctly calls save and restore based on time.""" + + x = tf.Variable(0, dtype=tf.int32) + directory = self.get_tempdir() + checkpointer = tf2_savers.Checkpointer( + objects_to_save={'x': x}, time_delta_minutes=1., directory=directory) + + with mock.patch.object(time, 'time') as mock_time: + mock_time.return_value = 0. + self.assertFalse(checkpointer.save()) + + mock_time.return_value = 40. + self.assertFalse(checkpointer.save()) + + mock_time.return_value = 70. + self.assertTrue(checkpointer.save()) + x.assign_add(1) + checkpointer.restore() + np.testing.assert_array_equal(x.numpy(), np.int32(0)) + + def test_no_checkpoint(self): + """Test that checkpointer does nothing when checkpoint=False.""" + num_steps = tf.Variable(0) + checkpointer = tf2_savers.Checkpointer( + objects_to_save={'num_steps': num_steps}, enable_checkpointing=False) + + for _ in range(10): + self.assertFalse(checkpointer.save()) + self.assertIsNone(checkpointer._checkpoint_manager) + + def test_tf_saveable(self): + x = DummySaveable() + + directory = self.get_tempdir() + checkpoint_runner = tf2_savers.CheckpointingRunner( + x, time_delta_minutes=0, directory=directory) + checkpoint_runner._checkpointer.save() + + x._state.assign_add(1) + checkpoint_runner._checkpointer.restore() + + np.testing.assert_array_equal(x._state.numpy(), np.int32(0)) + + +class CheckpointingRunnerTest(test_utils.TestCase): + + def test_signal_handling(self): + x = DummySaveable() + + # Increment the value of DummySavable. 
+    x.state['state'].assign_add(1)
+
+    directory = self.get_tempdir()
+
+    # Patch launchpad.register_stop_handler so that the registered stop
+    # handler is invoked immediately.
+    with mock.patch.object(
+        launchpad, 'register_stop_handler') as mock_register_stop_handler:
+      def add_handler(fn):
+        fn()
+      mock_register_stop_handler.side_effect = add_handler
+
+      runner = tf2_savers.CheckpointingRunner(
+          wrapped=x,
+          time_delta_minutes=0,
+          directory=directory)
+      with self.assertRaises(SystemExit):
+        runner.run()
+
+    # Recreate DummySaveable(); its tf.Variable is initialized to 0.
+    x = DummySaveable()
+    # Recreate the CheckpointingRunner, which will restore DummySaveable to 1.
+    tf2_savers.CheckpointingRunner(
+        wrapped=x,
+        time_delta_minutes=0,
+        directory=directory)
+    # Check that DummySaveable's variable was restored properly.
+    np.testing.assert_array_equal(x.state['state'].numpy(), np.int32(1))
+
+  def test_checkpoint_dir(self):
+    directory = self.get_tempdir()
+    ckpt_runner = tf2_savers.CheckpointingRunner(
+        wrapped=DummySaveable(),
+        time_delta_minutes=0,
+        directory=directory)
+    expected_dir_re = f'{directory}/[a-z0-9-]*/checkpoints/default'
+    regexp = re.compile(expected_dir_re)
+    self.assertIsNotNone(regexp.fullmatch(ckpt_runner.get_directory()))
+
+
+class SnapshotterTest(test_utils.TestCase):
+
+  def test_snapshot(self):
+    """Test that the snapshotter correctly saves and restores snapshots."""
+    # Create a test network.
+    net1 = networks.LayerNormMLP([10, 10])
+    spec = specs.Array([10], dtype=np.float32)
+    tf2_utils.create_variables(net1, [spec])
+
+    # Save the test network.
+    directory = self.get_tempdir()
+    objects_to_save = {'net': net1}
+    snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
+    snapshotter.save()
+
+    # Reload the test network.
+    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
+    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))
+
+    with tf.GradientTape() as tape:
+      outputs1 = net1(inputs)
+      loss1 = tf.math.reduce_sum(outputs1)
+      grads1 = tape.gradient(loss1, net1.trainable_variables)
+
+    with tf.GradientTape() as tape:
+      outputs2 = net2(inputs)
+      loss2 = tf.math.reduce_sum(outputs2)
+      grads2 = tape.gradient(loss2, net2.trainable_variables)
+
+    assert np.allclose(outputs1, outputs2)
+    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
+
+  def test_snapshot_distribution(self):
+    """Test that the snapshotter correctly handles distribution-valued networks."""
+    # Create a test network.
+    net1 = snt.Sequential([
+        networks.LayerNormMLP([10, 10]),
+        networks.MultivariateNormalDiagHead(1)
+    ])
+    spec = specs.Array([10], dtype=np.float32)
+    tf2_utils.create_variables(net1, [spec])
+
+    # Save the test network.
+    directory = self.get_tempdir()
+    objects_to_save = {'net': net1}
+    snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
+    snapshotter.save()
+
+    # Reload the test network.
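+    # (Note: the snapshot is a TF SavedModel, so the loaded object is callable
+    # and exposes trainable_variables, even though it is no longer a Sonnet
+    # module.)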
+    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
+    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))
+
+    with tf.GradientTape() as tape:
+      dist1 = net1(inputs)
+      loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
+      grads1 = tape.gradient(loss1, net1.trainable_variables)
+
+    with tf.GradientTape() as tape:
+      dist2 = net2(inputs)
+      loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
+      grads2 = tape.gradient(loss2, net2.trainable_variables)
+
+    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
+
+  def test_force_snapshot(self):
+    """Test that the force flag in Snapshotter.save() works correctly."""
+    # Create a test network.
+    net = snt.Linear(10)
+    spec = specs.Array([10], dtype=np.float32)
+    tf2_utils.create_variables(net, [spec])
+
+    # Save the test network.
+    directory = self.get_tempdir()
+    objects_to_save = {'net': net}
+    # Very long time_delta_minutes.
+    snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory,
+                                         time_delta_minutes=1000)
+    self.assertTrue(snapshotter.save(force=False))
+
+    # Due to the long time_delta_minutes, only force=True will create a new
+    # snapshot. This also checks that the default is force=False.
+    self.assertFalse(snapshotter.save())
+    self.assertTrue(snapshotter.save(force=True))
+
+  def test_rnn_snapshot(self):
+    """Test that the snapshotter correctly saves and restores RNN snapshots."""
+    # Create a test network.
+    net = snt.LSTM(10)
+    spec = specs.Array([10], dtype=np.float32)
+    tf2_utils.create_variables(net, [spec])
+
+    # Test that adding some postprocessing without rerunning create_variables
+    # still works.
+    wrapped_net = snt.DeepRNN([net, lambda x: x])
+
+    for net1 in [net, wrapped_net]:
+      # Save the test network.
+      directory = self.get_tempdir()
+      objects_to_save = {'net': net1}
+      snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
+      snapshotter.save()
+
+      # Reload the test network.
+      net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
+      inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))
+
+      with tf.GradientTape() as tape:
+        outputs1, next_state1 = net1(inputs, net1.initial_state(1))
+        loss1 = tf.math.reduce_sum(outputs1)
+        grads1 = tape.gradient(loss1, net1.trainable_variables)
+
+      with tf.GradientTape() as tape:
+        outputs2, next_state2 = net2(inputs, net2.initial_state(1))
+        loss2 = tf.math.reduce_sum(outputs2)
+        grads2 = tape.gradient(loss2, net2.trainable_variables)
+
+      assert np.allclose(outputs1, outputs2)
+      assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
+      assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/tf/utils.py b/acme/acme/tf/utils.py
new file mode 100644
index 00000000..8a52e93c
--- /dev/null
+++ b/acme/acme/tf/utils.py
@@ -0,0 +1,184 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Utilities for nested data structures involving NumPy and TensorFlow 2.x.""" + +import functools +from typing import List, Optional, Union + +from acme import types +from acme.utils import tree_utils + +import sonnet as snt +import tensorflow as tf +import tree + + +def add_batch_dim(nest: types.NestedTensor) -> types.NestedTensor: + """Adds a batch dimension to each leaf of a nested structure of Tensors.""" + return tree.map_structure(lambda x: tf.expand_dims(x, axis=0), nest) + + +def squeeze_batch_dim(nest: types.NestedTensor) -> types.NestedTensor: + """Squeezes out a batch dimension from each leaf of a nested structure.""" + return tree.map_structure(lambda x: tf.squeeze(x, axis=0), nest) + + +def batch_concat(inputs: types.NestedTensor) -> tf.Tensor: + """Concatenate a collection of Tensors while preserving the batch dimension. + + This takes a potentially nested collection of tensors, flattens everything + but the batch (first) dimension, and concatenates along the resulting data + (second) dimension. + + Args: + inputs: a tensor or nested collection of tensors. + + Returns: + A concatenated tensor which maintains the batch dimension but concatenates + all other data along the flattened second dimension. + """ + flat_leaves = tree.map_structure(snt.Flatten(), inputs) + return tf.concat(tree.flatten(flat_leaves), axis=-1) + + +def batch_to_sequence(data: types.NestedTensor) -> types.NestedTensor: + """Converts data between sequence-major and batch-major format.""" + return tree.map_structure( + lambda t: tf.transpose(t, [1, 0] + list(range(2, t.shape.rank))), data) + + +def tile_tensor(tensor: tf.Tensor, multiple: int) -> tf.Tensor: + """Tiles `multiple` copies of `tensor` along a new leading axis.""" + rank = len(tensor.shape) + multiples = tf.constant([multiple] + [1] * rank, dtype=tf.int32) + expanded_tensor = tf.expand_dims(tensor, axis=0) + return tf.tile(expanded_tensor, multiples) + + +def tile_nested(inputs: types.NestedTensor, + multiple: int) -> types.NestedTensor: + """Tiles tensors in a nested structure along a new leading axis.""" + tile = functools.partial(tile_tensor, multiple=multiple) + return tree.map_structure(tile, inputs) + + +def create_variables( + network: snt.Module, + input_spec: List[Union[types.NestedSpec, tf.TensorSpec]], +) -> Optional[tf.TensorSpec]: + """Builds the network with dummy inputs to create the necessary variables. + + Args: + network: Sonnet Module whose variables are to be created. + input_spec: list of input specs to the network. The length of this list + should match the number of arguments expected by `network`. + + Returns: + output_spec: only returns an output spec if the output is a tf.Tensor, else + it doesn't return anything (None); e.g. if the output is a + tfp.distributions.Distribution. + """ + # Create a dummy observation with no batch dimension. + dummy_input = zeros_like(input_spec) + + # If we have an RNNCore the hidden state will be an additional input. + if isinstance(network, snt.RNNCore): + initial_state = squeeze_batch_dim(network.initial_state(1)) + dummy_input += [initial_state] + + # Forward pass of the network which will create variables as a side effect. + dummy_output = network(*add_batch_dim(dummy_input)) + + # Evaluate the input signature by converting the dummy input into a + # TensorSpec. We then save the signature as a property of the network. This is + # done so that we can later use it when creating snapshots. 
We do this here + # because the snapshot code may not have access to the precise form of the + # inputs. + input_signature = tree.map_structure( + lambda t: tf.TensorSpec((None,) + t.shape, t.dtype), dummy_input) + network._input_signature = input_signature # pylint: disable=protected-access + + def spec(output): + # If the output is not a Tensor, return None as spec is ill-defined. + if not isinstance(output, tf.Tensor): + return None + # If this is not a scalar Tensor, make sure to squeeze out the batch dim. + if tf.rank(output) > 0: + output = squeeze_batch_dim(output) + return tf.TensorSpec(output.shape, output.dtype) + + return tree.map_structure(spec, dummy_output) + + +class TransformationWrapper(snt.Module): + """Helper class for to_sonnet_module. + + This wraps arbitrary Tensor-valued callables as a Sonnet module. + A use case for this is in agent code that could take either a trainable + sonnet module or a hard-coded function as its policy. By wrapping a hard-coded + policy with this class, the agent can then treat it as if it were a Sonnet + module. This removes the need for "if is_hard_coded:..." branches, which you'd + otherwise need if e.g. calling get_variables() on the policy. + """ + + def __init__(self, + transformation: types.TensorValuedCallable, + name: Optional[str] = None): + super().__init__(name=name) + self._transformation = transformation + + def __call__(self, *args, **kwargs): + return self._transformation(*args, **kwargs) + + +def to_sonnet_module( + transformation: types.TensorValuedCallable + ) -> snt.Module: + """Convert a tensor transformation to a Sonnet Module. + + Args: + transformation: A Callable that takes one or more (nested) Tensors, and + returns one or more (nested) Tensors. + + Returns: + A Sonnet Module that wraps the transformation. + """ + + if isinstance(transformation, snt.Module): + return transformation + + module = TransformationWrapper(transformation) + + # Wrap the module to allow it to return an empty variable tuple. + return snt.allow_empty_variables(module) + + +def to_numpy(nest: types.NestedTensor) -> types.NestedArray: + """Converts a nest of Tensors to a nest of numpy arrays.""" + return tree.map_structure(lambda x: x.numpy(), nest) + + +def to_numpy_squeeze(nest: types.NestedTensor, axis=0) -> types.NestedArray: + """Converts a nest of Tensors to a nest of numpy arrays and squeeze axis.""" + return tree.map_structure(lambda x: tf.squeeze(x, axis=axis).numpy(), nest) + + +def zeros_like(nest: types.Nest) -> types.NestedTensor: + """Given a nest of array-like objects, returns similarly nested tf.zeros.""" + return tree.map_structure(lambda x: tf.zeros(x.shape, x.dtype), nest) + + +# TODO(b/160311329): Migrate call-sites and remove. +stack_sequence_fields = tree_utils.stack_sequence_fields diff --git a/acme/acme/tf/utils_test.py b/acme/acme/tf/utils_test.py new file mode 100644 index 00000000..f5421262 --- /dev/null +++ b/acme/acme/tf/utils_test.py @@ -0,0 +1,134 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for acme.tf.utils.""" + +from typing import Sequence, Tuple + +from acme import specs +from acme.tf import utils as tf2_utils +import numpy as np +import sonnet as snt +import tensorflow as tf + +from absl.testing import absltest +from absl.testing import parameterized + + +class PolicyValueHead(snt.Module): + """A network with two linear layers, for policy and value respectively.""" + + def __init__(self, num_actions: int): + super().__init__(name='policy_value_network') + self._policy_layer = snt.Linear(num_actions) + self._value_layer = snt.Linear(1) + + def __call__(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """Returns a (Logits, Value) tuple.""" + logits = self._policy_layer(inputs) # [B, A] + value = tf.squeeze(self._value_layer(inputs), axis=-1) # [B] + + return logits, value + + +class CreateVariableTest(parameterized.TestCase): + """Tests for tf2_utils.create_variables method.""" + + @parameterized.parameters([True, False]) + def test_feedforward(self, recurrent: bool): + model = snt.Linear(42) + if recurrent: + model = snt.DeepRNN([model]) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + tf2_utils.create_variables(model, [input_spec]) + variables: Sequence[tf.Variable] = model.variables + shapes = [v.shape.as_list() for v in variables] + self.assertSequenceEqual(shapes, [[42], [10, 42]]) + + @parameterized.parameters([True, False]) + def test_output_spec_feedforward(self, recurrent: bool): + input_spec = specs.Array(shape=(10,), dtype=np.float32) + model = snt.Linear(42) + expected_spec = tf.TensorSpec(shape=(42,), dtype=tf.float32) + if recurrent: + model = snt.DeepRNN([model]) + expected_spec = (expected_spec, ()) + + output_spec = tf2_utils.create_variables(model, [input_spec]) + self.assertEqual(output_spec, expected_spec) + + def test_multiple_outputs(self): + model = PolicyValueHead(42) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + expected_spec = (tf.TensorSpec(shape=(42,), dtype=tf.float32), + tf.TensorSpec(shape=(), dtype=tf.float32)) + output_spec = tf2_utils.create_variables(model, [input_spec]) + variables: Sequence[tf.Variable] = model.variables + shapes = [v.shape.as_list() for v in variables] + self.assertSequenceEqual(shapes, [[42], [10, 42], [1], [10, 1]]) + self.assertSequenceEqual(output_spec, expected_spec) + + def test_scalar_output(self): + model = tf2_utils.to_sonnet_module(tf.reduce_sum) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + expected_spec = tf.TensorSpec(shape=(), dtype=tf.float32) + output_spec = tf2_utils.create_variables(model, [input_spec]) + self.assertEqual(model.variables, ()) + self.assertEqual(output_spec, expected_spec) + + def test_none_output(self): + model = tf2_utils.to_sonnet_module(lambda x: None) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + expected_spec = None + output_spec = tf2_utils.create_variables(model, [input_spec]) + self.assertEqual(model.variables, ()) + self.assertEqual(output_spec, expected_spec) + + def test_multiple_inputs_and_outputs(self): + def transformation(aa, bb, cc): + return (tf.concat([aa, bb, cc], axis=-1), + tf.concat([bb, cc], axis=-1)) + + model = tf2_utils.to_sonnet_module(transformation) + dtype = np.float32 + input_spec = [specs.Array(shape=(2,), dtype=dtype), + specs.Array(shape=(3,), dtype=dtype), + specs.Array(shape=(4,), dtype=dtype)] + expected_output_spec = (tf.TensorSpec(shape=(9,), dtype=dtype), + 
tf.TensorSpec(shape=(7,), dtype=dtype)) + output_spec = tf2_utils.create_variables(model, input_spec) + self.assertEqual(model.variables, ()) + self.assertEqual(output_spec, expected_output_spec) + + +class Tf2UtilsTest(parameterized.TestCase): + """Tests for tf2_utils methods.""" + + def test_batch_concat(self): + batch_size = 32 + inputs = [ + tf.zeros(shape=(batch_size, 2)), + { + 'foo': tf.zeros(shape=(batch_size, 5, 3)) + }, + [tf.zeros(shape=(batch_size, 1))], + ] + + output_shape = tf2_utils.batch_concat(inputs).shape.as_list() + expected_shape = [batch_size, 2 + 5 * 3 + 1] + self.assertSequenceEqual(output_shape, expected_shape) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/tf/variable_utils.py b/acme/acme/tf/variable_utils.py new file mode 100644 index 00000000..462d96b2 --- /dev/null +++ b/acme/acme/tf/variable_utils.py @@ -0,0 +1,108 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Variable handling utilities for TensorFlow 2.""" + +from concurrent import futures +from typing import Mapping, Optional, Sequence + +from acme import core + +import tensorflow as tf +import tree + + +class VariableClient: + """A variable client for updating variables from a remote source.""" + + def __init__(self, + client: core.VariableSource, + variables: Mapping[str, Sequence[tf.Variable]], + update_period: int = 1): + self._keys = list(variables.keys()) + self._variables = tree.flatten(list(variables.values())) + self._call_counter = 0 + self._update_period = update_period + self._client = client + self._request = lambda: client.get_variables(self._keys) + + # Create a single background thread to fetch variables without necessarily + # blocking the actor. + self._executor = futures.ThreadPoolExecutor(max_workers=1) + self._async_request = lambda: self._executor.submit(self._request) + + # Initialize this client's future to None to indicate to the `update()` + # method that there is no pending/running request. + self._future: Optional[futures.Future] = None + + def update(self, wait: bool = False): + """Periodically updates the variables with the latest copy from the source. + + This stateful update method keeps track of the number of calls to it and, + every `update_period` call, sends a request to its server to retrieve the + latest variables. + + If wait is True, a blocking request is executed. Any active request will be + cancelled. + If wait is False, this method makes an asynchronous request for variables + and returns. Unless the request is immediately fulfilled, the variables are + only copied _within a subsequent call to_ `update()`, whenever the request + is fulfilled by the `VariableSource`. If there is an existing fulfilled + request when this method is called, the resulting variables are immediately + copied. + + Args: + wait: if True, executes blocking update. + """ + # Track the number of calls (we only update periodically). 
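+    # Rough sketch of the logic below, for the default wait=False case:
+    #   1. Increment the call counter until `update_period` calls have been
+    #      made.
+    #   2. Once the period is reached, issue a single asynchronous request for
+    #      variables (unless one is already in flight).
+    #   3. On a later call, once that request has been fulfilled, copy the
+    #      result into the local variables.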
+ if self._call_counter < self._update_period: + self._call_counter += 1 + + period_reached: bool = self._call_counter >= self._update_period + + if period_reached and wait: + # Cancel any active request. + self._future: Optional[futures.Future] = None + self.update_and_wait() + self._call_counter = 0 + return + + if period_reached and self._future is None: + # The update period has been reached and no request has been sent yet, so + # making an asynchronous request now. + self._future = self._async_request() + self._call_counter = 0 + + if self._future is not None and self._future.done(): + # The active request is done so copy the result and remove the future. + self._copy(self._future.result()) + self._future: Optional[futures.Future] = None + else: + # There is either a pending/running request or we're between update + # periods, so just carry on. + return + + def update_and_wait(self): + """Immediately update and block until we get the result.""" + self._copy(self._request()) + + def _copy(self, new_variables: Sequence[Sequence[tf.Variable]]): + """Copies the new variables to the old ones.""" + + new_variables = tree.flatten(new_variables) + if len(self._variables) != len(new_variables): + raise ValueError('Length mismatch between old variables and new.') + + for new, old in zip(new_variables, self._variables): + old.assign(new) diff --git a/acme/acme/tf/variable_utils_test.py b/acme/acme/tf/variable_utils_test.py new file mode 100644 index 00000000..2c626c43 --- /dev/null +++ b/acme/acme/tf/variable_utils_test.py @@ -0,0 +1,133 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for acme.tf.variable_utils.""" + +import threading + +from acme.testing import fakes +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +import sonnet as snt +import tensorflow as tf + + +_MLP_LAYERS = [50, 30] +_INPUT_SIZE = 28 +_BATCH_SIZE = 8 +_UPDATE_PERIOD = 2 + + +class VariableClientTest(tf.test.TestCase): + + def setUp(self): + super().setUp() + + # Create two instances of the same model. + self._actor_model = snt.nets.MLP(_MLP_LAYERS) + self._learner_model = snt.nets.MLP(_MLP_LAYERS) + + # Create variables first. + input_spec = tf.TensorSpec(shape=(_INPUT_SIZE,), dtype=tf.float32) + tf2_utils.create_variables(self._actor_model, [input_spec]) + tf2_utils.create_variables(self._learner_model, [input_spec]) + + def test_update_and_wait(self): + # Create a variable source (emulating the learner). + np_learner_variables = tf2_utils.to_numpy(self._learner_model.variables) + variable_source = fakes.VariableSource(np_learner_variables) + + # Create a variable client (emulating the actor). + variable_client = tf2_variable_utils.VariableClient( + variable_source, {'policy': self._actor_model.variables}) + + # Create some random batch of test input: + x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE)) + + # Before copying variables, the models have different outputs. 
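+    # (The two MLPs were initialized independently, so their outputs differ
+    # with overwhelming probability.)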
+ self.assertNotAllClose(self._actor_model(x), self._learner_model(x)) + + # Update the variable client. + variable_client.update_and_wait() + + # After copying variables (by updating the client), the models are the same. + self.assertAllClose(self._actor_model(x), self._learner_model(x)) + + def test_update(self): + # Create a barrier to be shared between the test body and the variable + # source. The barrier will block until, in this case, two threads call + # wait(). Note that the (fake) variable source will call it within its + # get_variables() call. + barrier = threading.Barrier(2) + + # Create a variable source (emulating the learner). + np_learner_variables = tf2_utils.to_numpy(self._learner_model.variables) + variable_source = fakes.VariableSource(np_learner_variables, barrier) + + # Create a variable client (emulating the actor). + variable_client = tf2_variable_utils.VariableClient( + variable_source, {'policy': self._actor_model.variables}, + update_period=_UPDATE_PERIOD) + + # Create some random batch of test input: + x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE)) + + # Create variables by doing the computation once. + learner_output = self._learner_model(x) + actor_output = self._actor_model(x) + del learner_output, actor_output + + for _ in range(_UPDATE_PERIOD): + # Before the update period is reached, the models have different outputs. + self.assertNotAllClose(self._actor_model.variables, + self._learner_model.variables) + + # Before the update period is reached, the variable client should not make + # any requests for variables. + self.assertIsNone(variable_client._future) + + variable_client.update() + + # Make sure the last call created a request for variables and reset the + # internal call counter. + self.assertIsNotNone(variable_client._future) + self.assertEqual(variable_client._call_counter, 0) + future = variable_client._future + + for _ in range(_UPDATE_PERIOD): + # Before the barrier allows the variables to be released, the models have + # different outputs. + self.assertNotAllClose(self._actor_model.variables, + self._learner_model.variables) + + variable_client.update() + + # Make sure no new requests are made. + self.assertEqual(variable_client._future, future) + + # Calling wait() on the barrier will now allow the variables to be copied + # over from source to client. + barrier.wait() + + # Update once more to ensure the variables are copied over. + while variable_client._future is not None: + variable_client.update() + + # After a number of update calls, the variables should be the same. + self.assertAllClose(self._actor_model.variables, + self._learner_model.variables) + + +if __name__ == '__main__': + tf.test.main() diff --git a/acme/acme/types.py b/acme/acme/types.py new file mode 100644 index 00000000..8305880f --- /dev/null +++ b/acme/acme/types.py @@ -0,0 +1,65 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Common types used throughout Acme.""" + +from typing import Any, Callable, Iterable, Mapping, NamedTuple, Union +from acme import specs + +# Define types for nested arrays and tensors. +# TODO(b/144758674): Replace these with recursive type definitions. +NestedArray = Any +NestedTensor = Any + +# pytype: disable=not-supported-yet +NestedSpec = Union[ + specs.Array, + Iterable['NestedSpec'], + Mapping[Any, 'NestedSpec'], +] +# pytype: enable=not-supported-yet + +# TODO(b/144763593): Replace all instances of nest with the tensor/array types. +Nest = Union[NestedArray, NestedTensor, NestedSpec] + +TensorTransformation = Callable[[NestedTensor], NestedTensor] +TensorValuedCallable = Callable[..., NestedTensor] + + +class Batches(int): + """Helper class for specification of quantities in units of batches. + + Example usage: + + # Configure the batch size and replay size in units of batches. + config.batch_size = 32 + config.replay_size = Batches(4) + + # ... + + # Convert the replay size at runtime. + if isinstance(config.replay_size, Batches): + config.replay_size = config.replay_size * config.batch_size # int: 128 + + """ + + +class Transition(NamedTuple): + """Container for a transition.""" + observation: NestedArray + action: NestedArray + reward: NestedArray + discount: NestedArray + next_observation: NestedArray + extras: NestedArray = () diff --git a/acme/acme/utils/__init__.py b/acme/acme/utils/__init__.py new file mode 100644 index 00000000..64ff1fdf --- /dev/null +++ b/acme/acme/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Acme utility functions.""" diff --git a/acme/acme/utils/async_utils.py b/acme/acme/utils/async_utils.py new file mode 100644 index 00000000..aaccc561 --- /dev/null +++ b/acme/acme/utils/async_utils.py @@ -0,0 +1,113 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities to use within loggers.""" + +import queue +import threading +from typing import Callable, TypeVar, Generic + +from absl import logging + + +E = TypeVar("E") + + +class AsyncExecutor(Generic[E]): + """Executes a blocking function asynchronously on a queue of items.""" + + def __init__( + self, + fn: Callable[[E], None], + queue_size: int = 1, + interruptible_interval_secs: float = 1.0, + ): + """Buffers elements in a queue and runs `fn` asynchronously.. 
+
+    NOTE: Once closed, `AsyncExecutor` will block until the current `fn` call
+    finishes, but it is not guaranteed to dequeue all elements currently
+    stored in the data queue. This is intentional, so as to prevent a blocking
+    `fn` call from preventing `AsyncExecutor` from closing.
+
+    Args:
+      fn: A callable to be executed upon dequeuing an element from the data
+        queue.
+      queue_size: The maximum size of the synchronized buffer queue.
+      interruptible_interval_secs: Timeout interval in seconds for blocking
+        queue operations, after which the background thread checks for errors
+        and whether it should stop.
+    """
+    self._data = queue.Queue(maxsize=queue_size)
+    self._should_stop = threading.Event()
+    self._errors = queue.Queue()
+    self._interruptible_interval_secs = interruptible_interval_secs
+
+    def _dequeue() -> None:
+      """Dequeues data from the queue and invokes the blocking call."""
+      while not self._should_stop.is_set():
+        try:
+          element = self._data.get(timeout=self._interruptible_interval_secs)
+          # Execute fn upon dequeuing an element from the data queue.
+          fn(element)
+        except queue.Empty:
+          # If the queue has been empty for longer than the specified time
+          # interval, check again whether a stop has been requested and retry.
+          continue
+        except Exception as e:
+          logging.error("AsyncExecutor thread terminated with an error.")
+          logging.exception(e)
+          self._errors.put(e)
+          self._should_stop.set()
+          raise  # Never caught by anything, just terminates the thread.
+
+    self._thread = threading.Thread(target=_dequeue, daemon=True)
+    self._thread.start()
+
+  def _raise_on_error(self) -> None:
+    try:
+      # Raise the error on the caller thread if an error has been raised in
+      # the looper thread.
+      raise self._errors.get_nowait()
+    except queue.Empty:
+      pass
+
+  def close(self):
+    self._should_stop.set()
+    # Join all background threads.
+    self._thread.join()
+    # Raise errors produced by background threads.
+    self._raise_on_error()
+
+  def put(self, element: E) -> None:
+    """Puts `element` asynchronously onto the underlying data queue.
+
+    The call blocks if the underlying data queue has held `queue_size`
+    elements for over `self._interruptible_interval_secs` seconds, in which
+    case we check whether a stop has been requested or whether an error has
+    been raised on the looper thread. If neither has happened, the enqueue is
+    retried.
+
+    Args:
+      element: an element to be put into the underlying data queue and
+        dequeued asynchronously for the `fn(element)` call.
+    """
+    while not self._should_stop.is_set():
+      try:
+        self._data.put(element, timeout=self._interruptible_interval_secs)
+        break
+      except queue.Full:
+        continue
+    else:
+      # If `should_stop` has been set, raise any error that was produced on
+      # the background thread.
+      self._raise_on_error()
diff --git a/acme/acme/utils/counting.py b/acme/acme/utils/counting.py
new file mode 100644
index 00000000..8492b3d4
--- /dev/null
+++ b/acme/acme/utils/counting.py
@@ -0,0 +1,139 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A simple, hierarchical distributed counter."""
+
+import threading
+import time
+from typing import Dict, Mapping, Optional, Union
+
+from acme import core
+
+Number = Union[int, float]
+
+
+class Counter(core.Saveable):
+  """A simple counter object that can periodically sync with a parent."""
+
+  def __init__(self,
+               parent: Optional['Counter'] = None,
+               prefix: str = '',
+               time_delta: float = 1.0,
+               return_only_prefixed: bool = False):
+    """Initialize the counter.
+
+    Args:
+      parent: a Counter object to cache locally (or None for no caching).
+      prefix: string prefix to use for all local counts.
+      time_delta: time difference in seconds between syncing with the parent
+        counter.
+      return_only_prefixed: if True, and if `prefix` isn't empty, return
+        counts restricted to the given `prefix` on each call to `increment`
+        and `get_counts`. The `prefix` is stripped from returned count names.
+    """
+
+    self._parent = parent
+    self._prefix = prefix
+    self._time_delta = time_delta
+
+    # Local counts, protected by a lock. These are the counts that have not
+    # yet been synced to the parent and the cache.
+    self._counts = {}
+    self._lock = threading.Lock()
+
+    # We'll sync periodically (when the last sync was more than
+    # self._time_delta seconds ago).
+    self._cache = {}
+    self._last_sync_time = 0.0
+
+    self._return_only_prefixed = return_only_prefixed
+
+  def increment(self, **counts: Number) -> Dict[str, Number]:
+    """Increment a set of counters.
+
+    Args:
+      **counts: keyword arguments specifying count increments.
+
+    Returns:
+      The [name, value] mapping of all counters stored, i.e. this will also
+      include counts that were not updated by this call to increment.
+    """
+    with self._lock:
+      for key, value in counts.items():
+        self._counts.setdefault(key, 0)
+        self._counts[key] += value
+    return self.get_counts()
+
+  def get_counts(self) -> Dict[str, Number]:
+    """Return all counts tracked by this counter."""
+    now = time.time()
+    # TODO(b/144421838): use futures instead of blocking.
+    if self._parent and (now - self._last_sync_time) > self._time_delta:
+      with self._lock:
+        counts = _prefix_keys(self._counts, self._prefix)
+        # Reset the local counts, as they will be merged into the parent and
+        # the cache.
+        self._counts = {}
+      self._cache = self._parent.increment(**counts)
+      self._last_sync_time = now
+
+    # Potentially prefix the keys in the counts dictionary.
+    counts = _prefix_keys(self._counts, self._prefix)
+
+    # If there's no prefix make a copy of the dictionary so we don't modify
+    # the internal self._counts.
+    if not self._prefix:
+      counts = dict(counts)
+
+    # Combine local counts with any parent counts.
+    for key, value in self._cache.items():
+      counts[key] = counts.get(key, 0) + value
+
+    if self._prefix and self._return_only_prefixed:
+      counts = dict([(key[len(self._prefix) + 1:], value)
+                     for key, value in counts.items()
+                     if key.startswith(f'{self._prefix}_')])
+    return counts
+
+  def save(self) -> Mapping[str, Mapping[str, Number]]:
+    return {'counts': self._counts, 'cache': self._cache}
+
+  def restore(self, state: Mapping[str, Mapping[str, Number]]):
+    # Force a sync, if necessary, on the next get_counts call.
+    self._last_sync_time = 0.
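+    # Restore both the local (not yet synced) counts and the cached parent
+    # counts.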
+    self._counts = state['counts']
+    self._cache = state['cache']
+
+  def get_steps_key(self) -> str:
+    """Returns the key to use for steps by this counter."""
+    if not self._prefix or self._return_only_prefixed:
+      return 'steps'
+    return f'{self._prefix}_steps'
+
+
+def _prefix_keys(dictionary: Dict[str, Number], prefix: str):
+  """Return a dictionary with prefixed keys.
+
+  Args:
+    dictionary: dictionary to return a copy of.
+    prefix: string to use as the prefix.
+
+  Returns:
+    A copy of the given dictionary whose keys are replaced by
+    "{prefix}_{key}". If the prefix is the empty string it returns the given
+    dictionary unchanged.
+  """
+  if prefix:
+    dictionary = {f'{prefix}_{k}': v for k, v in dictionary.items()}
+  return dictionary
diff --git a/acme/acme/utils/counting_test.py b/acme/acme/utils/counting_test.py
new file mode 100644
index 00000000..1736b632
--- /dev/null
+++ b/acme/acme/utils/counting_test.py
@@ -0,0 +1,117 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for acme.utils.counting."""
+
+import threading
+
+from acme.utils import counting
+
+from absl.testing import absltest
+
+
+class Barrier:
+  """Defines a simple barrier class to synchronize on a particular event."""
+
+  def __init__(self, num_threads):
+    """Constructor.
+
+    Args:
+      num_threads: int - how many threads will be synchronizing on this
+        barrier.
+    """
+    self._num_threads = num_threads
+    self._count = 0
+    self._cond = threading.Condition()
+
+  def wait(self):
+    """Waits on the barrier until all threads have called this method."""
+    with self._cond:
+      self._count += 1
+      self._cond.notify_all()
+      while self._count < self._num_threads:
+        self._cond.wait()
+
+
+class CountingTest(absltest.TestCase):
+
+  def test_counter_threading(self):
+    counter = counting.Counter()
+    num_threads = 10
+    barrier = Barrier(num_threads)
+
+    # Increment in every thread at the same time.
+    def add_to_counter():
+      barrier.wait()
+      counter.increment(foo=1)
+
+    # Run the threads.
+    threads = []
+    for _ in range(num_threads):
+      t = threading.Thread(target=add_to_counter)
+      t.start()
+      threads.append(t)
+    for t in threads:
+      t.join()
+
+    # Make sure the counter has been incremented once per thread.
+    counts = counter.get_counts()
+    self.assertEqual(counts['foo'], num_threads)
+
+  def test_counter_caching(self):
+    parent = counting.Counter()
+    counter = counting.Counter(parent, time_delta=0.)
+    counter.increment(foo=12)
+    self.assertEqual(parent.get_counts(), counter.get_counts())
+
+  def test_shared_counts(self):
+    # Two counters with a shared parent should share counts (modulo
+    # namespacing).
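+    # (Each child prefixes its counts with its own name before syncing them
+    # to the shared parent, hence the 'child1_'/'child2_' keys expected
+    # below.)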
+ parent = counting.Counter() + child1 = counting.Counter(parent, 'child1') + child2 = counting.Counter(parent, 'child2') + child1.increment(foo=1) + result = child2.increment(foo=2) + expected = {'child1_foo': 1, 'child2_foo': 2} + self.assertEqual(result, expected) + + def test_return_only_prefixed(self): + parent = counting.Counter() + child1 = counting.Counter( + parent, 'child1', time_delta=0., return_only_prefixed=False) + child2 = counting.Counter( + parent, 'child2', time_delta=0., return_only_prefixed=True) + child1.increment(foo=1) + child2.increment(bar=1) + self.assertEqual(child1.get_counts(), {'child1_foo': 1, 'child2_bar': 1}) + self.assertEqual(child2.get_counts(), {'bar': 1}) + + def test_get_steps_key(self): + parent = counting.Counter() + child1 = counting.Counter( + parent, 'child1', time_delta=0., return_only_prefixed=False) + child2 = counting.Counter( + parent, 'child2', time_delta=0., return_only_prefixed=True) + self.assertEqual(child1.get_steps_key(), 'child1_steps') + self.assertEqual(child2.get_steps_key(), 'steps') + child1.increment(steps=1) + child2.increment(steps=2) + self.assertEqual(child1.get_counts().get(child1.get_steps_key()), 1) + self.assertEqual(child2.get_counts().get(child2.get_steps_key()), 2) + + def test_parent_prefix(self): + parent = counting.Counter(prefix='parent') + child = counting.Counter(parent, prefix='child', time_delta=0.) + self.assertEqual(child.get_steps_key(), 'child_steps') + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/experiment_utils.py b/acme/acme/utils/experiment_utils.py new file mode 100644 index 00000000..89bec431 --- /dev/null +++ b/acme/acme/utils/experiment_utils.py @@ -0,0 +1,28 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility definitions for Acme experiments.""" + +from typing import Optional + +from acme.utils import loggers + + +def make_experiment_logger(label: str, + steps_key: Optional[str] = None, + task_instance: int = 0) -> loggers.Logger: + del task_instance + if steps_key is None: + steps_key = f'{label}_steps' + return loggers.make_default_logger(label=label, steps_key=steps_key) diff --git a/acme/acme/utils/frozen_learner.py b/acme/acme/utils/frozen_learner.py new file mode 100644 index 00000000..d33d221b --- /dev/null +++ b/acme/acme/utils/frozen_learner.py @@ -0,0 +1,58 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Frozen learner.""" + +from typing import Callable, List, Optional, Sequence + +import acme + + +class FrozenLearner(acme.Learner): + """Wraps a learner ignoring the step calls, i.e. freezing it.""" + + def __init__(self, + learner: acme.Learner, + step_fn: Optional[Callable[[], None]] = None): + """Initializes the frozen learner. + + Args: + learner: Learner to be wrapped. + step_fn: Function to call instead of the step() method of the learner. + This can be used, e.g. to drop samples from an iterator that would + normally be consumed by the learner. + """ + self._learner = learner + self._step_fn = step_fn + + def step(self): + """See base class.""" + if self._step_fn: + self._step_fn() + + def run(self, num_steps: Optional[int] = None): + """See base class.""" + self._learner.run(num_steps) + + def save(self): + """See base class.""" + return self._learner.save() + + def restore(self, state): + """See base class.""" + self._learner.restore(state) + + def get_variables(self, names: Sequence[str]) -> List[acme.types.NestedArray]: + """See base class.""" + return self._learner.get_variables(names) diff --git a/acme/acme/utils/frozen_learner_test.py b/acme/acme/utils/frozen_learner_test.py new file mode 100644 index 00000000..aabd3eb6 --- /dev/null +++ b/acme/acme/utils/frozen_learner_test.py @@ -0,0 +1,77 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for frozen_learner.""" + +from unittest import mock + +import acme +from acme.utils import frozen_learner +from absl.testing import absltest + + +class FrozenLearnerTest(absltest.TestCase): + + @mock.patch.object(acme, 'Learner', autospec=True) + def test_step_fn(self, mock_learner): + num_calls = 0 + + def step_fn(): + nonlocal num_calls + num_calls += 1 + + learner = frozen_learner.FrozenLearner(mock_learner, step_fn=step_fn) + + # Step two times. + learner.step() + learner.step() + + self.assertEqual(num_calls, 2) + # step() method of the wrapped learner should not be called. + mock_learner.step.assert_not_called() + + @mock.patch.object(acme, 'Learner', autospec=True) + def test_no_step_fn(self, mock_learner): + learner = frozen_learner.FrozenLearner(mock_learner) + learner.step() + # step() method of the wrapped learner should not be called. + mock_learner.step.assert_not_called() + + @mock.patch.object(acme, 'Learner', autospec=True) + def test_save_and_restore(self, mock_learner): + learner = frozen_learner.FrozenLearner(mock_learner) + + mock_learner.save.return_value = 'state1' + + state = learner.save() + self.assertEqual(state, 'state1') + + learner.restore('state2') + # State of the wrapped learner should be restored. 
+    mock_learner.restore.assert_called_once_with('state2')
+
+  @mock.patch.object(acme, 'Learner', autospec=True)
+  def test_get_variables(self, mock_learner):
+    learner = frozen_learner.FrozenLearner(mock_learner)
+
+    mock_learner.get_variables.return_value = [1, 2]
+
+    variables = learner.get_variables(['a', 'b'])
+    # Values should match those returned by the wrapped learner.
+    self.assertEqual(variables, [1, 2])
+    mock_learner.get_variables.assert_called_once_with(['a', 'b'])
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/utils/iterator_utils.py b/acme/acme/utils/iterator_utils.py
new file mode 100644
index 00000000..67a2f17f
--- /dev/null
+++ b/acme/acme/utils/iterator_utils.py
@@ -0,0 +1,38 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Iterator utilities."""
+import itertools
+import operator
+from typing import Any, Iterator, List, Sequence
+
+
+def unzip_iterators(zipped_iterators: Iterator[Sequence[Any]],
+                    num_sub_iterators: int) -> List[Iterator[Any]]:
+  """Returns unzipped iterators.
+
+  Note that simply returning:
+    [(x[i] for x in iter_tuple[i]) for i in range(num_sub_iterators)]
+  does not work: generator expressions bind the loop variable `i` late, so by
+  the time the generators are consumed, every one of them reads the final
+  value of `i`, causing all sub-learners to consume data from that final
+  iterator.
+
+  Args:
+    zipped_iterators: zipped iterators (e.g., from zip_iterators()).
+    num_sub_iterators: the number of sub-iterators in the zipped iterator.
+  """
+  iter_tuple = itertools.tee(zipped_iterators, num_sub_iterators)
+  return [
+      map(operator.itemgetter(i), iter_tuple[i])
+      for i in range(num_sub_iterators)
+  ]
diff --git a/acme/acme/utils/iterator_utils_test.py b/acme/acme/utils/iterator_utils_test.py
new file mode 100644
index 00000000..ebe21f3a
--- /dev/null
+++ b/acme/acme/utils/iterator_utils_test.py
@@ -0,0 +1,40 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for iterator_utils.""" + +from acme.utils import iterator_utils +import numpy as np + +from absl.testing import absltest + + +class IteratorUtilsTest(absltest.TestCase): + + def test_iterator_zipping(self): + + def get_iters(): + x = iter(range(0, 10)) + y = iter(range(20, 30)) + return [x, y] + + zipped = zip(*get_iters()) + unzipped = iterator_utils.unzip_iterators(zipped, num_sub_iterators=2) + expected_x, expected_y = get_iters() + np.testing.assert_equal(list(unzipped[0]), list(expected_x)) + np.testing.assert_equal(list(unzipped[1]), list(expected_y)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/loggers/__init__.py b/acme/acme/utils/loggers/__init__.py new file mode 100644 index 00000000..a79942e7 --- /dev/null +++ b/acme/acme/utils/loggers/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Acme loggers.""" + +from acme.utils.loggers.aggregators import Dispatcher +from acme.utils.loggers.asynchronous import AsyncLogger +from acme.utils.loggers.auto_close import AutoCloseLogger +from acme.utils.loggers.base import Logger +from acme.utils.loggers.base import LoggerFactory +from acme.utils.loggers.base import LoggerLabel +from acme.utils.loggers.base import LoggerStepsKey +from acme.utils.loggers.base import LoggingData +from acme.utils.loggers.base import NoOpLogger +from acme.utils.loggers.base import TaskInstance +from acme.utils.loggers.base import to_numpy +from acme.utils.loggers.constant import ConstantLogger +from acme.utils.loggers.csv import CSVLogger +from acme.utils.loggers.dataframe import InMemoryLogger +from acme.utils.loggers.filters import GatedFilter +from acme.utils.loggers.filters import KeyFilter +from acme.utils.loggers.filters import NoneFilter +from acme.utils.loggers.filters import TimeFilter +from acme.utils.loggers.default import make_default_logger # pylint: disable=g-bad-import-order +from acme.utils.loggers.terminal import TerminalLogger + +# Internal imports. diff --git a/acme/acme/utils/loggers/aggregators.py b/acme/acme/utils/loggers/aggregators.py new file mode 100644 index 00000000..354f72cb --- /dev/null +++ b/acme/acme/utils/loggers/aggregators.py @@ -0,0 +1,42 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for aggregating to other loggers.""" + +from typing import Callable, Optional, Sequence +from acme.utils.loggers import base + + +class Dispatcher(base.Logger): + """Writes data to multiple `Logger` objects.""" + + def __init__( + self, + to: Sequence[base.Logger], + serialize_fn: Optional[Callable[[base.LoggingData], str]] = None, + ): + """Initialize `Dispatcher` connected to several `Logger` objects.""" + self._to = to + self._serialize_fn = serialize_fn + + def write(self, values: base.LoggingData): + """Writes `values` to the underlying `Logger` objects.""" + if self._serialize_fn: + values = self._serialize_fn(values) + for logger in self._to: + logger.write(values) + + def close(self): + for logger in self._to: + logger.close() diff --git a/acme/acme/utils/loggers/asynchronous.py b/acme/acme/utils/loggers/asynchronous.py new file mode 100644 index 00000000..06aeb005 --- /dev/null +++ b/acme/acme/utils/loggers/asynchronous.py @@ -0,0 +1,42 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logger which makes another logger asynchronous.""" + +from typing import Any, Mapping + +from acme.utils import async_utils +from acme.utils.loggers import base + + +class AsyncLogger(base.Logger): + """Logger which makes the logging to another logger asyncronous.""" + + def __init__(self, to: base.Logger): + """Initializes the logger. + + Args: + to: A `Logger` object to which the current object will forward its results + when `write` is called. + """ + self._to = to + self._async_worker = async_utils.AsyncExecutor(self._to.write, queue_size=5) + + def write(self, values: Mapping[str, Any]): + self._async_worker.put(values) + + def close(self): + """Closes the logger, closing is synchronous.""" + self._async_worker.close() + self._to.close() diff --git a/acme/acme/utils/loggers/auto_close.py b/acme/acme/utils/loggers/auto_close.py new file mode 100644 index 00000000..c3a92eef --- /dev/null +++ b/acme/acme/utils/loggers/auto_close.py @@ -0,0 +1,43 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Logger which self closes on exit if not closed yet.""" + +import weakref + +from acme.utils.loggers import base + + +class AutoCloseLogger(base.Logger): + """Logger which auto closes itself on exit if not already closed.""" + + def __init__(self, logger: base.Logger): + self._logger = logger + # The finalizer "logger.close" is invoked in one of the following scenario: + # 1) the current logger is GC + # 2) from the python doc, when the program exits, each remaining live + # finalizer is called. + # Note that in the normal flow, where "close" is explicitly called, + # the finalizer is marked as dead using the detach function so that + # the underlying logger is not closed twice (once explicitly and once + # implicitly when the object is GC or when the program exits). + self._finalizer = weakref.finalize(self, logger.close) + + def write(self, values: base.LoggingData): + self._logger.write(values) + + def close(self): + if self._finalizer.detach(): + self._logger.close() + self._logger = None diff --git a/acme/acme/utils/loggers/base.py b/acme/acme/utils/loggers/base.py new file mode 100644 index 00000000..8dba65d0 --- /dev/null +++ b/acme/acme/utils/loggers/base.py @@ -0,0 +1,87 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base logger.""" + +import abc +from typing import Any, Mapping, Protocol, Optional + +import numpy as np +import tree + +LoggingData = Mapping[str, Any] + + +class Logger(abc.ABC): + """A logger has a `write` method.""" + + @abc.abstractmethod + def write(self, data: LoggingData) -> None: + """Writes `data` to destination (file, terminal, database, etc).""" + + @abc.abstractmethod + def close(self) -> None: + """Closes the logger, not expecting any further write.""" + + +TaskInstance = int +# TODO(stanczyk): Turn LoggerLabel into an enum of [Learner, Actor, Evaluator]. +LoggerLabel = str +LoggerStepsKey = str + + +class LoggerFactory(Protocol): + + def __call__(self, + label: LoggerLabel, + steps_key: Optional[LoggerStepsKey] = None, + instance: Optional[TaskInstance] = None) -> Logger: + ... + + +class NoOpLogger(Logger): + """Simple Logger which does nothing and outputs no logs. + + This should be used sparingly, but it can prove useful if we want to quiet an + individual component and have it produce no logging whatsoever. + """ + + def write(self, data: LoggingData): + pass + + def close(self): + pass + + +def tensor_to_numpy(value: Any): + if hasattr(value, 'numpy'): + return value.numpy() # tf.Tensor (TF2). + if hasattr(value, 'device_buffer'): + return np.asarray(value) # jnp.DeviceArray. + return value + + +def to_numpy(values: Any): + """Converts tensors in a nested structure to numpy. + + Converts tensors from TensorFlow to Numpy if needed without importing TF + dependency. + + Args: + values: nested structure with numpy and / or TF tensors. + + Returns: + Same nested structure as values, but with numpy tensors. 
+ """ + return tree.map_structure(tensor_to_numpy, values) diff --git a/acme/acme/utils/loggers/base_test.py b/acme/acme/utils/loggers/base_test.py new file mode 100644 index 00000000..b392a2d4 --- /dev/null +++ b/acme/acme/utils/loggers/base_test.py @@ -0,0 +1,41 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for acme.utils.loggers.base.""" + +from acme.utils.loggers import base +import jax.numpy as jnp +import numpy as np +import tensorflow as tf + +from absl.testing import absltest + + +class BaseTest(absltest.TestCase): + + def test_tensor_serialisation(self): + data = {'x': tf.zeros(shape=(32,))} + output = base.to_numpy(data) + expected = {'x': np.zeros(shape=(32,))} + np.testing.assert_array_equal(output['x'], expected['x']) + + def test_device_array_serialisation(self): + data = {'x': jnp.zeros(shape=(32,))} + output = base.to_numpy(data) + expected = {'x': np.zeros(shape=(32,))} + np.testing.assert_array_equal(output['x'], expected['x']) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/loggers/constant.py b/acme/acme/utils/loggers/constant.py new file mode 100644 index 00000000..2dba2686 --- /dev/null +++ b/acme/acme/utils/loggers/constant.py @@ -0,0 +1,46 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logger for values that remain constant.""" + +from acme.utils.loggers import base + + +class ConstantLogger(base.Logger): + """Logger for values that remain constant throughout the experiment. + + This logger is used to log additional values e.g. level_name or + hyperparameters that do not change in an experiment. Having these values + allows to group or facet plots when analysing data post-experiment. + """ + + def __init__( + self, + constant_data: base.LoggingData, + to: base.Logger, + ): + """Initialise the extra info logger. + + Args: + constant_data: Key-value pairs containing the constant info to be logged. + to: The logger to add these extra info to. + """ + self._constant_data = constant_data + self._to = to + + def write(self, data: base.LoggingData): + self._to.write({**self._constant_data, **data}) + + def close(self): + self._to.close() diff --git a/acme/acme/utils/loggers/csv.py b/acme/acme/utils/loggers/csv.py new file mode 100644 index 00000000..eb19a514 --- /dev/null +++ b/acme/acme/utils/loggers/csv.py @@ -0,0 +1,141 @@ +# Copyright 2018 DeepMind Technologies Limited. 
All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A simple CSV logger.
+
+Warning: Does not support preemption.
+"""
+
+import csv
+import os
+import time
+from typing import TextIO, Union
+
+from absl import logging
+
+from acme.utils import paths
+from acme.utils.loggers import base
+
+
+class CSVLogger(base.Logger):
+  """Standard CSV logger.
+
+  The fields are inferred from the first call to write() and any additional
+  fields afterwards are ignored.
+
+  TODO(jaslanides): Consider making this stateless/robust to preemption.
+  """
+
+  _open = open
+
+  def __init__(
+      self,
+      directory_or_file: Union[str, TextIO] = '~/acme',
+      label: str = '',
+      time_delta: float = 0.,
+      add_uid: bool = True,
+      flush_every: int = 30,
+  ):
+    """Instantiates the logger.
+
+    Args:
+      directory_or_file: Either a directory path as a string, or a file TextIO
+        object.
+      label: Extra label to add to logger. This is added as a suffix to the
+        directory.
+      time_delta: Interval in seconds between which writes are dropped to
+        throttle throughput.
+      add_uid: Whether to add a UID to the file path. See `paths.process_path`
+        for details.
+      flush_every: Interval (in writes) between flushes.
+    """
+
+    if flush_every <= 0:
+      raise ValueError(
+          f'`flush_every` must be a positive integer (got {flush_every}).')
+
+    self._last_log_time = time.time() - time_delta
+    self._time_delta = time_delta
+    self._flush_every = flush_every
+    self._add_uid = add_uid
+    self._writer = None
+    self._file_owner = False
+    self._file = self._create_file(directory_or_file, label)
+    self._writes = 0
+    logging.info('Logging to %s', self.file_path)
+
+  def _create_file(
+      self,
+      directory_or_file: Union[str, TextIO],
+      label: str,
+  ) -> TextIO:
+    """Opens a file if the input is a directory, or uses the existing file."""
+    if isinstance(directory_or_file, str):
+      directory = paths.process_path(
+          directory_or_file, 'logs', label, add_uid=self._add_uid)
+      file_path = os.path.join(directory, 'logs.csv')
+      self._file_owner = True
+      return self._open(file_path, mode='a')
+
+    # TextIO instance.
+    file = directory_or_file
+    if label:
+      logging.info('File, not directory, passed to CSVLogger; label not used.')
+    if not file.mode.startswith('a'):
+      raise ValueError('File must be open in append mode; instead got '
+                       f'mode="{file.mode}".')
+    return file
+
+  def write(self, data: base.LoggingData):
+    """Writes `data` into a row of comma-separated values."""
+    # Only log if `time_delta` seconds have passed since last logging event.
+    now = time.time()
+
+    # TODO(b/192227744): Remove this in favour of filters.TimeFilter.
+    elapsed = now - self._last_log_time
+    if elapsed < self._time_delta:
+      logging.debug('Not due to log for another %.2f seconds, dropping data.',
+                    self._time_delta - elapsed)
+      return
+    self._last_log_time = now
+
+    # Append row to CSV.
+    data = base.to_numpy(data)
+    # Use fields from initial `data` to create the header. If extra fields are
+    # present in subsequent `data`, we ignore them.
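+    # E.g. if the first write is {'a': 1, 'b': 2} the header is 'a,b'; a later
+    # write of {'a': 3, 'c': 4} has 'c' ignored and 'b' left empty.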
+ if not self._writer: + fields = sorted(data.keys()) + self._writer = csv.DictWriter(self._file, fieldnames=fields, + extrasaction='ignore') + # Write header only if the file is empty. + if not self._file.tell(): + self._writer.writeheader() + self._writer.writerow(data) + + # Flush every `flush_every` writes. + if self._writes % self._flush_every == 0: + self.flush() + self._writes += 1 + + def close(self): + self.flush() + if self._file_owner: + self._file.close() + + def flush(self): + self._file.flush() + + @property + def file_path(self) -> str: + return self._file.name diff --git a/acme/acme/utils/loggers/csv_test.py b/acme/acme/utils/loggers/csv_test.py new file mode 100644 index 00000000..3bf8d070 --- /dev/null +++ b/acme/acme/utils/loggers/csv_test.py @@ -0,0 +1,102 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for csv logging.""" + +import csv +import os + +from acme.testing import test_utils +from acme.utils import paths +from acme.utils.loggers import csv as csv_logger + +from absl.testing import absltest +from absl.testing import parameterized + +_TEST_INPUTS = [{ + 'c': 'foo', + 'a': '1337', + 'b': '42.0001', +}, { + 'c': 'foo2', + 'a': '1338', + 'b': '43.0001', +}] + + +class CSVLoggingTest(test_utils.TestCase): + + def test_logging_input_is_directory(self): + + # Set up logger. + directory = self.get_tempdir() + label = 'test' + logger = csv_logger.CSVLogger(directory_or_file=directory, label=label) + + # Write data and close. + for inp in _TEST_INPUTS: + logger.write(inp) + logger.close() + + # Read back data. + outputs = [] + with open(logger.file_path) as f: + csv_reader = csv.DictReader(f) + for row in csv_reader: + outputs.append(dict(row)) + self.assertEqual(outputs, _TEST_INPUTS) + + @parameterized.parameters(True, False) + def test_logging_input_is_file(self, add_uid: bool): + + # Set up logger. + directory = paths.process_path( + self.get_tempdir(), 'logs', 'my_label', add_uid=add_uid) + file = open(os.path.join(directory, 'logs.csv'), 'a') + logger = csv_logger.CSVLogger(directory_or_file=file, add_uid=add_uid) + + # Write data and close. + for inp in _TEST_INPUTS: + logger.write(inp) + logger.close() + + # Logger doesn't close the file; caller must do this manually. + self.assertFalse(file.closed) + file.close() + + # Read back data. + outputs = [] + with open(logger.file_path) as f: + csv_reader = csv.DictReader(f) + for row in csv_reader: + outputs.append(dict(row)) + self.assertEqual(outputs, _TEST_INPUTS) + + def test_flush(self): + + logger = csv_logger.CSVLogger(self.get_tempdir(), flush_every=1) + for inp in _TEST_INPUTS: + logger.write(inp) + + # Read back data. 
+ outputs = [] + with open(logger.file_path) as f: + csv_reader = csv.DictReader(f) + for row in csv_reader: + outputs.append(dict(row)) + self.assertEqual(outputs, _TEST_INPUTS) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/loggers/dataframe.py b/acme/acme/utils/loggers/dataframe.py new file mode 100644 index 00000000..16c59bed --- /dev/null +++ b/acme/acme/utils/loggers/dataframe.py @@ -0,0 +1,52 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logger for writing to an in-memory list. + +This is convenient for e.g. interactive usage via Google Colab. + +For example, for usage with pandas: + +```python +from acme.utils import loggers +import pandas as pd + +logger = InMemoryLogger() +# ... +logger.write({'foo': 1.337, 'bar': 420}) + +results = pd.DataFrame(logger.data) +``` +""" + +from typing import Sequence + +from acme.utils.loggers import base + + +class InMemoryLogger(base.Logger): + """A simple logger that keeps all data in memory.""" + + def __init__(self): + self._data = [] + + def write(self, data: base.LoggingData): + self._data.append(data) + + def close(self): + pass + + @property + def data(self) -> Sequence[base.LoggingData]: + return self._data diff --git a/acme/acme/utils/loggers/default.py b/acme/acme/utils/loggers/default.py new file mode 100644 index 00000000..1c9e9de3 --- /dev/null +++ b/acme/acme/utils/loggers/default.py @@ -0,0 +1,69 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default logger.""" + +import logging +from typing import Any, Callable, Mapping, Optional + +from acme.utils.loggers import aggregators +from acme.utils.loggers import asynchronous as async_logger +from acme.utils.loggers import base +from acme.utils.loggers import csv +from acme.utils.loggers import filters +from acme.utils.loggers import terminal + + +def make_default_logger( + label: str, + save_data: bool = True, + time_delta: float = 1.0, + asynchronous: bool = False, + print_fn: Optional[Callable[[str], None]] = None, + serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy, + steps_key: str = 'steps', +) -> base.Logger: + """Makes a default Acme logger. + + Args: + label: Name to give to the logger. + save_data: Whether to persist data. + time_delta: Time (in seconds) between logging events. + asynchronous: Whether the write function should block or not. 
+    print_fn: How to print to terminal (defaults to logging.info).
+    serialize_fn: An optional function to apply to the write inputs before
+      passing them to the various loggers.
+    steps_key: Ignored.
+
+  Returns:
+    A logger object that responds to logger.write(some_dict).
+  """
+  del steps_key
+  if not print_fn:
+    print_fn = logging.info
+  terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn)
+
+  loggers = [terminal_logger]
+
+  if save_data:
+    loggers.append(csv.CSVLogger(label=label))
+
+  # Dispatch to all writers and filter Nones and by time.
+  logger = aggregators.Dispatcher(loggers, serialize_fn)
+  logger = filters.NoneFilter(logger)
+  if asynchronous:
+    logger = async_logger.AsyncLogger(logger)
+  logger = filters.TimeFilter(logger, time_delta)
+
+  return logger
diff --git a/acme/acme/utils/loggers/filters.py b/acme/acme/utils/loggers/filters.py
new file mode 100644
index 00000000..10a2c241
--- /dev/null
+++ b/acme/acme/utils/loggers/filters.py
@@ -0,0 +1,165 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loggers which filter other loggers."""
+
+import math
+import time
+from typing import Callable, Optional, Sequence
+
+from acme.utils.loggers import base
+
+
+class NoneFilter(base.Logger):
+  """Logger which writes to another logger, filtering any `None` values."""
+
+  def __init__(self, to: base.Logger):
+    """Initializes the logger.
+
+    Args:
+      to: A `Logger` object to which the current object will forward its results
+        when `write` is called.
+    """
+    self._to = to
+
+  def write(self, values: base.LoggingData):
+    values = {k: v for k, v in values.items() if v is not None}
+    self._to.write(values)
+
+  def close(self):
+    self._to.close()
+
+
+class TimeFilter(base.Logger):
+  """Logger which writes to another logger at a given time interval."""
+
+  def __init__(self, to: base.Logger, time_delta: float):
+    """Initializes the logger.
+
+    Args:
+      to: A `Logger` object to which the current object will forward its results
+        when `write` is called.
+      time_delta: How often to write values out in seconds.
+        Note that writes within `time_delta` are dropped.
+    """
+    self._to = to
+    self._time = 0
+    self._time_delta = time_delta
+    if time_delta < 0:
+      raise ValueError(f'time_delta must be non-negative (got {time_delta}).')
+
+  def write(self, values: base.LoggingData):
+    now = time.time()
+    if (now - self._time) > self._time_delta:
+      self._to.write(values)
+      self._time = now
+
+  def close(self):
+    self._to.close()
+
+
+class KeyFilter(base.Logger):
+  """Logger which filters keys in logged data."""
+
+  def __init__(
+      self,
+      to: base.Logger,
+      *,
+      keep: Optional[Sequence[str]] = None,
+      drop: Optional[Sequence[str]] = None,
+  ):
+    """Creates the filter.
+
+    Args:
+      to: A `Logger` object to which the current object will forward its writes.
+      keep: Keys that are kept by the filter. Note that `keep` and `drop` cannot
+        be both set at once.
+      drop: Keys that are dropped by the filter.
Note that `keep` and `drop` + cannot be both set at once. + """ + if bool(keep) == bool(drop): + raise ValueError('Exactly one of `keep` & `drop` arguments must be set.') + self._to = to + self._keep = keep + self._drop = drop + + def write(self, data: base.LoggingData): + if self._keep: + data = {k: data[k] for k in self._keep} + if self._drop: + data = {k: v for k, v in data.items() if k not in self._drop} + self._to.write(data) + + def close(self): + self._to.close() + + +class GatedFilter(base.Logger): + """Logger which writes to another logger based on a gating function. + + This logger tracks the number of times its `write` method is called, and uses + a gating function on this number to decide when to write. + """ + + def __init__(self, to: base.Logger, gating_fn: Callable[[int], bool]): + """Initialises the logger. + + Args: + to: A `Logger` object to which the current object will forward its results + when `write` is called. + gating_fn: A function that takes an integer (number of calls) as input. + For example, to log every tenth call: gating_fn=lambda t: t % 10 == 0. + """ + self._to = to + self._gating_fn = gating_fn + self._calls = 0 + + def write(self, values: base.LoggingData): + if self._gating_fn(self._calls): + self._to.write(values) + self._calls += 1 + + def close(self): + self._to.close() + + @classmethod + def logarithmic(cls, to: base.Logger, n: int = 10) -> 'GatedFilter': + """Builds a logger for writing at logarithmically-spaced intervals. + + This will log on a linear scale at each order of magnitude of `n`. + For example, with n=10, this will log at times: + [0, 1, 2, ..., 9, 10, 20, 30, ... 90, 100, 200, 300, ... 900, 1000] + + Args: + to: The underlying logger to write to. + n: Base (default 10) on which to operate. + Returns: + A GatedFilter logger, which gates logarithmically as described above. + """ + def logarithmic_filter(t: int) -> bool: + magnitude = math.floor(math.log10(max(t, 1))/math.log10(n)) + return t % (n**magnitude) == 0 + return cls(to, gating_fn=logarithmic_filter) + + @classmethod + def periodic(cls, to: base.Logger, interval: int = 10) -> 'GatedFilter': + """Builds a logger for writing at linearly-spaced intervals. + + Args: + to: The underlying logger to write to. + interval: The interval between writes. + Returns: + A GatedFilter logger, which gates periodically as described above. + """ + return cls(to, gating_fn=lambda t: t % interval == 0) diff --git a/acme/acme/utils/loggers/filters_test.py b/acme/acme/utils/loggers/filters_test.py new file mode 100644 index 00000000..c3278741 --- /dev/null +++ b/acme/acme/utils/loggers/filters_test.py @@ -0,0 +1,112 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for logging filters.""" + +import time + +from acme.utils.loggers import base +from acme.utils.loggers import filters + +from absl.testing import absltest + + +# TODO(jaslanides): extract this to test_utils, or similar, for re-use. 
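As an aside, the gating filters compose with any of the loggers above; a small sketch using `InMemoryLogger` from `dataframe.py`:

```python
# Sketch: keep only every tenth write, buffered in memory.
from acme.utils.loggers import dataframe, filters

logger = dataframe.InMemoryLogger()
gated = filters.GatedFilter.periodic(logger, interval=10)
for t in range(100):
  gated.write({'t': t})
assert len(logger.data) == 10  # Writes landed at t = 0, 10, ..., 90.
```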
+class FakeLogger(base.Logger): + """A fake logger for testing.""" + + def __init__(self): + self.data = [] + + def write(self, data): + self.data.append(data) + + @property + def last_write(self): + return self.data[-1] + + def close(self): + pass + + +class GatedFilterTest(absltest.TestCase): + + def test_logarithmic_filter(self): + logger = FakeLogger() + filtered = filters.GatedFilter.logarithmic(logger, n=10) + for t in range(100): + filtered.write({'t': t}) + rows = [row['t'] for row in logger.data] + self.assertEqual(rows, [*range(10), *range(10, 100, 10)]) + + def test_periodic_filter(self): + logger = FakeLogger() + filtered = filters.GatedFilter.periodic(logger, interval=10) + for t in range(100): + filtered.write({'t': t}) + rows = [row['t'] for row in logger.data] + self.assertEqual(rows, list(range(0, 100, 10))) + + +class TimeFilterTest(absltest.TestCase): + + def test_delta(self): + logger = FakeLogger() + filtered = filters.TimeFilter(logger, time_delta=0.1) + + # Logged. + filtered.write({'foo': 1}) + self.assertIn('foo', logger.last_write) + + # *Not* logged. + filtered.write({'bar': 2}) + self.assertNotIn('bar', logger.last_write) + + # Wait out delta. + time.sleep(0.11) + + # Logged. + filtered.write({'baz': 3}) + self.assertIn('baz', logger.last_write) + + self.assertLen(logger.data, 2) + + +class KeyFilterTest(absltest.TestCase): + + def test_keep_filter(self): + logger = FakeLogger() + filtered = filters.KeyFilter(logger, keep=('foo',)) + filtered.write({'foo': 'bar', 'baz': 12}) + row, *_ = logger.data + self.assertIn('foo', row) + self.assertNotIn('baz', row) + + def test_drop_filter(self): + logger = FakeLogger() + filtered = filters.KeyFilter(logger, drop=('foo',)) + filtered.write({'foo': 'bar', 'baz': 12}) + row, *_ = logger.data + self.assertIn('baz', row) + self.assertNotIn('foo', row) + + def test_bad_arguments(self): + with self.assertRaises(ValueError): + filters.KeyFilter(FakeLogger()) + with self.assertRaises(ValueError): + filters.KeyFilter(FakeLogger(), keep=('a',), drop=('b',)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/loggers/image.py b/acme/acme/utils/loggers/image.py new file mode 100644 index 00000000..d6def890 --- /dev/null +++ b/acme/acme/utils/loggers/image.py @@ -0,0 +1,76 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An image logger, for writing out arrays to disk as PNG.""" + +import collections +import pathlib +from typing import Optional + +from absl import logging +from acme.utils.loggers import base +from PIL import Image + + +class ImageLogger(base.Logger): + """Logger for writing NumPy arrays as PNG images to disk. + + Assumes that all data passed are NumPy arrays that can be converted to images. + + TODO(jaslanides): Make this stateless/robust to preemptions. + """ + + def __init__( + self, + directory: str, + *, + label: str = '', + mode: Optional[str] = None, + ): + """Initialises the writer. 
+ + Args: + directory: Base directory to which images are logged. + label: Optional subdirectory in which to save images. + mode: Image mode for use with Pillow. If `None` (default), mode is + determined by data type. See [0] for details. + + [0] https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes + """ + + self._path = self._get_path(directory, label) + if not self._path.exists(): + self._path.mkdir(parents=True) + + self._mode = mode + self._indices = collections.defaultdict(int) + + def write(self, data: base.LoggingData): + for k, v in data.items(): + image = Image.fromarray(v, mode=self._mode) + path = self._path / f'{k}_{self._indices[k]:06}.png' + self._indices[k] += 1 + with path.open(mode='wb') as f: + logging.info('Writing image to %s.', str(path)) + image.save(f) + + def close(self): + pass + + @property + def directory(self) -> str: + return str(self._path) + + def _get_path(self, *args, **kwargs) -> pathlib.Path: + return pathlib.Path(*args, **kwargs) diff --git a/acme/acme/utils/loggers/image_test.py b/acme/acme/utils/loggers/image_test.py new file mode 100644 index 00000000..a2412558 --- /dev/null +++ b/acme/acme/utils/loggers/image_test.py @@ -0,0 +1,60 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for image logger.""" + +import os + +from acme.testing import test_utils +from acme.utils.loggers import image +import numpy as np +from PIL import Image + +from absl.testing import absltest + + +class ImageTest(test_utils.TestCase): + + def test_save_load_identity(self): + directory = self.get_tempdir() + logger = image.ImageLogger(directory, label='foo') + array = (np.random.rand(10, 10) * 255).astype(np.uint8) + logger.write({'img': array}) + + with open(f'{directory}/foo/img_000000.png', mode='rb') as f: + out = np.asarray(Image.open(f)) + np.testing.assert_array_equal(array, out) + + def test_indexing(self): + directory = self.get_tempdir() + logger = image.ImageLogger(directory, label='foo') + zeros = np.zeros(shape=(3, 3), dtype=np.uint8) + logger.write({'img': zeros, 'other_img': zeros + 1}) + logger.write({'img': zeros - 1}) + logger.write({'other_img': zeros + 1}) + logger.write({'other_img': zeros + 2}) + + fnames = sorted(os.listdir(f'{directory}/foo')) + expected = [ + 'img_000000.png', + 'img_000001.png', + 'other_img_000000.png', + 'other_img_000001.png', + 'other_img_000002.png', + ] + self.assertEqual(fnames, expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/loggers/terminal.py b/acme/acme/utils/loggers/terminal.py new file mode 100644 index 00000000..821a00c1 --- /dev/null +++ b/acme/acme/utils/loggers/terminal.py @@ -0,0 +1,95 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for logging to the terminal."""
+
+import logging
+import time
+from typing import Any, Callable
+
+from acme.utils.loggers import base
+import numpy as np
+
+
+def _format_key(key: str) -> str:
+  """Internal function for formatting keys."""
+  return key.replace('_', ' ').title()
+
+
+def _format_value(value: Any) -> str:
+  """Internal function for formatting values."""
+  value = base.to_numpy(value)
+  if isinstance(value, (float, np.number)):
+    return f'{value:0.3f}'
+  return f'{value}'
+
+
+def serialize(values: base.LoggingData) -> str:
+  """Converts `values` to a pretty-printed string.
+
+  This takes a dictionary `values` whose keys are strings and returns
+  a formatted string such that each [key, value] pair is separated by ' = ' and
+  each entry is separated by ' | '. The keys are sorted alphabetically to ensure
+  a consistent order, and snake case is split into words.
+
+  For example:
+
+    values = {'a': 1, 'b': 2.33333333, 'c': 'hello', 'big_value': 10}
+    # Returns 'A = 1 | B = 2.333 | Big Value = 10 | C = hello'
+    values_string = serialize(values)
+
+  Args:
+    values: A dictionary with string keys.
+
+  Returns:
+    A formatted string.
+  """
+  return ' | '.join(f'{_format_key(k)} = {_format_value(v)}'
+                    for k, v in sorted(values.items()))
+
+
+class TerminalLogger(base.Logger):
+  """Logs to terminal."""
+
+  def __init__(
+      self,
+      label: str = '',
+      print_fn: Callable[[str], None] = logging.info,
+      serialize_fn: Callable[[base.LoggingData], str] = serialize,
+      time_delta: float = 0.0,
+  ):
+    """Initializes the logger.
+
+    Args:
+      label: label string to use when logging.
+      print_fn: function to call which acts like print.
+      serialize_fn: function to call which transforms values into a str.
+      time_delta: How often (in seconds) to write values. This can be used to
+        minimize terminal spam, but is 0 by default---i.e. everything is
+        written.
+    """
+
+    self._print_fn = print_fn
+    self._serialize_fn = serialize_fn
+    self._label = label and f'[{_format_key(label)}] '
+    self._time = time.time()
+    self._time_delta = time_delta
+
+  def write(self, values: base.LoggingData):
+    now = time.time()
+    if (now - self._time) > self._time_delta:
+      self._print_fn(f'{self._label}{self._serialize_fn(values)}')
+      self._time = now
+
+  def close(self):
+    pass
diff --git a/acme/acme/utils/loggers/terminal_test.py b/acme/acme/utils/loggers/terminal_test.py
new file mode 100644
index 00000000..facdcacb
--- /dev/null
+++ b/acme/acme/utils/loggers/terminal_test.py
@@ -0,0 +1,46 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for terminal logger.""" + +from acme.utils.loggers import terminal + +from absl.testing import absltest + + +class LoggingTest(absltest.TestCase): + + def test_logging_output_format(self): + inputs = { + 'c': 'foo', + 'a': 1337, + 'b': 42.0001, + } + expected_outputs = 'A = 1337 | B = 42.000 | C = foo' + test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) + + logger = terminal.TerminalLogger(print_fn=test_fn) + logger.write(inputs) + + def test_label(self): + inputs = {'foo': 'bar', 'baz': 123} + expected_outputs = '[Test] Baz = 123 | Foo = bar' + test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) + + logger = terminal.TerminalLogger(print_fn=test_fn, label='test') + logger.write(inputs) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/loggers/tf_summary.py b/acme/acme/utils/loggers/tf_summary.py new file mode 100644 index 00000000..868c900b --- /dev/null +++ b/acme/acme/utils/loggers/tf_summary.py @@ -0,0 +1,74 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for logging to a tf.summary.""" + +import time +from typing import Optional + +from absl import logging +from acme.utils.loggers import base +import tensorflow as tf + + +def _format_key(key: str) -> str: + """Internal function for formatting keys in Tensorboard format.""" + return key.title().replace('_', '') + + +class TFSummaryLogger(base.Logger): + """Logs to a tf.summary created in a given logdir. + + If multiple TFSummaryLogger are created with the same logdir, results will be + categorized by labels. + """ + + def __init__( + self, + logdir: str, + label: str = 'Logs', + steps_key: Optional[str] = None + ): + """Initializes the logger. + + Args: + logdir: directory to which we should log files. + label: label string to use when logging. Default to 'Logs'. + steps_key: key to use for steps. Must be in the values passed to write. + """ + self._time = time.time() + self.label = label + self._iter = 0 + self.summary = tf.summary.create_file_writer(logdir) + self._steps_key = steps_key + + def write(self, values: base.LoggingData): + if self._steps_key is not None and self._steps_key not in values: + logging.warning('steps key %s not found. Skip logging.', self._steps_key) + return + + step = values[ + self._steps_key] if self._steps_key is not None else self._iter + + with self.summary.as_default(): + # TODO(b/159065169): Remove this suppression once the bug is resolved. 
+ # pytype: disable=unsupported-operands + for key in values.keys() - [self._steps_key]: + # pytype: enable=unsupported-operands + tf.summary.scalar( + f'{self.label}/{_format_key(key)}', data=values[key], step=step) + self._iter += 1 + + def close(self): + self.summary.close() diff --git a/acme/acme/utils/lp_utils.py b/acme/acme/utils/lp_utils.py new file mode 100644 index 00000000..354c0b0d --- /dev/null +++ b/acme/acme/utils/lp_utils.py @@ -0,0 +1,229 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility function for building and launching launchpad programs.""" + +import atexit +import functools +import inspect +import os +import sys +import time +from typing import Any, Callable, Optional + +from absl import flags +from absl import logging +from acme.utils import counting +from acme.utils import signals + +FLAGS = flags.FLAGS + + +def partial_kwargs(function: Callable[..., Any], + **kwargs: Any) -> Callable[..., Any]: + """Return a partial function application by overriding default keywords. + + This function is equivalent to `functools.partial(function, **kwargs)` but + will raise a `ValueError` when called if either the given keyword arguments + are not defined by `function` or if they do not have defaults. + + This is useful as a way to define a factory function with default parameters + and then to override them in a safe way. + + Args: + function: the base function before partial application. + **kwargs: keyword argument overrides. + + Returns: + A function. + """ + # Try to get the argspec of our function which we'll use to get which keywords + # have defaults. + argspec = inspect.getfullargspec(function) + + # Figure out which keywords have defaults. + if argspec.defaults is None: + defaults = [] + else: + defaults = argspec.args[-len(argspec.defaults):] + + # Find any keys not given as defaults by the function. + unknown_kwargs = set(kwargs.keys()).difference(defaults) + + # Raise an error + if unknown_kwargs: + error_string = 'Cannot override unknown or non-default kwargs: {}' + raise ValueError(error_string.format(', '.join(unknown_kwargs))) + + return functools.partial(function, **kwargs) + + +class StepsLimiter: + """Process that terminates an experiment when `max_steps` is reached.""" + + def __init__(self, + counter: counting.Counter, + max_steps: int, + steps_key: str = 'actor_steps'): + self._counter = counter + self._max_steps = max_steps + self._steps_key = steps_key + + def run(self): + """Run steps limiter to terminate an experiment when max_steps is reached. + """ + + logging.info('StepsLimiter: Starting with max_steps = %d (%s)', + self._max_steps, self._steps_key) + with signals.runtime_terminator(): + while True: + # Update the counts. 
+        counts = self._counter.get_counts()
+        num_steps = counts.get(self._steps_key, 0)
+
+        logging.info('StepsLimiter: Reached %d recorded steps', num_steps)
+
+        if num_steps > self._max_steps:
+          logging.info('StepsLimiter: Max steps of %d was reached, terminating',
+                       self._max_steps)
+          # Avoid importing Launchpad until it is actually used.
+          import launchpad as lp  # pylint: disable=g-import-not-at-top
+          lp.stop()
+
+        # Don't spam the counter.
+        for _ in range(10):
+          # Do not sleep for a long period of time to avoid LaunchPad program
+          # termination hangs (time.sleep is not interruptible).
+          time.sleep(1)
+
+
+def is_local_run() -> bool:
+  return FLAGS.lp_launch_type.startswith('local')
+
+
+# Resources for each individual instance of the program.
+def make_xm_docker_resources(program,
+                             requirements: Optional[str] = None):
+  """Returns Docker XManager resources for each of the program's nodes.
+
+  For each node of the Launchpad program, appropriate hardware requirements
+  (CPU, memory, ...) are specified, and the PyPI packages listed in the
+  requirements file are installed inside the Docker images.
+
+  Args:
+    program: program for which to construct Docker XManager resources.
+    requirements: file containing additional requirements to use.
+      If not specified, default Acme dependencies are used instead.
+  """
+  if (FLAGS.lp_launch_type != 'vertex_ai' and
+      FLAGS.lp_launch_type != 'local_docker'):
+    # Avoid importing 'xmanager' for local runs.
+    return None
+
+  # Avoid importing Launchpad until it is actually used.
+  import launchpad as lp  # pylint: disable=g-import-not-at-top
+  # Reference lp.DockerConfig to force lazy import of xmanager by Launchpad and
+  # then import it. It is done this way to avoid heavy imports by default.
+  lp.DockerConfig  # pylint: disable=pointless-statement
+  from xmanager import xm  # pylint: disable=g-import-not-at-top
+
+  # Get number of each type of node.
+  num_nodes = {k: len(v) for k, v in program.groups.items()}
+
+  xm_resources = {}
+
+  acme_location = os.path.dirname(os.path.dirname(__file__))
+  if not requirements:
+    # Acme requirements are located in the Acme directory (when installed
+    # with pip), or need to be extracted from setup.py when using the Acme
+    # codebase from GitHub without PyPI installation.
+    requirements = os.path.join(acme_location, 'requirements.txt')
+    if not os.path.isfile(requirements):
+      # Try to generate requirements.txt from setup.py.
+      setup = os.path.join(os.path.dirname(acme_location), 'setup.py')
+      if os.path.isfile(setup):
+        # Generate requirements.txt file using setup.py.
+        import importlib.util  # pylint: disable=g-import-not-at-top
+        spec = importlib.util.spec_from_file_location('setup', setup)
+        setup = importlib.util.module_from_spec(spec)
+        try:
+          spec.loader.exec_module(setup)  # pytype: disable=attribute-error
+        except SystemExit:
+          pass
+        atexit.register(os.remove, requirements)
+        setup.generate_requirements_file(requirements)
+
+  # Extend PYTHONPATH with paths used by the launcher.
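+  # E.g. (hypothetical paths) with acme_location '/src/acme', a sys.path entry
+  # of '/src/acme/examples' is recorded as '/examples' inside the image.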
+  python_path = []
+  for path in sys.path:
+    if path.startswith(acme_location) and acme_location != path:
+      python_path.append(path[len(acme_location):])
+
+  if 'replay' in num_nodes:
+    replay_cpu = 6 + num_nodes.get('actor', 0) * 0.01
+    replay_cpu = min(40, replay_cpu)
+
+    xm_resources['replay'] = lp.DockerConfig(
+        acme_location,
+        requirements,
+        hw_requirements=xm.JobRequirements(cpu=replay_cpu, ram=10 * xm.GiB),
+        python_path=python_path)
+
+  if 'evaluator' in num_nodes:
+    xm_resources['evaluator'] = lp.DockerConfig(
+        acme_location,
+        requirements,
+        hw_requirements=xm.JobRequirements(cpu=2, ram=4 * xm.GiB),
+        python_path=python_path)
+
+  if 'actor' in num_nodes:
+    xm_resources['actor'] = lp.DockerConfig(
+        acme_location,
+        requirements,
+        hw_requirements=xm.JobRequirements(cpu=2, ram=4 * xm.GiB),
+        python_path=python_path)
+
+  if 'learner' in num_nodes:
+    learner_cpu = 6 + num_nodes.get('actor', 0) * 0.01
+    learner_cpu = min(40, learner_cpu)
+    xm_resources['learner'] = lp.DockerConfig(
+        acme_location,
+        requirements,
+        hw_requirements=xm.JobRequirements(
+            cpu=learner_cpu, ram=6 * xm.GiB, P100=1),
+        python_path=python_path)
+
+  if 'environment_loop' in num_nodes:
+    xm_resources['environment_loop'] = lp.DockerConfig(
+        acme_location,
+        requirements,
+        hw_requirements=xm.JobRequirements(
+            cpu=6, ram=6 * xm.GiB, P100=1),
+        python_path=python_path)
+
+  if 'counter' in num_nodes:
+    xm_resources['counter'] = lp.DockerConfig(
+        acme_location,
+        requirements,
+        hw_requirements=xm.JobRequirements(cpu=3, ram=4 * xm.GiB),
+        python_path=python_path)
+
+  if 'cacher' in num_nodes:
+    xm_resources['cacher'] = lp.DockerConfig(
+        acme_location,
+        requirements,
+        hw_requirements=xm.JobRequirements(cpu=3, ram=6 * xm.GiB),
+        python_path=python_path)
+
+  return xm_resources
diff --git a/acme/acme/utils/lp_utils_test.py b/acme/acme/utils/lp_utils_test.py
new file mode 100644
index 00000000..d1254698
--- /dev/null
+++ b/acme/acme/utils/lp_utils_test.py
@@ -0,0 +1,52 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for acme launchpad utilities."""
+
+from acme.utils import lp_utils
+
+from absl.testing import absltest
+
+
+class LpUtilsTest(absltest.TestCase):
+
+  def test_partial_kwargs(self):
+
+    def foo(a, b, c=2):
+      return a, b, c
+
+    def bar(a, b):
+      return a, b
+
+    # Override the default values. The last two should be no-ops.
+    foo1 = lp_utils.partial_kwargs(foo, c=1)
+    foo2 = lp_utils.partial_kwargs(foo)
+    bar1 = lp_utils.partial_kwargs(bar)
+
+    # Check that we raise errors on overriding kwargs with no default values.
+    with self.assertRaises(ValueError):
+      lp_utils.partial_kwargs(foo, a=2)
+
+    # Check that we raise if we try to override a kwarg that doesn't exist.
+    with self.assertRaises(ValueError):
+      lp_utils.partial_kwargs(foo, d=2)
+
+    # Make sure we get back the correct values.
+ self.assertEqual(foo1(1, 2), (1, 2, 1)) + self.assertEqual(foo2(1, 2), (1, 2, 2)) + self.assertEqual(bar1(1, 2), (1, 2)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/metrics.py b/acme/acme/utils/metrics.py new file mode 100644 index 00000000..5a1991f8 --- /dev/null +++ b/acme/acme/utils/metrics.py @@ -0,0 +1,23 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module does nothing and exists solely for the sake of OS compatibility.""" + +from typing import Type, TypeVar + +T = TypeVar('T') + + +def record_class_usage(cls: Type[T]) -> Type[T]: + return cls diff --git a/acme/acme/utils/observers/__init__.py b/acme/acme/utils/observers/__init__.py new file mode 100644 index 00000000..093853d0 --- /dev/null +++ b/acme/acme/utils/observers/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Acme observers.""" + +from acme.utils.observers.action_metrics import ContinuousActionObserver +from acme.utils.observers.action_norm import ActionNormObserver +from acme.utils.observers.base import EnvLoopObserver +from acme.utils.observers.base import Number +from acme.utils.observers.env_info import EnvInfoObserver +from acme.utils.observers.measurement_metrics import MeasurementObserver diff --git a/acme/acme/utils/observers/action_metrics.py b/acme/acme/utils/observers/action_metrics.py new file mode 100644 index 00000000..cb5665f1 --- /dev/null +++ b/acme/acme/utils/observers/action_metrics.py @@ -0,0 +1,66 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
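The observers package above is consumed by Acme's `EnvironmentLoop`, which calls `observe_first`/`observe` as an episode unrolls and merges each observer's `get_metrics()` into the episode log. A hedged sketch of the wiring, assuming an existing dm_env `environment` and Acme `actor`:

```python
# Sketch: attach observers to an environment loop (environment and actor
# are assumed to already exist; see acme.environment_loop for the loop).
from acme.environment_loop import EnvironmentLoop
from acme.utils.observers import ActionNormObserver, ContinuousActionObserver

loop = EnvironmentLoop(
    environment, actor,
    observers=[ActionNormObserver(), ContinuousActionObserver()])
loop.run(num_episodes=10)  # Observer metrics are logged with episode results.
```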
+ +"""An observer that tracks statistics about the actions.""" + +from typing import Dict + +from acme.utils.observers import base +import dm_env +import numpy as np + + +class ContinuousActionObserver(base.EnvLoopObserver): + """Observer that tracks statstics of continuous actions taken by the agent. + + Assumes the action is a np.ndarray, and for each dimension in the action, + calculates some useful statistics for a particular episode. + """ + + def __init__(self): + self._actions = None + + def observe_first(self, env: dm_env.Environment, + timestep: dm_env.TimeStep) -> None: + """Observes the initial state.""" + self._actions = [] + + def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, + action: np.ndarray) -> None: + """Records one environment step.""" + self._actions.append(action) + + def get_metrics(self) -> Dict[str, base.Number]: + """Returns metrics collected for the current episode.""" + aggregate_metrics = {} + if not self._actions: + return aggregate_metrics + + metrics = { + 'action_max': np.max(self._actions, axis=0), + 'action_min': np.min(self._actions, axis=0), + 'action_mean': np.mean(self._actions, axis=0), + 'action_p50': np.percentile(self._actions, q=50., axis=0) + } + + for index, sub_action_metric in np.ndenumerate(metrics['action_max']): + aggregate_metrics[f'action{list(index)}_max'] = sub_action_metric + aggregate_metrics[f'action{list(index)}_min'] = metrics['action_min'][ + index] + aggregate_metrics[f'action{list(index)}_mean'] = metrics['action_mean'][ + index] + aggregate_metrics[f'action{list(index)}_p50'] = metrics['action_p50'][ + index] + + return aggregate_metrics diff --git a/acme/acme/utils/observers/action_metrics_test.py b/acme/acme/utils/observers/action_metrics_test.py new file mode 100644 index 00000000..406e78c5 --- /dev/null +++ b/acme/acme/utils/observers/action_metrics_test.py @@ -0,0 +1,127 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for action_metrics_observers.""" + + +from acme import specs +from acme.testing import fakes +from acme.utils.observers import action_metrics +import dm_env +import numpy as np + +from absl.testing import absltest + + +def _make_fake_env() -> dm_env.Environment: + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.BoundedArray( + shape=(1,), dtype=np.float32, minimum=-100., maximum=100.), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0., maximum=1.), + ) + return fakes.Environment(env_spec, episode_length=10) + +_FAKE_ENV = _make_fake_env() +_TIMESTEP = _FAKE_ENV.reset() + + +class ActionMetricsTest(absltest.TestCase): + + def test_observe_nothing(self): + observer = action_metrics.ContinuousActionObserver() + self.assertEqual({}, observer.get_metrics()) + + def test_observe_first(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + self.assertEqual({}, observer.get_metrics()) + + def test_observe_single_step(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + self.assertEqual( + { + 'action[0]_max': 1, + 'action[0]_min': 1, + 'action[0]_mean': 1, + 'action[0]_p50': 1, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_step(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([4])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([5])) + self.assertEqual( + { + 'action[0]_max': 5, + 'action[0]_min': 1, + 'action[0]_mean': 10 / 3, + 'action[0]_p50': 4, + }, + observer.get_metrics(), + ) + + def test_observe_zero_dimensions(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array(1)) + self.assertEqual( + { + 'action[]_max': 1, + 'action[]_min': 1, + 'action[]_mean': 1, + 'action[]_p50': 1, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_dimensions(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe( + env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([[1, 2], [3, 4]])) + np.testing.assert_equal( + { + 'action[0, 0]_max': 1, + 'action[0, 0]_min': 1, + 'action[0, 0]_mean': 1, + 'action[0, 0]_p50': 1, + 'action[0, 1]_max': 2, + 'action[0, 1]_min': 2, + 'action[0, 1]_mean': 2, + 'action[0, 1]_p50': 2, + 'action[1, 0]_max': 3, + 'action[1, 0]_min': 3, + 'action[1, 0]_mean': 3, + 'action[1, 0]_p50': 3, + 'action[1, 1]_max': 4, + 'action[1, 1]_min': 4, + 'action[1, 1]_mean': 4, + 'action[1, 1]_p50': 4, + }, + observer.get_metrics(), + ) + + +if __name__ == '__main__': + absltest.main() + diff --git a/acme/acme/utils/observers/action_norm.py b/acme/acme/utils/observers/action_norm.py new file mode 100644 index 00000000..ed20aaaf --- /dev/null +++ b/acme/acme/utils/observers/action_norm.py @@ -0,0 +1,44 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An observer that collects action norm stats. +""" +from typing import Dict + +from acme.utils.observers import base +import dm_env +import numpy as np + + +class ActionNormObserver(base.EnvLoopObserver): + """An observer that collects action norm stats.""" + + def __init__(self): + self._action_norms = None + + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep + ) -> None: + """Observes the initial state.""" + self._action_norms = [] + + def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, + action: np.ndarray) -> None: + """Records one environment step.""" + self._action_norms.append(np.linalg.norm(action)) + + def get_metrics(self) -> Dict[str, base.Number]: + """Returns metrics collected for the current episode.""" + return {'action_norm_avg': np.mean(self._action_norms), + 'action_norm_min': np.min(self._action_norms), + 'action_norm_max': np.max(self._action_norms)} diff --git a/acme/acme/utils/observers/action_norm_test.py b/acme/acme/utils/observers/action_norm_test.py new file mode 100644 index 00000000..d6732f24 --- /dev/null +++ b/acme/acme/utils/observers/action_norm_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
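The test below exercises `ActionNormObserver`; as a reference point, its metrics reduce to plain numpy reductions over one `np.linalg.norm` per step. A minimal sketch with made-up actions (not part of the patch):

```python
import numpy as np

# One norm per environment step, as ActionNormObserver collects them.
action_norms = [np.linalg.norm(a) for a in (np.array([3.0, 4.0]),
                                            np.array([0.0, 1.0]))]

print({'action_norm_avg': np.mean(action_norms),   # 3.0
       'action_norm_min': np.min(action_norms),    # 1.0
       'action_norm_max': np.max(action_norms)})   # 5.0
```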
+ +"""Tests for acme.utils.observers.action_norm.""" + +from acme import specs +from acme.testing import fakes +from acme.utils.observers import action_norm +import dm_env +import numpy as np + +from absl.testing import absltest + + +def _make_fake_env() -> dm_env.Environment: + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.BoundedArray( + shape=(1,), dtype=np.float32, minimum=-10., maximum=10.), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0., maximum=1.), + ) + return fakes.Environment(env_spec, episode_length=10) + + +class ActionNormTest(absltest.TestCase): + + def test_basic(self): + env = _make_fake_env() + observer = action_norm.ActionNormObserver() + timestep = env.reset() + observer.observe_first(env, timestep) + for it in range(5): + action = np.ones((1,), dtype=np.float32) * it + timestep = env.step(action) + observer.observe(env, timestep, action) + metrics = observer.get_metrics() + self.assertLen(metrics, 3) + np.testing.assert_equal(metrics['action_norm_min'], 0) + np.testing.assert_equal(metrics['action_norm_max'], 4) + np.testing.assert_equal(metrics['action_norm_avg'], 2) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/observers/base.py b/acme/acme/utils/observers/base.py new file mode 100644 index 00000000..3e85a71b --- /dev/null +++ b/acme/acme/utils/observers/base.py @@ -0,0 +1,42 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metrics observers.""" + +import abc +from typing import Dict, Union + +import dm_env +import numpy as np + + +Number = Union[int, float] + + +class EnvLoopObserver(abc.ABC): + """An interface for collecting metrics/counters in EnvironmentLoop.""" + + @abc.abstractmethod + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep + ) -> None: + """Observes the initial state.""" + + @abc.abstractmethod + def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, + action: np.ndarray) -> None: + """Records one environment step.""" + + @abc.abstractmethod + def get_metrics(self) -> Dict[str, Number]: + """Returns metrics collected for the current episode.""" diff --git a/acme/acme/utils/observers/env_info.py b/acme/acme/utils/observers/env_info.py new file mode 100644 index 00000000..5fc77dca --- /dev/null +++ b/acme/acme/utils/observers/env_info.py @@ -0,0 +1,53 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""An observer that returns env's info. +""" +from typing import Dict + +from acme.utils.observers import base +import dm_env +import numpy as np + + +class EnvInfoObserver(base.EnvLoopObserver): + """An observer that collects and accumulates scalars from env's info.""" + + def __init__(self): + self._metrics = None + + def _accumulate_metrics(self, env: dm_env.Environment) -> None: + if not hasattr(env, 'get_info'): + return + info = getattr(env, 'get_info')() + if not info: + return + for k, v in info.items(): + if np.isscalar(v): + self._metrics[k] = self._metrics.get(k, 0) + v + + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep + ) -> None: + """Observes the initial state.""" + self._metrics = {} + self._accumulate_metrics(env) + + def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, + action: np.ndarray) -> None: + """Records one environment step.""" + self._accumulate_metrics(env) + + def get_metrics(self) -> Dict[str, base.Number]: + """Returns metrics collected for the current episode.""" + return self._metrics diff --git a/acme/acme/utils/observers/env_info_test.py b/acme/acme/utils/observers/env_info_test.py new file mode 100644 index 00000000..f8baabd3 --- /dev/null +++ b/acme/acme/utils/observers/env_info_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
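To make the accumulation rule in `EnvInfoObserver` concrete before the test: scalar info values are summed across steps, and non-scalar values are silently dropped. A minimal sketch over a hypothetical info stream (not part of the patch):

```python
import numpy as np

# Hypothetical per-step dicts, as an env's get_info() might return them.
infos = [{'survival_bonus': 1, 'positions': [0.1, 0.2]},
         {'survival_bonus': 1, 'found_checkpoint': 1}]

metrics = {}
for info in infos:
  for k, v in info.items():
    if np.isscalar(v):  # non-scalars such as 'positions' are ignored
      metrics[k] = metrics.get(k, 0) + v

print(metrics)  # {'survival_bonus': 2, 'found_checkpoint': 1}
```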
+
+"""Tests for acme.utils.observers.env_info."""
+
+from acme.utils.observers import env_info
+from acme.wrappers import gym_wrapper
+import gym
+from gym import spaces
+import numpy as np
+
+from absl.testing import absltest
+
+
+class GymEnvWithInfo(gym.Env):
+
+  def __init__(self):
+    obs_space = np.ones((10,))
+    self.observation_space = spaces.Box(-obs_space, obs_space, dtype=np.float32)
+    act_space = np.ones((3,))
+    self.action_space = spaces.Box(-act_space, act_space, dtype=np.float32)
+    self._step = 0
+
+  def reset(self):
+    self._step = 0
+    return self.observation_space.sample()
+
+  def step(self, action: np.ndarray):
+    self._step += 1
+    info = {'survival_bonus': 1}
+    if self._step == 1 or self._step == 7:
+      info['found_checkpoint'] = 1
+    if self._step == 5:
+      info['picked_up_an_apple'] = 1
+    return self.observation_space.sample(), 0, False, info
+
+
+class EnvInfoTest(absltest.TestCase):
+
+  def test_basic(self):
+    env = GymEnvWithInfo()
+    env = gym_wrapper.GymWrapper(env)
+    observer = env_info.EnvInfoObserver()
+    timestep = env.reset()
+    observer.observe_first(env, timestep)
+    for _ in range(20):
+      action = np.zeros((3,))
+      timestep = env.step(action)
+      observer.observe(env, timestep, action)
+    metrics = observer.get_metrics()
+    self.assertLen(metrics, 3)
+    np.testing.assert_equal(metrics['found_checkpoint'], 2)
+    np.testing.assert_equal(metrics['picked_up_an_apple'], 1)
+    np.testing.assert_equal(metrics['survival_bonus'], 20)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/utils/observers/measurement_metrics.py b/acme/acme/utils/observers/measurement_metrics.py
new file mode 100644
index 00000000..f2ad1369
--- /dev/null
+++ b/acme/acme/utils/observers/measurement_metrics.py
@@ -0,0 +1,74 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An observer that tracks statistics about the observations."""
+
+from typing import Mapping, List
+
+from acme.utils.observers import base
+import dm_env
+import numpy as np
+
+
+class MeasurementObserver(base.EnvLoopObserver):
+  """Observer that provides statistics for measurements at every timestep.
+
+  This assumes each measurement is a multidimensional array with a static spec.
+  Warning! It is not intended to be used for high dimensional observations.
+ + self._measurements: List[np.ndarray] + """ + + def __init__(self): + self._measurements = [] + + def observe_first(self, env: dm_env.Environment, + timestep: dm_env.TimeStep) -> None: + """Observes the initial state.""" + self._measurements = [] + + def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, + action: np.ndarray) -> None: + """Records one environment step.""" + self._measurements.append(timestep.observation) + + def get_metrics(self) -> Mapping[str, List[base.Number]]: + """Returns metrics collected for the current episode.""" + aggregate_metrics = {} + if not self._measurements: + return aggregate_metrics + + metrics = { + 'measurement_max': np.max(self._measurements, axis=0), + 'measurement_min': np.min(self._measurements, axis=0), + 'measurement_mean': np.mean(self._measurements, axis=0), + 'measurement_p25': np.percentile(self._measurements, q=25., axis=0), + 'measurement_p50': np.percentile(self._measurements, q=50., axis=0), + 'measurement_p75': np.percentile(self._measurements, q=75., axis=0), + } + for index, sub_observation_metric in np.ndenumerate( + metrics['measurement_max']): + aggregate_metrics[ + f'measurement{list(index)}_max'] = sub_observation_metric + aggregate_metrics[f'measurement{list(index)}_min'] = metrics[ + 'measurement_min'][index] + aggregate_metrics[f'measurement{list(index)}_mean'] = metrics[ + 'measurement_mean'][index] + aggregate_metrics[f'measurement{list(index)}_p50'] = metrics[ + 'measurement_p50'][index] + aggregate_metrics[f'measurement{list(index)}_p25'] = metrics[ + 'measurement_p25'][index] + aggregate_metrics[f'measurement{list(index)}_p75'] = metrics[ + 'measurement_p75'][index] + return aggregate_metrics diff --git a/acme/acme/utils/observers/measurement_metrics_test.py b/acme/acme/utils/observers/measurement_metrics_test.py new file mode 100644 index 00000000..31c97d37 --- /dev/null +++ b/acme/acme/utils/observers/measurement_metrics_test.py @@ -0,0 +1,172 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
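As a reference for the percentile assertions in the tests below, the per-dimension statistics are plain `np.percentile` reductions over the episode axis. A minimal sketch with hypothetical measurements (not part of the patch):

```python
import numpy as np

# Hypothetical episode of 2-element measurements.
measurements = [np.array([1.0, -2.0]), np.array([3.0, 0.0]), np.array([5.0, 2.0])]

p25, p50, p75 = (np.percentile(measurements, q=q, axis=0) for q in (25., 50., 75.))
print(p50)  # [3. 0.] -- per-dimension medians over the episode
```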
+ +"""Tests for measurement_metrics.""" + +import copy +from unittest import mock + +from acme import specs +from acme.testing import fakes +from acme.utils.observers import measurement_metrics +import dm_env +import numpy as np + +from absl.testing import absltest + + +def _make_fake_env() -> dm_env.Environment: + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.BoundedArray( + shape=(1,), dtype=np.float32, minimum=-100., maximum=100.), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0., maximum=1.), + ) + return fakes.Environment(env_spec, episode_length=10) + + +_FAKE_ENV = _make_fake_env() +_TIMESTEP = mock.MagicMock(spec=dm_env.TimeStep) + +_TIMESTEP.observation = [1.0, -2.0] + + +class MeasurementMetricsTest(absltest.TestCase): + + def test_observe_nothing(self): + observer = measurement_metrics.MeasurementObserver() + self.assertEqual({}, observer.get_metrics()) + + def test_observe_first(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + self.assertEqual({}, observer.get_metrics()) + + def test_observe_single_step(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + self.assertEqual( + { + 'measurement[0]_max': 1.0, + 'measurement[0]_mean': 1.0, + 'measurement[0]_p25': 1.0, + 'measurement[0]_p50': 1.0, + 'measurement[0]_p75': 1.0, + 'measurement[1]_max': -2.0, + 'measurement[1]_mean': -2.0, + 'measurement[1]_p25': -2.0, + 'measurement[1]_p50': -2.0, + 'measurement[1]_p75': -2.0, + 'measurement[0]_min': 1.0, + 'measurement[1]_min': -2.0, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_step_same_observation(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([4])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([5])) + self.assertEqual( + { + 'measurement[0]_max': 1.0, + 'measurement[0]_mean': 1.0, + 'measurement[0]_p25': 1.0, + 'measurement[0]_p50': 1.0, + 'measurement[0]_p75': 1.0, + 'measurement[1]_max': -2.0, + 'measurement[1]_mean': -2.0, + 'measurement[1]_p25': -2.0, + 'measurement[1]_p50': -2.0, + 'measurement[1]_p75': -2.0, + 'measurement[0]_min': 1.0, + 'measurement[1]_min': -2.0, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_step(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + first_obs_timestep = copy.deepcopy(_TIMESTEP) + first_obs_timestep.observation = [1000.0, -50.0] + observer.observe( + env=_FAKE_ENV, timestep=first_obs_timestep, action=np.array([4])) + second_obs_timestep = copy.deepcopy(_TIMESTEP) + second_obs_timestep.observation = [-1000.0, 500.0] + observer.observe( + env=_FAKE_ENV, timestep=second_obs_timestep, action=np.array([4])) + self.assertEqual( + { + 'measurement[0]_max': 1000.0, + 'measurement[0]_mean': 1.0/3, + 'measurement[0]_p25': -499.5, + 'measurement[0]_p50': 1.0, + 'measurement[0]_p75': 500.5, + 'measurement[1]_max': 500.0, + 'measurement[1]_mean': 448.0/3.0, + 
'measurement[1]_p25': -26.0, + 'measurement[1]_p50': -2.0, + 'measurement[1]_p75': 249.0, + 'measurement[0]_min': -1000.0, + 'measurement[1]_min': -50.0, + }, + observer.get_metrics(), + ) + + def test_observe_empty_observation(self): + observer = measurement_metrics.MeasurementObserver() + empty_timestep = copy.deepcopy(_TIMESTEP) + empty_timestep.observation = {} + observer.observe_first(env=_FAKE_ENV, timestep=empty_timestep) + self.assertEqual({}, observer.get_metrics()) + + def test_observe_single_dimensions(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + single_obs_timestep = copy.deepcopy(_TIMESTEP) + single_obs_timestep.observation = [1000.0, -50.0] + + observer.observe( + env=_FAKE_ENV, + timestep=single_obs_timestep, + action=np.array([[1, 2], [3, 4]])) + + np.testing.assert_equal( + { + 'measurement[0]_max': 1000.0, + 'measurement[0]_min': 1000.0, + 'measurement[0]_mean': 1000.0, + 'measurement[0]_p25': 1000.0, + 'measurement[0]_p50': 1000.0, + 'measurement[0]_p75': 1000.0, + 'measurement[1]_max': -50.0, + 'measurement[1]_mean': -50.0, + 'measurement[1]_p25': -50.0, + 'measurement[1]_p50': -50.0, + 'measurement[1]_p75': -50.0, + 'measurement[1]_min': -50.0, + }, + observer.get_metrics(), + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/paths.py b/acme/acme/utils/paths.py new file mode 100644 index 00000000..15e7cd52 --- /dev/null +++ b/acme/acme/utils/paths.py @@ -0,0 +1,80 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Filesystem path helpers.""" + +import os +import os.path +import shutil +import time +from typing import Optional, Tuple + +from absl import flags + +ACME_ID = flags.DEFINE_string('acme_id', None, + 'Experiment identifier to use for Acme.') + + +def process_path(path: str, + *subpaths: str, + ttl_seconds: Optional[int] = None, + backups: Optional[bool] = None, + add_uid: bool = True) -> str: + """Process the path string. + + This will process the path string by running `os.path.expanduser` to replace + any initial "~". It will also append a unique string on the end of the path + and create the directories leading to this path if necessary. + + Args: + path: string defining the path to process and create. + *subpaths: potential subpaths to include after uniqification. + ttl_seconds: ignored. + backups: ignored. + add_uid: Whether to add a unique directory identifier between `path` and + `subpaths`. If the `--acme_id` flag is set, will use that as the + identifier. + + Returns: + the processed, expanded path string. 
+ """ + del backups, ttl_seconds + + path = os.path.expanduser(path) + if add_uid: + path = os.path.join(path, *get_unique_id()) + path = os.path.join(path, *subpaths) + os.makedirs(path, exist_ok=True) + return path + + +def get_unique_id() -> Tuple[str, ...]: + """Makes a unique identifier for this process; override with --acme_id.""" + # By default we'll use the global id. + identifier = time.strftime('%Y%m%d-%H%M%S') + + # If the --acme_id flag is given prefer that; ignore if flag processing has + # been skipped (this happens in colab or in tests). + try: + identifier = ACME_ID.value or identifier + except flags.UnparsedFlagAccessError: + pass + + # Return as a tuple (for future proofing). + return (identifier,) + + +def rmdir(path: str): + """Remove directory recursively.""" + shutil.rmtree(path) diff --git a/acme/acme/utils/paths_test.py b/acme/acme/utils/paths_test.py new file mode 100644 index 00000000..6af5d6a1 --- /dev/null +++ b/acme/acme/utils/paths_test.py @@ -0,0 +1,41 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for paths.""" + +from unittest import mock + +from acme.testing import test_utils +import acme.utils.paths as paths + +from absl.testing import flagsaver +from absl.testing import absltest + + +class PathTest(test_utils.TestCase): + + def test_process_path(self): + root_directory = self.get_tempdir() + with mock.patch.object(paths, 'get_unique_id') as mock_unique_id: + mock_unique_id.return_value = ('test',) + path = paths.process_path(root_directory, 'foo', 'bar') + self.assertEqual(path, f'{root_directory}/test/foo/bar') + + def test_unique_id_with_flag(self): + with flagsaver.flagsaver((paths.ACME_ID, 'test_flag')): + self.assertEqual(paths.get_unique_id(), ('test_flag',)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/utils/reverb_utils.py b/acme/acme/utils/reverb_utils.py new file mode 100644 index 00000000..5df39153 --- /dev/null +++ b/acme/acme/utils/reverb_utils.py @@ -0,0 +1,140 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reverb utils. + +Contains functions manipulating reverb tables and samples. 
+""" + +from acme import types +import jax +import numpy as np +import reverb +from reverb import item_selectors +from reverb import rate_limiters +from reverb import reverb_types +import tensorflow as tf +import tree + + +def make_replay_table_from_info( + table_info: reverb_types.TableInfo) -> reverb.Table: + """Build a replay table out of its specs in a TableInfo. + + Args: + table_info: A TableInfo containing the Table specs. + + Returns: + A reverb replay table matching the info specs. + """ + sampler = _make_selector_from_key_distribution_options( + table_info.sampler_options) + remover = _make_selector_from_key_distribution_options( + table_info.remover_options) + rate_limiter = _make_rate_limiter_from_rate_limiter_info( + table_info.rate_limiter_info) + return reverb.Table( + name=table_info.name, + sampler=sampler, + remover=remover, + max_size=table_info.max_size, + rate_limiter=rate_limiter, + max_times_sampled=table_info.max_times_sampled, + signature=table_info.signature) + + +def _make_selector_from_key_distribution_options( + options) -> reverb_types.SelectorType: + """Returns a Selector from its KeyDistributionOptions description.""" + one_of = options.WhichOneof('distribution') + if one_of == 'fifo': + return item_selectors.Fifo() + if one_of == 'uniform': + return item_selectors.Uniform() + if one_of == 'prioritized': + return item_selectors.Prioritized(options.prioritized.priority_exponent) + if one_of == 'heap': + if options.heap.min_heap: + return item_selectors.MinHeap() + return item_selectors.MaxHeap() + if one_of == 'lifo': + return item_selectors.Lifo() + raise ValueError(f'Unknown distribution field: {one_of}') + + +def _make_rate_limiter_from_rate_limiter_info( + info) -> rate_limiters.RateLimiter: + return rate_limiters.SampleToInsertRatio( + samples_per_insert=info.samples_per_insert, + min_size_to_sample=info.min_size_to_sample, + error_buffer=(info.min_diff, info.max_diff)) + + +def replay_sample_to_sars_transition( + sample: reverb.ReplaySample, + is_sequence: bool, + strip_last_transition: bool = False, + flatten_batch: bool = False) -> types.Transition: + """Converts the replay sample to a types.Transition. + + NB: If is_sequence is True then the last next_observation of each sequence is + rubbish. Don't train on it. + + Args: + sample: The replay sample + is_sequence: If False we expect the sample data to match the + types.Transition already. Otherwise we expect a batch of sequences of + steps. + strip_last_transition: If True and is_sequence, the last transition will be + stripped as its next_observation field is incorrect. + flatten_batch: If True and is_sequence, the two batch dimensions will be + flatten to one. + + Returns: + A types.Transition built from the sample data. + If is_sequence and strip_last_transition are both True, the output will be + smaller than the output as the last transition of every sequence will have + been removed. + """ + if not is_sequence: + return types.Transition(*sample.data) + # Note that the last next_observation is invalid. + steps = sample.data + def roll(observation): + return np.roll(observation, shift=-1, axis=1) + transitions = types.Transition( + observation=steps.observation, + action=steps.action, + reward=steps.reward, + discount=steps.discount, + next_observation=tree.map_structure(roll, steps.observation), + extras=steps.extras) + if strip_last_transition: + # We remove the last transition as its next_observation field is incorrect. 
+ # It has been obtained by rolling the observation field, such that + # transitions.next_observations[:, -1] is transitions.observations[:, 0] + transitions = jax.tree_map(lambda x: x[:, :-1, ...], transitions) + if flatten_batch: + # Merge the 2 leading batch dimensions into 1. + transitions = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), + transitions) + return transitions + + +def transition_to_replaysample( + transitions: types.Transition) -> reverb.ReplaySample: + """Converts a types.Transition to a reverb.ReplaySample.""" + info = tree.map_structure(lambda dtype: tf.ones([], dtype), + reverb.SampleInfo.tf_dtypes()) + return reverb.ReplaySample(info=info, data=transitions) diff --git a/acme/acme/utils/reverb_utils_test.py b/acme/acme/utils/reverb_utils_test.py new file mode 100644 index 00000000..2c71c524 --- /dev/null +++ b/acme/acme/utils/reverb_utils_test.py @@ -0,0 +1,87 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for acme.utils.reverb_utils.""" + +from acme import types +from acme.adders import reverb as reverb_adders +from acme.utils import reverb_utils +import numpy as np +import reverb +import tree + +from absl.testing import absltest + + +class ReverbUtilsTest(absltest.TestCase): + + def test_make_replay_table_preserves_table_info(self): + limiter = reverb.rate_limiters.SampleToInsertRatio( + samples_per_insert=1, min_size_to_sample=2, error_buffer=(0, 10)) + table = reverb.Table( + name='test', + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=10, + rate_limiter=limiter) + new_table = reverb_utils.make_replay_table_from_info(table.info) + new_info = new_table.info + + # table_worker_time is not set by the above utility since this is meant to + # be monitoring information about any given table. So instead we copy this + # so that the assertion below checks that everything else matches. 
+
+    new_info.table_worker_time.sleeping_ms = (
+        table.info.table_worker_time.sleeping_ms)
+
+    self.assertEqual(new_info, table.info)
+
+  _EMPTY_INFO = reverb.SampleInfo(*[() for _ in reverb.SampleInfo.tf_dtypes()])
+  _DUMMY_OBS = np.array([[[0], [1], [2]]])
+  _DUMMY_ACTION = np.array([[[3], [4], [5]]])
+  _DUMMY_REWARD = np.array([[6, 7, 8]])
+  _DUMMY_DISCOUNT = np.array([[.99, .99, .99]])
+  _DUMMY_NEXT_OBS = np.array([[[1], [2], [0]]])
+  _DUMMY_RETURN = np.array([[20.77, 14.92, 8.]])
+
+  def _create_dummy_steps(self):
+    return reverb_adders.Step(
+        observation=self._DUMMY_OBS,
+        action=self._DUMMY_ACTION,
+        reward=self._DUMMY_REWARD,
+        discount=self._DUMMY_DISCOUNT,
+        start_of_episode=True,
+        extras={'return': self._DUMMY_RETURN})
+
+  def _create_dummy_transitions(self):
+    return types.Transition(
+        observation=self._DUMMY_OBS,
+        action=self._DUMMY_ACTION,
+        reward=self._DUMMY_REWARD,
+        discount=self._DUMMY_DISCOUNT,
+        next_observation=self._DUMMY_NEXT_OBS,
+        extras={'return': self._DUMMY_RETURN})
+
+  def test_replay_sample_to_sars_transition_is_sequence(self):
+    fake_sample = reverb.ReplaySample(
+        info=self._EMPTY_INFO, data=self._create_dummy_steps())
+    fake_transition = self._create_dummy_transitions()
+    transition_from_sample = reverb_utils.replay_sample_to_sars_transition(
+        fake_sample, is_sequence=True)
+    tree.map_structure(np.testing.assert_array_equal, transition_from_sample,
+                       fake_transition)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/utils/signals.py b/acme/acme/utils/signals.py
new file mode 100644
index 00000000..907faac2
--- /dev/null
+++ b/acme/acme/utils/signals.py
@@ -0,0 +1,49 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper methods for handling signals."""
+
+import contextlib
+import ctypes
+import threading
+from typing import Any, Callable, Optional
+
+import launchpad
+
+_Handler = Callable[[], Any]
+
+
+@contextlib.contextmanager
+def runtime_terminator(callback: Optional[_Handler] = None):
+  """Runtime terminator used for stopping computation upon agent termination.
+
+  Runtime terminator optionally executes a provided `callback` and then raises
+  a `SystemExit` exception in the thread performing the computation.
+
+  Args:
+    callback: callback to execute before raising the exception.
+
+  Yields:
+    None.
+  """
+  worker_id = threading.get_ident()
+  def signal_handler():
+    if callback:
+      callback()
+    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+        ctypes.c_long(worker_id), ctypes.py_object(SystemExit))
+    assert res < 2, 'Stopping worker failed'
+  launchpad.register_stop_handler(signal_handler)
+  yield
+  launchpad.unregister_stop_handler(signal_handler)
diff --git a/acme/acme/utils/tree_utils.py b/acme/acme/utils/tree_utils.py
new file mode 100644
index 00000000..8447ee09
--- /dev/null
+++ b/acme/acme/utils/tree_utils.py
@@ -0,0 +1,191 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tensor framework-agnostic utilities for manipulating nested structures."""
+
+from typing import Sequence, List, TypeVar, Any
+
+import numpy as np
+import tree
+
+ElementType = TypeVar('ElementType')
+
+
+def fast_map_structure(func, *structure):
+  """Faster map_structure implementation which skips some error checking."""
+  flat_structure = (tree.flatten(s) for s in structure)
+  entries = zip(*flat_structure)
+  # Arbitrarily choose one of the structures of the original sequence (the last)
+  # to match the structure for the flattened sequence.
+  return tree.unflatten_as(structure[-1], [func(*x) for x in entries])
+
+
+def fast_map_structure_with_path(func, *structure):
+  """Faster map_structure_with_path implementation."""
+  head_entries_with_path = tree.flatten_with_path(structure[0])
+  if len(structure) > 1:
+    tail_entries = (tree.flatten(s) for s in structure[1:])
+    entries_with_path = [
+        e[0] + e[1:] for e in zip(head_entries_with_path, *tail_entries)
+    ]
+  else:
+    entries_with_path = head_entries_with_path
+  # Arbitrarily choose one of the structures of the original sequence (the last)
+  # to match the structure for the flattened sequence.
+  return tree.unflatten_as(structure[-1], [func(*x) for x in entries_with_path])
+
+
+def stack_sequence_fields(sequence: Sequence[ElementType]) -> ElementType:
+  """Stacks a list of identically nested objects.
+
+  This takes a sequence of identically nested objects and returns a single
+  nested object whose ith leaf is a stacked numpy array of the corresponding
+  ith leaf from each element of the sequence.
+
+  For example, if `sequence` is:
+
+  ```python
+  [{
+      'action': np.array([1.0]),
+      'observation': (np.array([0.0, 1.0, 2.0]),),
+      'reward': 1.0
+  }, {
+      'action': np.array([0.5]),
+      'observation': (np.array([1.0, 2.0, 3.0]),),
+      'reward': 0.0
+  }, {
+      'action': np.array([0.3]),
+      'observation': (np.array([2.0, 3.0, 4.0]),),
+      'reward': 0.5
+  }]
+  ```
+
+  Then this function will return:
+
+  ```python
+  {
+      'action': np.array([...]),  # array shape = [3 x 1]
+      'observation': (np.array([...]),),  # array shape = [3 x 3]
+      'reward': np.array([...])  # array shape = [3]
+  }
+  ```
+
+  Note that the 'observation' entry in the above example has two levels of
+  nesting, i.e. it is a tuple of arrays.
+
+  Args:
+    sequence: a list of identically nested objects.
+
+  Returns:
+    A nested object with the same structure as each element of `sequence`,
+    whose leaves are numpy arrays stacked along a new leading axis.
+
+  Raises:
+    ValueError: If `sequence` is an empty sequence.
+  """
+  # Handle empty input sequences.
+  if not sequence:
+    raise ValueError('Input sequence must not be empty')
+
+  # Default to asarray when arrays don't have the same shape to be compatible
+  # with old behaviour.
+  try:
+    return fast_map_structure(lambda *values: np.stack(values), *sequence)
+  except ValueError:
+    return fast_map_structure(lambda *values: np.asarray(values), *sequence)
+
+
+def unstack_sequence_fields(struct: ElementType,
+                            batch_size: int) -> List[ElementType]:
+  """Converts a struct of batched arrays to a list of structs.
+
+  This is effectively the inverse of `stack_sequence_fields`.
+
+  Args:
+    struct: An (arbitrarily nested) structure of arrays.
+    batch_size: The length of the leading dimension of each array in the struct.
+      This is assumed to be static and known.
+
+  Returns:
+    A list of structs with the same structure as `struct`, where each leaf node
+    is an unbatched element of the original leaf node.
+  """
+
+  return [
+      tree.map_structure(lambda s, i=i: s[i], struct) for i in range(batch_size)
+  ]
+
+
+def broadcast_structures(*args: Any) -> Any:
+  """Returns versions of the arguments that give them the same nested structure.
+
+  Any nested items in *args must have the same structure.
+
+  Any non-nested item will be replaced with a nested version that shares that
+  structure. The leaves will all be references to the same original non-nested
+  item.
+
+  If all *args are nested, or all *args are non-nested, this function will
+  return *args unchanged.
+
+  Example:
+  ```
+  a = ('a', 'b')
+  b = 'c'
+  tree_a, tree_b = broadcast_structures(a, b)
+  tree_a
+  > ('a', 'b')
+  tree_b
+  > ('c', 'c')
+  ```
+
+  Args:
+    *args: A Sequence of nested or non-nested items.
+
+  Returns:
+    `*args`, except with all items sharing the same nest structure.
+  """
+  if not args:
+    return
+
+  reference_tree = None
+  for arg in args:
+    if tree.is_nested(arg):
+      reference_tree = arg
+      break
+
+  # If reference_tree is None then none of args are nested and we can skip over
+  # the rest of this function, which would be a no-op.
+  if reference_tree is None:
+    return args
+
+  def mirror_structure(value, reference_tree):
+    if tree.is_nested(value):
+      # Use check_types=True so that the types of the trees we construct aren't
+      # dependent on our arbitrary choice of which nested arg to use as the
+      # reference_tree.
+      tree.assert_same_structure(value, reference_tree, check_types=True)
+      return value
+    else:
+      return tree.map_structure(lambda _: value, reference_tree)
+
+  return tuple(mirror_structure(arg, reference_tree) for arg in args)
+
+
+def tree_map(f):
+  """Transforms `f` into a tree-mapped version."""
+
+  def mapped_f(*structures):
+    return tree.map_structure(f, *structures)
+
+  return mapped_f
diff --git a/acme/acme/utils/tree_utils_test.py b/acme/acme/utils/tree_utils_test.py
new file mode 100644
index 00000000..b24b87b7
--- /dev/null
+++ b/acme/acme/utils/tree_utils_test.py
@@ -0,0 +1,106 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
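Before the tests, a small usage sketch (not part of the patch) tying the three helpers above together; it assumes the `acme` package from this diff is importable:

```python
import numpy as np

from acme.utils import tree_utils

batch = tree_utils.stack_sequence_fields(
    [{'reward': np.array(1.0)}, {'reward': np.array(0.5)}])
print(batch['reward'])     # [1.  0.5] -- leaves gain a leading batch axis

items = tree_utils.unstack_sequence_fields(batch, batch_size=2)
print(items[1]['reward'])  # 0.5 -- round-trips back to per-step structs

a, b = tree_utils.broadcast_structures(('x', 'y'), 'z')
print(a, b)                # ('x', 'y') ('z', 'z')
```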
+ +"""Tests for tree_utils.""" + +import functools +from typing import Sequence + +from acme.utils import tree_utils +import numpy as np +import tree + +from absl.testing import absltest + +TEST_SEQUENCE = [ + { + 'action': np.array([1.0]), + 'observation': (np.array([0.0, 1.0, 2.0]),), + 'reward': np.array(1.0), + }, + { + 'action': np.array([0.5]), + 'observation': (np.array([1.0, 2.0, 3.0]),), + 'reward': np.array(0.0), + }, + { + 'action': np.array([0.3]), + 'observation': (np.array([2.0, 3.0, 4.0]),), + 'reward': np.array(0.5), + }, +] + + +class SequenceStackTest(absltest.TestCase): + """Tests for various tree utilities.""" + + def test_stack_sequence_fields(self): + """Tests that `stack_sequence_fields` behaves correctly on nested data.""" + + stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE) + + # Check that the stacked output has the correct structure. + tree.assert_same_structure(stacked, TEST_SEQUENCE[0]) + + # Check that the leaves have the correct array shapes. + self.assertEqual(stacked['action'].shape, (3, 1)) + self.assertEqual(stacked['observation'][0].shape, (3, 3)) + self.assertEqual(stacked['reward'].shape, (3,)) + + # Check values. + self.assertEqual(stacked['observation'][0].tolist(), [ + [0., 1., 2.], + [1., 2., 3.], + [2., 3., 4.], + ]) + self.assertEqual(stacked['action'].tolist(), [[1.], [0.5], [0.3]]) + self.assertEqual(stacked['reward'].tolist(), [1., 0., 0.5]) + + def test_unstack_sequence_fields(self): + """Tests that `unstack_sequence_fields(stack_sequence_fields(x)) == x`.""" + stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE) + batch_size = len(TEST_SEQUENCE) + unstacked = tree_utils.unstack_sequence_fields(stacked, batch_size) + tree.map_structure(np.testing.assert_array_equal, unstacked, TEST_SEQUENCE) + + def test_fast_map_structure_with_path(self): + structure = { + 'a': { + 'b': np.array([0.0]) + }, + 'c': (np.array([1.0]), np.array([2.0])), + 'd': [np.array(3.0), np.array(4.0)], + } + + def map_fn(path: Sequence[str], x: np.ndarray, y: np.ndarray): + return x + y + len(path) + + single_arg_map_fn = functools.partial(map_fn, y=np.array([0.0])) + + expected_mapped_structure = ( + tree.map_structure_with_path(single_arg_map_fn, structure)) + mapped_structure = ( + tree_utils.fast_map_structure_with_path(single_arg_map_fn, structure)) + self.assertEqual(mapped_structure, expected_mapped_structure) + + expected_double_mapped_structure = ( + tree.map_structure_with_path(map_fn, structure, mapped_structure)) + double_mapped_structure = ( + tree_utils.fast_map_structure_with_path(map_fn, structure, + mapped_structure)) + self.assertEqual(double_mapped_structure, expected_double_mapped_structure) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/wrappers/__init__.py b/acme/acme/wrappers/__init__.py new file mode 100644 index 00000000..fbf383de --- /dev/null +++ b/acme/acme/wrappers/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
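The `fast_map_structure` helpers tested above skip dm-tree's structure validation by flattening and rebuilding directly; a minimal sketch of that underlying pattern, assuming the `dm-tree` package is installed (not part of the patch):

```python
import tree  # dm-tree

structure = {'a': 1, 'b': (2, 3)}

# Flatten, transform the leaves, then rebuild the original nesting.
leaves = tree.flatten(structure)
doubled = tree.unflatten_as(structure, [x * 2 for x in leaves])
print(doubled)  # {'a': 2, 'b': (4, 6)}
```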
+ +"""Common environment wrapper classes.""" + +from acme.wrappers.action_repeat import ActionRepeatWrapper +from acme.wrappers.atari_wrapper import AtariWrapper +from acme.wrappers.base import EnvironmentWrapper +from acme.wrappers.base import wrap_all +from acme.wrappers.canonical_spec import CanonicalSpecWrapper +from acme.wrappers.concatenate_observations import ConcatObservationWrapper +from acme.wrappers.delayed_reward import DelayedRewardWrapper +from acme.wrappers.expand_scalar_observation_shapes import ExpandScalarObservationShapesWrapper +from acme.wrappers.frame_stacking import FrameStackingWrapper +from acme.wrappers.gym_wrapper import GymAtariAdapter +from acme.wrappers.gym_wrapper import GymWrapper +from acme.wrappers.noop_starts import NoopStartsWrapper +from acme.wrappers.observation_action_reward import ObservationActionRewardWrapper +from acme.wrappers.single_precision import SinglePrecisionWrapper +from acme.wrappers.step_limit import StepLimitWrapper + +try: + # pylint: disable=g-import-not-at-top + from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper +except ImportError: + pass diff --git a/acme/acme/wrappers/action_repeat.py b/acme/acme/wrappers/action_repeat.py new file mode 100644 index 00000000..8c6978ef --- /dev/null +++ b/acme/acme/wrappers/action_repeat.py @@ -0,0 +1,47 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper that implements action repeats.""" + +from acme import types +from acme.wrappers import base +import dm_env + + +class ActionRepeatWrapper(base.EnvironmentWrapper): + """Action repeat wrapper.""" + + def __init__(self, environment: dm_env.Environment, num_repeats: int = 1): + super().__init__(environment) + self._num_repeats = num_repeats + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + # Initialize accumulated reward and discount. + reward = 0. + discount = 1. + + # Step the environment by repeating action. + for _ in range(self._num_repeats): + timestep = self._environment.step(action) + + # Accumulate reward and discount. + reward += timestep.reward * discount + discount *= timestep.discount + + # Don't go over episode boundaries. + if timestep.last(): + break + + # Replace the final timestep's reward and discount. + return timestep._replace(reward=reward, discount=discount) diff --git a/acme/acme/wrappers/atari_wrapper.py b/acme/acme/wrappers/atari_wrapper.py new file mode 100644 index 00000000..64c15b6a --- /dev/null +++ b/acme/acme/wrappers/atari_wrapper.py @@ -0,0 +1,390 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Atari wrappers functionality for Python environments."""
+
+import abc
+from typing import Tuple, List, Optional, Sequence, Union
+
+from acme.wrappers import base
+from acme.wrappers import frame_stacking
+
+import dm_env
+from dm_env import specs
+import numpy as np
+from PIL import Image
+
+RGB_INDEX = 0  # Observation index holding the RGB data.
+LIVES_INDEX = 1  # Observation index holding the lives count.
+NUM_COLOR_CHANNELS = 3  # Number of color channels in RGB data.
+
+
+class BaseAtariWrapper(abc.ABC, base.EnvironmentWrapper):
+  """Abstract base class for Atari wrappers.
+
+  This assumes that the input environment is a dm_env.Environment instance in
+  which observations are tuples whose first element is an RGB observation and
+  the second element is the lives count.
+
+  The wrapper itself performs the following modifications:
+
+  1. Soft-termination (setting discount to zero) on loss of life.
+  2. Action repeats.
+  3. Frame pooling for action repeats.
+  4. Conversion to grayscale and downscaling.
+  5. Reward clipping.
+  6. Observation stacking.
+
+  The details of grayscale conversion, downscaling, and frame pooling are
+  delegated to the concrete subclasses.
+
+  This wrapper will raise an error if the underlying Atari environment does not:
+
+  - Expose RGB observations in interleaved format (shape `(H, W, C)`).
+  - Expose zero-indexed actions.
+
+  Note that this class does not expose a configurable rescale method (defaults
+  to bilinear internally).
+
+  This class also exposes an additional option `to_float` that doesn't feature
+  in other wrappers, which rescales pixel values to floats in the range [0, 1].
+  """
+
+  def __init__(self,
+               environment: dm_env.Environment,
+               *,
+               max_abs_reward: Optional[float] = None,
+               scale_dims: Optional[Tuple[int, int]] = (84, 84),
+               action_repeats: int = 4,
+               pooled_frames: int = 2,
+               zero_discount_on_life_loss: bool = False,
+               expose_lives_observation: bool = False,
+               num_stacked_frames: int = 4,
+               max_episode_len: Optional[int] = None,
+               to_float: bool = False,
+               grayscaling: bool = True):
+    """Initializes a new AtariWrapper.
+
+    Args:
+      environment: An Atari environment.
+      max_abs_reward: Maximum absolute reward value before clipping is applied.
+        If set to `None` (default), no clipping is applied.
+      scale_dims: Image size for the rescaling step after grayscaling, given as
+        `(height, width)`. Set to `None` to disable resizing.
+      action_repeats: Number of times to step the wrapped environment for each
+        given action.
+      pooled_frames: Number of observations to pool over. Set to 1 to disable
+        frame pooling.
+      zero_discount_on_life_loss: If `True`, sets the discount to zero when the
+        number of lives decreases in the Atari environment.
+      expose_lives_observation: If `False`, the `lives` part of the observation
+        is discarded, otherwise it is kept as part of an observation tuple. This
+        does not affect the `zero_discount_on_life_loss` feature. When disabled,
+        the observation consists of a single pixel array, otherwise it is a
+        tuple (pixel_array, lives).
+ num_stacked_frames: Number of recent (pooled) observations to stack into + the returned observation. + max_episode_len: Number of frames before truncating episode. By default, + there is no maximum length. + to_float: If `True`, rescales RGB observations to floats in [0, 1]. + grayscaling: If `True` returns a grayscale version of the observations. In + this case, the observation is 3D (H, W, num_stacked_frames). If `False` + the observations are RGB and have shape (H, W, C, num_stacked_frames). + + Raises: + ValueError: For various invalid inputs. + """ + if not 1 <= pooled_frames <= action_repeats: + raise ValueError("pooled_frames ({}) must be between 1 and " + "action_repeats ({}) inclusive".format( + pooled_frames, action_repeats)) + + if zero_discount_on_life_loss: + super().__init__(_ZeroDiscountOnLifeLoss(environment)) + else: + super().__init__(environment) + + if not max_episode_len: + max_episode_len = np.inf + + self._frame_stacker = frame_stacking.FrameStacker( + num_frames=num_stacked_frames) + self._action_repeats = action_repeats + self._pooled_frames = pooled_frames + self._scale_dims = scale_dims + self._max_abs_reward = max_abs_reward or np.inf + self._to_float = to_float + self._expose_lives_observation = expose_lives_observation + + if scale_dims: + self._height, self._width = scale_dims + else: + spec = environment.observation_spec() + self._height, self._width = spec[RGB_INDEX].shape[:2] + + self._episode_len = 0 + self._max_episode_len = max_episode_len + self._reset_next_step = True + + self._grayscaling = grayscaling + + # Based on underlying observation spec, decide whether lives are to be + # included in output observations. + observation_spec = self._environment.observation_spec() + spec_names = [spec.name for spec in observation_spec] + if "lives" in spec_names and spec_names.index("lives") != 1: + raise ValueError("`lives` observation needs to have index 1 in Atari.") + + self._observation_spec = self._init_observation_spec() + + self._raw_observation = None + + def _init_observation_spec(self): + """Computes the observation spec for the pixel observations. + + Returns: + An `Array` specification for the pixel observations. + """ + if self._to_float: + pixels_dtype = float + else: + pixels_dtype = np.uint8 + + if self._grayscaling: + pixels_spec_shape = (self._height, self._width) + pixels_spec_name = "grayscale" + else: + pixels_spec_shape = (self._height, self._width, NUM_COLOR_CHANNELS) + pixels_spec_name = "RGB" + + pixel_spec = specs.Array( + shape=pixels_spec_shape, dtype=pixels_dtype, name=pixels_spec_name) + pixel_spec = self._frame_stacker.update_spec(pixel_spec) + + if self._expose_lives_observation: + return (pixel_spec,) + self._environment.observation_spec()[1:] + return pixel_spec + + def reset(self) -> dm_env.TimeStep: + """Resets environment and provides the first timestep.""" + self._reset_next_step = False + self._episode_len = 0 + self._frame_stacker.reset() + timestep = self._environment.reset() + + observation = self._observation_from_timestep_stack([timestep]) + + return self._postprocess_observation( + timestep._replace(observation=observation)) + + def step(self, action: int) -> dm_env.TimeStep: + """Steps up to action_repeat times and returns a post-processed step.""" + if self._reset_next_step: + return self.reset() + + timestep_stack = [] + + # Step on environment multiple times for each selected action. 
+ for _ in range(self._action_repeats): + timestep = self._environment.step([np.array([action])]) + + self._episode_len += 1 + if self._episode_len == self._max_episode_len: + timestep = timestep._replace(step_type=dm_env.StepType.LAST) + + timestep_stack.append(timestep) + + if timestep.last(): + # Action repeat frames should not span episode boundaries. Also, no need + # to pad with zero-valued observations as all the reductions in + # _postprocess_observation work gracefully for any non-zero size of + # timestep_stack. + self._reset_next_step = True + break + + # Determine a single step type. We let FIRST take priority over LAST, since + # we think it's more likely algorithm code will be set up to deal with that, + # due to environments supporting reset() (which emits a FIRST). + # Note we'll never have LAST then FIRST in timestep_stack here. + step_type = dm_env.StepType.MID + for timestep in timestep_stack: + if timestep.first(): + step_type = dm_env.StepType.FIRST + break + elif timestep.last(): + step_type = dm_env.StepType.LAST + break + + if timestep_stack[0].first(): + # Update first timestep to have identity effect on reward and discount. + timestep_stack[0] = timestep_stack[0]._replace(reward=0., discount=1.) + + # Sum reward over stack. + reward = sum(timestep_t.reward for timestep_t in timestep_stack) + + # Multiply discount over stack (will either be 0. or 1.). + discount = np.product( + [timestep_t.discount for timestep_t in timestep_stack]) + + observation = self._observation_from_timestep_stack(timestep_stack) + + timestep = dm_env.TimeStep( + step_type=step_type, + reward=reward, + observation=observation, + discount=discount) + + return self._postprocess_observation(timestep) + + @abc.abstractmethod + def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]): + """Process Atari pixels.""" + + def _observation_from_timestep_stack(self, + timestep_stack: List[dm_env.TimeStep]): + """Compute the observation for a stack of timesteps.""" + self._raw_observation = timestep_stack[-1].observation[RGB_INDEX].copy() + processed_pixels = self._preprocess_pixels(timestep_stack) + + if self._to_float: + stacked_observation = self._frame_stacker.step(processed_pixels / 255.0) + else: + stacked_observation = self._frame_stacker.step(processed_pixels) + + # We use last timestep for lives only. + observation = timestep_stack[-1].observation + if self._expose_lives_observation: + return (stacked_observation,) + observation[1:] + + return stacked_observation + + def _postprocess_observation(self, + timestep: dm_env.TimeStep) -> dm_env.TimeStep: + """Observation processing applied after action repeat consolidation.""" + + if timestep.first(): + return dm_env.restart(timestep.observation) + + reward = np.clip(timestep.reward, -self._max_abs_reward, + self._max_abs_reward) + + return timestep._replace(reward=reward) + + def action_spec(self) -> specs.DiscreteArray: + raw_spec = self._environment.action_spec()[0] + return specs.DiscreteArray(num_values=raw_spec.maximum.item() - + raw_spec.minimum.item() + 1) + + def observation_spec(self) -> Union[specs.Array, Sequence[specs.Array]]: + return self._observation_spec + + def reward_spec(self) -> specs.Array: + return specs.Array(shape=(), dtype=float) + + @property + def raw_observation(self) -> np.ndarray: + """Returns the raw observation, after any pooling has been applied.""" + return self._raw_observation + + +class AtariWrapper(BaseAtariWrapper): + """Standard "Nature Atari" wrapper for Python environments. 
+
+  Before being fed to a neural network, Atari frames go through preprocessing,
+  implemented in this wrapper. For historical reasons, the Dopamine library and
+  Acme made different choices in how to apply that preprocessing. Three
+  operations happen during the processing of Atari frames: images are
+  transformed from RGB to grayscale, max-pooled on the time scale, and resized
+  to 84x84.
+
+  1. The `standard` style (this one, matches previous acme versions):
+    - does max pooling, then rgb -> grayscale
+    - uses Pillow bilinear interpolation for resizing
+  2. The `dopamine` style:
+    - does rgb -> grayscale, then max pooling
+    - uses opencv area interpolation for resizing.
+
+  This can change the behavior of RL agents on some games. The recommended
+  setting is to use the standard style with this class. The Dopamine setting is
+  available in `atari_wrapper_dopamine.py` for users who wish to compare agents
+  between libraries.
+  """
+
+  def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]):
+    """Preprocess Atari frames."""
+    # 1. Max pooling
+    processed_pixels = np.max(
+        np.stack([
+            s.observation[RGB_INDEX]
+            for s in timestep_stack[-self._pooled_frames:]
+        ]),
+        axis=0)
+
+    # 2. RGB to grayscale
+    if self._grayscaling:
+      processed_pixels = np.tensordot(processed_pixels,
+                                      [0.299, 0.587, 1 - (0.299 + 0.587)],
+                                      (-1, 0))
+
+    # 3. Resize
+    processed_pixels = processed_pixels.astype(np.uint8, copy=False)
+    if self._scale_dims != processed_pixels.shape[:2]:
+      processed_pixels = Image.fromarray(processed_pixels).resize(
+          (self._width, self._height), Image.BILINEAR)
+      processed_pixels = np.array(processed_pixels, dtype=np.uint8)
+
+    return processed_pixels
+
+
+class _ZeroDiscountOnLifeLoss(base.EnvironmentWrapper):
+  """Implements soft-termination (zero discount) on life loss."""
+
+  def __init__(self, environment: dm_env.Environment):
+    """Initializes a new `_ZeroDiscountOnLifeLoss` wrapper.
+
+    Args:
+      environment: An Atari environment.
+
+    Raises:
+      ValueError: If the environment does not expose a lives observation.
+    """
+    super().__init__(environment)
+    self._reset_next_step = True
+    self._last_num_lives = None
+
+  def reset(self) -> dm_env.TimeStep:
+    timestep = self._environment.reset()
+    self._reset_next_step = False
+    self._last_num_lives = timestep.observation[LIVES_INDEX]
+    return timestep
+
+  def step(self, action: int) -> dm_env.TimeStep:
+    if self._reset_next_step:
+      return self.reset()
+
+    timestep = self._environment.step(action)
+    lives = timestep.observation[LIVES_INDEX]
+
+    is_life_loss = True
+    # We have a life loss when:
+    # The wrapped environment is in a regular (MID) transition.
+    is_life_loss &= timestep.mid()
+    # Lives have decreased since last time `step` was called.
+    is_life_loss &= lives < self._last_num_lives
+
+    self._last_num_lives = lives
+    if is_life_loss:
+      return timestep._replace(discount=0.0)
+    return timestep
diff --git a/acme/acme/wrappers/atari_wrapper_dopamine.py b/acme/acme/wrappers/atari_wrapper_dopamine.py
new file mode 100644
index 00000000..b8df88e6
--- /dev/null
+++ b/acme/acme/wrappers/atari_wrapper_dopamine.py
@@ -0,0 +1,67 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Atari wrapper using OpenCV for pixel preprocessing.
+
+Note that the AtariWrapper comes with several options, including the legacy
+style that allows reproducing the behavior of previous Acme versions.
+
+To accurately reproduce standard results, we recommend using the default
+configuration.
+"""
+
+from typing import List
+
+from acme.wrappers import atari_wrapper
+# pytype: disable=import-error
+import cv2
+# pytype: enable=import-error
+import dm_env
+import numpy as np
+
+
+class AtariWrapperDopamine(atari_wrapper.BaseAtariWrapper):
+  """Atari wrapper that exactly matches Dopamine's preprocessing.
+
+  Warning: using this wrapper requires that you have OpenCV and its
+  dependencies installed. In general, OpenCV is not required for Acme.
+  """
+
+  def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]):
+    """Preprocess Atari frames."""
+
+    # 1. RGB to grayscale
+    def rgb_to_grayscale(obs):
+      if self._grayscaling:
+        return np.tensordot(obs, [0.2989, 0.5870, 0.1140], (-1, 0))
+      return obs
+
+    # 2. Max pooling
+    processed_pixels = np.max(
+        np.stack([
+            rgb_to_grayscale(s.observation[atari_wrapper.RGB_INDEX])
+            for s in timestep_stack[-self._pooled_frames:]
+        ]),
+        axis=0)
+
+    # 3. Resize
+    processed_pixels = np.round(processed_pixels).astype(np.uint8)
+    if self._scale_dims != processed_pixels.shape[:2]:
+      processed_pixels = cv2.resize(
+          processed_pixels, (self._width, self._height),
+          interpolation=cv2.INTER_AREA)
+
+      processed_pixels = np.round(processed_pixels).astype(np.uint8)
+
+    return processed_pixels
diff --git a/acme/acme/wrappers/atari_wrapper_test.py b/acme/acme/wrappers/atari_wrapper_test.py
new file mode 100644
index 00000000..6b07e702
--- /dev/null
+++ b/acme/acme/wrappers/atari_wrapper_test.py
@@ -0,0 +1,91 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for atari_wrapper."""
+
+import unittest
+
+from acme.wrappers import atari_wrapper
+from dm_env import specs
+import numpy as np
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+SKIP_GYM_TESTS = False
+SKIP_GYM_MESSAGE = 'gym not installed.'
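Because max-pooling and grayscaling do not commute, the two styles can produce different pixels from the same frames. A small self-contained NumPy check (illustrative values only):

```python
# gray(max(a, b)) is generally not max(gray(a), gray(b)): with non-negative
# luma weights, pooling before grayscaling can only give >= values.
import numpy as np

rng = np.random.RandomState(0)
frame_a = rng.randint(0, 256, size=(4, 4, 3))
frame_b = rng.randint(0, 256, size=(4, 4, 3))

def gray(rgb):
  # Rec. 601 luma weights; both wrappers use slightly different variants.
  return np.tensordot(rgb, [0.299, 0.587, 0.114], (-1, 0))

standard = gray(np.maximum(frame_a, frame_b))        # pool, then gray.
dopamine = np.maximum(gray(frame_a), gray(frame_b))  # gray, then pool.

print(np.abs(standard - dopamine).max())  # Typically > 0.
```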
+SKIP_ATARI_TESTS = False +SKIP_ATARI_MESSAGE = '' + +try: + # pylint: disable=g-import-not-at-top + from acme.wrappers import gym_wrapper + import gym + # pylint: enable=g-import-not-at-top +except ModuleNotFoundError: + SKIP_GYM_TESTS = True + + +try: + import atari_py # pylint: disable=g-import-not-at-top + atari_py.get_game_path('pong') +except ModuleNotFoundError as e: + SKIP_ATARI_TESTS = True + SKIP_ATARI_MESSAGE = str(e) +except Exception as e: # pylint: disable=broad-except + # This exception is raised by atari_py.get_game_path('pong') if the Atari ROM + # file has not been installed. + SKIP_ATARI_TESTS = True + SKIP_ATARI_MESSAGE = str(e) + del atari_py +else: + del atari_py + + +@unittest.skipIf(SKIP_ATARI_TESTS, SKIP_ATARI_MESSAGE) +@unittest.skipIf(SKIP_GYM_TESTS, SKIP_GYM_MESSAGE) +class AtariWrapperTest(parameterized.TestCase): + + @parameterized.parameters(True, False) + def test_pong(self, zero_discount_on_life_loss: bool): + env = gym.make('PongNoFrameskip-v4', full_action_space=True) + env = gym_wrapper.GymAtariAdapter(env) + env = atari_wrapper.AtariWrapper( + env, zero_discount_on_life_loss=zero_discount_on_life_loss) + + # Test converted observation spec. + observation_spec = env.observation_spec() + self.assertEqual(type(observation_spec), specs.Array) + + # Test converted action spec. + action_spec: specs.DiscreteArray = env.action_spec() + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 17) + self.assertEqual(action_spec.num_values, 18) + self.assertEqual(action_spec.dtype, np.dtype('int32')) + + # Check that the `render` call gets delegated to the underlying Gym env. + env.render('rgb_array') + + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step(0) + env.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/wrappers/base.py b/acme/acme/wrappers/base.py new file mode 100644 index 00000000..987dccc7 --- /dev/null +++ b/acme/acme/wrappers/base.py @@ -0,0 +1,80 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environment wrapper base class.""" + +from typing import Callable, Sequence + +import dm_env + + +class EnvironmentWrapper(dm_env.Environment): + """Environment that wraps another environment. + + This exposes the wrapped environment with the `.environment` property and also + defines `__getattr__` so that attributes are invisibly forwarded to the + wrapped environment (and hence enabling duck-typing). 
+ """ + + _environment: dm_env.Environment + + def __init__(self, environment: dm_env.Environment): + self._environment = environment + + def __getattr__(self, name): + if name.startswith("__"): + raise AttributeError( + "attempted to get missing private attribute '{}'".format(name)) + return getattr(self._environment, name) + + @property + def environment(self) -> dm_env.Environment: + return self._environment + + # The following lines are necessary because methods defined in + # `dm_env.Environment` are not delegated through `__getattr__`, which would + # only be used to expose methods or properties that are not defined in the + # base `dm_env.Environment` class. + + def step(self, action) -> dm_env.TimeStep: + return self._environment.step(action) + + def reset(self) -> dm_env.TimeStep: + return self._environment.reset() + + def action_spec(self): + return self._environment.action_spec() + + def discount_spec(self): + return self._environment.discount_spec() + + def observation_spec(self): + return self._environment.observation_spec() + + def reward_spec(self): + return self._environment.reward_spec() + + def close(self): + return self._environment.close() + + +def wrap_all( + environment: dm_env.Environment, + wrappers: Sequence[Callable[[dm_env.Environment], dm_env.Environment]], +) -> dm_env.Environment: + """Given an environment, wrap it in a list of wrappers.""" + for w in wrappers: + environment = w(environment) + + return environment diff --git a/acme/acme/wrappers/base_test.py b/acme/acme/wrappers/base_test.py new file mode 100644 index 00000000..96160f63 --- /dev/null +++ b/acme/acme/wrappers/base_test.py @@ -0,0 +1,44 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for base.""" + +import copy +import pickle + +from acme.testing import fakes +from acme.wrappers import base + +from absl.testing import absltest + + +class BaseTest(absltest.TestCase): + + def test_pickle_unpickle(self): + test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) + + test_env_pickled = pickle.dumps(test_env) + test_env_restored = pickle.loads(test_env_pickled) + self.assertEqual( + test_env.observation_spec(), + test_env_restored.observation_spec(), + ) + + def test_deepcopy(self): + test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) + copied_env = copy.deepcopy(test_env) + del copied_env + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/wrappers/canonical_spec.py b/acme/acme/wrappers/canonical_spec.py new file mode 100644 index 00000000..c5363e48 --- /dev/null +++ b/acme/acme/wrappers/canonical_spec.py @@ -0,0 +1,98 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
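A minimal sketch of how `EnvironmentWrapper` is meant to be extended, together with `wrap_all` (the `RewardScaleWrapper` below is hypothetical, for illustration only):

```python
import dm_env

from acme.wrappers import base


class RewardScaleWrapper(base.EnvironmentWrapper):
  """Hypothetical wrapper that multiplies rewards by a constant."""

  def __init__(self, environment: dm_env.Environment, scale: float):
    super().__init__(environment)
    self._scale = scale

  def step(self, action) -> dm_env.TimeStep:
    timestep = self._environment.step(action)
    if timestep.reward is None:  # FIRST timesteps carry no reward.
      return timestep
    return timestep._replace(reward=timestep.reward * self._scale)


# wrap_all applies wrappers in order, innermost first, e.g.:
# env = base.wrap_all(env, [lambda e: RewardScaleWrapper(e, 0.1)])
```

All methods and attributes not overridden here (specs, `reset`, `close`, and anything environment-specific) are forwarded to the wrapped environment via `__getattr__` and the explicit delegating methods above.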
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environment wrapper that converts environments to use canonical action specs. + +This only affects action specs of type `specs.BoundedArray`. + +For bounded action specs, we refer to a canonical action spec as the bounding +box [-1, 1]^d where d is the dimensionality of the spec. So the shape and dtype +of the spec is unchanged, while the maximum/minimum values are set to +/- 1. +""" + +from acme import specs +from acme import types +from acme.wrappers import base + +import dm_env +import numpy as np +import tree + + +class CanonicalSpecWrapper(base.EnvironmentWrapper): + """Wrapper which converts environments to use canonical action specs. + + This only affects action specs of type `specs.BoundedArray`. + + For bounded action specs, we refer to a canonical action spec as the bounding + box [-1, 1]^d where d is the dimensionality of the spec. So the shape and + dtype of the spec is unchanged, while the maximum/minimum values are set + to +/- 1. + """ + + def __init__(self, environment: dm_env.Environment, clip: bool = False): + super().__init__(environment) + self._action_spec = environment.action_spec() + self._clip = clip + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + scaled_action = _scale_nested_action(action, self._action_spec, self._clip) + return self._environment.step(scaled_action) + + def action_spec(self): + return _convert_spec(self._environment.action_spec()) + + +def _convert_spec(nested_spec: types.NestedSpec) -> types.NestedSpec: + """Converts all bounded specs in nested spec to the canonical scale.""" + + def _convert_single_spec(spec: specs.Array) -> specs.Array: + """Converts a single spec to canonical if bounded.""" + if isinstance(spec, specs.BoundedArray): + return spec.replace( + minimum=-np.ones(spec.shape), maximum=np.ones(spec.shape)) + else: + return spec + + return tree.map_structure(_convert_single_spec, nested_spec) + + +def _scale_nested_action( + nested_action: types.NestedArray, + nested_spec: types.NestedSpec, + clip: bool, +) -> types.NestedArray: + """Converts a canonical nested action back to the given nested action spec.""" + + def _scale_action(action: np.ndarray, spec: specs.Array): + """Converts a single canonical action back to the given action spec.""" + if isinstance(spec, specs.BoundedArray): + # Get scale and offset of output action spec. + scale = spec.maximum - spec.minimum + offset = spec.minimum + + # Maybe clip the action. + if clip: + action = np.clip(action, -1.0, 1.0) + + # Map action to [0, 1]. + action = 0.5 * (action + 1.0) + + # Map action to [spec.minimum, spec.maximum]. + action *= scale + action += offset + + return action + + return tree.map_structure(_scale_action, nested_action, nested_spec) diff --git a/acme/acme/wrappers/concatenate_observations.py b/acme/acme/wrappers/concatenate_observations.py new file mode 100644 index 00000000..895de39e --- /dev/null +++ b/acme/acme/wrappers/concatenate_observations.py @@ -0,0 +1,97 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
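A worked example of the mapping performed by `_scale_action`, under an assumed bounded spec with `minimum=-2` and `maximum=4`:

```python
# Canonical actions in [-1, 1] are first shifted to [0, 1], then scaled and
# offset into [minimum, maximum].
minimum, maximum = -2.0, 4.0
scale, offset = maximum - minimum, minimum  # scale=6.0, offset=-2.0

for canonical in (-1.0, 0.0, 1.0):
  action = 0.5 * (canonical + 1.0)   # Map [-1, 1] -> [0, 1].
  action = action * scale + offset   # Map [0, 1] -> [minimum, maximum].
  print(canonical, '->', action)     # -1.0 -> -2.0, 0.0 -> 1.0, 1.0 -> 4.0
```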
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper that implements concatenation of observation fields.""" + +from typing import Sequence, Optional + +from acme import types +from acme.wrappers import base +import dm_env +import numpy as np +import tree + + +def _concat(values: types.NestedArray) -> np.ndarray: + """Concatenates the leaves of `values` along the leading dimension. + + Treats scalars as 1d arrays and expects that the shapes of all leaves are + the same except for the leading dimension. + + Args: + values: the nested arrays to concatenate. + + Returns: + The concatenated array. + """ + leaves = list(map(np.atleast_1d, tree.flatten(values))) + return np.concatenate(leaves) + + +def _zeros_like(nest, dtype=None): + """Generate a nested NumPy array according to spec.""" + return tree.map_structure(lambda x: np.zeros(x.shape, dtype or x.dtype), nest) + + +class ConcatObservationWrapper(base.EnvironmentWrapper): + """Wrapper that concatenates observation fields. + + It takes an environment with nested observations and concatenates the fields + in a single tensor. The original fields should be 1-dimensional. + Observation fields that are not in name_filter are dropped. + + **NOTE**: The fields in the flattened observations will be in sorted order by + their names, see tree.flatten for more information. + """ + + def __init__(self, + environment: dm_env.Environment, + name_filter: Optional[Sequence[str]] = None): + """Initializes a new ConcatObservationWrapper. + + Args: + environment: Environment to wrap. + name_filter: Sequence of observation names to keep. None keeps them all. + """ + super().__init__(environment) + observation_spec = environment.observation_spec() + if name_filter is None: + name_filter = list(observation_spec.keys()) + self._obs_names = [x for x in name_filter if x in observation_spec.keys()] + + dummy_obs = _zeros_like(observation_spec) + dummy_obs = self._convert_observation(dummy_obs) + self._observation_spec = dm_env.specs.BoundedArray( + shape=dummy_obs.shape, + dtype=dummy_obs.dtype, + minimum=-np.inf, + maximum=np.inf, + name='state') + + def _convert_observation(self, observation): + obs = {k: observation[k] for k in self._obs_names} + return _concat(obs) + + def step(self, action) -> dm_env.TimeStep: + timestep = self._environment.step(action) + return timestep._replace( + observation=self._convert_observation(timestep.observation)) + + def reset(self) -> dm_env.TimeStep: + timestep = self._environment.reset() + return timestep._replace( + observation=self._convert_observation(timestep.observation)) + + def observation_spec(self) -> types.NestedSpec: + return self._observation_spec diff --git a/acme/acme/wrappers/delayed_reward.py b/acme/acme/wrappers/delayed_reward.py new file mode 100644 index 00000000..65fb45c4 --- /dev/null +++ b/acme/acme/wrappers/delayed_reward.py @@ -0,0 +1,89 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
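A small sketch of what `_concat` does to a nested observation (illustrative values; note that `tree.flatten` visits dict keys in sorted order):

```python
import numpy as np
import tree

observation = {
    'position': np.array([0.1, 0.2, 0.3]),
    'steps': np.array(7.0),  # Scalar leaf, promoted to shape (1,).
    'velocity': np.array([1.0, 2.0]),
}

leaves = list(map(np.atleast_1d, tree.flatten(observation)))
print(np.concatenate(leaves))  # [0.1 0.2 0.3 7.  1.  2. ]
```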
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Delayed reward wrapper."""
+
+import operator
+from typing import Optional
+
+from acme import types
+from acme.wrappers import base
+import dm_env
+import numpy as np
+import tree
+
+
+class DelayedRewardWrapper(base.EnvironmentWrapper):
+  """Implements delayed reward on environments.
+
+  This wrapper sparsifies any environment by adding a reward delay. Instead of
+  returning a reward at each step, the wrapped environment returns the
+  accumulated reward every N steps or at the end of an episode, whichever comes
+  first. This does not change the optimal expected return, but typically makes
+  the environment harder by adding exploration and longer term dependencies.
+  """
+
+  def __init__(self,
+               environment: dm_env.Environment,
+               accumulation_period: Optional[int] = 1):
+    """Initializes a `DelayedRewardWrapper`.
+
+    Args:
+      environment: An environment conforming to the dm_env.Environment
+        interface.
+      accumulation_period: number of steps to accumulate the reward over. If
+        `accumulation_period` is an integer, reward is accumulated and returned
+        every `accumulation_period` steps, and at the end of an episode. If
+        `accumulation_period` is None, reward is only returned at the end of an
+        episode. If `accumulation_period`=1, this wrapper is a no-op.
+    """
+
+    super().__init__(environment)
+    if accumulation_period is not None and accumulation_period < 1:
+      raise ValueError(
+          f'Accumulation period is {accumulation_period} but should be at '
+          'least 1.')
+    self._accumulation_period = accumulation_period
+    self._delayed_reward = self._zero_reward
+    self._accumulation_counter = 0
+
+  @property
+  def _zero_reward(self):
+    return tree.map_structure(lambda s: np.zeros(s.shape, s.dtype),
+                              self._environment.reward_spec())
+
+  def reset(self) -> dm_env.TimeStep:
+    """Resets environment and provides the first timestep."""
+    timestep = self.environment.reset()
+    self._delayed_reward = self._zero_reward
+    self._accumulation_counter = 0
+    return timestep
+
+  def step(self, action: types.NestedArray) -> dm_env.TimeStep:
+    """Performs one step and maybe returns a reward."""
+    timestep = self.environment.step(action)
+    self._delayed_reward = tree.map_structure(operator.iadd,
+                                              self._delayed_reward,
+                                              timestep.reward)
+    self._accumulation_counter += 1
+
+    if (self._accumulation_period is not None and self._accumulation_counter
+        == self._accumulation_period) or timestep.last():
+      timestep = timestep._replace(reward=self._delayed_reward)
+      self._accumulation_counter = 0
+      self._delayed_reward = self._zero_reward
+    else:
+      timestep = timestep._replace(reward=self._zero_reward)
+
+    return timestep
diff --git a/acme/acme/wrappers/delayed_reward_test.py b/acme/acme/wrappers/delayed_reward_test.py
new file mode 100644
index 00000000..c85032df
--- /dev/null
+++ b/acme/acme/wrappers/delayed_reward_test.py
@@ -0,0 +1,90 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
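The accumulate-and-emit behaviour can be traced by hand. A self-contained sketch with `accumulation_period=3` and illustrative per-step rewards:

```python
# The wrapper emits the running sum every third step (and at episode end),
# and a zero reward otherwise; the episode total is unchanged.
accumulation_period = 3
per_step_rewards = [1.0, 1.0, 1.0, 1.0, 1.0]  # Last step ends the episode.

delayed, counter, emitted = 0.0, 0, []
for step, reward in enumerate(per_step_rewards):
  delayed += reward
  counter += 1
  if counter == accumulation_period or step == len(per_step_rewards) - 1:
    emitted.append(delayed)
    delayed, counter = 0.0, 0
  else:
    emitted.append(0.0)

print(emitted)  # [0.0, 0.0, 3.0, 0.0, 2.0] -- same total as the raw rewards.
```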
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the delayed reward wrapper."""
+
+from typing import Any
+from acme import wrappers
+from acme.testing import fakes
+from dm_env import specs
+import numpy as np
+import tree
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+def _episode_reward(env):
+  timestep = env.reset()
+  action_spec = env.action_spec()
+  rng = np.random.RandomState(seed=1)
+  reward = []
+  while not timestep.last():
+    timestep = env.step(rng.randint(action_spec.num_values))
+    reward.append(timestep.reward)
+  return reward
+
+
+def _compare_nested_sequences(seq1, seq2):
+  """Compare two sequences of arrays."""
+  return all([(l == m).all() for l, m in zip(seq1, seq2)])
+
+
+class _DiscreteEnvironmentOneReward(fakes.DiscreteEnvironment):
+  """A fake discrete environment with constant reward of 1."""
+
+  def _generate_fake_reward(self) -> Any:
+    return tree.map_structure(lambda s: s.generate_value() + 1.,
+                              self._spec.rewards)
+
+
+class DelayedRewardTest(parameterized.TestCase):
+
+  def test_noop(self):
+    """Ensure when accumulation_period=1 it does not change anything."""
+    base_env = _DiscreteEnvironmentOneReward(
+        action_dtype=np.int64,
+        reward_spec=specs.Array(dtype=np.float32, shape=()))
+    wrapped_env = wrappers.DelayedRewardWrapper(base_env, accumulation_period=1)
+    base_episode_reward = _episode_reward(base_env)
+    wrapped_episode_reward = _episode_reward(wrapped_env)
+    self.assertEqual(base_episode_reward, wrapped_episode_reward)
+
+  def test_noop_composite_reward(self):
+    """No-op test with composite rewards."""
+    base_env = _DiscreteEnvironmentOneReward(
+        action_dtype=np.int64,
+        reward_spec=specs.Array(dtype=np.float32, shape=(2, 1)))
+    wrapped_env = wrappers.DelayedRewardWrapper(base_env, accumulation_period=1)
+    base_episode_reward = _episode_reward(base_env)
+    wrapped_episode_reward = _episode_reward(wrapped_env)
+    self.assertTrue(
+        _compare_nested_sequences(base_episode_reward, wrapped_episode_reward))
+
+  @parameterized.parameters(10, None)
+  def test_same_episode_composite_reward(self, accumulation_period):
+    """Ensure that wrapper does not change total reward."""
+    base_env = _DiscreteEnvironmentOneReward(
+        action_dtype=np.int64,
+        reward_spec=specs.Array(dtype=np.float32, shape=()))
+    wrapped_env = wrappers.DelayedRewardWrapper(
+        base_env, accumulation_period=accumulation_period)
+    base_episode_reward = _episode_reward(base_env)
+    wrapped_episode_reward = _episode_reward(wrapped_env)
+    self.assertTrue(
+        (sum(base_episode_reward) == sum(wrapped_episode_reward)).all())
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/wrappers/expand_scalar_observation_shapes.py b/acme/acme/wrappers/expand_scalar_observation_shapes.py
new file mode 100644
index 00000000..7554d202
--- /dev/null
+++ b/acme/acme/wrappers/expand_scalar_observation_shapes.py
@@ -0,0 +1,68 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This wrapper expands scalar observations to have non-trivial shape. + +This is useful for example if the observation holds the previous (scalar) +action, which can cause issues when manipulating arrays with axis=-1. This +wrapper makes sure the environment returns a previous action with shape [1]. + +This can be necessary when stacking observations with previous actions. +""" + +from typing import Any + +from acme.wrappers import base +import dm_env +from dm_env import specs +import numpy as np +import tree + + +class ExpandScalarObservationShapesWrapper(base.EnvironmentWrapper): + """Expands scalar shapes in the observation. + + For example, if the observation holds the previous (scalar) action, this + wrapper makes sure the environment returns a previous action with shape [1]. + + This can be necessary when stacking observations with previous actions. + """ + + def step(self, action: Any) -> dm_env.TimeStep: + timestep = self._environment.step(action) + expanded_observation = tree.map_structure(_expand_scalar_array_shape, + timestep.observation) + return timestep._replace(observation=expanded_observation) + + def reset(self) -> dm_env.TimeStep: + timestep = self._environment.reset() + expanded_observation = tree.map_structure(_expand_scalar_array_shape, + timestep.observation) + return timestep._replace(observation=expanded_observation) + + def observation_spec(self) -> specs.Array: + return tree.map_structure(_expand_scalar_spec_shape, + self._environment.observation_spec()) + + +def _expand_scalar_spec_shape(spec: specs.Array) -> specs.Array: + if not spec.shape: + # NOTE: This line upcasts the spec to an Array to avoid edge cases (as in + # DiscreteSpec) where we cannot set the spec's shape. + spec = specs.Array(shape=(1,), dtype=spec.dtype, name=spec.name) + return spec + + +def _expand_scalar_array_shape(array: np.ndarray) -> np.ndarray: + return array if array.shape else np.expand_dims(array, axis=-1) diff --git a/acme/acme/wrappers/frame_stacking.py b/acme/acme/wrappers/frame_stacking.py new file mode 100644 index 00000000..ce06dd71 --- /dev/null +++ b/acme/acme/wrappers/frame_stacking.py @@ -0,0 +1,99 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
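A quick check of the expansion rule applied by this wrapper (illustrative observation values):

```python
import numpy as np
import tree

observation = {'prev_action': np.array(3), 'pixels': np.zeros((84, 84))}
expanded = tree.map_structure(
    lambda a: a if a.shape else np.expand_dims(a, axis=-1), observation)

print(expanded['prev_action'].shape)  # (1,) -- the scalar was expanded.
print(expanded['pixels'].shape)       # (84, 84) -- unchanged.
```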
+
+"""Frame stacking utilities."""
+
+import collections
+
+from acme import types
+from acme.wrappers import base
+import dm_env
+from dm_env import specs as dm_env_specs
+import numpy as np
+import tree
+
+
+class FrameStackingWrapper(base.EnvironmentWrapper):
+  """Wrapper that stacks observations along a new final axis."""
+
+  def __init__(self, environment: dm_env.Environment, num_frames: int = 4,
+               flatten: bool = False):
+    """Initializes a new FrameStackingWrapper.
+
+    Args:
+      environment: Environment.
+      num_frames: Number of frames to stack.
+      flatten: Whether to flatten the channel and stack dimensions together.
+    """
+    self._environment = environment
+    original_spec = self._environment.observation_spec()
+    self._stackers = tree.map_structure(
+        lambda _: FrameStacker(num_frames=num_frames, flatten=flatten),
+        self._environment.observation_spec())
+    self._observation_spec = tree.map_structure(
+        lambda stacker, spec: stacker.update_spec(spec),
+        self._stackers, original_spec)
+
+  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
+    observation = tree.map_structure(lambda stacker, x: stacker.step(x),
+                                     self._stackers, timestep.observation)
+    return timestep._replace(observation=observation)
+
+  def reset(self) -> dm_env.TimeStep:
+    for stacker in tree.flatten(self._stackers):
+      stacker.reset()
+    return self._process_timestep(self._environment.reset())
+
+  def step(self, action: int) -> dm_env.TimeStep:
+    return self._process_timestep(self._environment.step(action))
+
+  def observation_spec(self) -> types.NestedSpec:
+    return self._observation_spec
+
+
+class FrameStacker:
+  """Simple class for frame-stacking observations."""
+
+  def __init__(self, num_frames: int, flatten: bool = False):
+    self._num_frames = num_frames
+    self._flatten = flatten
+    self.reset()
+
+  @property
+  def num_frames(self) -> int:
+    return self._num_frames
+
+  def reset(self):
+    self._stack = collections.deque(maxlen=self._num_frames)
+
+  def step(self, frame: np.ndarray) -> np.ndarray:
+    """Append frame to stack and return the stack."""
+    if not self._stack:
+      # Fill stack with blank frames if empty.
+      self._stack.extend([np.zeros_like(frame)] * (self._num_frames - 1))
+    self._stack.append(frame)
+    stacked_frames = np.stack(self._stack, axis=-1)
+
+    if not self._flatten:
+      return stacked_frames
+    else:
+      new_shape = stacked_frames.shape[:-2] + (-1,)
+      return stacked_frames.reshape(*new_shape)
+
+  def update_spec(self, spec: dm_env_specs.Array) -> dm_env_specs.Array:
+    if not self._flatten:
+      new_shape = spec.shape + (self._num_frames,)
+    else:
+      new_shape = spec.shape[:-1] + (self._num_frames * spec.shape[-1],)
+    return dm_env_specs.Array(shape=new_shape, dtype=spec.dtype, name=spec.name)
diff --git a/acme/acme/wrappers/frame_stacking_test.py b/acme/acme/wrappers/frame_stacking_test.py
new file mode 100644
index 00000000..ff21f47e
--- /dev/null
+++ b/acme/acme/wrappers/frame_stacking_test.py
@@ -0,0 +1,81 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
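`FrameStacker` can also be exercised on its own; a minimal sketch:

```python
import numpy as np

from acme.wrappers import frame_stacking

stacker = frame_stacking.FrameStacker(num_frames=4)
frame = np.ones((84, 84), dtype=np.float32)

stacked = stacker.step(frame)  # First call pads with zero frames.
print(stacked.shape)           # (84, 84, 4)
print(stacked[..., :3].max())  # 0.0 -- the three padded frames.
print(stacked[..., -1].max())  # 1.0 -- the frame just appended.
```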
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the frame stacking wrapper."""
+
+from acme import wrappers
+from acme.testing import fakes
+import numpy as np
+import tree
+
+from absl.testing import absltest
+
+
+class FakeNonZeroObservationEnvironment(fakes.ContinuousEnvironment):
+  """Fake environment with non-zero observations."""
+
+  def _generate_fake_observation(self):
+    original_observation = super()._generate_fake_observation()
+    return tree.map_structure(np.ones_like, original_observation)
+
+
+class FrameStackingTest(absltest.TestCase):
+
+  def test_specs(self):
+    original_env = FakeNonZeroObservationEnvironment()
+    env = wrappers.FrameStackingWrapper(original_env, 2)
+
+    original_observation_spec = original_env.observation_spec()
+    expected_shape = original_observation_spec.shape + (2,)
+    observation_spec = env.observation_spec()
+    self.assertEqual(expected_shape, observation_spec.shape)
+
+    expected_action_spec = original_env.action_spec()
+    action_spec = env.action_spec()
+    self.assertEqual(expected_action_spec, action_spec)
+
+    expected_reward_spec = original_env.reward_spec()
+    reward_spec = env.reward_spec()
+    self.assertEqual(expected_reward_spec, reward_spec)
+
+    expected_discount_spec = original_env.discount_spec()
+    discount_spec = env.discount_spec()
+    self.assertEqual(expected_discount_spec, discount_spec)
+
+  def test_step(self):
+    original_env = FakeNonZeroObservationEnvironment()
+    env = wrappers.FrameStackingWrapper(original_env, 2)
+    observation_spec = env.observation_spec()
+    action_spec = env.action_spec()
+
+    timestep = env.reset()
+    self.assertEqual(observation_spec.shape, timestep.observation.shape)
+    self.assertTrue(np.all(timestep.observation[..., 0] == 0))
+
+    timestep = env.step(action_spec.generate_value())
+    self.assertEqual(observation_spec.shape, timestep.observation.shape)
+
+  def test_second_reset(self):
+    original_env = FakeNonZeroObservationEnvironment()
+    env = wrappers.FrameStackingWrapper(original_env, 2)
+    action_spec = env.action_spec()
+
+    env.reset()
+    env.step(action_spec.generate_value())
+    timestep = env.reset()
+    self.assertTrue(np.all(timestep.observation[..., 0] == 0))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/acme/wrappers/gym_wrapper.py b/acme/acme/wrappers/gym_wrapper.py
new file mode 100644
index 00000000..8170f8a1
--- /dev/null
+++ b/acme/acme/wrappers/gym_wrapper.py
@@ -0,0 +1,206 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Wraps an OpenAI Gym environment to be used as a dm_env environment.""" + +from typing import Any, Dict, List, Optional + +from acme import specs +from acme import types + +import dm_env +import gym +from gym import spaces +import numpy as np +import tree + + +class GymWrapper(dm_env.Environment): + """Environment wrapper for OpenAI Gym environments.""" + + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. + + def __init__(self, environment: gym.Env): + + self._environment = environment + self._reset_next_step = True + self._last_info = None + + # Convert action and observation specs. + obs_space = self._environment.observation_space + act_space = self._environment.action_space + self._observation_spec = _convert_to_spec(obs_space, name='observation') + self._action_spec = _convert_to_spec(act_space, name='action') + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + self._reset_next_step = False + observation = self._environment.reset() + # Reset the diagnostic information. + self._last_info = None + return dm_env.restart(observation) + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + """Steps the environment.""" + if self._reset_next_step: + return self.reset() + + observation, reward, done, info = self._environment.step(action) + self._reset_next_step = done + self._last_info = info + + # Convert the type of the reward based on the spec, respecting the scalar or + # array property. + reward = tree.map_structure( + lambda x, t: ( # pylint: disable=g-long-lambda + t.dtype.type(x) + if np.isscalar(x) else np.asarray(x, dtype=t.dtype)), + reward, + self.reward_spec()) + + if done: + truncated = info.get('TimeLimit.truncated', False) + if truncated: + return dm_env.truncation(reward, observation) + return dm_env.termination(reward, observation) + return dm_env.transition(reward, observation) + + def observation_spec(self) -> types.NestedSpec: + return self._observation_spec + + def action_spec(self) -> types.NestedSpec: + return self._action_spec + + def get_info(self) -> Optional[Dict[str, Any]]: + """Returns the last info returned from env.step(action). + + Returns: + info: dictionary of diagnostic information from the last environment step + """ + return self._last_info + + @property + def environment(self) -> gym.Env: + """Returns the wrapped environment.""" + return self._environment + + def __getattr__(self, name: str): + if name.startswith('__'): + raise AttributeError( + "attempted to get missing private attribute '{}'".format(name)) + return getattr(self._environment, name) + + def close(self): + self._environment.close() + + +def _convert_to_spec(space: gym.Space, + name: Optional[str] = None) -> types.NestedSpec: + """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs. + + Box, MultiBinary and MultiDiscrete Gym spaces are converted to BoundedArray + specs. Discrete OpenAI spaces are converted to DiscreteArray specs. Tuple and + Dict spaces are recursively converted to tuples and dictionaries of specs. + + Args: + space: The Gym space to convert. + name: Optional name to apply to all return spec(s). + + Returns: + A dm_env spec or nested structure of specs, corresponding to the input + space. 
+  """
+  if isinstance(space, spaces.Discrete):
+    return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name)
+
+  elif isinstance(space, spaces.Box):
+    return specs.BoundedArray(
+        shape=space.shape,
+        dtype=space.dtype,
+        minimum=space.low,
+        maximum=space.high,
+        name=name)
+
+  elif isinstance(space, spaces.MultiBinary):
+    return specs.BoundedArray(
+        shape=space.shape,
+        dtype=space.dtype,
+        minimum=0.0,
+        maximum=1.0,
+        name=name)
+
+  elif isinstance(space, spaces.MultiDiscrete):
+    return specs.BoundedArray(
+        shape=space.shape,
+        dtype=space.dtype,
+        minimum=np.zeros(space.shape),
+        maximum=space.nvec - 1,
+        name=name)
+
+  elif isinstance(space, spaces.Tuple):
+    return tuple(_convert_to_spec(s, name) for s in space.spaces)
+
+  elif isinstance(space, spaces.Dict):
+    return {
+        key: _convert_to_spec(value, key)
+        for key, value in space.spaces.items()
+    }
+
+  else:
+    raise ValueError('Unexpected gym space: {}'.format(space))
+
+
+class GymAtariAdapter(GymWrapper):
+  """Specialized wrapper exposing a Gym Atari environment.
+
+  This wraps the Gym Atari environment in the same way as GymWrapper, but also
+  exposes the lives count as an observation. The resulting observations are
+  a tuple whose first element is the RGB observations and the second is the
+  lives count.
+  """
+
+  def _wrap_observation(self,
+                        observation: types.NestedArray) -> types.NestedArray:
+    # pytype: disable=attribute-error
+    return observation, self._environment.ale.lives()
+    # pytype: enable=attribute-error
+
+  def reset(self) -> dm_env.TimeStep:
+    """Resets the episode."""
+    self._reset_next_step = False
+    observation = self._environment.reset()
+    observation = self._wrap_observation(observation)
+    return dm_env.restart(observation)
+
+  def step(self, action: List[np.ndarray]) -> dm_env.TimeStep:
+    """Steps the environment."""
+    if self._reset_next_step:
+      return self.reset()
+
+    observation, reward, done, _ = self._environment.step(action[0].item())
+    self._reset_next_step = done
+
+    observation = self._wrap_observation(observation)
+
+    if done:
+      return dm_env.termination(reward, observation)
+    return dm_env.transition(reward, observation)
+
+  def observation_spec(self) -> types.NestedSpec:
+    return (self._observation_spec,
+            specs.Array(shape=(), dtype=np.dtype('float64'), name='lives'))
+
+  def action_spec(self) -> List[specs.BoundedArray]:
+    return [self._action_spec]  # pytype: disable=bad-return-type
diff --git a/acme/acme/wrappers/gym_wrapper_test.py b/acme/acme/wrappers/gym_wrapper_test.py
new file mode 100644
index 00000000..bc6fdd80
--- /dev/null
+++ b/acme/acme/wrappers/gym_wrapper_test.py
@@ -0,0 +1,140 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for gym_wrapper."""
+
+import unittest
+
+from dm_env import specs
+import numpy as np
+
+from absl.testing import absltest
+
+SKIP_GYM_TESTS = False
+SKIP_GYM_MESSAGE = 'gym not installed.'
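Typical usage, mirroring the `gym_wrapper_test.py` file below (assumes `gym` is installed):

```python
import gym

from acme.wrappers import gym_wrapper

env = gym_wrapper.GymWrapper(gym.make('CartPole-v0'))

print(env.action_spec().num_values)  # 2 for CartPole.
timestep = env.reset()               # step_type is FIRST, reward is None.
timestep = env.step(1)               # Reward is cast to match reward_spec().
env.close()
```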
+SKIP_ATARI_TESTS = False
+SKIP_ATARI_MESSAGE = ''
+
+try:
+  # pylint: disable=g-import-not-at-top
+  from acme.wrappers import gym_wrapper
+  import gym
+  # pylint: enable=g-import-not-at-top
+except ModuleNotFoundError:
+  SKIP_GYM_TESTS = True
+
+try:
+  import atari_py  # pylint: disable=g-import-not-at-top
+  atari_py.get_game_path('pong')
+except ModuleNotFoundError as e:
+  SKIP_ATARI_TESTS = True
+  SKIP_ATARI_MESSAGE = str(e)
+except Exception as e:  # pylint: disable=broad-except
+  # This exception is raised by atari_py.get_game_path('pong') if the Atari ROM
+  # file has not been installed.
+  SKIP_ATARI_TESTS = True
+  SKIP_ATARI_MESSAGE = str(e)
+  del atari_py
+else:
+  del atari_py
+
+
+@unittest.skipIf(SKIP_GYM_TESTS, SKIP_GYM_MESSAGE)
+class GymWrapperTest(absltest.TestCase):
+
+  def test_gym_cartpole(self):
+    env = gym_wrapper.GymWrapper(gym.make('CartPole-v0'))
+
+    # Test converted observation spec.
+    observation_spec: specs.BoundedArray = env.observation_spec()
+    self.assertEqual(type(observation_spec), specs.BoundedArray)
+    self.assertEqual(observation_spec.shape, (4,))
+    self.assertEqual(observation_spec.minimum.shape, (4,))
+    self.assertEqual(observation_spec.maximum.shape, (4,))
+    self.assertEqual(observation_spec.dtype, np.dtype('float32'))
+
+    # Test converted action spec.
+    action_spec: specs.BoundedArray = env.action_spec()
+    self.assertEqual(type(action_spec), specs.DiscreteArray)
+    self.assertEqual(action_spec.shape, ())
+    self.assertEqual(action_spec.minimum, 0)
+    self.assertEqual(action_spec.maximum, 1)
+    self.assertEqual(action_spec.num_values, 2)
+    self.assertEqual(action_spec.dtype, np.dtype('int64'))
+
+    # Test step.
+    timestep = env.reset()
+    self.assertTrue(timestep.first())
+    timestep = env.step(1)
+    self.assertEqual(timestep.reward, 1.0)
+    self.assertTrue(np.isscalar(timestep.reward))
+    self.assertEqual(timestep.observation.shape, (4,))
+    env.close()
+
+  def test_early_truncation(self):
+    # Pendulum has no early termination condition. Recent versions of gym
+    # force the use of v1. We try both in case an earlier version is installed.
+    try:
+      gym_env = gym.make('Pendulum-v1')
+    except:  # pylint: disable=bare-except
+      gym_env = gym.make('Pendulum-v0')
+    env = gym_wrapper.GymWrapper(gym_env)
+    ts = env.reset()
+    while not ts.last():
+      ts = env.step(env.action_spec().generate_value())
+    self.assertEqual(ts.discount, 1.0)
+    self.assertTrue(np.isscalar(ts.reward))
+    env.close()
+
+  def test_multi_discrete(self):
+    space = gym.spaces.MultiDiscrete([2, 3])
+    spec = gym_wrapper._convert_to_spec(space)
+
+    spec.validate([0, 0])
+    spec.validate([1, 2])
+
+    self.assertRaises(ValueError, spec.validate, [2, 2])
+    self.assertRaises(ValueError, spec.validate, [1, 3])
+
+
+@unittest.skipIf(SKIP_ATARI_TESTS, SKIP_ATARI_MESSAGE)
+class AtariGymWrapperTest(absltest.TestCase):
+
+  def test_pong(self):
+    env = gym.make('PongNoFrameskip-v4', full_action_space=True)
+    env = gym_wrapper.GymAtariAdapter(env)
+
+    # Test converted observation spec. This should expose (RGB, LIVES).
+    observation_spec = env.observation_spec()
+    self.assertEqual(type(observation_spec[0]), specs.BoundedArray)
+    self.assertEqual(type(observation_spec[1]), specs.Array)
+
+    # Test converted action spec.
+ action_spec: specs.DiscreteArray = env.action_spec()[0] + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 17) + self.assertEqual(action_spec.num_values, 18) + self.assertEqual(action_spec.dtype, np.dtype('int64')) + + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step([np.array(0)]) + env.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/wrappers/mujoco.py b/acme/acme/wrappers/mujoco.py new file mode 100644 index 00000000..60d1afc1 --- /dev/null +++ b/acme/acme/wrappers/mujoco.py @@ -0,0 +1,50 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An environment wrapper to produce pixel observations from dm_control.""" + +import collections +from acme.wrappers import base +from dm_control.rl import control +from dm_control.suite.wrappers import pixels # type: ignore +import dm_env + + +class MujocoPixelWrapper(base.EnvironmentWrapper): + """Produces pixel observations from Mujoco environment observations.""" + + def __init__(self, + environment: control.Environment, + *, + height: int = 84, + width: int = 84, + camera_id: int = 0): + render_kwargs = {'height': height, 'width': width, 'camera_id': camera_id} + pixel_environment = pixels.Wrapper( + environment, pixels_only=True, render_kwargs=render_kwargs) + super().__init__(pixel_environment) + + def step(self, action) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.step(action)) + + def reset(self) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.reset()) + + def observation_spec(self): + return self._environment.observation_spec()['pixels'] + + def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + """Removes the pixel observation's OrderedDict wrapper.""" + observation: collections.OrderedDict = timestep.observation + return timestep._replace(observation=observation['pixels']) diff --git a/acme/acme/wrappers/multiagent_dict_key_wrapper.py b/acme/acme/wrappers/multiagent_dict_key_wrapper.py new file mode 100644 index 00000000..1c15c121 --- /dev/null +++ b/acme/acme/wrappers/multiagent_dict_key_wrapper.py @@ -0,0 +1,87 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
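A usage sketch for `MujocoPixelWrapper`, assuming `dm_control` is installed and that the chosen domain renders from camera 0:

```python
from dm_control import suite

from acme.wrappers import mujoco

env = mujoco.MujocoPixelWrapper(
    suite.load('cartpole', 'swingup'), height=84, width=84, camera_id=0)

timestep = env.reset()
print(timestep.observation.shape)  # (84, 84, 3) rendered RGB pixels.
```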
+
+"""Multiagent dict-indexed environment wrapper."""
+
+from typing import Any, Dict, List, TypeVar, Union
+from acme import types
+
+from acme.wrappers import base
+import dm_env
+
+V = TypeVar('V')
+
+
+class MultiagentDictKeyWrapper(base.EnvironmentWrapper):
+  """Wrapper that converts list-indexed multiagent environments to dict-indexed.
+
+  Specifically, if the underlying environment observation and actions are:
+    observation = [observation_agent_0, observation_agent_1, ...]
+    action = [action_agent_0, action_agent_1, ...]
+
+  They are converted instead to:
+    observation = {'0': observation_agent_0, '1': observation_agent_1, ...}
+    action = {'0': action_agent_0, '1': action_agent_1, ...}
+
+  This can be helpful in situations where dict-based structures are natively
+  supported, whereas lists are not (e.g., in tfds, where ragged observation data
+  can directly be supported if dicts, but not natively supported as lists).
+  """
+
+  def __init__(self, environment: dm_env.Environment):
+    self._environment = environment
+    # Convert action and observation specs.
+    self._action_spec = self._list_to_dict(self._environment.action_spec())
+    self._discount_spec = self._list_to_dict(self._environment.discount_spec())
+    self._observation_spec = self._list_to_dict(
+        self._environment.observation_spec())
+    self._reward_spec = self._list_to_dict(self._environment.reward_spec())
+
+  def _list_to_dict(self, data: Union[List[V], V]) -> Union[Dict[str, V], V]:
+    """Convert list-indexed data to dict-indexed, otherwise passthrough."""
+    if isinstance(data, list):
+      return {str(k): v for k, v in enumerate(data)}
+    return data
+
+  def _dict_to_list(self, data: Union[Dict[str, V], V]) -> Union[List[V], V]:
+    """Convert dict-indexed data to list-indexed, otherwise passthrough."""
+    if isinstance(data, dict):
+      return [data[str(i_agent)]
+              for i_agent in range(self._environment.num_agents)]  # pytype: disable=attribute-error
+    return data
+
+  def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
+    return timestep._replace(
+        reward=self._list_to_dict(timestep.reward),
+        discount=self._list_to_dict(timestep.discount),
+        observation=self._list_to_dict(timestep.observation))
+
+  def step(self, action: Dict[str, Any]) -> dm_env.TimeStep:
+    return self._convert_timestep(
+        self._environment.step(self._dict_to_list(action)))
+
+  def reset(self) -> dm_env.TimeStep:
+    return self._convert_timestep(self._environment.reset())
+
+  def action_spec(self) -> types.NestedSpec:  # Internal pytype check.
+    return self._action_spec
+
+  def discount_spec(self) -> types.NestedSpec:  # Internal pytype check.
+    return self._discount_spec
+
+  def observation_spec(self) -> types.NestedSpec:  # Internal pytype check.
+    return self._observation_spec
+
+  def reward_spec(self) -> types.NestedSpec:  # Internal pytype check.
+    return self._reward_spec
diff --git a/acme/acme/wrappers/multigrid_wrapper.py b/acme/acme/wrappers/multigrid_wrapper.py
new file mode 100644
index 00000000..9c4b2287
--- /dev/null
+++ b/acme/acme/wrappers/multigrid_wrapper.py
@@ -0,0 +1,314 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
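The list/dict conversion is a lossless round trip; a small sketch with illustrative per-agent rewards:

```python
rewards = [0.0, 1.0, 2.0]            # List-indexed, one entry per agent.
as_dict = {str(k): v for k, v in enumerate(rewards)}
print(as_dict)                       # {'0': 0.0, '1': 1.0, '2': 2.0}

as_list = [as_dict[str(i)] for i in range(len(rewards))]
print(as_list == rewards)            # True -- lossless round trip.
```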
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wraps a Multigrid multiagent environment to be used as a dm_env."""
+
+from typing import Any, Dict, List, Optional
+import warnings
+
+from acme import specs
+from acme import types
+from acme import wrappers
+from acme.multiagent import types as ma_types
+from acme.wrappers import multiagent_dict_key_wrapper
+import dm_env
+import gym
+from gym import spaces
+import jax
+import numpy as np
+import tree
+
+try:
+  # The following import registers multigrid environments in gym. Do not remove.
+  # pylint: disable=unused-import, disable=g-import-not-at-top
+  # pytype: disable=import-error
+  from social_rl.gym_multigrid import multigrid
+  # pytype: enable=import-error
+  # pylint: enable=unused-import, enable=g-import-not-at-top
+except ModuleNotFoundError as err:
+  raise ModuleNotFoundError(
+      'The multiagent multigrid environment module could not be found. '
+      'Ensure you have downloaded it from '
+      'https://github.com/google-research/google-research/tree/master/social_rl/gym_multigrid'
+      ' before running this example.') from err
+
+# Disables verbose np.bool warnings that occur in multigrid.
+warnings.filterwarnings(
+    action='ignore',
+    category=DeprecationWarning,
+    message='`np.bool` is a deprecated alias')
+
+
+class MultigridWrapper(dm_env.Environment):
+  """Environment wrapper for Multigrid environments.
+
+  Note: the main difference with the vanilla GymWrapper is that reward_spec()
+  is overridden and rewards are cast to np.arrays in step().
+  """
+
+  def __init__(self, environment: multigrid.MultiGridEnv):
+    """Initializes environment.
+
+    Args:
+      environment: the environment.
+    """
+    self._environment = environment
+    self._reset_next_step = True
+    self._last_info = None
+    self.num_agents = environment.n_agents  # pytype: disable=attribute-error
+
+    # Convert action and observation specs.
+    obs_space = self._environment.observation_space
+    act_space = self._environment.action_space
+    self._observation_spec = _convert_to_spec(
+        obs_space, self.num_agents, name='observation')
+    self._action_spec = _convert_to_spec(
+        act_space, self.num_agents, name='action')
+
+  def process_obs(self, observation: types.NestedArray) -> types.NestedArray:
+    # Convert observations to agent-index-first format.
+    observation = dict_obs_to_list_obs(observation)
+
+    # Assign dtypes to multigrid observations (some of which are lists by
+    # default, so do not have a precise dtype matching their observation
+    # spec). This ensures no replay signature mismatch issues occur.
+    observation = tree.map_structure(lambda x, t: np.asarray(x, dtype=t.dtype),
+                                     observation, self.observation_spec())
+    return observation
+
+  def reset(self) -> dm_env.TimeStep:
+    """Resets the episode."""
+    self._reset_next_step = False
+    observation = self.process_obs(self._environment.reset())
+
+    # Reset the diagnostic information.
+    self._last_info = None
+    return dm_env.restart(observation)
+
+  def step(self, action: types.NestedArray) -> dm_env.TimeStep:
+    """Steps the environment."""
+    if self._reset_next_step:
+      return self.reset()
+
+    observation, reward, done, info = self._environment.step(action)
+    observation = self.process_obs(observation)
+
+    self._reset_next_step = done
+    self._last_info = info
+
+    def _map_reward_spec(x, t):
+      if np.isscalar(x):
+        return t.dtype.type(x)
+      return np.asarray(x, dtype=t.dtype)
+
+    reward = tree.map_structure(
+        _map_reward_spec,
+        reward,
+        self.reward_spec())
+
+    if done:
+      truncated = info.get('TimeLimit.truncated', False)
+      if truncated:
+        return dm_env.truncation(reward, observation)
+      return dm_env.termination(reward, observation)
+    return dm_env.transition(reward, observation)
+
+  def observation_spec(self) -> types.NestedSpec:  # Internal pytype check.
+    return self._observation_spec
+
+  def action_spec(self) -> types.NestedSpec:  # Internal pytype check.
+    return self._action_spec
+
+  def reward_spec(self) -> types.NestedSpec:  # Internal pytype check.
+    return [specs.Array(shape=(), dtype=float, name='rewards')
+           ] * self._environment.n_agents
+
+  def get_info(self) -> Optional[Dict[str, Any]]:
+    """Returns the last info returned from env.step(action).
+
+    Returns:
+      info: dictionary of diagnostic information from the last environment step
+    """
+    return self._last_info
+
+  @property
+  def environment(self) -> gym.Env:
+    """Returns the wrapped environment."""
+    return self._environment
+
+  def __getattr__(self, name: str) -> Any:
+    """Returns any other attributes of the underlying environment."""
+    return getattr(self._environment, name)
+
+  def close(self):
+    self._environment.close()
+
+
+def _get_single_agent_spec(spec):
+  """Returns a single-agent spec from a multiagent multigrid spec.
+
+  Primarily used for converting multigrid specs to multiagent Acme specs,
+  wherein action and observation specs are expected to be lists (each entry
+  corresponding to the spec of that particular agent). Note that this function
+  assumes homogeneous observation / action specs across all agents, which is
+  the case in multigrid.
+
+  Args:
+    spec: multigrid environment spec.
+  """
+  def make_single_agent_spec(spec):
+    if not spec.shape:  # Rewards & discounts
+      shape = ()
+    elif len(spec.shape) == 1:  # Actions
+      shape = ()
+    else:  # Observations
+      shape = spec.shape[1:]
+
+    if isinstance(spec, specs.BoundedArray):
+      # Bounded rewards and discounts often have no leading agent dimension,
+      # as they are shared amongst the agents, whereas observations are of
+      # shape [num_agents, ...]. The following pair of conditional expressions
+      # handles both cases accordingly.
+      minimum = spec.minimum if spec.minimum.ndim == 0 else spec.minimum[0]
+      maximum = spec.maximum if spec.maximum.ndim == 0 else spec.maximum[0]
+      return specs.BoundedArray(
+          shape=shape,
+          name=spec.name,
+          minimum=minimum,
+          maximum=maximum,
+          dtype=spec.dtype)
+    elif isinstance(spec, specs.DiscreteArray):
+      return specs.DiscreteArray(
+          num_values=spec.num_values, dtype=spec.dtype, name=spec.name)
+    elif isinstance(spec, specs.Array):
+      return specs.Array(shape=shape, dtype=spec.dtype, name=spec.name)
+    else:
+      raise ValueError(f'Unexpected spec type {type(spec)}.')
+
+  single_agent_spec = jax.tree_map(make_single_agent_spec, spec)
+  return single_agent_spec
+
+
+def _gym_to_spec(space: gym.Space,
+                 name: Optional[str] = None) -> types.NestedSpec:
+  """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs.
+
+  Box, MultiBinary and MultiDiscrete Gym spaces are converted to BoundedArray
+  specs. Discrete OpenAI spaces are converted to DiscreteArray specs. Tuple and
+  Dict spaces are recursively converted to tuples and dictionaries of specs.
+
+  Args:
+    space: The Gym space to convert.
+    name: Optional name to apply to all returned spec(s).
+
+  Returns:
+    A dm_env spec or nested structure of specs, corresponding to the input
+    space.
+  """
+  if isinstance(space, spaces.Discrete):
+    return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name)
+
+  elif isinstance(space, spaces.Box):
+    return specs.BoundedArray(
+        shape=space.shape,
+        dtype=space.dtype,
+        minimum=space.low,
+        maximum=space.high,
+        name=name)
+
+  elif isinstance(space, spaces.MultiBinary):
+    return specs.BoundedArray(
+        shape=space.shape,
+        dtype=space.dtype,
+        minimum=0.0,
+        maximum=1.0,
+        name=name)
+
+  elif isinstance(space, spaces.MultiDiscrete):
+    return specs.BoundedArray(
+        shape=space.shape,
+        dtype=space.dtype,
+        minimum=np.zeros(space.shape),
+        maximum=space.nvec - 1,
+        name=name)
+
+  elif isinstance(space, spaces.Tuple):
+    return tuple(_gym_to_spec(s, name) for s in space.spaces)
+
+  elif isinstance(space, spaces.Dict):
+    return {
+        key: _gym_to_spec(value, key) for key, value in space.spaces.items()
+    }
+
+  else:
+    raise ValueError('Unexpected gym space: {}'.format(space))
+
+
+def _convert_to_spec(space: gym.Space,
+                     num_agents: int,
+                     name: Optional[str] = None) -> types.NestedSpec:
+  """Converts a multigrid Gym space to an Acme multiagent spec.
+
+  Args:
+    space: The Gym space to convert.
+    num_agents: the number of agents.
+    name: Optional name to apply to all returned spec(s).
+
+  Returns:
+    A dm_env spec or nested structure of specs, corresponding to the input
+    space.
+  """
+  # Convert gym specs to acme specs.
+  spec = _gym_to_spec(space, name)
+  # Then change spec indexing from observation-key-first to agent-index-first.
+  return [_get_single_agent_spec(spec)] * num_agents
+
+
+def dict_obs_to_list_obs(
+    observation: types.NestedArray
+) -> List[Dict[ma_types.AgentID, types.NestedArray]]:
+  """Returns multigrid observations converted to agent-index-first format.
+
+  By default, multigrid observations are structured as:
+    observation['image'][agent_index]
+    observation['direction'][agent_index]
+    ...
+
+  However, multiagent Acme expects observations with agent indices first:
+    observation[agent_index]['image']
+    observation[agent_index]['direction']
+
+  This function simply converts multigrid observations to the latter format.
+
+  Args:
+    observation: multigrid observations in observation-key-first format, i.e. a
+      dict keyed by observation type, with each value indexed by agent.
+  """
+  return [dict(zip(observation, v)) for v in zip(*observation.values())]
+
+
+def make_multigrid_environment(
+    env_name: str = 'MultiGrid-Empty-5x5-v0') -> dm_env.Environment:
+  """Returns a Multigrid multiagent Gym environment.
+
+  Args:
+    env_name: name of the multigrid task. See social_rl.gym_multigrid.envs for
+      the available environments.
+  """
+  # Load the gym environment.
+  env = gym.make(env_name)
+
+  # Make sure the environment obeys the dm_env.Environment interface.
+  env = MultigridWrapper(env)
+  env = wrappers.SinglePrecisionWrapper(env)
+  env = multiagent_dict_key_wrapper.MultiagentDictKeyWrapper(env)
+  return env
diff --git a/acme/acme/wrappers/noop_starts.py b/acme/acme/wrappers/noop_starts.py
new file mode 100644
index 00000000..ed23ebf8
--- /dev/null
+++ b/acme/acme/wrappers/noop_starts.py
@@ -0,0 +1,68 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NoOp Starts wrapper to allow stochastic initial state for deterministic Python environments.""" + +from typing import Optional + +from acme import types +from acme.wrappers import base +import dm_env +import numpy as np + + +class NoopStartsWrapper(base.EnvironmentWrapper): + """Implements random noop starts to episodes. + + This introduces randomness into an otherwise deterministic environment. + + Note that the base environment must support a no-op action and the value + of this action must be known and provided to this wrapper. + """ + + def __init__(self, + environment: dm_env.Environment, + noop_action: types.NestedArray = 0, + noop_max: int = 30, + seed: Optional[int] = None): + """Initializes a `NoopStartsWrapper` wrapper. + + Args: + environment: An environment conforming to the dm_env.Environment + interface. + noop_action: The noop action used to step the environment for random + initialisation. + noop_max: The maximal number of noop actions at the start of an episode. + seed: The random seed used to sample the number of noops. + """ + if noop_max < 0: + raise ValueError( + 'Maximal number of no-ops after reset cannot be negative. ' + f'Received noop_max={noop_max}') + + super().__init__(environment) + self.np_random = np.random.RandomState(seed) + self._noop_max = noop_max + self._noop_action = noop_action + + def reset(self) -> dm_env.TimeStep: + """Resets environment and provides the first timestep.""" + noops = self.np_random.randint(self._noop_max + 1) + timestep = self.environment.reset() + for _ in range(noops): + timestep = self.environment.step(self._noop_action) + if timestep.last(): + timestep = self.environment.reset() + + return timestep._replace(step_type=dm_env.StepType.FIRST) diff --git a/acme/acme/wrappers/noop_starts_test.py b/acme/acme/wrappers/noop_starts_test.py new file mode 100644 index 00000000..74d96e78 --- /dev/null +++ b/acme/acme/wrappers/noop_starts_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the noop starts wrapper.""" + +from unittest import mock + +from acme import wrappers +from acme.testing import fakes +from dm_env import specs +import numpy as np + +from absl.testing import absltest + + +class NoopStartsTest(absltest.TestCase): + + def test_reset(self): + """Ensure that noop starts `reset` steps the environment multiple times.""" + noop_action = 0 + noop_max = 10 + seed = 24 + + base_env = fakes.DiscreteEnvironment( + action_dtype=np.int64, + obs_dtype=np.int64, + reward_spec=specs.Array(dtype=np.float64, shape=())) + mock_step_fn = mock.MagicMock() + expected_num_step_calls = np.random.RandomState(seed).randint(noop_max + 1) + + with mock.patch.object(base_env, 'step', mock_step_fn): + env = wrappers.NoopStartsWrapper( + base_env, + noop_action=noop_action, + noop_max=noop_max, + seed=seed, + ) + env.reset() + + # Test environment step called with noop action as part of wrapper.reset + mock_step_fn.assert_called_with(noop_action) + self.assertEqual(mock_step_fn.call_count, expected_num_step_calls) + self.assertEqual(mock_step_fn.call_args, ((noop_action,), {})) + + def test_raises_value_error(self): + """Ensure that wrapper raises error if noop_max is <0.""" + base_env = fakes.DiscreteEnvironment( + action_dtype=np.int64, + obs_dtype=np.int64, + reward_spec=specs.Array(dtype=np.float64, shape=())) + + with self.assertRaises(ValueError): + wrappers.NoopStartsWrapper(base_env, noop_action=0, noop_max=-1, seed=24) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/wrappers/observation_action_reward.py b/acme/acme/wrappers/observation_action_reward.py new file mode 100644 index 00000000..2433de14 --- /dev/null +++ b/acme/acme/wrappers/observation_action_reward.py @@ -0,0 +1,62 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A wrapper that puts the previous action and reward into the observation.""" + +from typing import NamedTuple + +from acme import types +from acme.wrappers import base + +import dm_env +import tree + + +class OAR(NamedTuple): + """Container for (Observation, Action, Reward) tuples.""" + observation: types.Nest + action: types.Nest + reward: types.Nest + + +class ObservationActionRewardWrapper(base.EnvironmentWrapper): + """A wrapper that puts the previous action and reward into the observation.""" + + def reset(self) -> dm_env.TimeStep: + # Initialize with zeros of the appropriate shape/dtype. 
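+    # (Each spec's `generate_value()` returns a minimal spec-conforming
+    # placeholder value (zeros for plain Array specs), which stands in for the
+    # "previous" action and reward at the first timestep of an episode.)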
+ action = tree.map_structure( + lambda x: x.generate_value(), self._environment.action_spec()) + reward = tree.map_structure( + lambda x: x.generate_value(), self._environment.reward_spec()) + timestep = self._environment.reset() + new_timestep = self._augment_observation(action, reward, timestep) + return new_timestep + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + timestep = self._environment.step(action) + new_timestep = self._augment_observation(action, timestep.reward, timestep) + return new_timestep + + def _augment_observation(self, action: types.NestedArray, + reward: types.NestedArray, + timestep: dm_env.TimeStep) -> dm_env.TimeStep: + oar = OAR(observation=timestep.observation, + action=action, + reward=reward) + return timestep._replace(observation=oar) + + def observation_spec(self): + return OAR(observation=self._environment.observation_spec(), + action=self.action_spec(), + reward=self.reward_spec()) diff --git a/acme/acme/wrappers/open_spiel_wrapper.py b/acme/acme/wrappers/open_spiel_wrapper.py new file mode 100644 index 00000000..3d6d3230 --- /dev/null +++ b/acme/acme/wrappers/open_spiel_wrapper.py @@ -0,0 +1,147 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wraps an OpenSpiel RL environment to be used as a dm_env environment.""" + +from typing import List, NamedTuple + +from acme import specs +from acme import types +import dm_env +import numpy as np +# pytype: disable=import-error +from open_spiel.python import rl_environment +# pytype: enable=import-error + + +class OLT(NamedTuple): + """Container for (observation, legal_actions, terminal) tuples.""" + observation: types.Nest + legal_actions: types.Nest + terminal: types.Nest + + +class OpenSpielWrapper(dm_env.Environment): + """Environment wrapper for OpenSpiel RL environments.""" + + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. 
+ + def __init__(self, environment: rl_environment.Environment): + self._environment = environment + self._reset_next_step = True + if not environment.is_turn_based: + raise ValueError("Currently only supports turn based games.") + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + self._reset_next_step = False + open_spiel_timestep = self._environment.reset() + observations = self._convert_observation(open_spiel_timestep) + return dm_env.restart(observations) + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + """Steps the environment.""" + if self._reset_next_step: + return self.reset() + + open_spiel_timestep = self._environment.step(action) + + if open_spiel_timestep.step_type == rl_environment.StepType.LAST: + self._reset_next_step = True + + observations = self._convert_observation(open_spiel_timestep) + rewards = np.asarray(open_spiel_timestep.rewards) + discounts = np.asarray(open_spiel_timestep.discounts) + step_type = open_spiel_timestep.step_type + + if step_type == rl_environment.StepType.FIRST: + step_type = dm_env.StepType.FIRST + elif step_type == rl_environment.StepType.MID: + step_type = dm_env.StepType.MID + elif step_type == rl_environment.StepType.LAST: + step_type = dm_env.StepType.LAST + else: + raise ValueError( + "Did not recognize OpenSpiel StepType: {}".format(step_type)) + + return dm_env.TimeStep(observation=observations, + reward=rewards, + discount=discounts, + step_type=step_type) + + # Convert OpenSpiel observation so it's dm_env compatible. Also, the list + # of legal actions must be converted to a legal actions mask. + def _convert_observation( + self, open_spiel_timestep: rl_environment.TimeStep) -> List[OLT]: + observations = [] + for pid in range(self._environment.num_players): + legals = np.zeros(self._environment.game.num_distinct_actions(), + dtype=np.float32) + legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1.0 + player_observation = OLT(observation=np.asarray( + open_spiel_timestep.observations["info_state"][pid], + dtype=np.float32), + legal_actions=legals, + terminal=np.asarray([open_spiel_timestep.last()], + dtype=np.float32)) + observations.append(player_observation) + return observations + + def observation_spec(self) -> OLT: + # Observation spec depends on whether the OpenSpiel environment is using + # observation/information_state tensors. 
+ if self._environment.use_observation: + return OLT(observation=specs.Array( + (self._environment.game.observation_tensor_size(),), np.float32), + legal_actions=specs.Array( + (self._environment.game.num_distinct_actions(),), + np.float32), + terminal=specs.Array((1,), np.float32)) + else: + return OLT(observation=specs.Array( + (self._environment.game.information_state_tensor_size(),), + np.float32), + legal_actions=specs.Array( + (self._environment.game.num_distinct_actions(),), + np.float32), + terminal=specs.Array((1,), np.float32)) + + def action_spec(self) -> specs.DiscreteArray: + return specs.DiscreteArray(self._environment.game.num_distinct_actions()) + + def reward_spec(self) -> specs.BoundedArray: + return specs.BoundedArray((), + np.float32, + minimum=self._environment.game.min_utility(), + maximum=self._environment.game.max_utility()) + + def discount_spec(self) -> specs.BoundedArray: + return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) + + @property + def environment(self) -> rl_environment.Environment: + """Returns the wrapped environment.""" + return self._environment + + @property + def current_player(self) -> int: + return self._environment.get_state.current_player() + + def __getattr__(self, name: str): + """Expose any other attributes of the underlying environment.""" + if name.startswith("__"): + raise AttributeError( + "attempted to get missing private attribute '{}'".format(name)) + return getattr(self._environment, name) diff --git a/acme/acme/wrappers/open_spiel_wrapper_test.py b/acme/acme/wrappers/open_spiel_wrapper_test.py new file mode 100644 index 00000000..faaaf899 --- /dev/null +++ b/acme/acme/wrappers/open_spiel_wrapper_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for open_spiel_wrapper.""" + +import unittest + +from dm_env import specs +import numpy as np + +from absl.testing import absltest + +SKIP_OPEN_SPIEL_TESTS = False +SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' + +try: + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from acme.wrappers import open_spiel_wrapper + from open_spiel.python import rl_environment + # pytype: enable=import-error +except ModuleNotFoundError: + SKIP_OPEN_SPIEL_TESTS = True + + +@unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) +class OpenSpielWrapperTest(absltest.TestCase): + + def test_tic_tac_toe(self): + raw_env = rl_environment.Environment('tic_tac_toe') + env = open_spiel_wrapper.OpenSpielWrapper(raw_env) + + # Test converted observation spec. + observation_spec = env.observation_spec() + self.assertEqual(type(observation_spec), open_spiel_wrapper.OLT) + self.assertEqual(type(observation_spec.observation), specs.Array) + self.assertEqual(type(observation_spec.legal_actions), specs.Array) + self.assertEqual(type(observation_spec.terminal), specs.Array) + + # Test converted action spec. 
+ action_spec: specs.DiscreteArray = env.action_spec() + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 8) + self.assertEqual(action_spec.num_values, 9) + self.assertEqual(action_spec.dtype, np.dtype('int32')) + + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step([0]) + env.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/wrappers/single_precision.py b/acme/acme/wrappers/single_precision.py new file mode 100644 index 00000000..e1b90c7f --- /dev/null +++ b/acme/acme/wrappers/single_precision.py @@ -0,0 +1,85 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environment wrapper which converts double-to-single precision.""" + +from acme import specs +from acme import types +from acme.wrappers import base + +import dm_env +import numpy as np +import tree + + +class SinglePrecisionWrapper(base.EnvironmentWrapper): + """Wrapper which converts environments from double- to single-precision.""" + + def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + return timestep._replace( + reward=_convert_value(timestep.reward), + discount=_convert_value(timestep.discount), + observation=_convert_value(timestep.observation)) + + def step(self, action) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.step(action)) + + def reset(self) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.reset()) + + def action_spec(self): + return _convert_spec(self._environment.action_spec()) + + def discount_spec(self): + return _convert_spec(self._environment.discount_spec()) + + def observation_spec(self): + return _convert_spec(self._environment.observation_spec()) + + def reward_spec(self): + return _convert_spec(self._environment.reward_spec()) + + +def _convert_spec(nested_spec: types.NestedSpec) -> types.NestedSpec: + """Convert a nested spec.""" + + def _convert_single_spec(spec: specs.Array): + """Convert a single spec.""" + if spec.dtype == 'O': + # Pass StringArray objects through unmodified. 
+ return spec + if np.issubdtype(spec.dtype, np.float64): + dtype = np.float32 + elif np.issubdtype(spec.dtype, np.int64): + dtype = np.int32 + else: + dtype = spec.dtype + return spec.replace(dtype=dtype) + + return tree.map_structure(_convert_single_spec, nested_spec) + + +def _convert_value(nested_value: types.Nest) -> types.Nest: + """Convert a nested value given a desired nested spec.""" + + def _convert_single_value(value): + if value is not None: + value = np.array(value, copy=False) + if np.issubdtype(value.dtype, np.float64): + value = np.array(value, copy=False, dtype=np.float32) + elif np.issubdtype(value.dtype, np.int64): + value = np.array(value, copy=False, dtype=np.int32) + return value + + return tree.map_structure(_convert_single_value, nested_value) diff --git a/acme/acme/wrappers/single_precision_test.py b/acme/acme/wrappers/single_precision_test.py new file mode 100644 index 00000000..f99779cb --- /dev/null +++ b/acme/acme/wrappers/single_precision_test.py @@ -0,0 +1,71 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the single precision wrapper.""" + +from acme import wrappers +from acme.testing import fakes +from dm_env import specs +import numpy as np + +from absl.testing import absltest + + +class SinglePrecisionTest(absltest.TestCase): + + def test_continuous(self): + env = wrappers.SinglePrecisionWrapper( + fakes.ContinuousEnvironment( + action_dim=0, dtype=np.float64, reward_dtype=np.float64)) + + self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.action_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32)) + + timestep = env.reset() + self.assertIsNone(timestep.reward) + self.assertIsNone(timestep.discount) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32)) + + timestep = env.step(0.0) + self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32)) + self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32)) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32)) + + def test_discrete(self): + env = wrappers.SinglePrecisionWrapper( + fakes.DiscreteEnvironment( + action_dtype=np.int64, + obs_dtype=np.int64, + reward_spec=specs.Array(dtype=np.float64, shape=()))) + + self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.int32)) + self.assertTrue(np.issubdtype(env.action_spec().dtype, np.int32)) + self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32)) + + timestep = env.reset() + self.assertIsNone(timestep.reward) + self.assertIsNone(timestep.discount) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32)) + + timestep = env.step(0) + self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32)) + 
self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32)) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32)) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/acme/wrappers/step_limit.py b/acme/acme/wrappers/step_limit.py new file mode 100644 index 00000000..8d474f3c --- /dev/null +++ b/acme/acme/wrappers/step_limit.py @@ -0,0 +1,42 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper that implements environment step limit.""" + +from typing import Optional +from acme import types +from acme.wrappers import base +import dm_env + + +class StepLimitWrapper(base.EnvironmentWrapper): + """A wrapper which truncates episodes at the specified step limit.""" + + def __init__(self, environment: dm_env.Environment, + step_limit: Optional[int] = None): + super().__init__(environment) + self._step_limit = step_limit + self._elapsed_steps = 0 + + def reset(self) -> dm_env.TimeStep: + self._elapsed_steps = 0 + return self._environment.reset() + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + timestep = self._environment.step(action) + self._elapsed_steps += 1 + if self._step_limit is not None and self._elapsed_steps >= self._step_limit: + return dm_env.truncation( + timestep.reward, timestep.observation, timestep.discount) + return timestep diff --git a/acme/acme/wrappers/video.py b/acme/acme/wrappers/video.py new file mode 100644 index 00000000..33493140 --- /dev/null +++ b/acme/acme/wrappers/video.py @@ -0,0 +1,237 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environment wrappers which record videos. + +The code used to generate animations in this wrapper is based on that used in +the `dm_control/tutorial.ipynb` file. +""" + +import os.path +from typing import Callable, Optional, Sequence, Tuple, Union +from acme.utils import paths +from acme.wrappers import base +import dm_env + +import matplotlib +matplotlib.use('Agg') # Switch to headless 'Agg' to inhibit figure rendering. +import matplotlib.animation as anim # pylint: disable=g-import-not-at-top +import matplotlib.pyplot as plt +import numpy as np + +# Internal imports. +# Make sure you have FFMpeg configured. + +def make_animation(frames: Sequence[np.ndarray], frame_rate: float, + figsize: Union[float, Tuple[int, int]]) -> anim.Animation: + """Generates an animation from a stack of frames.""" + + # Set animation characteristics. 
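+  # Note: `figsize` may be None (use the native frame size), an explicit
+  # (height, width) tuple, or a single float giving the diagonal length to
+  # which the native frame size is scaled.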
+ if figsize is None: + height, width, _ = frames[0].shape + elif isinstance(figsize, tuple): + height, width = figsize + else: + diagonal = figsize + height, width, _ = frames[0].shape + scale_factor = diagonal / np.sqrt(height**2 + width**2) + width *= scale_factor + height *= scale_factor + + dpi = 70 + interval = int(round(1e3 / frame_rate)) # Time (in ms) between frames. + + # Create and configure the figure. + fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi) + ax.set_axis_off() + ax.set_aspect('equal') + ax.set_position([0, 0, 1, 1]) + + # Initialize the first frame. + im = ax.imshow(frames[0]) + + # Create the function that will modify the frame, creating an animation. + def update(frame): + im.set_data(frame) + return [im] + + return anim.FuncAnimation( + fig=fig, + func=update, + frames=frames, + interval=interval, + blit=True, + repeat=False) + + +class VideoWrapper(base.EnvironmentWrapper): + """Wrapper which creates and records videos from generated observations. + + This will limit itself to recording once every `record_every` episodes and + videos will be recorded to the directory `path` + '//videos' where + `path` defaults to '~/acme'. Users can specify the size of the screen by + passing either a tuple giving height and width or a float giving the size + of the diagonal. + """ + + def __init__(self, + environment: dm_env.Environment, + *, + path: str = '~/acme', + filename: str = '', + process_path: Callable[[str, str], str] = paths.process_path, + record_every: int = 100, + frame_rate: int = 30, + figsize: Optional[Union[float, Tuple[int, int]]] = None): + super(VideoWrapper, self).__init__(environment) + self._path = process_path(path, 'videos') + self._filename = filename + self._record_every = record_every + self._frame_rate = frame_rate + self._frames = [] + self._counter = 0 + self._figsize = figsize + + def _render_frame(self, observation): + """Renders a frame from the given environment observation.""" + return observation + + def _write_frames(self): + """Writes frames to video.""" + if self._counter % self._record_every == 0: + path = os.path.join(self._path, + f'{self._filename}_{self._counter:04d}.html') + video = make_animation(self._frames, self._frame_rate, + self._figsize).to_html5_video() + + with open(path, 'w') as f: + f.write(video) + + # Clear the frame buffer whether a video was generated or not. + self._frames = [] + + def _append_frame(self, observation): + """Appends a frame to the sequence of frames.""" + if self._counter % self._record_every == 0: + self._frames.append(self._render_frame(observation)) + + def step(self, action) -> dm_env.TimeStep: + timestep = self.environment.step(action) + self._append_frame(timestep.observation) + return timestep + + def reset(self) -> dm_env.TimeStep: + # If the frame buffer is nonempty, flush it and record video + if self._frames: + self._write_frames() + self._counter += 1 + timestep = self.environment.reset() + self._append_frame(timestep.observation) + return timestep + + def make_html_animation(self): + if self._frames: + return make_animation(self._frames, self._frame_rate, + self._figsize).to_html5_video() + else: + raise ValueError('make_html_animation should be called after running a ' + 'trajectory and before calling reset().') + + def close(self): + if self._frames: + self._write_frames() + self._frames = [] + self.environment.close() + + +class MujocoVideoWrapper(VideoWrapper): + """VideoWrapper which generates videos from a mujoco physics object. 
+ + This passes its keyword arguments into the parent `VideoWrapper` class (refer + here for any default arguments). + """ + + # Note that since we can be given a wrapped mujoco environment we can't give + # the type as dm_control.Environment. + + def __init__(self, + environment: dm_env.Environment, + *, + frame_rate: Optional[int] = None, + camera_id: Optional[int] = 0, + height: int = 240, + width: int = 320, + playback_speed: float = 1., + **kwargs): + + # Check that we have a mujoco environment (or a wrapper thereof). + if not hasattr(environment, 'physics'): + raise ValueError('MujocoVideoWrapper expects an environment which ' + 'exposes a physics attribute corresponding to a MuJoCo ' + 'physics engine') + + # Compute frame rate if not set. + if frame_rate is None: + try: + control_timestep = getattr(environment, 'control_timestep')() + except AttributeError as e: + raise AttributeError('MujocoVideoWrapper expects an environment which ' + 'exposes a control_timestep method, like ' + 'dm_control environments, or frame_rate ' + 'to be specified.') from e + frame_rate = int(round(playback_speed / control_timestep)) + + super().__init__(environment, frame_rate=frame_rate, **kwargs) + self._camera_id = camera_id + self._height = height + self._width = width + + def _render_frame(self, unused_observation): + del unused_observation + + # We've checked above that this attribute should exist. Pytype won't like + # it if we just try and do self.environment.physics, so we use the slightly + # grosser version below. + physics = getattr(self.environment, 'physics') + + if self._camera_id is not None: + frame = physics.render( + camera_id=self._camera_id, height=self._height, width=self._width) + else: + # If camera_id is None, we create a minimal canvas that will accommodate + # physics.model.ncam frames, and render all of them on a grid. + num_cameras = physics.model.ncam + num_columns = int(np.ceil(np.sqrt(num_cameras))) + num_rows = int(np.ceil(float(num_cameras)/num_columns)) + height = self._height + width = self._width + + # Make a black canvas. + frame = np.zeros((num_rows*height, num_columns*width, 3), dtype=np.uint8) + + for col in range(num_columns): + for row in range(num_rows): + + camera_id = row*num_columns + col + + if camera_id >= num_cameras: + break + + subframe = physics.render( + camera_id=camera_id, height=height, width=width) + + # Place the frame in the appropriate rectangle on the pixel canvas. + frame[row*height:(row+1)*height, col*width:(col+1)*width] = subframe + + return frame diff --git a/acme/docs/_static/custom.css b/acme/docs/_static/custom.css new file mode 100644 index 00000000..1927ca58 --- /dev/null +++ b/acme/docs/_static/custom.css @@ -0,0 +1,7 @@ +div.version { + color:#404040 !important; +} + +.wy-side-nav-search { + background-color: #fff; +} diff --git a/acme/docs/conf.py b/acme/docs/conf.py new file mode 100644 index 00000000..5611bf80 --- /dev/null +++ b/acme/docs/conf.py @@ -0,0 +1,41 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sphinx configuration."""
+
+project = 'Acme'
+author = 'DeepMind Technologies Limited'
+copyright = '2018, DeepMind Technologies Limited'  # pylint: disable=redefined-builtin
+version = ''
+release = ''
+master_doc = 'index'
+
+extensions = [
+    'myst_parser'
+]
+
+html_theme = 'sphinx_rtd_theme'
+html_logo = 'imgs/acme.png'
+html_theme_options = {
+    'logo_only': True,
+}
+html_css_files = [
+    'custom.css',
+]
+
+templates_path = []
+html_static_path = ['_static']
+exclude_patterns = ['_build', 'requirements.txt']
+
diff --git a/acme/docs/faq.md b/acme/docs/faq.md
new file mode 100644
index 00000000..ae728374
--- /dev/null
+++ b/acme/docs/faq.md
@@ -0,0 +1,42 @@
+# FAQ
+
+## Environments
+
+-   **Does Acme support my environment?** All _agents_ in Acme are designed to
+    work with environments which implement the
+    [dm_env environment interface][dm_env]. This interface, however, has been
+    designed to match general concepts widely in use across the RL research
+    community. As a result, it should be quite straightforward to write a
+    wrapper for other environments in order to make them conform to this
+    interface. See e.g. the `acme.wrappers.gym_wrapper` module which can be
+    used to interact with [OpenAI Gym][gym] environments.
+
+    Note: Follow the instructions [here][atari] to install ROMs for Atari
+    environments.
+
+    Similarly, _learners_ in Acme are designed to consume dataset iterators
+    (generally `tf.data.Dataset` instances) which yield either transition
+    tuples or sequences of (state, action, reward, ...) tuples. If your data
+    does not match these formats it should be relatively straightforward to
+    write an adaptor! See individual agents for more information on their
+    expected input.
+
+[dm_env]: https://github.com/deepmind/dm_env
+[gym]: https://gym.openai.com/
+[atari]: https://github.com/openai/atari-py#roms
+
+## TensorFlow agents
+
+-   **How do I debug my TF2 learner?** Debugging TensorFlow code has never been
+    easier! All our learners’ `_step()` functions are decorated with a
+    `@tf.function` which can easily be commented out to run them in eager mode.
+    In this mode, one can easily run through the code (say, via `pdb`) line by
+    line and examine outputs. Most of the time, if your code works in eager
+    mode, it will work in graph mode (with the `@tf.function` decorator) but
+    there are rare exceptions when using exotic ops with unsupported dtypes.
+    Finally, don’t forget to add the decorator back in or you’ll find your
+    learner to be a little sluggish!
+
+## Misc.
+
+-   **How should I spell Acme?** Acme is a proper noun, not an acronym, and
+    hence should be spelled "Acme" without caps.
diff --git a/acme/docs/imgs/acme-notext.png b/acme/docs/imgs/acme-notext.png
new file mode 100644
index 00000000..ba748235
Binary files /dev/null and b/acme/docs/imgs/acme-notext.png differ
diff --git a/acme/docs/imgs/acme.png b/acme/docs/imgs/acme.png
new file mode 100644
index 00000000..7af7c084
Binary files /dev/null and b/acme/docs/imgs/acme.png differ
diff --git a/acme/docs/index.rst b/acme/docs/index.rst
new file mode 100644
index 00000000..3bf3d2ba
--- /dev/null
+++ b/acme/docs/index.rst
@@ -0,0 +1,26 @@
+Welcome to Acme
+---------------
+
+Acme is a library of reinforcement learning (RL) building blocks that strives to
+expose simple, efficient, and readable agents. 
These agents first and foremost
+serve as reference implementations as well as strong baselines for algorithm
+performance. However, the baseline agents exposed by Acme should also provide
+enough flexibility and simplicity that they can be used as a starting point
+for novel research. Finally, the building blocks of Acme are designed in such
+a way that the agents can be written at multiple scales (e.g.
+single-stream vs. distributed agents).
+
+.. toctree::
+   :hidden:
+   :titlesonly:
+
+   self
+
+
+.. toctree::
+   :caption: Getting started
+   :titlesonly:
+
+   user/overview
+   user/agents
+   user/components
+   faq
diff --git a/acme/docs/requirements.txt b/acme/docs/requirements.txt
new file mode 100644
index 00000000..1ce629a3
--- /dev/null
+++ b/acme/docs/requirements.txt
@@ -0,0 +1,3 @@
+myst-parser
+markdown-callouts
+
diff --git a/acme/docs/user/agents.md b/acme/docs/user/agents.md
new file mode 100644
index 00000000..f9463194
--- /dev/null
+++ b/acme/docs/user/agents.md
@@ -0,0 +1,115 @@
+# Agents
+
+Acme includes a number of pre-built agents listed below. All are provided as
+single-process agents, but we also include a distributed implementation using
+[Launchpad](https://github.com/deepmind/launchpad). Distributed agents share
+the exact same learning and acting code as their single-process counterparts
+and can be executed either on a single machine (the
+`--lp_launch_type=[local_mt|local_mp]` command-line flag selects multi-threaded
+or multi-process execution) or in a multi-machine setup on GCP
+(`--lp_launch_type=vertex_ai`). For details please refer to the
+[Launchpad documentation](https://github.com/deepmind/launchpad/search?q=%22class+LaunchType%22).
+
+We've listed the agents below in separate sections based on their different
+use cases; however, these distinctions are often subtle. For more information
+on each implementation see the relevant agent-specific README.
+
+## Continuous control
+
+Acme has long had a focus on continuous control agents (i.e. settings where the
+action space is continuous). The following agents focus on this setting (a
+minimal usage sketch follows the table):
+
+Agent | Paper | Code
+----------------------------------------------------------------- | ------------------------------- | ----
+Deep Deterministic Policy Gradient (DDPG) | Lillicrap et al., 2015 | [![TF]][DDPG_TF2]
+Distributed Distributional Deterministic Policy Gradients (D4PG) | Barth-Maron et al., 2018 | [![TF]][D4PG_TF2]
+Maximum a posteriori Policy Optimisation (MPO) | Abdolmaleki et al., 2018 | [![TF]][MPO_TF2]
+Distributional Maximum a posteriori Policy Optimisation (DMPO) | - | [![TF]][DMPO_TF2]
+Multi-Objective Maximum a posteriori Policy Optimisation (MO-MPO) | Abdolmaleki, Huang et al., 2020 | [![TF]][MOMPO_TF2]
+
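+As a rough sketch of how these agents are used (network definitions elided,
+and D4PG chosen purely for illustration), an agent can be connected to an
+environment and trained as follows:
+
+```python
+import acme
+from acme.agents.tf import d4pg
+
+environment = ...  # Any dm_env.Environment, e.g. from the control suite.
+environment_spec = acme.make_environment_spec(environment)
+
+agent = d4pg.D4PG(
+    environment_spec=environment_spec,
+    policy_network=...,  # snt.Module mapping observations to actions.
+    critic_network=...,  # snt.Module with a distributional critic head.
+)
+
+# Run the standard agent-environment interaction loop.
+loop = acme.EnvironmentLoop(environment, agent)
+loop.run(num_episodes=100)
+```
+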
+
+## Discrete control
+
+We also include a number of agents built with discrete action spaces in mind.
+Note that the distinction between these agents and the continuous agents listed
+above can be somewhat arbitrary; IMPALA, for example, could be implemented for
+continuous action spaces as well, but here we focus on a discrete-action
+variant.
+
+Agent | Paper | Code
+-------------------------------------------------------- | -------------------------- | ----
+Deep Q-Networks (DQN) | [Horgan et al., 2018] | [![TF]][DQN_TF2] [![JAX]][DQN_JAX]
+Importance-Weighted Actor-Learner Architectures (IMPALA) | [Espeholt et al., 2018] | [![TF]][IMPALA_TF2] [![JAX]][IMPALA_JAX]
+Recurrent Replay Distributed DQN (R2D2) | [Kapturowski et al., 2019] | [![TF]][R2D2_TF2]
+
+
+## Batch RL
+
+The structure of Acme also lends itself quite nicely to "learner-only"
+algorithms for use in Batch RL (with no environment interactions). Implemented
+algorithms include:
+
+Agent | Paper | Code
+--------------------- | ----- | --------------------------------
+Behavior Cloning (BC) | - | [![TF]][BC_TF2] [![JAX]][BC_JAX]
+
+
+## Learning from demonstrations
+
+Acme also easily allows active data acquisition to be combined with data from
+demonstrations. Such algorithms include:
+
+Agent | Paper | Code
+----------------------------------------------------------- | --------------------- | ----
+Deep Q-Learning from Demonstrations (DQfD) | Hester et al., 2017 | [![TF]][DQFD_TF2]
+Recurrent Replay Distributed DQN from Demonstrations (R2D3) | Gulcehre et al., 2020 | [![TF]][R2D3_TF2]
+
+
+## Model-based RL
+
+Finally, Acme also includes a variant of MCTS which can be used for model-based
+RL using a given or learned simulator:
+
+Agent | Paper | Code
+------------------------------ | --------------------- | -----------------
+Monte-Carlo Tree Search (MCTS) | [Silver et al., 2018] | [![TF]][MCTS_TF2]
+
+ + + +[TF]: logos/tf-small.png +[JAX]: logos/jax-small.png + + + +[DQN_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/dqn/ +[IMPALA_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/impala/ +[R2D2_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/r2d2/ +[MCTS_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/mcts/ +[DDPG_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/ddpg/ +[D4PG_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/d4pg/ +[MPO_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/mpo/ +[DMPO_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/dmpo/ +[MOMPO_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/mompo/ +[BC_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/bc/ +[DQFD_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/dqfd/ +[R2D3_TF2]: https://github.com/deepmind/acme/blob/master/acme/agents/tf/r2d3/ + + + +[DQN_JAX]: https://github.com/deepmind/acme/blob/master/acme/agents/jax/dqn/ +[IMPALA_JAX]: https://github.com/deepmind/acme/blob/master/acme/agents/jax/impala/ +[D4PG_JAX]: https://github.com/deepmind/acme/blob/master/acme/agents/jax/d4pg/ +[BC_JAX]: https://github.com/deepmind/acme/blob/master/acme/agents/jax/bc/ + + + +[Horgan et al., 2018]: https://arxiv.org/abs/1803.00933 +[Silver et al., 2018]: https://science.sciencemag.org/content/362/6419/1140 +[Espeholt et al., 2018]: https://arxiv.org/abs/1802.01561 +[Kapturowski et al., 2019]: https://openreview.net/pdf?id=r1lyTjAqYX diff --git a/acme/docs/user/components.md b/acme/docs/user/components.md new file mode 100644 index 00000000..0f50a9f2 --- /dev/null +++ b/acme/docs/user/components.md @@ -0,0 +1,366 @@ +# Components + +## Environments + +Acme is designed to work with environments which implement the +[dm_env environment interface][dm_env]. This provides a common API for +interacting with an environment in order to take actions and receive +observations. This environment API also provides a standard way for environments +to specify the input and output spaces relevant to that environment via methods +like `environment.action_spec()`. Note that Acme also exposes these spec types +directly via `acme.specs`. However, it is also common for Acme agents to require +a _full environment spec_ which can be obtained by making use of +`acme.make_environment_spec(environment)`. + +[dm_env]: https://github.com/deepmind/dm_env + +Acme also exposes, under `acme.wrappers`, a number of classes which wrap and/or +expose a `dm_env` environment. All such wrappers are of the form: + +```python +environment = Wrapper(raw_environment, ...) +``` + +where additional parameters may be passed to the wrapper to control its behavior +(see individual implementations for more details). Wrappers exposed directly +include + +- `SinglePrecisionWrapper`: converts any double-precision `float` and `int` + components returned by the environment to single-precision. +- `AtariWrapper`: converts a standard ALE Atari environment using a stack of + wrappers corresponding to the modifications used in the + "[Human Level Control Through Deep Reinforcement Learning][nature-atari]" + publication. + +Acme also includes the `acme.wrappers.gym_wrapper` module which can be used to +interact with [OpenAI Gym][gym] environments. 
This includes a general
+`GymWrapper` class, as well as an `AtariGymWrapper` which exposes a lives count
+observation that the `AtariWrapper` can then optionally expose.
+
+[nature-atari]: https://deepmind.com/research/publications/playing-atari-deep-reinforcement-learning
+[gym]: https://gym.openai.com/
+
+## Networks
+
+An important building block for any agent implementation consists of the
+parameterized functions or networks which are used to construct policies, value
+functions, etc. Agents implemented in Acme are built to be as agnostic as
+possible to the environment on which they will be applied. As a result they
+typically require network(s) which are used to directly interact with the
+environment either by consuming observations, producing actions, or both. These
+are typically passed directly into the agent at initialization, e.g.
+
+```python
+policy_network = ...
+critic_network = ...
+agent = MyActorCriticAgent(policy_network, critic_network, ...)
+```
+
+### Tensorflow and Sonnet
+
+For TensorFlow agents, networks in Acme are typically implemented using the
+[Sonnet][sonnet] neural network library. These network objects take the form of
+a `Callable` object which takes a collection of (nested) `tf.Tensor` objects as
+input and outputs a collection of (nested) `tf.Tensor` or `tfp.Distribution`
+objects. In what follows we use the following aliases.
+
+```python
+import sonnet as snt
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+```
+
+While custom Sonnet modules can be implemented and used directly, Acme also
+provides a number of useful network primitives which are tailored to RL tasks;
+these can be imported from `acme.tf.networks` (see [networks] for more
+examples). These primitives can be combined using `snt.Sequential`, or
+`snt.DeepRNN` when stacking network modules with state.
+
+When stacking modules it is often, though not always, helpful to distinguish
+between what are often called _torso_, _head_, and multiplexer networks. Note
+that this categorization is purely pedagogical but has nevertheless proven
+useful when discussing network architectures.
+
+Torsos are usually the first to transform inputs (observations, actions, or a
+combination) and produce what is commonly known in the deep learning literature
+as an embedding vector. These modules can be stacked so that an embedding is
+transformed multiple times before it is fed into a head.
+
+Let us consider for instance a simple network we use in the Impala agent when
+training it on Atari games:
+
+```python
+impala_network = snt.DeepRNN([
+    # Torsos.
+    networks.AtariTorso(),  # Default Atari ConvNet offered as convenience.
+    snt.LSTM(256),  # Custom LSTM core.
+    snt.Linear(512),  # Custom perceptron layer before head.
+    tf.nn.relu,  # Seamlessly stack Sonnet modules and TF ops as usual.
+    # Head producing 18 action logits and a value estimate for the input
+    # observation.
+    networks.PolicyValueHead(num_actions=18),
+])
+```
+
+Heads are networks that consume the embedding vector to produce a desired
+output (e.g. action logits or distribution, value estimates, etc.). These
+modules can also be stacked, which is particularly useful when dealing with
+stochastic policies. 
For example, consider the following stochastic policy used in the MPO
+agent, trained on the control suite:
+
+```python
+policy_layer_sizes: Sequence[int] = (256, 256, 256)
+
+stochastic_policy_network = snt.Sequential([
+    # MLP torso with initial layer normalization; activate the final layer
+    # since it feeds into another module.
+    networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
+    # Head producing a tfd.Distribution: in this case `num_dimensions`
+    # independent normal distributions.
+    networks.MultivariateNormalDiagHead(num_dimensions),
+])
+```
+
+This stochastic policy is used internally by the MPO algorithm to compute log
+probabilities and Kullback-Leibler (KL) divergences. We can also stack an
+additional head that will select the mode of the stochastic policy as a greedy
+action:
+
+```python
+greedy_policy_network = snt.Sequential([
+    networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
+    networks.MultivariateNormalDiagHead(num_dimensions),
+    networks.StochasticModeHead(),
+])
+```
+
+When designing our actor-critic agents for continuous control tasks, we found
+one simple module particularly useful: the `CriticMultiplexer`. This callable
+Sonnet module takes two inputs, an observation and an action, and concatenates
+them along all but the batch dimension, after optionally transforming each
+input if an `observation_network` and/or `action_network` is passed. For
+example, the following is the C51 (see [Bellemare et al., 2017]) distributional
+critic network adapted for our D4PG experiments:
+
+```python
+critic_layer_sizes: Sequence[int] = (512, 512, 256)
+
+distributional_critic_network = snt.Sequential([
+    # Flattens and concatenates inputs; see `tf2_utils.batch_concat` for more.
+    networks.CriticMultiplexer(),
+    networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
+    # Distributional head corresponding to the C51 network.
+    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),
+])
+```
+
+Finally, our actor-critic control agents also allow the specification of an
+observation network that is shared by the policy and critic. This network
+embeds the observations once and uses the transformed input in both the policy
+and critic as needed, which saves computation, particularly when the
+transformation is expensive. This is the case, for example, when learning from
+pixels, where the observation network can be a large ResNet. In such cases, the
+shared visual network can be passed to any of DDPG, D4PG, MPO, or DMPO by
+simply defining and passing the following:
+
+```python
+shared_resnet = networks.ResNetTorso()  # Default (deep) Impala network.
+
+agent = dmpo.DMPO(
+    # Networks defined above.
+    policy_network=stochastic_policy_network,
+    critic_network=distributional_critic_network,
+    # New ResNet visual module, shared by both policy and critic.
+    observation_network=shared_resnet,
+    # ...
+)
+```
+
+In this case, the `policy_` and `critic_network` act as heads on top of the
+shared visual torso.
+
+[networks]: https://github.com/deepmind/acme/blob/master/acme/tf/networks/
+[sonnet]: https://github.com/deepmind/sonnet/
+
+## Internal components
+
+Acme also includes a number of components and concepts that are typically
+internal to an agent's implementation. These components can, in general, be
+ignored if you are only interested in using an Acme agent. However they prove
+useful when implementing a novel agent or modifying an existing agent.
+
+### Losses
+
+These are some commonly-used loss functions. 
Note that in general we defer to
+[TRFL][trfl] where possible, except in cases where it does not support
+TensorFlow 2.
+
+[trfl]: https://github.com/deepmind/trfl
+
+RL-specific losses implemented include:
+
+- a [distributional TD loss][distributional] for categorical distributions;
+  see [Bellemare et al., 2017].
+- the Deterministic Policy Gradient [(DPG) loss][dpg]; see
+  [Silver et al., 2014].
+- the Maximum a posteriori Policy Optimization [(MPO) loss][mpo]; see
+  [Abdolmaleki et al., 2018].
+
+Also implemented (and useful within the losses mentioned above) are:
+
+- the [Huber loss][huber] for robust regression.
+
+[distributional]: https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py
+[dpg]: https://github.com/deepmind/acme/blob/master/acme/tf/losses/dpg.py
+[mpo]: https://github.com/deepmind/acme/blob/master/acme/tf/losses/mpo.py
+[huber]: https://github.com/deepmind/acme/blob/master/acme/tf/losses/huber.py
+[Abdolmaleki et al., 2018]: https://arxiv.org/abs/1806.06920
+[Bellemare et al., 2017]: https://arxiv.org/abs/1707.06887
+[Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14
+
+### Adders
+
+An `Adder` packs together data to send to the replay buffer, and potentially
+does some reductions/transformations to this data in the process.
+
+All Acme `Adder`s can be interacted with through their `add()`, `add_first()`,
+and `reset()` methods.
+
+The `add()` method takes actions, timesteps, and potentially some extras and
+adds the `action`, `observation`, `reward`, `discount`, and `extra` fields to
+the buffer.
+
+The `add_first()` method takes the first timestep of an episode and adds it to
+the buffer, automatically padding the empty `action`, `reward`, `discount`, and
+`extra` fields that don't exist at the first timestep of an episode.
+
+The `reset()` method clears the buffer.
+
+Example usage of an adder:
+
+```python
+# Reset the environment and add the first observation.
+timestep = env.reset()
+adder.add_first(timestep)
+
+while not timestep.last():
+  # Generate an action from the policy and step the environment.
+  action = my_policy(timestep)
+  timestep = env.step(action)
+
+  # Add the action and the resulting timestep.
+  adder.add(action, next_timestep=timestep)
+```
+
+### ReverbAdders
+
+Acme uses [Reverb](http://github.com/deepmind/reverb) for creating data
+structures like *replay buffers* to store RL experiences.
+
+For convenience, Acme provides several `ReverbAdder`s for adding actor
+experiences to a Reverb table. The `ReverbAdder`s provided include:
+
+*   `NStepTransitionAdder`, which takes single steps from an environment/agent
+    loop, automatically concatenates them into N-step transitions, and adds the
+    transitions to Reverb for future retrieval. The steps are buffered and then
+    concatenated into N-step transitions, which are stored in and returned from
+    replay.
+
+    When N is 1, the transitions are of the form:
+
+    ```
+    (s_t, a_t, r_t, d_t, s_{t+1}, e_t)
+    ```
+
+    For N greater than 1, transitions are of the form:
+
+    ```
+    (s_t, a_t, R_{t:t+n}, D_{t:t+n}, s_{t+n}, e_t)
+    ```
+
+    Transitions can be stored as sequences or episodes.
+
+*   `EpisodeAdder`, which adds entire episodes as trajectories of the form:
+
+    ```python
+    (s_0, a_0, r_0, d_0, e_0,
+     s_1, a_1, r_1, d_1, e_1,
+              .
+              .
+              .
+     s_T, a_T, r_T, 0., e_T)
+    ```
+
+*   `SequenceAdder`, which adds sequences of fixed `sequence_length` n of the
+    form:
+
+    ```python
+    (s_0, a_0, r_0, d_0, e_0,
+     s_1, a_1, r_1, d_1, e_1,
+              .
+              .
+              .
+     s_n, a_n, r_n, d_n, e_n)
+    ```
+
+    Sequences can be overlapping (if `period < sequence_length`) or
+    non-overlapping (if `period >= sequence_length`).
+
+### [Loggers](https://github.com/deepmind/acme/blob/master/acme/utils/loggers/)
+
+Acme contains several loggers for writing out data to common places, all based
+on the abstract `Logger` class and all exposing a `write()` method.

+NOTE: By default, loggers immediately output all data passed to `write()`. To
+throttle this, construct the logger with a nonzero `time_delta` argument, which
+specifies the minimum number of seconds between successive outputs.
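+
+For reference, writing a custom logger only requires implementing `write()`.
+Below is a minimal sketch: it assumes the `Logger` base class and the
+`LoggingData` alias are exposed from `acme.utils.loggers` (and may differ
+between versions), and the class name here is hypothetical:
+
+```python
+from acme.utils import loggers
+
+
+class StdoutLogger(loggers.Logger):
+  """A minimal logger that prints each record to stdout."""
+
+  def write(self, data: loggers.LoggingData):
+    # Render the logged mapping as 'key = value' pairs in a stable order.
+    print(' | '.join(f'{k} = {v}' for k, v in sorted(data.items())))
+
+  def close(self):
+    pass  # Nothing to clean up for stdout.
+```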
+ +#### [Terminal Logger](https://github.com/deepmind/acme/blob/master/acme/utils/loggers/terminal.py) + +Logs data directly to the terminal.

+Example:
+
+```python
+terminal_logger = loggers.TerminalLogger(label='TRAINING', time_delta=5)
+terminal_logger.write({'step': 0, 'reward': 0.0})
+
+>> TRAINING: step: 0, reward: 0.0
+```
+
+#### [CSV Logger](https://github.com/deepmind/acme/blob/master/acme/utils/loggers/csv.py)
+
+Logs to a specified CSV file.

+Example:
+
+```python
+csv_logger = loggers.CSVLogger(logdir='logged_data', label='my_csv_file')
+csv_logger.write({'step': 0, 'reward': 0.0})
+```
+
+### [TensorFlow savers](https://github.com/deepmind/acme/blob/master/acme/tf/savers.py)
+
+To save trained TensorFlow models, we can *checkpoint* or *snapshot*
+them.
+
+Both *checkpointing* and *snapshotting* are ways to save and restore model
+state for later use. The difference lies in how the saved state is restored.
+
+With checkpoints, you have to first re-build the exact graph and then restore
+the checkpoint. Checkpoints are useful while running experiments: if a run is
+interrupted or preempted, it can be restored and continued without losing the
+experiment state.
+
+Snapshots re-build the graph internally, so all you have to do is restore the
+snapshot.
+
+Acme provides `Checkpointer` and `Snapshotter` classes to checkpoint and
+snapshot, respectively, parts of the model state as desired.
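+
+For example, a loop that periodically checkpoints and snapshots the variables
+of a simple Sonnet module looks like the following: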
+ +```python + model = snt.Linear(10) + checkpointer = tf2_savers.Checkpointer(objects_to_save={'model': model}) + snapshotter = tf2_savers.Snapshotter(objects_to_save={'model': model}) + for _ in range(100): + # ... + checkpointer.save() + snapshotter.save() +``` diff --git a/acme/docs/user/diagrams/actor_loop.png b/acme/docs/user/diagrams/actor_loop.png new file mode 100644 index 00000000..d2c22291 Binary files /dev/null and b/acme/docs/user/diagrams/actor_loop.png differ diff --git a/acme/docs/user/diagrams/agent_loop.png b/acme/docs/user/diagrams/agent_loop.png new file mode 100644 index 00000000..15ec82a9 Binary files /dev/null and b/acme/docs/user/diagrams/agent_loop.png differ diff --git a/acme/docs/user/diagrams/batch_loop.png b/acme/docs/user/diagrams/batch_loop.png new file mode 100644 index 00000000..49a2bde1 Binary files /dev/null and b/acme/docs/user/diagrams/batch_loop.png differ diff --git a/acme/docs/user/diagrams/distributed_loop.png b/acme/docs/user/diagrams/distributed_loop.png new file mode 100644 index 00000000..351d9d0c Binary files /dev/null and b/acme/docs/user/diagrams/distributed_loop.png differ diff --git a/acme/docs/user/diagrams/environment_loop.png b/acme/docs/user/diagrams/environment_loop.png new file mode 100644 index 00000000..56a69275 Binary files /dev/null and b/acme/docs/user/diagrams/environment_loop.png differ diff --git a/acme/docs/user/logos/jax-small.png b/acme/docs/user/logos/jax-small.png new file mode 100644 index 00000000..ba663d05 Binary files /dev/null and b/acme/docs/user/logos/jax-small.png differ diff --git a/acme/docs/user/logos/tf-small.png b/acme/docs/user/logos/tf-small.png new file mode 100644 index 00000000..85e775f1 Binary files /dev/null and b/acme/docs/user/logos/tf-small.png differ diff --git a/acme/docs/user/overview.md b/acme/docs/user/overview.md new file mode 100644 index 00000000..dbdfcf95 --- /dev/null +++ b/acme/docs/user/overview.md @@ -0,0 +1,81 @@ +# Overview + +The design of Acme attempts to provide multiple points of entry to the RL +problem at differing levels of complexity. The first entry-point — and +easiest way to get started — is just by running one of our +[baseline agents](agents.md). This can be done simply by connecting an agent (or +actor) instance to an environment using an environment loop. This instantiates +the standard mode of interaction with an environment common in RL and +illustrated by the following diagram: + +![Environment loop](diagrams/environment_loop.png) + +This setting should, of course, look familiar to any RL practitioner, and with +this you can be up and running an Acme agent within a few lines of code. +Environments used by Acme are assumed to conform to the +[DeepMind Environment API][dm_env] which provides a simple mechanism to both +reset the environment to some initial state as well as to step the environment +and produce observations. + +[dm_env]: https://github.com/deepmind/dm_env + +Actors in Acme expose three primary methods: `select_action` which returns +actions to be taken, `observe` which records observations from the environment, +and an `update` method. In fact, by making use of these methods, the +`EnvironmentLoop` illustrated above can be roughly approximated by + +```python +while True: + # Make an initial observation. + step = environment.reset() + actor.observe_first(step.observation) + + while not step.last(): + # Evaluate the policy and take a step in the environment. 
+    action = actor.select_action(step.observation)
+    step = environment.step(action)
+
+    # Make an observation and update the actor.
+    actor.observe(action, next_step=step)
+    actor.update()
+```
+
+> NOTE: Currently the default method of observing data in Acme makes use of
+> `observe`/`observe_last` methods (the reverse of the above). This is being
+> phased out in favor of the pattern shown here, which will soon be made the
+> default.
+
+Internally, agents built using Acme are written with modular _acting_ and
+_learning_ components. By acting we refer to the sub-components used to
+generate experience, and by learning we refer to the process of training the
+relevant action-selection models (typically neural networks). An illustration
+of this breakdown of an agent is given below:
+
+![Agent loop](diagrams/agent_loop.png)
+
+Superficially this allows us to share the same experience-generating code
+between multiple agents. More importantly, this split greatly simplifies the
+way in which distributed agents are constructed.
+
+Distributed agents are built using all of the same components as their
+single-process counterparts, but split so that the components for acting,
+learning, evaluation, replay, etc. each run in their own process. An
+illustration of this is shown below, and here you can see that it follows the
+same template as above, just with many different actors/environments:
+
+![Distributed loop](diagrams/distributed_loop.png)
+
+This greatly simplifies the process of designing a novel agent and testing
+existing agents, as the differences in scale can be largely ignored. This even
+allows us to scale all the way down to the batch or offline setting wherein
+there is no data-generation process and only a fixed dataset:
+
+![Batch loop](diagrams/batch_loop.png)
+
+Finally, Acme also includes a number of useful utilities that help keep agent
+code readable, and that make the process of writing the next agent that much
+easier. We provide common tooling ranging from checkpointing and snapshotting
+to various forms of logging and other low-level computations. For more
+information on these components, as well as the structure described above, see
+our more detailed discussion of Acme [components](components.md) or take a look
+at the implementations of specific [agents](agents.md).
diff --git a/acme/examples/README.md b/acme/examples/README.md
new file mode 100644
index 00000000..66ec9f12
--- /dev/null
+++ b/acme/examples/README.md
@@ -0,0 +1,83 @@
+# Examples
+
+This directory includes a number of working examples of Acme agents. These
+examples are not meant to be comprehensive and instead show a number of common
+use cases under which Acme agents can be applied.
+
+Our [quickstart] guide can be used to get running quickly. This notebook shows
+how to instantiate a simple agent and run it on an environment. You can also
+take a look at our [tutorial], which takes a more in-depth look at the
+construction of the D4PG agent and highlights the general structure shared by
+most agents implemented in Acme.
+
+[quickstart]: https://github.com/deepmind/acme/blob/master/examples/quickstart.ipynb
+[tutorial]: https://github.com/deepmind/acme/blob/master/examples/tutorial.ipynb
+
+
+## Continuous control
+
+We include a number of agents running on continuous control tasks. These agents
+are representative examples, but any continuous control algorithm implemented
+in Acme can be swapped in, as sketched below.
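+
+As an illustrative sketch of such a swap (assuming the `helpers` module used by
+the baseline scripts below; the values shown are made up for the example), only
+the builder and the network factory change between agents:
+
+```python
+from acme.agents.jax import td3
+from acme.jax import experiments
+import helpers
+
+# Sketch: running TD3 in place of, e.g., D4PG.
+suite, task = 'gym', 'HalfCheetah-v2'
+config = td3.TD3Config(policy_learning_rate=3e-4, critic_learning_rate=3e-4)
+experiment = experiments.ExperimentConfig(
+    builder=td3.TD3Builder(config),
+    environment_factory=lambda seed: helpers.make_environment(suite, task),
+    network_factory=lambda spec: td3.make_networks(
+        spec, hidden_layer_sizes=(256, 256, 256)),
+    seed=0,
+    max_num_actor_steps=1_000_000)
+experiments.run_experiment(experiment=experiment)
+```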
+
+Note that many of the examples, particularly those based on the DeepMind
+Control Suite, will require a [MuJoCo license](https://www.roboti.us/license.html)
+in order to run. See our [tutorial] for more details, or refer to the
+[dm_control] repository for further information.
+
+- [D4PG](https://github.com/deepmind/acme/blob/master/examples/baselines/rl_continuous/run_d4pg.py): a deterministic policy gradient (D4PG) agent which includes a deterministic
+  policy and a distributional critic, running on the DeepMind Control Suite or
+  the [OpenAI Gym]. By default it runs on the "half cheetah" environment from
+  the OpenAI Gym.
+- [MPO](https://github.com/deepmind/acme/blob/master/examples/baselines/rl_continuous/run_mpo.py): a maximum a posteriori policy optimization (MPO) agent which combines a
+  distributional critic with a stochastic policy.
+
+[dm_control]: https://github.com/deepmind/dm_control
+[OpenAI Gym]: https://github.com/openai/gym
+
+
+## Discrete agents (Atari)
+
+The development of the [Arcade Learning Environment] and the coinciding use of
+Atari as a benchmark has played a very prominent role in the modern usage and
+testing of reinforcement learning algorithms. As a result we've also included
+direct examples of prominent discrete-action algorithms implemented in Acme and
+running on this environment.
+
+- [DQN](https://github.com/deepmind/acme/blob/master/examples/baselines/rl_discrete/run_dqn.py): a "classic" benchmark agent for Atari.
+
+[Arcade Learning Environment]: https://arxiv.org/abs/1207.4708
+
+
+## Offline agents
+
+Acme includes examples of offline agents, i.e. agents trained using external
+data generated by another agent:
+
+- [BC](https://github.com/deepmind/acme/blob/master/examples/offline/run_bc.py): a behaviour cloning agent.
+- [BC (JAX)](https://github.com/deepmind/acme/blob/master/examples/offline/run_bc_jax.py): a behaviour cloning agent (implemented
+  in JAX).
+- [BCQ](https://github.com/deepmind/acme/blob/master/examples/offline/run_bcq.py): an implementation of BCQ.
+
+Similarly, we also include so-called "from demonstration" agents which mix
+offline and online data:
+
+- [DQfD](https://github.com/deepmind/acme/blob/master/examples/offline/run_dqfd.py): the DQfD agent running on hard-exploration
+  tasks within bsuite (e.g. deep sea) using demonstration data.
+
+
+## Behaviour Suite
+
+The [Behaviour Suite for Reinforcement Learning][bsuite] defines a collection
+of tasks and environments which collectively investigate core capabilities of
+RL algorithms across a number of different axes. The examples we include show
+how to run Acme agents on this suite.
+
+- [DQN](https://github.com/deepmind/acme/blob/master/examples/bsuite/run_dqn.py): an off-policy DQN example;
+- [Impala](https://github.com/deepmind/acme/blob/master/examples/bsuite/run_impala.py): an on-policy Impala agent; and
+- [MCTS](https://github.com/deepmind/acme/blob/master/examples/bsuite/run_mcts.py): a model-based agent running on the
+  task suite using either a simulator of the environment or a learned model.
+
+For more information see https://github.com/deepmind/bsuite.
+
+[bsuite]: https://github.com/deepmind/bsuite
diff --git a/acme/examples/baselines/rl_continuous/helpers.py b/acme/examples/baselines/rl_continuous/helpers.py
new file mode 100644
index 00000000..9c74703c
--- /dev/null
+++ b/acme/examples/baselines/rl_continuous/helpers.py
@@ -0,0 +1,57 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared helpers for rl_continuous experiments."""
+
+from acme import wrappers
+import dm_env
+import gym
+
+
+_VALID_TASK_SUITES = ('gym', 'control')
+
+
+def make_environment(suite: str, task: str) -> dm_env.Environment:
+  """Makes the requested continuous control environment.
+
+  Args:
+    suite: One of 'gym' or 'control'.
+    task: Task to load. If `suite` is 'control', the task must be formatted as
+      f'{domain_name}:{task_name}'
+
+  Returns:
+    An environment satisfying the dm_env interface expected by Acme agents.
+  """
+
+  if suite not in _VALID_TASK_SUITES:
+    raise ValueError(
+        f'Unsupported suite: {suite}. Expected one of {_VALID_TASK_SUITES}')
+
+  if suite == 'gym':
+    env = gym.make(task)
+    # Make sure the environment obeys the dm_env.Environment interface.
+    env = wrappers.GymWrapper(env)
+
+  elif suite == 'control':
+    # Load dm_suite lazily to avoid requiring a MuJoCo license when not using it.
+    from dm_control import suite as dm_suite  # pylint: disable=g-import-not-at-top
+    domain_name, task_name = task.split(':')
+    env = dm_suite.load(domain_name, task_name)
+    env = wrappers.ConcatObservationWrapper(env)
+
+  # Wrap the environment so the expected continuous action spec is [-1, 1].
+  # Note: this is a no-op on 'control' tasks.
+  env = wrappers.CanonicalSpecWrapper(env, clip=True)
+  env = wrappers.SinglePrecisionWrapper(env)
+  return env
diff --git a/acme/examples/baselines/rl_continuous/run_d4pg.py b/acme/examples/baselines/rl_continuous/run_d4pg.py
new file mode 100644
index 00000000..79e7c4e8
--- /dev/null
+++ b/acme/examples/baselines/rl_continuous/run_d4pg.py
@@ -0,0 +1,85 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Example running D4PG on continuous control tasks."""
+
+from absl import flags
+from acme.agents.jax import d4pg
+import helpers
+from absl import app
+from acme.jax import experiments
+from acme.utils import lp_utils
+import launchpad as lp
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_bool(
+    'run_distributed', False, 'Should an agent be executed in a '
+    'distributed way (the default is a single-threaded agent)')
+flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run')
+flags.DEFINE_integer('seed', 0, 'Random seed.')
+flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.')
+flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.')
+flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.')
+
+
+def build_experiment_config():
+  """Builds D4PG experiment config which can be executed in different ways."""
+
+  # Create an environment, grab the spec, and use it to create networks.
+  suite, task = FLAGS.env_name.split(':', 1)
+
+  # Bound for the distributional critic. The reward is normalized for 'control'
+  # environments but not for 'gym' locomotion environments, hence the different
+  # scales.
+  vmax_values = {
+      'gym': 1000.,
+      'control': 150.,
+  }
+  vmax = vmax_values[suite]
+
+  def network_factory(spec) -> d4pg.D4PGNetworks:
+    return d4pg.make_networks(
+        spec,
+        policy_layer_sizes=(256, 256, 256),
+        critic_layer_sizes=(256, 256, 256),
+        vmin=-vmax,
+        vmax=vmax,
+    )
+
+  # Configure the agent.
+  d4pg_config = d4pg.D4PGConfig(learning_rate=3e-4, sigma=0.2)
+
+  return experiments.ExperimentConfig(
+      builder=d4pg.D4PGBuilder(d4pg_config),
+      environment_factory=lambda seed: helpers.make_environment(suite, task),
+      network_factory=network_factory,
+      seed=FLAGS.seed,
+      max_num_actor_steps=FLAGS.num_steps)
+
+
+def main(_):
+  config = build_experiment_config()
+  if FLAGS.run_distributed:
+    program = experiments.make_distributed_experiment(
+        experiment=config, num_actors=4)
+    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
+  else:
+    experiments.run_experiment(
+        experiment=config,
+        eval_every=FLAGS.eval_every,
+        num_eval_episodes=FLAGS.evaluation_episodes)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/acme/examples/baselines/rl_continuous/run_ppo.py b/acme/examples/baselines/rl_continuous/run_ppo.py
new file mode 100644
index 00000000..648c8758
--- /dev/null
+++ b/acme/examples/baselines/rl_continuous/run_ppo.py
@@ -0,0 +1,68 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Example running PPO on continuous control tasks.""" + +from absl import flags +from acme.agents.jax import ppo +import helpers +from absl import app +from acme.jax import experiments +from acme.utils import lp_utils +import launchpad as lp + +FLAGS = flags.FLAGS + +flags.DEFINE_bool( + 'run_distributed', False, 'Should an agent be executed in a ' + 'distributed way (the default is a single-threaded agent)') +flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run') +flags.DEFINE_integer('seed', 0, 'Random seed.') +flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') +flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.') +flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + + +def build_experiment_config(): + """Builds PPO experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. + suite, task = FLAGS.env_name.split(':', 1) + + config = ppo.PPOConfig(entropy_cost=0, learning_rate=1e-4) + ppo_builder = ppo.PPOBuilder(config) + + layer_sizes = (256, 256, 256) + return experiments.ExperimentConfig( + builder=ppo_builder, + environment_factory=lambda seed: helpers.make_environment(suite, task), + network_factory=lambda spec: ppo.make_networks(spec, layer_sizes), + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps) + + +def main(_): + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/baselines/rl_continuous/run_sac.py b/acme/examples/baselines/rl_continuous/run_sac.py new file mode 100644 index 00000000..a858aa3b --- /dev/null +++ b/acme/examples/baselines/rl_continuous/run_sac.py @@ -0,0 +1,86 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Example running SAC on continuous control tasks.""" + +from absl import flags +from acme import specs +from acme.agents.jax import normalization +from acme.agents.jax import sac +from acme.agents.jax.sac import builder +import helpers +from absl import app +from acme.jax import experiments +from acme.utils import lp_utils +import launchpad as lp + +FLAGS = flags.FLAGS + +flags.DEFINE_bool( + 'run_distributed', False, 'Should an agent be executed in a ' + 'distributed way (the default is a single-threaded agent)') +flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run') +flags.DEFINE_integer('seed', 0, 'Random seed.') +flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') +flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.') +flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + + +def build_experiment_config(): + """Builds SAC experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. + + suite, task = FLAGS.env_name.split(':', 1) + environment = helpers.make_environment(suite, task) + + environment_spec = specs.make_environment_spec(environment) + network_factory = ( + lambda spec: sac.make_networks(spec, hidden_layer_sizes=(256, 256, 256))) + + # Construct the agent. + config = sac.SACConfig( + learning_rate=3e-4, + n_step=2, + target_entropy=sac.target_entropy_from_env_spec(environment_spec)) + sac_builder = builder.SACBuilder(config) + # One batch dimension: [batch_size, ...] + batch_dims = (0,) + sac_builder = normalization.NormalizationBuilder( + sac_builder, + is_sequence_based=False, + batch_dims=batch_dims) + + return experiments.ExperimentConfig( + builder=sac_builder, + environment_factory=lambda seed: helpers.make_environment(suite, task), + network_factory=network_factory, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps) + + +def main(_): + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/baselines/rl_continuous/run_td3.py b/acme/examples/baselines/rl_continuous/run_td3.py new file mode 100644 index 00000000..f8a0a776 --- /dev/null +++ b/acme/examples/baselines/rl_continuous/run_td3.py @@ -0,0 +1,76 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Example running TD3 on continuous control tasks."""
+
+from absl import flags
+from acme.agents.jax import td3
+import helpers
+from absl import app
+from acme.jax import experiments
+from acme.utils import lp_utils
+import launchpad as lp
+
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_bool(
+    'run_distributed', False, 'Should an agent be executed in a '
+    'distributed way (the default is a single-threaded agent)')
+flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run')
+flags.DEFINE_integer('seed', 0, 'Random seed.')
+flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.')
+flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.')
+flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.')
+
+
+def build_experiment_config():
+  """Builds TD3 experiment config which can be executed in different ways."""
+  # Create an environment, grab the spec, and use it to create networks.
+
+  suite, task = FLAGS.env_name.split(':', 1)
+  network_factory = (
+      lambda spec: td3.make_networks(spec, hidden_layer_sizes=(256, 256, 256)))
+
+  # Construct the agent.
+  config = td3.TD3Config(
+      policy_learning_rate=3e-4,
+      critic_learning_rate=3e-4,
+  )
+  td3_builder = td3.TD3Builder(config)
+  # pylint:disable=g-long-lambda
+  return experiments.ExperimentConfig(
+      builder=td3_builder,
+      environment_factory=lambda seed: helpers.make_environment(suite, task),
+      network_factory=network_factory,
+      seed=FLAGS.seed,
+      max_num_actor_steps=FLAGS.num_steps)
+  # pylint:enable=g-long-lambda
+
+
+def main(_):
+  config = build_experiment_config()
+  if FLAGS.run_distributed:
+    program = experiments.make_distributed_experiment(
+        experiment=config, num_actors=4)
+    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
+  else:
+    experiments.run_experiment(
+        experiment=config,
+        eval_every=FLAGS.eval_every,
+        num_eval_episodes=FLAGS.evaluation_episodes)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/acme/examples/baselines/rl_discrete/helpers.py b/acme/examples/baselines/rl_discrete/helpers.py
new file mode 100644
index 00000000..1e1a34ae
--- /dev/null
+++ b/acme/examples/baselines/rl_discrete/helpers.py
@@ -0,0 +1,118 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared helpers for different discrete RL experiment flavours."""
+
+import functools
+import os
+from typing import Union
+
+from absl import flags
+from acme import core
+from acme import environment_loop
+from acme import specs
+from acme import wrappers
+from acme.agents.jax import builders
+from acme.jax import experiments as experiments_lib
+from acme.jax import networks as networks_lib
+from acme.jax import types as jax_types
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import experiment_utils
+import atari_py  # pylint:disable=unused-import
+import dm_env
+import gym
+import haiku as hk
+
+FLAGS = flags.FLAGS
+
+
+def make_atari_environment(
+    level: str = 'Pong',
+    sticky_actions: bool = True,
+    zero_discount_on_life_loss: bool = False,
+    oar_wrapper: bool = False,
+) -> dm_env.Environment:
+  """Loads the Atari environment."""
+  version = 'v0' if sticky_actions else 'v4'
+  level_name = f'{level}NoFrameskip-{version}'
+  env = gym.make(level_name, full_action_space=True)
+
+  wrapper_list = [
+      wrappers.GymAtariAdapter,
+      functools.partial(
+          wrappers.AtariWrapper,
+          to_float=True,
+          max_episode_len=108_000,
+          zero_discount_on_life_loss=zero_discount_on_life_loss,
+      ),
+      wrappers.SinglePrecisionWrapper,
+  ]
+
+  if oar_wrapper:
+    # E.g. IMPALA and R2D2 use this particular variant.
+    wrapper_list.append(wrappers.ObservationActionRewardWrapper)
+
+  return wrappers.wrap_all(env, wrapper_list)
+
+
+def make_atari_evaluator_factory(
+    level_name: str,
+    network_factory: experiments_lib.NetworkFactory,
+    agent_builder: Union[builders.ActorLearnerBuilder, builders.OfflineBuilder],
+) -> experiments_lib.EvaluatorFactory:
+  """Returns an Atari evaluator process."""
+
+  def evaluator_factory(
+      random_key: jax_types.PRNGKey,
+      variable_source: core.VariableSource,
+      counter: counting.Counter,
+      make_actor: experiments_lib.MakeActorFn,
+  ) -> environment_loop.EnvironmentLoop:
+    """The evaluation process."""
+
+    environment = make_atari_environment(
+        level_name,
+        sticky_actions=False,  # Turn off sticky actions for evaluation.
+        oar_wrapper=True)
+    environment_spec = specs.make_environment_spec(environment)
+    networks = network_factory(environment_spec)
+    policy = agent_builder.make_policy(
+        networks, environment_spec, evaluation=True)
+    actor = make_actor(random_key, policy, environment_spec, variable_source)
+
+    # Create logger and counter.
+    counter = counting.Counter(counter, 'evaluator')
+    logger = experiment_utils.make_experiment_logger('evaluator', 'actor_steps')
+
+    # Create the run loop and return it.
+ return environment_loop.EnvironmentLoop(environment, actor, counter, logger) + + return evaluator_factory + + +def make_dqn_atari_network( + environment_spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork: + """Creates networks for training DQN on Atari.""" + def network(inputs): + model = hk.Sequential([ + networks_lib.AtariTorso(), + hk.nets.MLP([512, environment_spec.actions.num_values]), + ]) + return model(inputs) + network_hk = hk.without_apply_rng(hk.transform(network)) + obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations)) + return networks_lib.FeedForwardNetwork( + init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply) diff --git a/acme/examples/baselines/rl_discrete/run_dqn.py b/acme/examples/baselines/rl_discrete/run_dqn.py new file mode 100644 index 00000000..5e20b09f --- /dev/null +++ b/acme/examples/baselines/rl_discrete/run_dqn.py @@ -0,0 +1,91 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running DQN on discrete control tasks.""" + +from absl import flags +from acme import specs +from acme.agents.jax import dqn +from acme.agents.jax.dqn import losses +import helpers +from absl import app +from acme.jax import experiments +from acme.utils import lp_utils +import launchpad as lp + + +FLAGS = flags.FLAGS + +flags.DEFINE_bool( + 'run_distributed', False, 'Should an agent be executed in a ' + 'distributed way (the default is a single-threaded agent)') +flags.DEFINE_string('env_name', 'Pong', 'What environment to run') +flags.DEFINE_integer('seed', 0, 'Random seed.') +flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') + + +def build_experiment_config(): + """Builds DQN experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. + env_name = FLAGS.env_name + + def env_factory(seed): + del seed + return helpers.make_atari_environment( + level=env_name, sticky_actions=True, zero_discount_on_life_loss=False) + + environment_spec = specs.make_environment_spec(env_factory(0)) + + # Create network + network = helpers.make_dqn_atari_network(environment_spec) + + # Construct the agent. + config = dqn.DQNConfig( + discount=0.99, + learning_rate=5e-5, + n_step=1, + epsilon=0.01, + target_update_period=2000, + min_replay_size=20_000, + max_replay_size=1_000_000, + samples_per_insert=8, + batch_size=32) + loss_fn = losses.QLearning( + discount=config.discount, max_abs_reward=1.) + + dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn) + + return experiments.ExperimentConfig( + builder=dqn_builder, + environment_factory=env_factory, + network_factory=lambda spec: network, + evaluator_factories=[], + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps) + + +def main(_): + config = build_experiment_config() + # Evaluation is disabled for performance reasons. Set `num_eval_episodes` to + # a positive number and remove `evaluator_factories=[]` to enable it. 
+ if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment(experiment=config, num_eval_episodes=0) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/baselines/rl_discrete/run_mdqn.py b/acme/examples/baselines/rl_discrete/run_mdqn.py new file mode 100644 index 00000000..ff4a45bc --- /dev/null +++ b/acme/examples/baselines/rl_discrete/run_mdqn.py @@ -0,0 +1,92 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running Munchausen-DQN on discrete control tasks.""" + +from absl import flags +from acme import specs +from acme.agents.jax import dqn +from acme.agents.jax.dqn import losses +import helpers +from absl import app +from acme.jax import experiments +from acme.utils import lp_utils +import launchpad as lp + + +FLAGS = flags.FLAGS + +flags.DEFINE_bool( + 'run_distributed', False, 'Should an agent be executed in a ' + 'distributed way (the default is a single-threaded agent)') +flags.DEFINE_string('env_name', 'Pong', 'What environment to run') +flags.DEFINE_integer('seed', 0, 'Random seed.') +flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') + + +def build_experiment_config(): + """Builds MDQN experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. + env_name = FLAGS.env_name + + def env_factory(seed): + del seed + return helpers.make_atari_environment( + level=env_name, sticky_actions=True, zero_discount_on_life_loss=False) + + environment_spec = specs.make_environment_spec(env_factory(0)) + + # Create network. + network = helpers.make_dqn_atari_network(environment_spec) + + # Construct the agent. + config = dqn.DQNConfig( + discount=0.99, + learning_rate=5e-5, + n_step=1, + epsilon=0.01, + target_update_period=2000, + min_replay_size=20_000, + max_replay_size=1_000_000, + samples_per_insert=8, + batch_size=32) + loss_fn = losses.MunchausenQLearning( + discount=config.discount, max_abs_reward=1., huber_loss_parameter=1., + entropy_temperature=0.03, munchausen_coefficient=0.9) + + dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn) + + return experiments.ExperimentConfig( + builder=dqn_builder, + environment_factory=env_factory, + network_factory=lambda spec: network, + evaluator_factories=[], + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps) + + +def main(_): + config = build_experiment_config() + # Evaluation is disabled for performance reasons. Set `num_eval_episodes` to + # a positive number and remove `evaluator_factories=[]` to enable it. 
+ if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment(experiment=config, num_eval_episodes=0) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/bsuite/run_dqn.py b/acme/examples/bsuite/run_dqn.py new file mode 100644 index 00000000..0a570b44 --- /dev/null +++ b/acme/examples/bsuite/run_dqn.py @@ -0,0 +1,60 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running DQN on BSuite in a single process.""" + +from absl import app +from absl import flags + +import acme +from acme import specs +from acme import wrappers +from acme.agents.tf import dqn + +import bsuite +import sonnet as snt + +# Bsuite flags +flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') +flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') +flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +FLAGS = flags.FLAGS + + +def main(_): + # Create an environment and grab the spec. + raw_environment = bsuite.load_and_record_to_csv( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment = wrappers.SinglePrecisionWrapper(raw_environment) + environment_spec = specs.make_environment_spec(environment) + + network = snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, environment_spec.actions.num_values]) + ]) + + # Construct the agent. + agent = dqn.DQN( + environment_spec=environment_spec, network=network) + + # Run the environment loop. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/bsuite/run_impala.py b/acme/examples/bsuite/run_impala.py new file mode 100644 index 00000000..c592e604 --- /dev/null +++ b/acme/examples/bsuite/run_impala.py @@ -0,0 +1,69 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Runs IMPALA on bsuite locally.""" + +from absl import app +from absl import flags +import acme +from acme import specs +from acme import wrappers +from acme.agents.tf import impala +from acme.tf import networks +import bsuite +import sonnet as snt + +# Bsuite flags +flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') +flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') +flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +FLAGS = flags.FLAGS + + +def make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore: + return snt.DeepRNN([ + snt.Flatten(), + snt.nets.MLP([50, 50]), + snt.LSTM(20), + networks.PolicyValueHead(action_spec.num_values), + ]) + + +def main(_): + # Create an environment and grab the spec. + raw_environment = bsuite.load_and_record_to_csv( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment = wrappers.SinglePrecisionWrapper(raw_environment) + environment_spec = specs.make_environment_spec(environment) + + # Create the networks to optimize. + network = make_network(environment_spec.actions) + + agent = impala.IMPALA( + environment_spec=environment_spec, + network=network, + sequence_length=3, + sequence_period=3, + ) + + # Run the environment loop. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/bsuite/run_mcts.py b/acme/examples/bsuite/run_mcts.py new file mode 100644 index 00000000..dc89b88d --- /dev/null +++ b/acme/examples/bsuite/run_mcts.py @@ -0,0 +1,105 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Example running MCTS on BSuite in a single process.""" + +from typing import Tuple + +from absl import app +from absl import flags +import acme +from acme import specs +from acme import wrappers +from acme.agents.tf import mcts +from acme.agents.tf.mcts import models +from acme.agents.tf.mcts.models import mlp +from acme.agents.tf.mcts.models import simulator +from acme.tf import networks +import bsuite +from bsuite.logging import csv_logging +import dm_env +import sonnet as snt + +# Bsuite flags +flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') +flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') +flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +# Agent flags +flags.DEFINE_boolean('simulator', True, 'Simulator or learned model?') +FLAGS = flags.FLAGS + + +def make_env_and_model( + bsuite_id: str, + results_dir: str, + overwrite: bool) -> Tuple[dm_env.Environment, models.Model]: + """Create environment and corresponding model (learned or simulator).""" + raw_env = bsuite.load_from_id(bsuite_id) + if FLAGS.simulator: + model = simulator.Simulator(raw_env) # pytype: disable=attribute-error + else: + model = mlp.MLPModel( + specs.make_environment_spec(raw_env), + replay_capacity=1000, + batch_size=16, + hidden_sizes=(50,), + ) + environment = csv_logging.wrap_environment( + raw_env, bsuite_id, results_dir, overwrite) + environment = wrappers.SinglePrecisionWrapper(environment) + + return environment, model + + +def make_network(action_spec: specs.DiscreteArray) -> snt.Module: + return snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([50, 50]), + networks.PolicyValueHead(action_spec.num_values), + ]) + + +def main(_): + # Create an environment and environment model. + environment, model = make_env_and_model( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment_spec = specs.make_environment_spec(environment) + + # Create the network and optimizer. + network = make_network(environment_spec.actions) + optimizer = snt.optimizers.Adam(learning_rate=1e-3) + + # Construct the agent. + agent = mcts.MCTS( + environment_spec=environment_spec, + model=model, + network=network, + optimizer=optimizer, + discount=0.99, + replay_capacity=10000, + n_step=1, + batch_size=16, + num_simulations=50, + ) + + # Run the environment loop. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/multiagent/multigrid/helpers.py b/acme/examples/multiagent/multigrid/helpers.py new file mode 100644 index 00000000..6894b562 --- /dev/null +++ b/acme/examples/multiagent/multigrid/helpers.py @@ -0,0 +1,208 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Helpers for multigrid environment.""" + +import functools +from typing import Any, Dict, NamedTuple, Sequence + +from acme import specs +from acme.agents.jax import ppo +from acme.agents.jax.multiagent.decentralized import factories +from acme.jax import networks as networks_lib +from acme.jax import utils as acme_jax_utils +from acme.multiagent import types as ma_types +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow_probability + +tfp = tensorflow_probability.substrates.jax +tfd = tfp.distributions + + +class CategoricalParams(NamedTuple): + """Parameters for a categorical distribution.""" + logits: jnp.ndarray + + +def multigrid_obs_preproc(obs: Dict[str, Any], + conv_filters: int = 8, + conv_kernel: int = 3, + scalar_fc: int = 5, + scalar_name: str = 'direction', + scalar_dim: int = 4) -> jnp.ndarray: + """Conducts preprocessing on 'multigrid' environment dict observations. + + The preprocessing applied here is similar to those in: + https://github.com/google-research/google-research/blob/master/social_rl/multiagent_tfagents/multigrid_networks.py + + Args: + obs: multigrid observation dict, which can include observation inputs such + as 'image', 'position', and a custom additional observation (defined by + scalar_name). + conv_filters: Number of convolution filters. + conv_kernel: Size of the convolution kernel. + scalar_fc: Number of neurons in the fully connected layer processing the + scalar input. + scalar_name: a special observation key, which is set to + `direction` in most multigrid environments (and can be overridden here if + otherwise). + scalar_dim: Highest possible value for the scalar input. Used to convert to + one-hot representation. + + Returns: + out: output observation. + """ + + def _cast_and_scale(x, scale_by=10.0): + if isinstance(x, jnp.ndarray): + x = x.astype(jnp.float32) + return x / scale_by + + outputs = [] + + if 'image' in obs.keys(): + image_preproc = hk.Sequential([ + _cast_and_scale, + hk.Conv2D(output_channels=conv_filters, kernel_shape=conv_kernel), + jax.nn.relu, + hk.Flatten() + ]) + outputs.append(image_preproc(obs['image'])) + + if 'position' in obs.keys(): + position_preproc = hk.Sequential([_cast_and_scale, hk.Linear(scalar_fc)]) + outputs.append(position_preproc(obs['position'])) + + if scalar_name in obs.keys(): + direction_preproc = hk.Sequential([ + functools.partial(jax.nn.one_hot, num_classes=scalar_dim), + hk.Flatten(), + hk.Linear(scalar_fc) + ]) + outputs.append(direction_preproc(obs[scalar_name])) + + out = jnp.concatenate(outputs, axis=-1) + return out + + +def make_multigrid_dqn_networks( + environment_spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork: + """Returns DQN networks used by the agent in the multigrid environment.""" + # Check that multigrid environment is defined with discrete actions, 0-indexed + assert np.issubdtype(environment_spec.actions.dtype, np.integer), ( + 'Expected multigrid environment to have discrete actions with int dtype' + f' but environment_spec.actions.dtype == {environment_spec.actions.dtype}' + ) + assert environment_spec.actions.minimum == 0, ( + 'Expected multigrid environment to have 0-indexed action indices, but' + f' environment_spec.actions.minimum == {environment_spec.actions.minimum}' + ) + num_actions = environment_spec.actions.maximum + 1 + + def network(inputs): + model = hk.Sequential([ + hk.Flatten(), + hk.nets.MLP([50, 50, num_actions]), + ]) + processed_inputs = multigrid_obs_preproc(inputs) + return 
model(processed_inputs) + + network_hk = hk.without_apply_rng(hk.transform(network)) + dummy_obs = acme_jax_utils.add_batch_dim( + acme_jax_utils.zeros_like(environment_spec.observations)) + + return networks_lib.FeedForwardNetwork( + init=lambda rng: network_hk.init(rng, dummy_obs), apply=network_hk.apply) + + +def make_multigrid_ppo_networks( + environment_spec: specs.EnvironmentSpec, + hidden_layer_sizes: Sequence[int] = (64, 64), +) -> ppo.PPONetworks: + """Returns PPO networks used by the agent in the multigrid environments.""" + + # Check that multigrid environment is defined with discrete actions, 0-indexed + assert np.issubdtype(environment_spec.actions.dtype, np.integer), ( + 'Expected multigrid environment to have discrete actions with int dtype' + f' but environment_spec.actions.dtype == {environment_spec.actions.dtype}' + ) + assert environment_spec.actions.minimum == 0, ( + 'Expected multigrid environment to have 0-indexed action indices, but' + f' environment_spec.actions.minimum == {environment_spec.actions.minimum}' + ) + num_actions = environment_spec.actions.maximum + 1 + + def forward_fn(inputs): + processed_inputs = multigrid_obs_preproc(inputs) + trunk = hk.nets.MLP(hidden_layer_sizes, activation=jnp.tanh) + h = trunk(processed_inputs) + logits = hk.Linear(num_actions)(h) + values = hk.Linear(1)(h) + values = jnp.squeeze(values, axis=-1) + return (CategoricalParams(logits=logits), values) + + # Transform into pure functions. + forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) + + dummy_obs = acme_jax_utils.zeros_like(environment_spec.observations) + dummy_obs = acme_jax_utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. + network = networks_lib.FeedForwardNetwork( + lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply) + return make_categorical_ppo_networks(network) # pylint:disable=undefined-variable + + +def make_categorical_ppo_networks( + network: networks_lib.FeedForwardNetwork) -> ppo.PPONetworks: + """Constructs a PPONetworks for Categorical Policy from FeedForwardNetwork. + + Args: + network: a transformed Haiku network (or equivalent in other libraries) that + takes in observations and returns the action distribution and value. + + Returns: + A PPONetworks instance with pure functions wrapping the input network. 
+ """ + + def log_prob(params: CategoricalParams, action): + return tfd.Categorical(logits=params.logits).log_prob(action) + + def entropy(params: CategoricalParams): + return tfd.Categorical(logits=params.logits).entropy() + + def sample(params: CategoricalParams, key: networks_lib.PRNGKey): + return tfd.Categorical(logits=params.logits).sample(seed=key) + + def sample_eval(params: CategoricalParams, key: networks_lib.PRNGKey): + del key + return tfd.Categorical(logits=params.logits).mode() + + return ppo.PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval) + + +def init_default_multigrid_network( + agent_type: str, + agent_spec: specs.EnvironmentSpec) -> ma_types.Networks: + """Returns default networks for multigrid environment.""" + if agent_type == factories.DefaultSupportedAgent.PPO: + return make_multigrid_ppo_networks(agent_spec) + else: + raise ValueError(f'Unsupported agent type: {agent_type}.') diff --git a/acme/examples/multiagent/multigrid/run_multigrid.py b/acme/examples/multiagent/multigrid/run_multigrid.py new file mode 100644 index 00000000..af50132c --- /dev/null +++ b/acme/examples/multiagent/multigrid/run_multigrid.py @@ -0,0 +1,97 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiagent multigrid training run example.""" +from absl import flags + +import acme +from acme import specs +from acme.agents.jax.multiagent.decentralized import agents +from acme.agents.jax.multiagent.decentralized import factories +from absl import app +import helpers +from acme.utils import loggers +from acme.wrappers import multigrid_wrapper +import jax + +FLAGS = flags.FLAGS +_NUM_STEPS = flags.DEFINE_integer('num_steps', 10000, + 'Number of env steps to run training for.') +_EVAL_EVERY = flags.DEFINE_integer('eval_every', 1000, + 'How often to run evaluation.') +_ENV_NAME = flags.DEFINE_string('env_name', 'MultiGrid-Empty-5x5-v0', + 'What environment to run.') +_BATCH_SIZE = flags.DEFINE_integer('batch_size', 64, 'Batch size.') +_SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') + + +def main(_): + """Runs multigrid experiment.""" + # Training environment + train_env = multigrid_wrapper.make_multigrid_environment(_ENV_NAME.value) + train_environment_spec = specs.make_environment_spec(train_env) + + agent_types = { + str(i): factories.DefaultSupportedAgent.PPO + for i in range(train_env.num_agents) # pytype: disable=attribute-error + } + # Example of how to set custom sub-agent configurations. 
+ ppo_configs = {'unroll_length': 16, 'num_minibatches': 32, 'num_epochs': 10} + config_overrides = { + k: ppo_configs for k, v in agent_types.items() if v == 'ppo' + } + train_agents, eval_policy_networks = agents.init_decentralized_multiagent( + agent_types=agent_types, + environment_spec=train_environment_spec, + seed=_SEED.value, + batch_size=_BATCH_SIZE.value, + init_network_fn=helpers.init_default_multigrid_network, + config_overrides=config_overrides + ) + + train_loop = acme.EnvironmentLoop( + train_env, + train_agents, + label='train_loop', + logger=loggers.TerminalLogger( + label='trainer', time_delta=1.0)) + + # Evaluation environment + eval_env = multigrid_wrapper.make_multigrid_environment(_ENV_NAME.value) + eval_environment_spec = specs.make_environment_spec(eval_env) + eval_actors = train_agents.builder.make_actor( + random_key=jax.random.PRNGKey(_SEED.value), + policy_networks=eval_policy_networks, + environment_spec=eval_environment_spec, + variable_source=train_agents + ) + eval_loop = acme.EnvironmentLoop( + eval_env, + eval_actors, + label='eval_loop', + logger=loggers.TerminalLogger( + label='evaluator', time_delta=1.0)) + + # Run + assert _NUM_STEPS.value % _EVAL_EVERY.value == 0 + for _ in range(_NUM_STEPS.value // _EVAL_EVERY.value): + eval_loop.run(num_episodes=5) + train_loop.run(num_steps=_EVAL_EVERY.value) + eval_loop.run(num_episodes=5) + + return train_agents + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/offline/bc_utils.py b/acme/examples/offline/bc_utils.py new file mode 100644 index 00000000..6de047e7 --- /dev/null +++ b/acme/examples/offline/bc_utils.py @@ -0,0 +1,206 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for running behavioral cloning. +""" +import functools +import operator +from typing import Callable + +from acme import core +from acme import environment_loop +from acme import specs +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import bc +from acme.agents.tf.dqfd import bsuite_demonstrations +from acme.jax import networks as networks_lib +from acme.jax import types as jax_types +from acme.jax import utils +from acme.jax import variable_utils +from acme.jax.layouts import offline_distributed_layout +from acme.utils import counting +from acme.utils import loggers +from acme.wrappers import single_precision +import bsuite +import dm_env +import haiku as hk +import jax +import jax.numpy as jnp +from jax.scipy import special +import rlax +import tensorflow as tf +import tree + + +def make_network(spec: specs.EnvironmentSpec) -> bc.BCNetworks: + """Creates networks used by the agent.""" + num_actions = spec.actions.num_values + + def actor_fn(obs, is_training=True, key=None): + # is_training and key allows to utilize train/test dependant modules + # like dropout. 
+    del is_training
+    del key
+    mlp = hk.Sequential(
+        [hk.Flatten(),
+         hk.nets.MLP([64, 64, num_actions])])
+    return mlp(obs)
+
+  policy = hk.without_apply_rng(hk.transform(actor_fn))
+
+  # Create dummy observations to create network parameters.
+  dummy_obs = utils.zeros_like(spec.observations)
+  dummy_obs = utils.add_batch_dim(dummy_obs)
+
+  policy_network = bc.BCPolicyNetwork(lambda key: policy.init(key, dummy_obs),
+                                      policy.apply)
+
+  def sample_fn(logits: networks_lib.NetworkOutput,
+                key: jax_types.PRNGKey) -> networks_lib.Action:
+    return rlax.epsilon_greedy(epsilon=0.0).sample(key, logits)
+
+  def log_prob(logits: networks_lib.NetworkOutput,
+               actions: networks_lib.Action) -> networks_lib.LogProb:
+    logits_actions = jnp.sum(
+        jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
+    logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
+    return logits_actions
+
+  return bc.BCNetworks(policy_network, sample_fn, log_prob)
+
+
+def _n_step_transition_from_episode(
+    observations: types.NestedTensor,
+    actions: tf.Tensor,
+    rewards: tf.Tensor,
+    discounts: tf.Tensor, n_step: int,
+    additional_discount: float) -> types.Transition:
+  """Produce Reverb-like N-step transition from a full episode.
+
+  Observations, actions, rewards and discounts have the same length. This
+  function will ignore the first reward and discount and the last action.
+
+  Args:
+    observations: [episode_length, ...] Tensor.
+    actions: [episode_length, ...] Tensor.
+    rewards: [episode_length] Tensor.
+    discounts: [episode_length] Tensor.
+    n_step: number of steps to squash into a single transition.
+    additional_discount: discount to use for TD updates.
+
+  Returns:
+    A types.Transition.
+  """
+
+  max_index = tf.shape(rewards)[0] - 1
+  first = tf.random.uniform(
+      shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32)
+  last = tf.minimum(first + n_step, max_index)
+
+  o_t = tree.map_structure(operator.itemgetter(first), observations)
+  a_t = tree.map_structure(operator.itemgetter(first), actions)
+  o_tp1 = tree.map_structure(operator.itemgetter(last), observations)
+
+  # 0, 1, ..., n-1.
+  discount_range = tf.cast(tf.range(last - first), tf.float32)
+  # 1, g, ..., g^{n-1}.
+  additional_discounts = tf.pow(additional_discount, discount_range)
+  # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}.
+  discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last - 1])], 0)
+  # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}.
+  discounts *= additional_discounts
+  # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1}
+  # We have to shift rewards by one so last=max_index corresponds to transitions
+  # that include the last reward.
+  r_t = tf.reduce_sum(rewards[first + 1:last + 1] * discounts)
+
+  # g^{n-1} * d_{t} * ... * d_{t+n-2}.
+  d_t = discounts[-1]
+
+  return types.Transition(o_t, a_t, r_t, d_t, o_tp1)
+
+
+def make_environment(training: bool = True):
+  del training
+  env = bsuite.load(experiment_name='deep_sea', kwargs={'size': 10})
+  return single_precision.SinglePrecisionWrapper(env)
+
+
+def make_demonstrations(env: dm_env.Environment,
+                        batch_size: int) -> tf.data.Dataset:
+  """Prepare the dataset of demonstrations."""
+  batch_dataset = bsuite_demonstrations.make_dataset(env, stochastic=False)
+  # Convert the demonstration episodes into N-step transitions.
+  transition = functools.partial(
+      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
+
+  dataset = batch_dataset.map(transition)
+
+  # Batch and prefetch.
+  dataset = dataset.batch(batch_size, drop_remainder=True)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+  return dataset
+
+
+def make_actor_evaluator(
+    environment_factory: Callable[[bool], dm_env.Environment],
+    evaluator_network: actor_core_lib.FeedForwardPolicy,
+) -> offline_distributed_layout.EvaluatorFactory:
+  """Makes an evaluator that runs the agent on the environment.
+
+  Args:
+    environment_factory: Function that creates a dm_env.
+    evaluator_network: Network to be used by the actor.
+
+  Returns:
+    actor_evaluator: Function that returns a Worker that will be executed
+      by Launchpad.
+  """
+  def actor_evaluator(
+      random_key: networks_lib.PRNGKey,
+      variable_source: core.VariableSource,
+      counter: counting.Counter,
+  ):
+    """The evaluation process."""
+    # Create the actor loading the weights from variable source.
+    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
+        evaluator_network)
+    # Inference happens on CPU, so it's better to move variables there too.
+    variable_client = variable_utils.VariableClient(variable_source, 'policy',
+                                                    device='cpu')
+    actor = actors.GenericActor(
+        actor_core, random_key, variable_client, backend='cpu')
+
+    # Logger.
+    logger = loggers.make_default_logger(
+        'evaluator', steps_key='evaluator_steps')
+
+    # Create the evaluation environment.
+    environment = environment_factory(False)
+
+    # Create the evaluator counter.
+    counter = counting.Counter(counter, 'evaluator')
+
+    # Create the run loop and return it.
+    return environment_loop.EnvironmentLoop(
+        environment,
+        actor,
+        counter,
+        logger,
+    )
+
+  return actor_evaluator
diff --git a/acme/examples/offline/run_bc.py b/acme/examples/offline/run_bc.py
new file mode 100644
index 00000000..9d451801
--- /dev/null
+++ b/acme/examples/offline/run_bc.py
@@ -0,0 +1,195 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An example BC running on BSuite."""
+
+import functools
+import operator
+
+from absl import app
+from absl import flags
+import acme
+from acme import specs
+from acme import types
+from acme.agents.tf import actors
+from acme.agents.tf.bc import learning
+from acme.agents.tf.dqfd import bsuite_demonstrations
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+from acme.wrappers import single_precision
+import bsuite
+import reverb
+import sonnet as snt
+import tensorflow as tf
+import tree
+import trfl
+
+# Bsuite flags
+flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.')
+flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.')
+flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.')
+
+# Agent flags
+flags.DEFINE_float('learning_rate', 2e-4, 'Learning rate.')
+flags.DEFINE_integer('batch_size', 16, 'Batch size.')
+flags.DEFINE_float('epsilon', 0., 'Epsilon for the epsilon greedy in the env.')
+flags.DEFINE_integer('evaluate_every', 100, 'Evaluation period.')
+flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.')
+
+FLAGS = flags.FLAGS
+
+
+def make_policy_network(action_spec: specs.DiscreteArray) -> snt.Module:
+  return snt.Sequential([
+      snt.Flatten(),
+      snt.nets.MLP([64, 64, action_spec.num_values]),
+  ])
+
+
+# TODO(b/152733199): Move this function to acme utils.
+def _n_step_transition_from_episode(observations: types.NestedTensor,
+                                    actions: tf.Tensor, rewards: tf.Tensor,
+                                    discounts: tf.Tensor, n_step: int,
+                                    additional_discount: float):
+  """Produce Reverb-like N-step transition from a full episode.
+
+  Observations, actions, rewards and discounts have the same length. This
+  function will ignore the first reward and discount and the last action.
+
+  Args:
+    observations: [L, ...] Tensor.
+    actions: [L, ...] Tensor.
+    rewards: [L] Tensor.
+    discounts: [L] Tensor.
+    n_step: number of steps to squash into a single transition.
+    additional_discount: discount to use for TD updates.
+
+  Returns:
+    (o_t, a_t, r_t, d_t, o_tp1) tuple.
+  """
+
+  max_index = tf.shape(rewards)[0] - 1
+  first = tf.random.uniform(
+      shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32)
+  last = tf.minimum(first + n_step, max_index)
+
+  o_t = tree.map_structure(operator.itemgetter(first), observations)
+  a_t = tree.map_structure(operator.itemgetter(first), actions)
+  o_tp1 = tree.map_structure(operator.itemgetter(last), observations)
+
+  # 0, 1, ..., n-1.
+  discount_range = tf.cast(tf.range(last - first), tf.float32)
+  # 1, g, ..., g^{n-1}.
+  additional_discounts = tf.pow(additional_discount, discount_range)
+  # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}.
+  discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last - 1])], 0)
+  # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}.
+  discounts *= additional_discounts
+  # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1}
+  # We have to shift rewards by one so last=max_index corresponds to transitions
+  # that include the last reward.
+  r_t = tf.reduce_sum(rewards[first + 1:last + 1] * discounts)
+
+  # g^{n-1} * d_{t} * ... * d_{t+n-2}.
+  d_t = discounts[-1]
+
+  # Reverb requires every sample to be given a key and priority.
+  # In the supervised learning case for BC, neither of those will be used.
+  # We set the key to `0` and the priority and probability to `1`, but these
+  # values should not matter much.
+  key = tf.constant(0, tf.uint64)
+  probability = tf.constant(1.0, tf.float64)
+  table_size = tf.constant(1, tf.int64)
+  priority = tf.constant(1.0, tf.float64)
+  times_sampled = tf.constant(1, tf.int32)
+  info = reverb.SampleInfo(
+      key=key,
+      probability=probability,
+      table_size=table_size,
+      priority=priority,
+      times_sampled=times_sampled,
+  )
+
+  return reverb.ReplaySample(info=info, data=(o_t, a_t, r_t, d_t, o_tp1))
+
+
+def main(_):
+  # Create an environment and grab the spec.
+  raw_environment = bsuite.load_and_record_to_csv(
+      bsuite_id=FLAGS.bsuite_id,
+      results_dir=FLAGS.results_dir,
+      overwrite=FLAGS.overwrite,
+  )
+  environment = single_precision.SinglePrecisionWrapper(raw_environment)
+  environment_spec = specs.make_environment_spec(environment)
+
+  # Build demonstration dataset.
+  if hasattr(raw_environment, 'raw_env'):
+    raw_environment = raw_environment.raw_env
+
+  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment,
+                                                     stochastic=False)
+  # Convert the demonstration episodes into N-step transitions.
+  transition = functools.partial(
+      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
+
+  dataset = batch_dataset.map(transition)
+
+  # Batch and prefetch.
+  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+  # Create the networks to optimize.
+  policy_network = make_policy_network(environment_spec.actions)
+
+  # If the agent is non-autoregressive, use epsilon=0, which yields a greedy
+  # policy.
+  evaluator_network = snt.Sequential([
+      policy_network,
+      lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
+  ])
+
+  # Ensure that we create the variables before proceeding (maybe not needed).
+  tf2_utils.create_variables(policy_network, [environment_spec.observations])
+
+  counter = counting.Counter()
+  learner_counter = counting.Counter(counter, prefix='learner')
+
+  # Create the actor which defines how we take actions.
+  evaluation_network = actors.FeedForwardActor(evaluator_network)
+
+  eval_loop = acme.EnvironmentLoop(
+      environment=environment,
+      actor=evaluation_network,
+      counter=counter,
+      logger=loggers.TerminalLogger('evaluation', time_delta=1.))
+
+  # The learner updates the parameters (and initializes them).
+  learner = learning.BCLearner(
+      network=policy_network,
+      learning_rate=FLAGS.learning_rate,
+      dataset=dataset,
+      counter=learner_counter)
+
+  # Run the environment loop.
+  while True:
+    for _ in range(FLAGS.evaluate_every):
+      learner.step()
+    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
+    eval_loop.run(FLAGS.evaluation_episodes)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/acme/examples/offline/run_bc_jax.py b/acme/examples/offline/run_bc_jax.py
new file mode 100644
index 00000000..f902b23b
--- /dev/null
+++ b/acme/examples/offline/run_bc_jax.py
@@ -0,0 +1,98 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""An example BC running on BSuite.""" + +from absl import app +from absl import flags +import acme +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import bc +from acme.examples.offline import bc_utils +from acme.jax import utils +from acme.jax import variable_utils +from acme.utils import loggers +import haiku as hk +import jax +import jax.numpy as jnp +import optax +import rlax + +# Agent flags +flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.') +flags.DEFINE_integer('batch_size', 64, 'Batch size.') +flags.DEFINE_float('evaluation_epsilon', 0., + 'Epsilon for the epsilon greedy in the evaluation agent.') +flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') +flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') + +FLAGS = flags.FLAGS + + +def main(_): + # Create an environment and grab the spec. + environment = bc_utils.make_environment() + environment_spec = specs.make_environment_spec(environment) + + # Unwrap the environment to get the demonstrations. + dataset = bc_utils.make_demonstrations(environment.environment, + FLAGS.batch_size) + dataset = dataset.as_numpy_iterator() + + # Create the networks to optimize. + bc_networks = bc_utils.make_network(environment_spec) + + key = jax.random.PRNGKey(FLAGS.seed) + key, key1 = jax.random.split(key, 2) + + loss_fn = bc.logp() + + learner = bc.BCLearner( + networks=bc_networks, + random_key=key1, + loss_fn=loss_fn, + optimizer=optax.adam(FLAGS.learning_rate), + prefetching_iterator=utils.sharded_prefetch(dataset), + num_sgd_steps_per_step=1) + + def evaluator_network(params: hk.Params, key: jnp.DeviceArray, + observation: jnp.DeviceArray) -> jnp.DeviceArray: + dist_params = bc_networks.policy_network.apply(params, observation) + return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample( + key, dist_params) + + actor_core = actor_core_lib.batched_feed_forward_to_actor_core( + evaluator_network) + variable_client = variable_utils.VariableClient( + learner, 'policy', device='cpu') + evaluator = actors.GenericActor( + actor_core, key, variable_client, backend='cpu') + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger('evaluation', time_delta=0.)) + + # Run the environment loop. + while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/offline/run_bcq.py b/acme/examples/offline/run_bcq.py new file mode 100644 index 00000000..3c2eefa8 --- /dev/null +++ b/acme/examples/offline/run_bcq.py @@ -0,0 +1,144 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Run BCQ offline agent on Atari RL Unplugged datasets. 
+
+Instructions:
+
+1 - Download dataset:
+> gsutil cp gs://rl_unplugged/atari/Pong/run_1-00000-of-00100 \
+    /tmp/dataset/Pong/run_1-00000-of-00001
+
+2 - Install RL Unplugged dependencies:
+https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged#running-the-code
+
+3 - Download RL Unplugged library:
+> git clone https://github.com/deepmind/deepmind-research.git deepmind_research
+
+4 - Run script:
+> python -m run_bcq --dataset_path=/tmp/dataset --game=Pong --run=1 \
+    --num_shards=1
+"""
+
+from absl import app
+from absl import flags
+import acme
+from acme import specs
+from acme.agents.tf import actors
+from acme.agents.tf import bcq
+from acme.tf import networks
+from acme.tf import utils as tf2_utils
+from acme.utils import counting
+from acme.utils import loggers
+import sonnet as snt
+import tensorflow as tf
+import trfl
+
+from deepmind_research.rl_unplugged import atari  # type: ignore
+
+# Atari dataset flags
+flags.DEFINE_string('dataset_path', None, 'Dataset path.')
+flags.DEFINE_string('game', 'Pong', 'Atari game name.')
+flags.DEFINE_integer('run', 1, 'Dataset run number.')
+flags.DEFINE_integer('num_shards', 100, 'Number of dataset shards.')
+flags.DEFINE_integer('batch_size', 16, 'Batch size.')
+
+# Agent flags
+flags.DEFINE_float('bcq_threshold', 0.5, 'BCQ threshold.')
+flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
+flags.DEFINE_float('discount', 0.99, 'Discount.')
+flags.DEFINE_float('importance_sampling_exponent', 0.2,
+                   'Importance sampling exponent.')
+flags.DEFINE_integer('target_update_period', 2500,
+                     ('Number of learner steps to perform before updating '
+                      'the target networks.'))
+
+# Evaluation flags.
+flags.DEFINE_float('epsilon', 0., 'Epsilon for the epsilon greedy in the env.')
+flags.DEFINE_integer('evaluate_every', 100, 'Evaluation period.')
+flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.')
+
+FLAGS = flags.FLAGS
+
+
+def make_network(action_spec: specs.DiscreteArray) -> snt.Module:
+  return snt.Sequential([
+      lambda x: tf.image.convert_image_dtype(x, tf.float32),
+      networks.DQNAtariNetwork(action_spec.num_values)
+  ])
+
+
+def main(_):
+  # Create an environment and grab the spec.
+  environment = atari.environment(FLAGS.game)
+  environment_spec = specs.make_environment_spec(environment)
+
+  # Create dataset.
+  dataset = atari.dataset(path=FLAGS.dataset_path,
+                          game=FLAGS.game,
+                          run=FLAGS.run,
+                          num_shards=FLAGS.num_shards)
+  # Discard extra inputs.
+  dataset = dataset.map(lambda x: x._replace(data=x.data[:5]))
+
+  # Batch and prefetch.
+  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+  # Build network.
+  g_network = make_network(environment_spec.actions)
+  q_network = make_network(environment_spec.actions)
+  network = networks.DiscreteFilteredQNetwork(g_network=g_network,
+                                              q_network=q_network,
+                                              threshold=FLAGS.bcq_threshold)
+  tf2_utils.create_variables(network, [environment_spec.observations])
+
+  evaluator_network = snt.Sequential([
+      q_network,
+      lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
+  ])
+
+  # Counters.
+  counter = counting.Counter()
+  learner_counter = counting.Counter(counter, prefix='learner')
+
+  # Create the actor which defines how we take actions.
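+  # (Note: the FeedForwardActor below is constructed without an adder or
+  # variable client, so it only selects actions for evaluation.)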
+ evaluation_network = actors.FeedForwardActor(evaluator_network) + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluation_network, + counter=counter, + logger=loggers.TerminalLogger('evaluation', time_delta=1.)) + + # The learner updates the parameters (and initializes them). + learner = bcq.DiscreteBCQLearner( + network=network, + dataset=dataset, + learning_rate=FLAGS.learning_rate, + discount=FLAGS.discount, + importance_sampling_exponent=FLAGS.importance_sampling_exponent, + target_update_period=FLAGS.target_update_period, + counter=counter) + + # Run the environment loop. + while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + learner_counter.increment(learner_steps=FLAGS.evaluate_every) + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/offline/run_cql_jax.py b/acme/examples/offline/run_cql_jax.py new file mode 100644 index 00000000..44c2c5a0 --- /dev/null +++ b/acme/examples/offline/run_cql_jax.py @@ -0,0 +1,114 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example CQL running on locomotion datasets (mujoco) from D4rl.""" + +from absl import app +from absl import flags +import acme +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import cql +from acme.datasets import tfds +from acme.examples.offline import helpers as gym_helpers +from acme.jax import variable_utils +from acme.utils import loggers +import haiku as hk +import jax +import jax.numpy as jnp +import optax + +# Agent flags +flags.DEFINE_integer('batch_size', 64, 'Batch size.') +flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') +flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_integer( + 'num_demonstrations', 10, + 'Number of demonstration episodes to load from the dataset. If None, loads the full dataset.' +) +flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') +# CQL specific flags. +flags.DEFINE_float('policy_learning_rate', 3e-5, 'Policy learning rate.') +flags.DEFINE_float('critic_learning_rate', 3e-4, 'Critic learning rate.') +flags.DEFINE_float('fixed_cql_coefficient', None, + 'Fixed CQL coefficient. If None, an adaptive one is used.') +flags.DEFINE_float('cql_lagrange_threshold', 10., + 'Lagrange threshold for the adaptive CQL coefficient.') +# Environment flags. +flags.DEFINE_string('env_name', 'HalfCheetah-v2', + 'Gym mujoco environment name.') +flags.DEFINE_string( + 'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium', + 'D4rl dataset name. Can be any locomotion dataset from ' + 'https://www.tensorflow.org/datasets/catalog/overview#d4rl.') + +FLAGS = flags.FLAGS + + +def main(_): + key = jax.random.PRNGKey(FLAGS.seed) + key_demonstrations, key_learner = jax.random.split(key, 2) + + # Create an environment and grab the spec. 
+ environment = gym_helpers.make_environment(task=FLAGS.env_name) + environment_spec = specs.make_environment_spec(environment) + + # Get a demonstrations dataset. + transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name, + FLAGS.num_demonstrations) + demonstrations = tfds.JaxInMemoryRandomSampleIterator( + transitions_iterator, key=key_demonstrations, batch_size=FLAGS.batch_size) + + # Create the networks to optimize. + networks = cql.make_networks(environment_spec) + + # Create the learner. + learner = cql.CQLLearner( + batch_size=FLAGS.batch_size, + networks=networks, + random_key=key_learner, + policy_optimizer=optax.adam(FLAGS.policy_learning_rate), + critic_optimizer=optax.adam(FLAGS.critic_learning_rate), + fixed_cql_coefficient=FLAGS.fixed_cql_coefficient, + cql_lagrange_threshold=FLAGS.cql_lagrange_threshold, + demonstrations=demonstrations, + num_sgd_steps_per_step=1) + + def evaluator_network(params: hk.Params, key: jnp.DeviceArray, + observation: jnp.DeviceArray) -> jnp.DeviceArray: + dist_params = networks.policy_network.apply(params, observation) + return networks.sample_eval(dist_params, key) + + actor_core = actor_core_lib.batched_feed_forward_to_actor_core( + evaluator_network) + variable_client = variable_utils.VariableClient( + learner, 'policy', device='cpu') + evaluator = actors.GenericActor( + actor_core, key, variable_client, backend='cpu') + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger('evaluation', time_delta=0.)) + + # Run the environment loop. + while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/offline/run_crr_jax.py b/acme/examples/offline/run_crr_jax.py new file mode 100644 index 00000000..8893be6a --- /dev/null +++ b/acme/examples/offline/run_crr_jax.py @@ -0,0 +1,137 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example CRR running on locomotion datasets (mujoco) from D4rl.""" + +from absl import app +from absl import flags +import acme +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import crr +from acme.datasets import tfds +from acme.examples.offline import helpers as gym_helpers +from acme.jax import variable_utils +from acme.types import Transition +from acme.utils import loggers +import haiku as hk +import jax +import jax.numpy as jnp +import optax +import rlds + +# Agent flags +flags.DEFINE_integer('batch_size', 64, 'Batch size.') +flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') +flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_integer( + 'num_demonstrations', 10, + 'Number of demonstration episodes to load from the dataset. If None, loads the full dataset.' 
+)
+flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.')
+# CRR specific flags.
+flags.DEFINE_float('policy_learning_rate', 3e-5, 'Policy learning rate.')
+flags.DEFINE_float('critic_learning_rate', 3e-4, 'Critic learning rate.')
+flags.DEFINE_float('discount', 0.99, 'Discount.')
+flags.DEFINE_integer('target_update_period', 100, 'Target update period.')
+flags.DEFINE_integer('grad_updates_per_batch', 1, 'Grad updates per batch.')
+flags.DEFINE_bool(
+    'use_sarsa_target', True,
+    'Compute on-policy target using iterator actions rather than sampled '
+    'actions.'
+)
+# Environment flags.
+flags.DEFINE_string('env_name', 'HalfCheetah-v2',
+                    'Gym mujoco environment name.')
+flags.DEFINE_string(
+    'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium',
+    'D4rl dataset name. Can be any locomotion dataset from '
+    'https://www.tensorflow.org/datasets/catalog/overview#d4rl.')
+
+FLAGS = flags.FLAGS
+
+
+def _add_next_action_extras(double_transitions: Transition) -> Transition:
+  return Transition(
+      observation=double_transitions.observation[0],
+      action=double_transitions.action[0],
+      reward=double_transitions.reward[0],
+      discount=double_transitions.discount[0],
+      next_observation=double_transitions.next_observation[0],
+      extras={'next_action': double_transitions.action[1]})
+
+
+def main(_):
+  key = jax.random.PRNGKey(FLAGS.seed)
+  key_demonstrations, key_learner = jax.random.split(key, 2)
+
+  # Create an environment and grab the spec.
+  environment = gym_helpers.make_environment(task=FLAGS.env_name)
+  environment_spec = specs.make_environment_spec(environment)
+
+  # Get a demonstrations dataset with next_actions extra.
+  transitions = tfds.get_tfds_dataset(
+      FLAGS.dataset_name, FLAGS.num_demonstrations)
+  double_transitions = rlds.transformations.batch(
+      transitions, size=2, shift=1, drop_remainder=True)
+  transitions = double_transitions.map(_add_next_action_extras)
+  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
+      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)
+
+  # Create the networks to optimize.
+  networks = crr.make_networks(environment_spec)
+
+  # CRR policy loss function.
+  policy_loss_coeff_fn = crr.policy_loss_coeff_advantage_exp
+
+  # Create the learner.
+  learner = crr.CRRLearner(
+      networks=networks,
+      random_key=key_learner,
+      discount=FLAGS.discount,
+      target_update_period=FLAGS.target_update_period,
+      policy_loss_coeff_fn=policy_loss_coeff_fn,
+      iterator=demonstrations,
+      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
+      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
+      grad_updates_per_batch=FLAGS.grad_updates_per_batch,
+      use_sarsa_target=FLAGS.use_sarsa_target)
+
+  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
+                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
+    dist_params = networks.policy_network.apply(params, observation)
+    return networks.sample_eval(dist_params, key)
+
+  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
+      evaluator_network)
+  variable_client = variable_utils.VariableClient(
+      learner, 'policy', device='cpu')
+  evaluator = actors.GenericActor(
+      actor_core, key, variable_client, backend='cpu')
+
+  eval_loop = acme.EnvironmentLoop(
+      environment=environment,
+      actor=evaluator,
+      logger=loggers.TerminalLogger('evaluation', time_delta=0.))
+
+  # Run the environment loop.
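+  # (The loop below alternates CRR learner updates with periodic evaluation
+  # and runs until the process is interrupted.)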
+ while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/offline/run_dqfd.py b/acme/examples/offline/run_dqfd.py new file mode 100644 index 00000000..d5b3a1d0 --- /dev/null +++ b/acme/examples/offline/run_dqfd.py @@ -0,0 +1,83 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running DQfD on BSuite in a single process. +""" + +from absl import app +from absl import flags + +import acme +from acme import specs +from acme import wrappers +from acme.agents.tf import dqfd +from acme.agents.tf.dqfd import bsuite_demonstrations + +import bsuite +import sonnet as snt + + +# Bsuite flags +flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') +flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') +flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') + +# Agent flags +flags.DEFINE_float('demonstration_ratio', 0.5, + ('Proportion of demonstration transitions in the replay ' + 'buffer.')) +flags.DEFINE_integer('n_step', 5, + ('Number of steps to squash into a single transition.')) +flags.DEFINE_float('samples_per_insert', 8, + ('Number of samples to take from replay for every insert ' + 'that is made.')) +flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.') + +FLAGS = flags.FLAGS + + +def make_network(action_spec: specs.DiscreteArray) -> snt.Module: + return snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, action_spec.num_values]), + ]) + + +def main(_): + # Create an environment and grab the spec. + raw_environment = bsuite.load_and_record_to_csv( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment = wrappers.SinglePrecisionWrapper(raw_environment) + environment_spec = specs.make_environment_spec(environment) + + # Construct the agent. + agent = dqfd.DQfD( + environment_spec=environment_spec, + network=make_network(environment_spec.actions), + demonstration_dataset=bsuite_demonstrations.make_dataset( + raw_environment, stochastic=False), + demonstration_ratio=FLAGS.demonstration_ratio, + samples_per_insert=FLAGS.samples_per_insert, + learning_rate=FLAGS.learning_rate) + + # Run the environment loop. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/offline/run_mbop_jax.py b/acme/examples/offline/run_mbop_jax.py new file mode 100644 index 00000000..62a58cfd --- /dev/null +++ b/acme/examples/offline/run_mbop_jax.py @@ -0,0 +1,135 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An example running MBOP on a D4RL dataset."""
+
+import functools
+
+from absl import app
+from absl import flags
+import acme
+from acme import specs
+from acme.agents.jax import mbop
+from acme.datasets import tfds
+from acme.examples.offline import helpers as gym_helpers
+from acme.jax import running_statistics
+from acme.utils import loggers
+import jax
+import optax
+import tensorflow_datasets
+
+# Training flags.
+_NUM_NETWORKS = flags.DEFINE_integer('num_networks', 10,
+                                     'Number of ensemble networks.')
+_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.')
+_BATCH_SIZE = flags.DEFINE_integer('batch_size', 64, 'Batch size.')
+_HIDDEN_LAYER_SIZES = flags.DEFINE_multi_integer('hidden_layer_sizes', [64, 64],
+                                                 'Sizes of the hidden layers.')
+_NUM_SGD_STEPS_PER_STEP = flags.DEFINE_integer(
+    'num_sgd_steps_per_step', 1,
+    'How many gradient updates to perform per learner step.')
+_NUM_NORMALIZATION_BATCHES = flags.DEFINE_integer(
+    'num_normalization_batches', 50,
+    'Number of batches used for calculating the normalization statistics.')
+_EVALUATE_EVERY = flags.DEFINE_integer('evaluate_every', 20,
+                                       'Evaluation period.')
+_EVALUATION_EPISODES = flags.DEFINE_integer('evaluation_episodes', 10,
+                                            'Evaluation episodes.')
+_SEED = flags.DEFINE_integer('seed', 0,
+                             'Random seed for learner and evaluator.')
+
+# Environment flags.
+_ENV_NAME = flags.DEFINE_string('env_name', 'HalfCheetah-v2',
+                                'Gym mujoco environment name.')
+_DATASET_NAME = flags.DEFINE_string(
+    'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium',
+    'D4rl dataset name. Can be any locomotion dataset from '
+    'https://www.tensorflow.org/datasets/catalog/overview#d4rl.')
+
+
+def main(_):
+  # Create an environment and grab the spec.
+  environment = gym_helpers.make_environment(task=_ENV_NAME.value)
+  spec = specs.make_environment_spec(environment)
+
+  key = jax.random.PRNGKey(_SEED.value)
+  key, dataset_key, evaluator_key = jax.random.split(key, 3)
+
+  # Load the dataset.
+  dataset = tensorflow_datasets.load(_DATASET_NAME.value)['train']
+  # Convert the demonstration episodes into batched timestep transitions.
+  dataset = mbop.episodes_to_timestep_batched_transitions(
+      dataset, return_horizon=10)
+  dataset = tfds.JaxInMemoryRandomSampleIterator(
+      dataset, key=dataset_key, batch_size=_BATCH_SIZE.value)
+
+  # Apply normalization to the dataset.
+  mean_std = mbop.get_normalization_stats(dataset,
+                                          _NUM_NORMALIZATION_BATCHES.value)
+  apply_normalization = jax.jit(
+      functools.partial(running_statistics.normalize, mean_std=mean_std))
+  dataset = (apply_normalization(sample) for sample in dataset)
+
+  # Create the networks.
+  networks = mbop.make_networks(
+      spec, hidden_layer_sizes=tuple(_HIDDEN_LAYER_SIZES.value))
+
+  # Use the default losses.
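+  # (MBOP fits three ensembles: a world model, a policy prior and an n-step
+  # return model; `MBOPLosses` bundles one loss per component, matched to the
+  # three `make_learner` factories below.)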
+ losses = mbop.MBOPLosses() + + def logger_fn(label: str, steps_key: str): + return loggers.make_default_logger(label, steps_key=steps_key) + + def make_learner(name, logger_fn, counter, rng_key, dataset, network, loss): + return mbop.make_ensemble_regressor_learner( + name, + _NUM_NETWORKS.value, + logger_fn, + counter, + rng_key, + dataset, + network, + loss, + optax.adam(_LEARNING_RATE.value), + _NUM_SGD_STEPS_PER_STEP.value, + ) + + learner = mbop.MBOPLearner(networks, losses, dataset, key, logger_fn, + functools.partial(make_learner, 'world_model'), + functools.partial(make_learner, 'policy_prior'), + functools.partial(make_learner, 'n_step_return')) + + planning_config = mbop.MPPIConfig() + + assert planning_config.n_trajectories % _NUM_NETWORKS.value == 0, ( + 'Number of trajectories must be a multiple of the number of networks.') + + actor_core = mbop.make_ensemble_actor_core( + networks, planning_config, spec, mean_std, use_round_robin=False) + evaluator = mbop.make_actor(actor_core, evaluator_key, learner) + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger('evaluation', time_delta=0.)) + + # Train the agent. + while True: + for _ in range(_EVALUATE_EVERY.value): + learner.step() + eval_loop.run(_EVALUATION_EPISODES.value) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/offline/run_offline_td3_jax.py b/acme/examples/offline/run_offline_td3_jax.py new file mode 100644 index 00000000..ea68534a --- /dev/null +++ b/acme/examples/offline/run_offline_td3_jax.py @@ -0,0 +1,144 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example offline TD3 running on locomotion datasets (mujoco) from D4rl.""" + +from absl import app +from absl import flags +import acme +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.agents.jax import td3 +from acme.datasets import tfds +from acme.examples.offline import helpers as gym_helpers +from acme.jax import variable_utils +from acme.types import Transition +from acme.utils import loggers +import haiku as hk +import jax +import jax.numpy as jnp +import optax +import reverb +import rlds +import tensorflow as tf +import tree + +# Agent flags +flags.DEFINE_integer('batch_size', 64, 'Batch size.') +flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') +flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_integer( + 'num_demonstrations', 10, + 'Number of demonstration episodes to load from the dataset. If None, loads the full dataset.' +) +flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') +# TD3 specific flags. 
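+# (The `bc_alpha` flag below appears to correspond to the TD3+BC variant of
+# Fujimoto & Gu, 2021, whose suggested weighting is alpha=2.5.)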
+flags.DEFINE_float('discount', 0.99, 'Discount.')
+flags.DEFINE_float('policy_learning_rate', 3e-4, 'Policy learning rate.')
+flags.DEFINE_float('critic_learning_rate', 3e-4, 'Critic learning rate.')
+flags.DEFINE_float('bc_alpha', 2.5,
+                   'Add a BC regularization term to the policy loss. '
+                   'If set to None, TD3 is run without BC regularization.')
+flags.DEFINE_bool(
+    'use_sarsa_target', True,
+    'Compute on-policy target using iterator actions rather than sampled '
+    'actions.'
+)
+# Environment flags.
+flags.DEFINE_string('env_name', 'HalfCheetah-v2',
+                    'Gym mujoco environment name.')
+flags.DEFINE_string(
+    'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium',
+    'D4rl dataset name. Can be any locomotion dataset from '
+    'https://www.tensorflow.org/datasets/catalog/overview#d4rl.')
+
+FLAGS = flags.FLAGS
+
+
+def _add_next_action_extras(double_transitions: Transition
+                           ) -> reverb.ReplaySample:
+  # As TD3 is online by default, it expects an iterator over replay samples.
+  info = tree.map_structure(lambda dtype: tf.ones([], dtype),
+                            reverb.SampleInfo.tf_dtypes())
+  return reverb.ReplaySample(
+      info=info,
+      data=Transition(
+          observation=double_transitions.observation[0],
+          action=double_transitions.action[0],
+          reward=double_transitions.reward[0],
+          discount=double_transitions.discount[0],
+          next_observation=double_transitions.next_observation[0],
+          extras={'next_action': double_transitions.action[1]}))
+
+
+def main(_):
+  key = jax.random.PRNGKey(FLAGS.seed)
+  key_demonstrations, key_learner = jax.random.split(key, 2)
+
+  # Create an environment and grab the spec.
+  environment = gym_helpers.make_environment(task=FLAGS.env_name)
+  environment_spec = specs.make_environment_spec(environment)
+
+  # Get a demonstrations dataset with next_actions extra.
+  transitions = tfds.get_tfds_dataset(
+      FLAGS.dataset_name, FLAGS.num_demonstrations)
+  double_transitions = rlds.transformations.batch(
+      transitions, size=2, shift=1, drop_remainder=True)
+  transitions = double_transitions.map(_add_next_action_extras)
+  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
+      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)
+
+  # Create the networks to optimize.
+  networks = td3.make_networks(environment_spec)
+
+  # Create the learner.
+  learner = td3.TD3Learner(
+      networks=networks,
+      random_key=key_learner,
+      discount=FLAGS.discount,
+      iterator=demonstrations,
+      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
+      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
+      twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
+      use_sarsa_target=FLAGS.use_sarsa_target,
+      bc_alpha=FLAGS.bc_alpha,
+      num_sgd_steps_per_step=1)
+
+  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
+                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
+    del key
+    return networks.policy_network.apply(params, observation)
+
+  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
+      evaluator_network)
+  variable_client = variable_utils.VariableClient(
+      learner, 'policy', device='cpu')
+  evaluator = actors.GenericActor(
+      actor_core, key, variable_client, backend='cpu')
+
+  eval_loop = acme.EnvironmentLoop(
+      environment=environment,
+      actor=evaluator,
+      logger=loggers.TerminalLogger('evaluation', time_delta=0.))
+
+  # Run the environment loop.
+ while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/open_spiel/run_dqn.py b/acme/examples/open_spiel/run_dqn.py new file mode 100644 index 00000000..5264b356 --- /dev/null +++ b/acme/examples/open_spiel/run_dqn.py @@ -0,0 +1,75 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running DQN on OpenSpiel game in a single process.""" + +from absl import app +from absl import flags + +import acme +from acme import wrappers +from acme.agents.tf import dqn +from acme.environment_loops import open_spiel_environment_loop +from acme.tf.networks import legal_actions +from acme.wrappers import open_spiel_wrapper +import sonnet as snt + +from open_spiel.python import rl_environment + +flags.DEFINE_string('game', 'tic_tac_toe', 'Name of the game') +flags.DEFINE_integer('num_players', None, 'Number of players') + +FLAGS = flags.FLAGS + + +def main(_): + # Create an environment and grab the spec. + env_configs = {'players': FLAGS.num_players} if FLAGS.num_players else {} + raw_environment = rl_environment.Environment(FLAGS.game, **env_configs) + + environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment) + environment = wrappers.SinglePrecisionWrapper(environment) # type: open_spiel_wrapper.OpenSpielWrapper # pytype: disable=annotation-type-mismatch + environment_spec = acme.make_environment_spec(environment) + + # Build the networks. + networks = [] + policy_networks = [] + for _ in range(environment.num_players): + network = legal_actions.MaskedSequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, environment_spec.actions.num_values]) + ]) + policy_network = snt.Sequential( + [network, + legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)]) + networks.append(network) + policy_networks.append(policy_network) + + # Construct the agents. + agents = [] + + for network, policy_network in zip(networks, policy_networks): + agents.append( + dqn.DQN(environment_spec=environment_spec, + network=network, + policy_network=policy_network)) + + # Run the environment loop. 
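+  # (OpenSpielEnvironmentLoop drives all players of the turn-based game,
+  # querying whichever agent is to act at each step.)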
+ loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( + environment, agents) + loop.run(num_episodes=100000) + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/examples/quickstart.ipynb b/acme/examples/quickstart.ipynb new file mode 100644 index 00000000..87ffda30 --- /dev/null +++ b/acme/examples/quickstart.ipynb @@ -0,0 +1,468 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ULdrhOaVbsdO" + }, + "source": [ + "# Acme: Quickstart\n", + "## Guide to installing Acme and training your first D4PG agent.\n", + "# \u003ca href=\"https://colab.research.google.com/github/deepmind/acme/blob/master/examples/quickstart.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ogw2P040-F5P" + }, + "source": [ + "## Select your environment library\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "xflSXJPS8Qpm" + }, + "outputs": [], + "source": [ + "environment_library = 'gym' # @param ['dm_control', 'gym']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xaJxoatMhJ71" + }, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ovuCuHCC78Zu" + }, + "source": [ + "### Install Acme" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "id": "KH3O0zcXUeun" + }, + "outputs": [], + "source": [ + "!pip install dm-acme\n", + "!pip install dm-acme[reverb]\n", + "!pip install dm-acme[tf]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VEEj3Qw60y73" + }, + "source": [ + "### Install the environment library" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IbZxYDxzoz5R" + }, + "outputs": [], + "source": [ + "if environment_library == 'dm_control':\n", + " import distutils.util\n", + " import subprocess\n", + " if subprocess.run('nvidia-smi').returncode:\n", + " raise RuntimeError(\n", + " 'Cannot communicate with GPU. '\n", + " 'Make sure you are using a GPU Colab runtime. 
'\n", + " 'Go to the Runtime menu and select Choose runtime type.')\n", + "\n", + " mujoco_dir = \"$HOME/.mujoco\"\n", + "\n", + " print('Installing OpenGL dependencies...')\n", + " !apt-get update -qq\n", + " !apt-get install -qq -y --no-install-recommends libglew2.0 \u003e /dev/null\n", + "\n", + " print('Downloading MuJoCo...')\n", + " BASE_URL = 'https://github.com/deepmind/mujoco/releases/download'\n", + " MUJOCO_VERSION = '2.1.1'\n", + " MUJOCO_ARCHIVE = (\n", + " f'mujoco-{MUJOCO_VERSION}-{distutils.util.get_platform()}.tar.gz')\n", + " !wget -q \"{BASE_URL}/{MUJOCO_VERSION}/{MUJOCO_ARCHIVE}\"\n", + " !wget -q \"{BASE_URL}/{MUJOCO_VERSION}/{MUJOCO_ARCHIVE}.sha256\"\n", + " check_result = !shasum -c \"{MUJOCO_ARCHIVE}.sha256\"\n", + " if _exit_code:\n", + " raise RuntimeError(\n", + " 'Downloaded MuJoCo archive is corrupted (checksum mismatch)')\n", + "\n", + " print('Unpacking MuJoCo...')\n", + " MUJOCO_DIR = '$HOME/.mujoco'\n", + " !mkdir -p \"{MUJOCO_DIR}\"\n", + " !tar -zxf {MUJOCO_ARCHIVE} -C \"{MUJOCO_DIR}\"\n", + "\n", + " # Configure dm_control to use the EGL rendering backend (requires GPU)\n", + " %env MUJOCO_GL=egl\n", + "\n", + " print('Installing dm_control...')\n", + " # Version 0.0.416848645 is the first one to support MuJoCo 2.1.1.\n", + " !pip install -q dm_control\u003e=0.0.416848645\n", + "\n", + " print('Checking that the dm_control installation succeeded...')\n", + " try:\n", + " from dm_control import suite\n", + " env = suite.load('cartpole', 'swingup')\n", + " pixels = env.physics.render()\n", + " except Exception as e:\n", + " raise e from RuntimeError(\n", + " 'Something went wrong during installation. Check the shell output above '\n", + " 'for more information.\\n'\n", + " 'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n", + " 'by going to the Runtime menu and selecting \"Choose runtime type\".')\n", + " else:\n", + " del suite, env, pixels\n", + "\n", + " !echo Installed dm_control $(pip show dm_control | grep -Po \"(?\u003c=Version: ).+\")\n", + "\n", + "elif environment_library == 'gym':\n", + " !pip install gym" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Cl8eyWblSs-z" + }, + "source": [ + "### Install visualization packages" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "aSM7KHDFSsQS" + }, + "outputs": [], + "source": [ + "!sudo apt-get install -y xvfb ffmpeg\n", + "!pip install imageio\n", + "!pip install PILLOW\n", + "!pip install pyvirtualdisplay" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "c-H2d6UZi7Sf" + }, + "source": [ + "## Import Modules" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "HJ74Id-8MERq" + }, + "outputs": [], + "source": [ + "import IPython\n", + "\n", + "from acme import environment_loop\n", + "from acme import specs\n", + "from acme import wrappers\n", + "from acme.agents.tf import d4pg\n", + "from acme.tf import networks\n", + "from acme.tf import utils as tf2_utils\n", + "from acme.utils import loggers\n", + "import numpy as np\n", + "import sonnet as snt\n", + "\n", + "# Import the selected environment lib\n", + "if environment_library == 'dm_control':\n", + " from dm_control import suite\n", + "elif environment_library == 'gym':\n", + " import gym\n", + "\n", + "# Imports required for visualization\n", + "import pyvirtualdisplay\n", + 
"import imageio\n", + "import base64\n", + "\n", + "# Set up a virtual display for rendering.\n", + "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "I6KuVGSk4uc9" + }, + "source": [ + "## Load an environment\n", + "\n", + "We can now load an environment. In what follows we'll create an environment and grab the environment's specifications." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "4PVlHtGF5yzt" + }, + "outputs": [], + "source": [ + "if environment_library == 'dm_control':\n", + " environment = suite.load('cartpole', 'balance')\n", + " \n", + "elif environment_library == 'gym':\n", + " environment = gym.make('MountainCarContinuous-v0')\n", + " environment = wrappers.GymWrapper(environment) # To dm_env interface.\n", + "\n", + "else:\n", + " raise ValueError(\n", + " \"Unknown environment library: {};\".format(environment_library) +\n", + " \"choose among ['dm_control', 'gym'].\")\n", + "\n", + "# Make sure the environment outputs single-precision floats.\n", + "environment = wrappers.SinglePrecisionWrapper(environment)\n", + "\n", + "# Grab the spec of the environment.\n", + "environment_spec = specs.make_environment_spec(environment)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BukOfOsmtSQn" + }, + "source": [ + " ## Create a D4PG agent" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3Jcjk1w6oHVX" + }, + "outputs": [], + "source": [ + "#@title Build agent networks\n", + "\n", + "# Get total number of action dimensions from action spec.\n", + "num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)\n", + "\n", + "# Create the shared observation network; here simply a state-less operation.\n", + "observation_network = tf2_utils.batch_concat\n", + "\n", + "# Create the deterministic policy network.\n", + "policy_network = snt.Sequential([\n", + " networks.LayerNormMLP((256, 256, 256), activate_final=True),\n", + " networks.NearZeroInitializedLinear(num_dimensions),\n", + " networks.TanhToSpec(environment_spec.actions),\n", + "])\n", + "\n", + "# Create the distributional critic network.\n", + "critic_network = snt.Sequential([\n", + " # The multiplexer concatenates the observations/actions.\n", + " networks.CriticMultiplexer(),\n", + " networks.LayerNormMLP((512, 512, 256), activate_final=True),\n", + " networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),\n", + "])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "9CD2sNK-oA9S" + }, + "outputs": [], + "source": [ + "# Create a logger for the agent and environment loop.\n", + "agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)\n", + "env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)\n", + "\n", + "# Create the D4PG agent.\n", + "agent = d4pg.D4PG(\n", + " environment_spec=environment_spec,\n", + " policy_network=policy_network,\n", + " critic_network=critic_network,\n", + " observation_network=observation_network,\n", + " sigma=1.0,\n", + " logger=agent_logger,\n", + " checkpoint=False\n", + ")\n", + "\n", + "# Create an loop connecting this agent to the environment created above.\n", + "env_loop = environment_loop.EnvironmentLoop(\n", + " environment, agent, 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "oKeGQxzitXYC"
+   },
+   "source": [
+    "## Run a training loop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "VWZd5N-Qoz82"
+   },
+   "outputs": [],
+   "source": [
+    "# Run `num_episodes` training episodes.\n",
+    "# Rerun this cell until the agent has learned the given task.\n",
+    "env_loop.run(num_episodes=100)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "Do57Ql4ZsWDu"
+   },
+   "source": [
+    "## Visualize an evaluation loop\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "BJXkfg6LSZ-h"
+   },
+   "source": [
+    "### Helper functions for rendering and visualization"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "OIJRbtAlxQVu"
+   },
+   "outputs": [],
+   "source": [
+    "# Create a simple helper function to render a frame from the current state of\n",
+    "# the environment.\n",
+    "if environment_library == 'dm_control':\n",
+    "  def render(env):\n",
+    "    return env.physics.render(camera_id=0)\n",
+    "elif environment_library == 'gym':\n",
+    "  def render(env):\n",
+    "    return env.environment.render(mode='rgb_array')\n",
+    "else:\n",
+    "  raise ValueError(\n",
+    "      \"Unknown environment library: {};\".format(environment_library) +\n",
+    "      \"choose among ['dm_control', 'gym'].\")\n",
+    "\n",
+    "def display_video(frames, filename='temp.mp4'):\n",
+    "  \"\"\"Save and display video.\"\"\"\n",
+    "\n",
+    "  # Write video\n",
+    "  with imageio.get_writer(filename, fps=60) as video:\n",
+    "    for frame in frames:\n",
+    "      video.append_data(frame)\n",
+    "\n",
+    "  # Read the video back and display it.\n",
+    "  video = open(filename, 'rb').read()\n",
+    "  b64_video = base64.b64encode(video)\n",
+    "  video_tag = ('