Skip to content

Commit

Permalink
Builder to gen structures for machine learning
Browse files Browse the repository at this point in the history
  • Loading branch information
shyamd committed Dec 16, 2017
1 parent fe608c3 commit a747a0f
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 0 deletions.
140 changes: 140 additions & 0 deletions emmet/vasp/builders/ml_structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@

from itertools import chain
from pymatgen import Structure
from pymatgen.entries.computed_entries import ComputedStructureEntry
from emmet.vasp.builders.task_tagger import task_type
from maggma.builder import Builder
from pydash.objects import get

__author__ = "Shyam Dwaraknath <[email protected]>"


class MLStructuresBuilder(Builder):

def __init__(self, tasks, ml_strucs, task_types = ("Structure Optimization","Static"),query={}, **kwargs):
"""
Creates a collection of structures, energies, forces, and stresses for machine learning efforts
Args:
tasks (Store): Store of task documents
ml_strucs (Store): Store of materials documents to generate
tasK_types (list): list of substrings for task_types to process
"""

self.tasks = tasks
self.ml_strucs = ml_strucs
self.task_types = task_types
self.query = query
super().__init__(sources=[tasks],
targets=[ml_strucs],
**kwargs)

def get_items(self):
"""
Gets all items to process into materials documents
Returns:
generator or list relevant tasks and materials to process into materials documents
"""

self.logger.info("Machine Learning Structure Database Builder Started")
self.logger.info("Setting indexes")
self.ensure_indexes()

# Get all processed tasks:
q = dict(self.query)
q["state"] = "successful"
q["calcs_reversed"] = {"$exists": 1}

all_tasks = set(self.tasks.distinct("task_id", q))
processed_tasks = set(self.ml_strucs.distinct("task_id"))
to_process_tasks = all_tasks - processed_tasks

self.logger.info(
"Found {} tasks to extract information from".format(len(to_process_tasks)))

for t_id in to_process_tasks:
task = self.tasks.query_one(criteria={"task_id": t_id})
yield task

def process_item(self, task):
"""
Process the tasks into a list of materials
Args:
task [dict] : a task doc
Returns:
list of C
"""

t_type = task_type(get(task, 'input.incar'))
entries = []

if not any([t in t_type for t in self.task_types]):
return []

is_hubbard = get(task, "input.is_hubbard", False)
hubbards = get(task, "input.hubbards", [])
i = 0

for calc in task.get("calcs_reversed", []):

parameters = {"is_hubbard": is_hubbard,
"hubbards": hubbards,
"potcar_spec": get(calc, "input.potcar_spec", []),
"run_type": calc.get("run_type", "GGA")
}

for step_num, step in enumerate(get(calc, "output.ionic_steps")):
struc = Structure.from_dict(step.get("structure"))
forces = calc.get("forces", [])
if forces:
struc.add_site_property("forces", forces)
stress = calc.get("stress", None)
data = {"stress": stress} if stress else {}
data["step"] =step_num
c = ComputedStructureEntry(structure=struc,
correction=0,
energy=step.get("e_wo_entrp"),
parameters=parameters,
entry_id="{}-{}".format(task[self.tasks.key],i),
data=data)
i += 1

d = c.as_dict()
d["chemsys"] = '-'.join(
sorted(set([e.symbol for e in struc.composition.elements])))
d["task_type"] = task_type(get(calc, 'input.incar'))
d["calc_name"] = get(calc, "task.name")
d["task_id"] = task[self.tasks.key]

entries.append(d)

return entries

def update_targets(self, items):
"""
Inserts the new entires into the task_types collection
Args:
items ([([dict],[int])]): A list of tuples of materials to update and the corresponding processed task_ids
"""
items = [i for i in filter(None, chain.from_iterable(items))]

if len(items) > 0:
self.logger.info("Updating {} entries".format(len(items)))
self.ml_strucs.update(docs=items)
else:
self.logger.info("No items to update")

def ensure_indexes(self):
"""
Ensures indexes on the tasks and materials collections
:return:
"""

# Basic search index for tasks
self.ml_strucs.ensure_index("entry_id")
self.ml_strucs.ensure_index("chemsys")
self.ml_strucs.ensure_index(self.ml_strucs.lu_field)
57 changes: 57 additions & 0 deletions emmet/vasp/builders/tests/test_ml_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import unittest
from itertools import chain
from pydash.objects import get
from maggma.stores import MongoStore
from maggma.runner import Runner
from emmet.vasp.builders.task_tagger import task_type
from emmet.vasp.builders.tests.test_builders import BuilderTest
from emmet.vasp.builders.ml_structures import MLStructuresBuilder


__author__ = "Shyam Dwaraknath"
__email__ = "[email protected]"


class TestMaterials(BuilderTest):

def setUp(self):
self.ml_strucs = MongoStore("emmet_test", "ml_strucs",key="entry_id")
self.ml_strucs.connect()

self.ml_strucs.collection.drop()
self.mlbuilder = MLStructuresBuilder(self.tasks, self.ml_strucs, task_types = ("Structure Optimization","Static"))


def test_get_items(self):
to_process = list(self.mlbuilder.get_items())
to_process_forms = {task["formula_pretty"] for task in to_process}

self.assertEqual(len(to_process), 197)
self.assertEqual(len(to_process_forms), 12)
self.assertTrue("Sr" in to_process_forms)
self.assertTrue("Hf" in to_process_forms)
self.assertTrue("O2" in to_process_forms)
self.assertFalse("H" in to_process_forms)

def test_process_item(self):
for task in self.tasks.query():
ml_strucs = self.mlbuilder.process_item(task)
t_type = task_type(get(task, 'input.incar'))
if not any([t in t_type for t in self.mlbuilder.task_types]):
self.assertEqual(len(ml_strucs),0)
else:
self.assertEqual(len(ml_strucs), sum([len(t["output"]["ionic_steps"]) for t in task["calcs_reversed"]]))


def test_update_targets(self):
for task in self.tasks.query():
ml_strucs = self.mlbuilder.process_item(task)
self.mlbuilder.update_targets([ml_strucs])
self.assertEqual(len(self.ml_strucs.distinct("task_id")), 102)
self.assertEqual(len(list(self.ml_strucs.query())), 1012)

def tearDown(self):
self.ml_strucs.collection.drop()

if __name__ == "__main__":
unittest.main()

2 comments on commit a747a0f

@mkhorton
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't realize you'd written this! Looks useful :)

For use in an API, not sure if it might be useful to unpack the Structure dict into lists of positions/elements/forces since this likely how a ML algorithm would ingest them -- as far as I can see, if you wanted to query this ML doc right now to get e.g. forces, you'd have to either reconstruct the ComputedStructureEntry/Structure object to get the site properties, or do a complex aggregation. Being able to query to get vectors of forces etc might be better. Happy to implement if you agree.

Also need to note if it's a spin polarized calc or not.

@shyamd
Copy link
Contributor

@shyamd shyamd commented on a747a0f Jan 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just an initial scaffold for anyone to build on. Go ahead and make a PR.

Please sign in to comment.