-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Builder to gen structures for machine learning
- Loading branch information
shyamd
committed
Dec 16, 2017
1 parent
fe608c3
commit a747a0f
Showing
2 changed files
with
197 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
|
||
from itertools import chain | ||
from pymatgen import Structure | ||
from pymatgen.entries.computed_entries import ComputedStructureEntry | ||
from emmet.vasp.builders.task_tagger import task_type | ||
from maggma.builder import Builder | ||
from pydash.objects import get | ||
|
||
__author__ = "Shyam Dwaraknath <[email protected]>" | ||
|
||
|
||
class MLStructuresBuilder(Builder):
    """
    Builder that turns every ionic step of selected VASP task documents into a
    serialized ComputedStructureEntry (structure, energy and, when available,
    forces and stress) for machine learning efforts.
    """

    def __init__(self, tasks, ml_strucs,
                 task_types=("Structure Optimization", "Static"),
                 query=None, **kwargs):
        """
        Creates a collection of structures, energies, forces, and stresses
        for machine learning efforts

        Args:
            tasks (Store): Store of task documents
            ml_strucs (Store): Store of machine learning structure documents
                to generate
            task_types (list): list of substrings for task_types to process
            query (dict): extra criteria used to limit the tasks that are
                processed; None (the default) means no extra criteria
        """
        self.tasks = tasks
        self.ml_strucs = ml_strucs
        self.task_types = task_types
        # Copy so that neither a shared default nor the caller's dict can be
        # mutated through this builder.
        self.query = dict(query) if query else {}
        super().__init__(sources=[tasks],
                         targets=[ml_strucs],
                         **kwargs)

    def get_items(self):
        """
        Gets all task documents that still need to be processed

        Returns:
            generator of task documents that are successful, contain
            "calcs_reversed", match self.query and whose task_id is not yet
            present in the ml_strucs store
        """
        self.logger.info("Machine Learning Structure Database Builder Started")
        self.logger.info("Setting indexes")
        self.ensure_indexes()

        # Restrict the user query to successfully finished, parsed tasks
        criteria = dict(self.query)
        criteria["state"] = "successful"
        criteria["calcs_reversed"] = {"$exists": 1}

        all_tasks = set(self.tasks.distinct("task_id", criteria))
        processed_tasks = set(self.ml_strucs.distinct("task_id"))
        to_process_tasks = all_tasks - processed_tasks

        self.logger.info(
            "Found {} tasks to extract information from".format(len(to_process_tasks)))

        for t_id in to_process_tasks:
            yield self.tasks.query_one(criteria={"task_id": t_id})

    def process_item(self, task):
        """
        Process a task document into a list of entry documents, one per
        ionic step of every calculation in the task

        Args:
            task (dict): a task doc

        Returns:
            list of dicts: ComputedStructureEntry dicts extended with
            "chemsys", "task_type", "calc_name" and "task_id"; empty if the
            task's type matches none of self.task_types
        """
        t_type = task_type(get(task, "input.incar"))
        if not any(t in t_type for t in self.task_types):
            return []

        task_id = task[self.tasks.key]
        is_hubbard = get(task, "input.is_hubbard", False)
        hubbards = get(task, "input.hubbards", [])

        entries = []
        for calc in task.get("calcs_reversed", []):
            # A calc without parsed ionic steps contributes nothing; guarding
            # here avoids iterating over None.
            steps = get(calc, "output.ionic_steps") or []
            if not steps:
                continue

            parameters = {"is_hubbard": is_hubbard,
                          "hubbards": hubbards,
                          "potcar_spec": get(calc, "input.potcar_spec", []),
                          "run_type": calc.get("run_type", "GGA")}
            # These depend only on the calc, not on the individual step
            calc_task_type = task_type(get(calc, "input.incar"))
            calc_name = get(calc, "task.name")

            for step_num, step in enumerate(steps):
                struc = Structure.from_dict(step.get("structure"))

                # Forces and stress belong to the individual ionic step; the
                # calc-level lookup is kept only as a fallback for documents
                # that store them there.
                forces = step.get("forces") or calc.get("forces", [])
                if forces:
                    struc.add_site_property("forces", forces)
                stress = step.get("stress") or calc.get("stress", None)

                data = {"stress": stress} if stress else {}
                data["step"] = step_num

                # len(entries) is a running index over all steps of all
                # calcs, which keeps entry_id unique within the task.
                entry = ComputedStructureEntry(
                    structure=struc,
                    correction=0,
                    energy=step.get("e_wo_entrp"),
                    parameters=parameters,
                    entry_id="{}-{}".format(task_id, len(entries)),
                    data=data)

                d = entry.as_dict()
                d["chemsys"] = "-".join(
                    sorted({e.symbol for e in struc.composition.elements}))
                d["task_type"] = calc_task_type
                d["calc_name"] = calc_name
                d["task_id"] = task_id

                entries.append(d)

        return entries

    def update_targets(self, items):
        """
        Inserts the new entries into the ml_strucs store

        Args:
            items ([[dict]]): list of the entry lists returned by
                process_item; empty lists and falsy entries are dropped
        """
        items = list(filter(None, chain.from_iterable(items)))

        if items:
            self.logger.info("Updating {} entries".format(len(items)))
            self.ml_strucs.update(docs=items)
        else:
            self.logger.info("No items to update")

    def ensure_indexes(self):
        """
        Ensures indexes on the ml_strucs store for entry_id, chemsys and the
        store's last-updated field
        """
        # Basic search indexes for the generated entries
        self.ml_strucs.ensure_index("entry_id")
        self.ml_strucs.ensure_index("chemsys")
        self.ml_strucs.ensure_index(self.ml_strucs.lu_field)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import unittest | ||
