Commit

simplify Tasks class
ungarj committed Dec 5, 2023
1 parent 6f560e0 commit 16cd59d
Showing 3 changed files with 59 additions and 100 deletions.
3 changes: 2 additions & 1 deletion mapchete/processing/base.py
@@ -676,13 +676,14 @@ def _task_batches(
     if tile:
         tile = process.config.process_pyramid.tile(*tile)
 
-    # first, materialize tile task batches
+    # first, materialize tile task batches to determine process AOI
     tile_task_batches = _tile_task_batches(
        process=process,
        zoom=zoom,
        tile=tile,
        profilers=profilers,
    )
 
+    # create processing AOI (i.e. processing area without overviews)
     if process.config.preprocessing_tasks().values():
         zoom_aois = []
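Note: the comment changes above document why the tile task batches are materialized first: the processing AOI can only be derived once those batches exist. A rough, hypothetical sketch of that idea, merging per-zoom tile footprints with shapely; the helper and the geometry inputs are illustrative, not mapchete's actual API:

from shapely.geometry import box
from shapely.ops import unary_union

def process_aoi(zoom_geometries):
    """Union per-zoom tile footprints into one processing AOI."""
    return unary_union([unary_union(geoms) for geoms in zoom_geometries])

# usage: two zoom levels, each covering a set of tile bounding boxes
aoi = process_aoi([
    [box(0, 0, 1, 1), box(1, 0, 2, 1)],  # zoom 0 tiles
    [box(0, 0, 0.5, 0.5)],               # zoom 1 tiles
])
print(aoi.bounds)  # (0.0, 0.0, 2.0, 1.0)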
145 changes: 55 additions & 90 deletions mapchete/processing/tasks.py
@@ -515,65 +515,79 @@ def _validate(self, items: Iterator[TileTask]) -> Iterator[TileTask]:
 
 
 class Tasks:
-    _len: int = None
-    _task_batches_generator: Iterator[Union[TaskBatch, TileTaskBatch]]
-    _preprocessing_batches: List[TaskBatch]
-    _tile_batches: List[TileTaskBatch]
-    materialized: bool = False
+    preprocessing_batches: List[TaskBatch]
+    tile_batches: List[TileTaskBatch]
 
     def __init__(
         self,
-        task_batches_generator: Iterator[Union[TaskBatch, TileTaskBatch]],
+        task_batches: Iterator[Union[TaskBatch, TileTaskBatch]],
     ):
-        self._task_batches_generator = task_batches_generator
-
-    def __len__(self):
-        # TODO: maybe make explicit that Task.materialize() has to be run first
-        return sum([len(batch) for batch in self._batches_generator()])
-
-    def _batches_generator(self):
-        for phase in (self.preprocessing_batches, self.tile_batches):
-            for batch in phase:
-                yield batch
-
-    def materialize(self):
-        if self.materialized:
-            return
         logger.debug("materializing task batches ...")
         with Timer() as tt:
-            self._preprocessing_batches = []
-            self._tile_batches = []
-            for batch in self._task_batches_generator:
+            self.preprocessing_batches = []
+            self.tile_batches = []
+            for batch in task_batches:
                 if isinstance(batch, TileTaskBatch):
-                    self._tile_batches.append(batch)
+                    self.tile_batches.append(batch)
                 else:
-                    self._preprocessing_batches.append(batch)
-            self.materialized = True
+                    self.preprocessing_batches.append(batch)
         logger.debug("task batches materialized in %s", tt)
 
-    @property
-    def preprocessing_batches(self) -> List[TaskBatch]:
-        self.materialize()
-        return self._preprocessing_batches
+    def __len__(self):
+        return sum([len(batch) for batch in self])
 
-    @property
-    def tile_batches(self) -> List[TileTaskBatch]:
-        self.materialize()
-        return self._tile_batches
+    def __iter__(self) -> Union[TaskBatch, TileTaskBatch]:
+        for phase in (self.preprocessing_batches, self.tile_batches):
+            for batch in phase:
+                yield batch
 
     def to_dask_graph(
         self,
         preprocessing_task_wrapper: Optional[Callable] = None,
         tile_task_wrapper: Optional[Callable] = None,
     ) -> List[Union[Delayed, DelayedLeaf]]:
         """Return task graph to use with dask Executor."""
-        return to_dask_collection(
-            self._batches_generator(),
-            preprocessing_task_wrapper=preprocessing_task_wrapper,
-            tile_task_wrapper=tile_task_wrapper,
-        )
+        tasks = {}
+        previous_batch = None
+        for batch in self:
+            logger.debug("converting batch %s", batch)
+
+            if isinstance(batch, TileTaskBatch):
+                task_func = tile_task_wrapper or batch.func
+            else:
+                task_func = preprocessing_task_wrapper or batch.func
+
+            if previous_batch:
+                logger.debug("previous batch had %s tasks", len(previous_batch))
+
+            for task in batch.values():
+                if previous_batch:
+                    dependencies = {
+                        child.id: tasks[child]
+                        for child in previous_batch.intersection(task)
+                    }
+                    logger.debug(
+                        "found %s dependencies from last batch for task %s",
+                        len(dependencies),
+                        task,
+                    )
+                else:
+                    dependencies = {}
+
+                tasks[task] = delayed(
+                    task_func,
+                    pure=True,
+                    name=f"{task.id}",
+                    traverse=len(dependencies) > 0,
+                )(
+                    task,
+                    dependencies=dependencies,
+                    **batch.fkwargs,
+                    dask_key_name=f"{task.result_key_name}",
+                )
+
+            previous_batch = batch
+
+        return list(tasks.values())
 
     def to_batch(self) -> Iterator[Task]:
         """Return all tasks as one batch."""
@@ -583,53 +597,4 @@ def to_batch(self) -> Iterator[Task]:
 
     def to_batches(self) -> Iterator[Iterator[Task]]:
         """Return batches of tasks."""
-        return list(self._batches_generator())
-
-
-def to_dask_collection(
-    batches: Iterator[Union[TaskBatch, TileTaskBatch]],
-    preprocessing_task_wrapper: Optional[Callable] = None,
-    tile_task_wrapper: Optional[Callable] = None,
-) -> List[Union[Delayed, DelayedLeaf]]:
-    tasks = {}
-    previous_batch = None
-    for batch in batches:
-        logger.debug("converting batch %s", batch)
-
-        if batch.id == "preprocessing_tasks":
-            task_func = preprocessing_task_wrapper or batch.func
-        else:
-            task_func = tile_task_wrapper or batch.func
-
-        if previous_batch:
-            logger.debug("previous batch had %s tasks", len(previous_batch))
-
-        for task in batch.values():
-            if previous_batch:
-                dependencies = {
-                    child.id: tasks[child]
-                    for child in previous_batch.intersection(task)
-                }
-                logger.debug(
-                    "found %s dependencies from last batch for task %s",
-                    len(dependencies),
-                    task,
-                )
-            else:
-                dependencies = {}
-
-            tasks[task] = delayed(
-                task_func,
-                pure=True,
-                name=f"{task.id}",
-                traverse=len(dependencies) > 0,
-            )(
-                task,
-                dependencies=dependencies,
-                **batch.fkwargs,
-                dask_key_name=f"{task.result_key_name}",
-            )
-
-        previous_batch = batch
-
-    return list(tasks.values())
+        return list(self)
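Note: the net effect of this refactor is that Tasks now consumes its input generator exactly once, in __init__, instead of deferring to a materialize() step that __len__ silently depended on (the removed TODO comment hinted at exactly that pitfall). A minimal, self-contained sketch of the pattern using stand-in batch classes, not the real TaskBatch/TileTaskBatch types:

from typing import Iterable, List, Union

class Batch:
    """Stand-in for TaskBatch: holds a list of tasks."""
    def __init__(self, tasks: List[str]):
        self.tasks = tasks
    def __len__(self) -> int:
        return len(self.tasks)

class TileBatch(Batch):
    """Stand-in for TileTaskBatch."""

class Tasks:
    """Eagerly split a batch stream into preprocessing and tile phases."""
    def __init__(self, batches: Iterable[Union[Batch, TileBatch]]):
        self.preprocessing_batches: List[Batch] = []
        self.tile_batches: List[TileBatch] = []
        for batch in batches:  # consumes the generator exactly once
            if isinstance(batch, TileBatch):
                self.tile_batches.append(batch)
            else:
                self.preprocessing_batches.append(batch)
    def __iter__(self):
        yield from self.preprocessing_batches
        yield from self.tile_batches
    def __len__(self) -> int:
        return sum(len(batch) for batch in self)

tasks = Tasks(iter([Batch(["a"]), TileBatch(["b", "c"])]))
assert len(tasks) == 3  # safe to call repeatedly: no generator is re-consumed

Because the split happens eagerly, len(tasks) and repeated iteration are both safe, and the lazy-state bookkeeping (materialized flag, private lists, property wrappers) disappears.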
11 changes: 2 additions & 9 deletions test/test_processing_tasks.py
@@ -6,14 +6,7 @@
 
 from mapchete.errors import NoTaskGeometry
 from mapchete.executor import Executor
-from mapchete.processing.tasks import (
-    Task,
-    TaskBatch,
-    Tasks,
-    TileTask,
-    TileTaskBatch,
-    to_dask_collection,
-)
+from mapchete.processing.tasks import Task, TaskBatch, Tasks, TileTask, TileTaskBatch
 from mapchete.testing import ProcessFixture
 from mapchete.tile import BufferedTilePyramid
 
@@ -128,7 +121,7 @@ def test_task_batches_to_dask_graph(dem_to_hillshade):
         )
         for zoom in dem_to_hillshade.mp().config.zoom_levels.descending()
     )
-    collection = to_dask_collection((preprocessing_batch, *zoom_batches))
+    collection = Tasks((preprocessing_batch, *zoom_batches)).to_dask_graph()
    import dask
 
    dask.compute(collection)
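Note: the updated test now builds the graph through Tasks.to_dask_graph() instead of the removed module-level function. A minimal sketch (assuming dask is installed) of the dependency wiring that method performs: each task is wrapped in dask.delayed and handed the delayed results of intersecting tasks from the previous batch, so dask only runs a task after its upstream neighbors. The run() helper and the key names are illustrative, not mapchete code:

import dask
from dask import delayed

def run(task, dependencies=None):
    # dask resolves the Delayed values inside the dict before calling us
    deps = sorted(dependencies or {})
    return f"{task} <- {deps}"

# first batch: two independent tasks
first = {
    t: delayed(run, pure=True)(t, dask_key_name=f"first-{t}")
    for t in ("a", "b")
}
# second batch: "c" depends on both results from the previous batch
second = delayed(run, pure=True)("c", dependencies=first, dask_key_name="second-c")

print(dask.compute(second)[0])  # c <- ['a', 'b']

Passing the upstream Delayed objects inside the dependencies dict is what makes dask traverse them and schedule the batches in order, mirroring previous_batch.intersection(task) in the real method.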