from itertools import chain | ||
from pydash.objects import get | ||
from maggma.stores import MongoStore | ||
from maggma.runner import Runner | ||
from emmet.vasp.builders.task_tagger import task_type | ||
from emmet.vasp.builders.tests.test_builders import BuilderTest | ||
from emmet.vasp.builders.ml_structures import MLStructuresBuilder | ||
|
||
|
||
__author__ = "Shyam Dwaraknath" | ||
__email__ = "[email protected]" | ||
|
||
|
||
class TestMaterials(BuilderTest):
    """
    Tests for MLStructuresBuilder.

    NOTE(review): self.tasks is not defined in this class; it is presumably
    provided by BuilderTest -- confirm against the base class.
    """

    def setUp(self):
        """Create an empty ml_strucs store and a builder that writes to it."""
        self.ml_strucs = MongoStore("emmet_test", "ml_strucs", key="entry_id")
        self.ml_strucs.connect()

        # Start every test from an empty target collection
        self.ml_strucs.collection.drop()
        self.mlbuilder = MLStructuresBuilder(
            self.tasks, self.ml_strucs,
            task_types=("Structure Optimization", "Static"))

    def test_get_items(self):
        """get_items yields the expected number of tasks and formulas."""
        to_process = list(self.mlbuilder.get_items())
        to_process_forms = {task["formula_pretty"] for task in to_process}

        self.assertEqual(len(to_process), 197)
        self.assertEqual(len(to_process_forms), 12)
        self.assertIn("Sr", to_process_forms)
        self.assertIn("Hf", to_process_forms)
        self.assertIn("O2", to_process_forms)
        self.assertNotIn("H", to_process_forms)

    def test_process_item(self):
        """process_item returns one entry per ionic step for matching tasks."""
        for task in self.tasks.query():
            ml_strucs = self.mlbuilder.process_item(task)
            t_type = task_type(get(task, 'input.incar'))
            if not any(t in t_type for t in self.mlbuilder.task_types):
                # Tasks of other types must be skipped entirely
                self.assertEqual(len(ml_strucs), 0)
            else:
                n_steps = sum(len(calc["output"]["ionic_steps"])
                              for calc in task["calcs_reversed"])
                self.assertEqual(len(ml_strucs), n_steps)

    def test_update_targets(self):
        """update_targets stores every processed entry in ml_strucs."""
        for task in self.tasks.query():
            ml_strucs = self.mlbuilder.process_item(task)
            self.mlbuilder.update_targets([ml_strucs])

        # Totals are checked once, after all tasks have been inserted
        self.assertEqual(len(self.ml_strucs.distinct("task_id")), 102)
        self.assertEqual(len(list(self.ml_strucs.query())), 1012)

    def tearDown(self):
        """Remove the documents written by the test."""
        self.ml_strucs.collection.drop()
|
||
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
a747a0f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't realize you'd written this! Looks useful :)
For use in an API, it might be useful to unpack the Structure dict into lists of positions/elements/forces, since this is likely how an ML algorithm would ingest them -- as far as I can see, if you wanted to query this ML doc right now to get e.g. forces, you'd have to either reconstruct the ComputedStructureEntry/Structure object to get the site properties, or do a complex aggregation. Being able to query directly for vectors of forces etc. might be better. Happy to implement if you agree.
We also need to note whether it's a spin-polarized calculation or not.
a747a0f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just an initial scaffold for anyone to build on. Go ahead and make a PR.