Labeling-tasks integration into fractal + (Torch) memory errors #109
As of 3c1c431, there are a few working examples of this integration; see for instance …. A parameter file looks like this: …
(Note: the structure has become somewhat similar to the idea of a pipeline JSON file, but this is not part of that work, which will only take place in the incoming server-based version.) These examples include the whole set of current tasks (illumination correction, per-FOV labeling, MIP creation, per-well MIP labeling). Note that whole-well labeling only takes place on MIP images (this is hardcoded at the moment). For per-FOV labeling in the multi-well case, parsl opens as many GPU jobs as possible (ideally as many as the number of wells). In the running folder, there are some log files named like … |
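Purely as an illustration of the structure described above (hypothetical task names and keys only — not the actual fractal parameter file, which is not reproduced here), such a file maps roughly onto a dictionary like this:

```python
# Hypothetical sketch: illustrative task names and keys, not the real fractal schema.
example_params = {
    "zarr_url": "/path/to/plate.zarr/",            # placeholder path
    "workflow": [                                  # ordered list of tasks, as described above
        {"task": "illumination_correction"},
        {"task": "image_labeling",                 # per-FOV labeling
         "labeling_level": 2,
         "labeling_channel": "A01_C01"},
        {"task": "maximum_intensity_projection"},  # MIP creation
        {"task": "image_labeling_whole_well"},     # whole-well labeling, on MIP images only
    ],
}
```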
Let's run a couple more tests, including those with larger (9x8) wells, before closing this issue. |
EDIT: what I wrote above is likely true, but in the specific example I was looking at ( |
The following Fractal examples ran through
A run with 10 5x5 wells (and 19 Z planes) led to a memory error, when run with … |
There's still something wrong with the 10-wells example, which successfully completes only 8 out of 10 labeling executions:
The parsl error is of the
which points at a very explicit virtual-memory error (reaching the 64 G available). To be verified. |
Addendum to the last comment: |
One more piece of information, trying to pinpoint the source of the memory error when per-FOV labeling 10 wells. To test the multi-well case and (temporarily) avoid memory errors, I ran a workflow where per-FOV segmentation of 10 5x5 wells takes place at level 1. This is allowed by the new ROI-based labeling (see fractal-analytics-platform/fractal-tasks-core#19). I also reduced … Here are the CPU and memory traces (the CPU trace is there just as a reference, as it is quite trivial). The first two blocks of execution (up to time ~1800 s) correspond to other (non-labeling) tasks, and then there are 10 blocks of per-FOV labeling (one for each well). The annoying feature is the build-up of memory usage between the 1st and 3rd segmentation tasks, from ~5 G to ~16 G. This is unexpected, as each task should use approximately the same memory (note that the number of Z planes is constant across wells in this dataset). Thus the actual question becomes something like: … |
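As a side note, a minimal sketch of how per-task resident memory could be logged to help pinpoint where the build-up happens (assuming `psutil` is installed; `run_next_task` is just a placeholder for whatever task runs next):

```python
import gc
import psutil

def log_rss(tag):
    # Resident-set size of the current process, in GiB
    rss_gib = psutil.Process().memory_info().rss / 2**30
    print(f"[{tag}] RSS = {rss_gib:.2f} GiB")

# Hypothetical usage around each per-well labeling task:
# log_rss("before well B/09"); run_next_task(); gc.collect(); log_rss("after well B/09")
```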
Quick update after discussing with @mfranzon: we tend to think that this is related to cellpose (and probably torch) not freeing up memory, rather than to parsl. Possibly related: … |
Sounds like a good assumption to test, because we've only seen this issue come up with labeling jobs after all :) |
Quick update: For debugging, we now import the Cellpose model within `segment_FOV`, so that

```python
use_gpu = core.use_gpu()
model = models.Cellpose(gpu=use_gpu, model_type=model_type)
```

is always re-executed. …

```python
device = torch.device('cuda:' + str(gpu_number))
_ = torch.zeros([1, 2, 3]).to(device)
```

It would be great to have a torch.reset() function (we are not the only ones who think something similar), but at the moment we only have …. This update points towards a likely problem, but not towards an obvious solution. Notice that we now tend to exclude that the problem is due to some attributes of the Cellpose model being kept in memory, since the model is now always re-initialized within `segment_FOV`. |
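For reference, the closest thing to a manual cleanup that torch exposes (a minimal sketch, assuming a CUDA device; whether this actually removes the accumulation seen here is not established):

```python
import gc
import torch

def free_torch_memory():
    # Force Python garbage collection first, so unreferenced tensors are dropped
    gc.collect()
    # Release cached GPU blocks back to the driver; this only affects GPU memory,
    # and only blocks that are no longer referenced by any tensor
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```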
Hmm, and with this new model initialization within each segment_FOV call, we still run into the same memory issue? Also, is a torch.reset() something we could safely do if multiple jobs run in parallel on the same GPU? |
Our understanding (by now) is that the Cellpose model initialization is irrelevant, and the problem comes from torch.
This is not really an issue, since torch does not have a torch.reset() function to begin with. |
Can we reproduce this in a synthetic setup, e.g. by just running 2 or 10 cellpose models sequentially in a single script? It could be that it's a new issue, or related to a specific cellpose/torch/GPU setup. But I've run scripts that used cellpose for model inference in the past on 1000s of 3D images, each of them close to the memory limit, and that was all sequential on a single GPU. So I would be kind of surprised if there were a very general issue with cellpose or torch memory handling... |
Yes. For now, I'm starting with the original … Here is the script that runs four wells sequentially:

```python
import sys
import time
import os
import shutil
from fractal.tasks.image_labeling import image_labeling
import gc

zarrurl = "/data/active/fractal/tests/Temporary_data_UZH_4_well_2x2_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/"
wells = ["B/03/0/", "B/05/0/", "C/04/0/", "D/05/0/"]
print(wells)
print()

for well in wells:
    label_dir = zarrurl + well + "labels"
    if os.path.isdir(label_dir):
        shutil.rmtree(label_dir)
print("Cleared label folders")
print()

with open("times.dat", "w") as f:
    t0 = time.perf_counter()
    for well in wells:
        print(well)
        sys.stdout.flush()
        image_labeling(
            zarrurl + well,
            coarsening_xy=2,
            labeling_level=2,
            labeling_channel="A01_C01",
            chl_list=["A01_C01"],
            num_threads=4,
            relabeling=1,
        )
        t1 = time.perf_counter()
        f.write(f"{t1-t0}\n")
        f.flush()
        gc.collect()
```

And here is the memory trace (also using https://bbengfort.github.io/2020/07/read-mprofile-into-pandas), where something is clearly accumulating between successive `image_labeling` calls. |
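For reference, a minimal sketch of how such a trace can be loaded for plotting, assuming it was produced with memory_profiler's `mprof run` (which writes `mprofile_<timestamp>.dat` files made of `MEM <MiB> <unix-time>` lines); this paraphrases the linked blog post rather than quoting it:

```python
import glob
import pandas as pd

# Pick the most recent mprof output file (filename pattern assumed)
dat_file = sorted(glob.glob("mprofile_*.dat"))[-1]

rows = []
with open(dat_file) as f:
    for line in f:
        if line.startswith("MEM"):
            _, mem_mib, timestamp = line.split()
            rows.append((float(timestamp), float(mem_mib)))

df = pd.DataFrame(rows, columns=["timestamp", "mem_MiB"])
df["t_s"] = df["timestamp"] - df["timestamp"].iloc[0]
print(df.describe())
# df.plot(x="t_s", y="mem_MiB")  # memory trace over time
```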
By the way: this last comment confirms that parsl is not involved in this memory issue (since the script was executed directly via SLURM, with parsl never appearing). |
Nice synthetic test, and great to know it's not a parsl issue. But can we also push this test over the edge into a memory error (e.g. by using level 0 or 1)? Reason I'm asking: if torch does some optimization in the background and decides when to clear what memory, I wouldn't care about it, as long as it does it well enough to avoid out-of-memory errors. So I'm not sure how concerning it is that some memory accumulates between runs, unless it is not freed up when needed. What is your expectation here: does torch do any fancy optimization for when to free up which memory? Because if it were an actual memory leak, I would expect linear accumulation. And for my understanding: we are always talking about CPU memory here, right? Or does GPU memory become an issue too? |
Sure, let's try. My guess is that this is exactly the memory error of #109 (comment), but let's check it explicitly.
It's clearly not linear, see the saturation to a plateau in the comment above: https://user-images.githubusercontent.com/3862206/181180402-409d000f-972c-4e8f-85eb-f05b0f44904b.png.
Yes, this is all standard CPU memory. GPU memory errors may appear, but (in our experience) only when running many (e.g. 10) simultaneous Cellpose calculations on the same GPU. |
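As a side note on telling CPU and GPU memory apart, a minimal sketch of how the GPU side could be logged via torch (these counters only track torch's own allocator, so they are a rough proxy rather than a full GPU picture):

```python
import torch

def log_gpu_memory(tag, device=0):
    if not torch.cuda.is_available():
        print(f"[{tag}] no CUDA device available")
        return
    allocated = torch.cuda.memory_allocated(device) / 2**30  # memory held by live tensors
    reserved = torch.cuda.memory_reserved(device) / 2**30    # memory cached by the allocator
    print(f"[{tag}] GPU allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")
```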
Hmm, ok. Let's see if that test can explicitly push it over the limit. Also, I wonder whether there are some torch parameters that would tell it about available memory. It does seem to handle memory cleanup sometimes, but maybe it's optimized for classical systems where it could go into swap a little bit? Maybe there are torch parameters we could set to make it more aggressive in CPU memory cleanup? |
The only related option we could find is no_grad, but this introduces an explicit change to the function, and we cannot say whether it modifies Cellpose's behavior. Is there a trivial answer? (btw, we have not tested it yet.) Other than that, we'd be glad to test other relevant torch options, if you discover any. |
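To make the `no_grad` option concrete, this is roughly what the change would look like (a sketch only; `segment_FOV_no_grad` is an illustrative wrapper, and whether disabling autograd affects Cellpose results is exactly the untested question):

```python
import torch

def segment_FOV_no_grad(FOV_column, model, **eval_kwargs):
    # Run Cellpose inference with autograd bookkeeping disabled; gradients are
    # not needed for model.eval(), so this should only reduce memory kept by torch
    with torch.no_grad():
        mask, flows, styles, diams = model.eval(FOV_column, **eval_kwargs)
    return mask
```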
Ok, just googling a bit. Have you had a look at setting e.g. …? Testing having this included would be interesting: … |
I confirm what I said: the example we are looking at does yield the expected memory error (AKA there is no smart optimization of memory usage by torch, but rather a memory accumulation as more and more wells are processed). Here is the memory trace (details in the plot title); the memory error appears during processing of the third well. Just as a reference, the detailed traceback is:

```
B/09/0/
B/11/0/
C/08/0/
Traceback (most recent call last):
  File "Many_segmentations.py", line 24, in <module>
    image_labeling(
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/image_labeling.py", line 291, in image_labeling
    write_pyramid(
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/lib_pyramid_creation.py", line 69, in write_pyramid
    level0 = to_zarr_custom(
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/lib_to_zarr_custom.py", line 64, in to_zarr_custom
    output = array.to_zarr(
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 2828, in to_zarr
    return to_zarr(self, *args, **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 3591, in to_zarr
    return arr.store(z, lock=False, compute=compute, return_stored=return_stored)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 1752, in store
    r = store([self], [target], **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/array/core.py", line 1214, in store
    store_dlyds = persist(*store_dlyds, **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/base.py", line 904, in persist
    results = schedule(dsk, keys, **kwargs)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/threaded.py", line 89, in get
    results = get_async(
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 511, in get_async
    raise_exception(exc, tb)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 319, in reraise
    raise exc
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/local.py", line 224, in execute_task
    result = _execute_task(task, data)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/core.py", line 119, in _execute_task
    return func(*(_execute_task(a, cache) for a in args))
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/dask/utils.py", line 41, in apply
    return func(*args, **kwargs)
  File "/net/nfs4/pelkmanslab-fileserver-common/data/homes/fractal/mwe_fractal/fractal/tasks/image_labeling.py", line 57, in segment_FOV
    mask, flows, styles, diams = model.eval(
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 227, in eval
    masks, flows, styles = self.cp.eval(x,
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 536, in eval
    masks, styles, dP, cellprob, p = self._run_cp(x,
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/models.py", line 625, in _run_cp
    masks, p = dynamics.compute_masks(dP, cellprob, niter=niter,
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/dynamics.py", line 718, in compute_masks
    mask = get_masks(p, iscell=cp_mask)
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/cellpose/dynamics.py", line 636, in get_masks
    h,_ = np.histogramdd(tuple(pflows), bins=edges)
  File "<__array_function__ internals>", line 180, in histogramdd
  File "/data/homes/fractal/.conda/envs/fractal/lib/python3.8/site-packages/numpy/lib/histograms.py", line 1101, in histogramdd
    hist = hist.astype(float, casting='safe')
numpy.core._exceptions.MemoryError: Unable to allocate 2.60 GiB for an array with shape (61, 2202, 2602) and data type float64
mprof: Sampling memory every 0.1s
running new process
```

The error happens to surface within a numpy call inside Cellpose, but that's not relevant, as it could happen on any other line of Cellpose code. |
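As a sanity check, the reported 2.60 GiB is exactly what a float64 array of that shape requires:

```python
import numpy as np

shape = (61, 2202, 2602)
n_bytes = np.prod(shape) * np.dtype(np.float64).itemsize
print(f"{n_bytes / 2**30:.2f} GiB")  # -> 2.60 GiB, matching the MemoryError above
```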
Thanks for the links, but @mfranzon noticed that this should already be fixed in pytorch>1.5 (and we are on a more recent version). And indeed, by adding these two lines we find a behavior which is fully compatible with #109 (comment) (note that reproducing the memory error each time takes too long, so we are testing on the smaller dataset). |
Hmm, thanks for the thorough checks! Does it also happen when calling the pure cellpose model inference on e.g. some synthetic test data (without our |
WARNING: Let's focus only on the case with 10 5x5 wells (and constant number of Z levels). |
We are trying to at least remove the |
Ah, but wouldn't that potentially explain the out of memory problem? Do we know whether we run out of memory in cases with "too many" Z planes (=> where the user should choose a lower pyramid level)? |
The memory error is for a dataset with a constant number of Z planes (10), and it's not affected by this last remark.
We are now trying with the 10-5x5 dataset, working at the per-well level in 3D and at pyramid level 3 (this corresponds to calling Cellpose on |
We tried to reduce complexity as far as possible, while keeping the problematic behavior. We prepared a mock of the `image_labeling` task, in two variants (with and without `dask.delayed`).
Code:

**With delayed function**

```python
import os
import shutil
import sys
import itertools
import numpy as np
import time
import dask
import dask.array as da
from cellpose import core
from cellpose import models
from concurrent.futures import ThreadPoolExecutor


def fun(FOV_column, model):
    t1_start = time.perf_counter()
    print("START: shape =", FOV_column.shape)
    sys.stdout.flush()
    mask, flows, styles, diams = model.eval(
        FOV_column,
        channels=[0, 0],
        do_3D=True,
        net_avg=False,
        augment=False,
        diameter=(80.0 / 2**labeling_level),
        anisotropy=6.0,
        cellprob_threshold=0.0,
    )
    t1 = time.perf_counter()
    print(f"END: I found {np.max(mask)} labels, in {t1-t1_start:.3f} seconds")
    sys.stdout.flush()
    return mask


def image_labeling(well, labeling_level=None, labeling_channel=None, num_threads=None):
    print(well)
    use_gpu = core.use_gpu()
    print("use_gpu:", use_gpu)
    model = models.Cellpose(gpu=use_gpu, model_type="nuclei")

    # Load full-well data
    column = da.from_zarr(zarrurl + well + f"{labeling_level}/")[labeling_channel]
    output = da.empty(column.shape, chunks=column.chunks, dtype=column.dtype)
    delayed_fun = dask.delayed(fun)

    # Select a single FOV
    for ind_FOV in itertools.product(range(2), repeat=2):
        # Define FOV indices
        ix, iy = ind_FOV
        size_x = 2560 // 2 ** labeling_level
        size_y = 2160 // 2 ** labeling_level
        start_x = size_x * ix
        end_x = size_x * (ix + 1)
        start_y = size_y * iy
        end_y = size_y * (iy + 1)
        # Select input and assign output
        FOV_column = column[:, start_y:end_y, start_x:end_x]
        FOV_mask = delayed_fun(FOV_column, model)
        output[:, start_y:end_y, start_x:end_x] = da.from_delayed(
            FOV_mask, shape=FOV_column.shape, dtype=FOV_column.dtype
        )

    # Remove output file, if needed
    outzarr = f"/tmp/{well}.zarr"
    if os.path.isdir(outzarr):
        shutil.rmtree(outzarr)

    # Write output (--> trigger execution of delayed functions)
    with dask.config.set(pool=ThreadPoolExecutor(num_threads)):
        output.to_zarr(outzarr)

    print()
    sys.stdout.flush()


root = "/data/active/fractal/tests/"
zarrurl = root + "Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/"
wells = ["B/09/0/", "B/11/0/", "C/08/0/", "C/10/0/", "D/09/0/", "D/11/0/", "E/08/0/", "E/10/0/", "F/09/0/", "F/11/0/"]

f = open("times.dat", "w")
t0 = time.perf_counter()
num_threads = 2
labeling_level = 0
labeling_channel = 0
for well in wells:
    image_labeling(well,
                   labeling_level=labeling_level,
                   labeling_channel=labeling_channel,
                   num_threads=num_threads)
    t1 = time.perf_counter()
    f.write(f"{t1-t0}\n")
    f.flush()
f.close()
```

**With sequential function**

```python
import os
import shutil
import sys
import itertools
import numpy as np
import time
import dask.array as da
from cellpose import core
from cellpose import models


def fun(FOV_column, model):
    t1_start = time.perf_counter()
    print("START: shape =", FOV_column.shape)
    sys.stdout.flush()
    mask, flows, styles, diams = model.eval(
        FOV_column,
        channels=[0, 0],
        do_3D=True,
        net_avg=False,
        augment=False,
        diameter=(80.0 / 2**labeling_level),
        anisotropy=6.0,
        cellprob_threshold=0.0,
    )
    t1 = time.perf_counter()
    print(f"END: I found {np.max(mask)} labels, in {t1-t1_start:.3f} seconds")
    sys.stdout.flush()
    return mask


def image_labeling(well, labeling_level=None, labeling_channel=None):
    print(well)
    use_gpu = core.use_gpu()
    print("use_gpu:", use_gpu)
    model = models.Cellpose(gpu=use_gpu, model_type="nuclei")

    # Load full-well data
    column = da.from_zarr(zarrurl + well + f"{labeling_level}/")[labeling_channel]
    output = da.empty(column.shape, chunks=column.chunks, dtype=column.dtype)

    # Select a single FOV
    for ind_FOV in itertools.product(range(2), repeat=2):
        # Define FOV indices
        ix, iy = ind_FOV
        size_x = 2560 // 2 ** labeling_level
        size_y = 2160 // 2 ** labeling_level
        start_x = size_x * ix
        end_x = size_x * (ix + 1)
        start_y = size_y * iy
        end_y = size_y * (iy + 1)
        # Select input and assign output
        FOV_column = column[:, start_y:end_y, start_x:end_x]
        FOV_mask = fun(FOV_column, model)
        output[:, start_y:end_y, start_x:end_x] = FOV_mask

    # Remove output file, if needed
    outzarr = f"/tmp/{well}_clean.zarr"
    if os.path.isdir(outzarr):
        shutil.rmtree(outzarr)

    # Write output (--> trigger execution of delayed functions)
    output.to_zarr(outzarr)

    print()
    sys.stdout.flush()


root = "/data/active/fractal/tests/"
zarrurl = root + "Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/"
wells = ["B/09/0/", "B/11/0/", "C/08/0/", "C/10/0/", "D/09/0/", "D/11/0/", "E/08/0/", "E/10/0/", "F/09/0/", "F/11/0/"]

f = open("times.dat", "w")
t0 = time.perf_counter()
labeling_level = 0
labeling_channel = 0
for well in wells:
    image_labeling(well,
                   labeling_level=labeling_level,
                   labeling_channel=labeling_channel,
                   )
    t1 = time.perf_counter()
    f.write(f"{t1-t0}\n")
    f.flush()
f.close()
```

We consider the usual 10-5x5 dataset, but we only segment a 2x2 subset of each 5x5 well. The memory trace of these two runs is below; black lines in the figure are rolling averages, as a guide to the eye.
These examples seem robust, but we noticed that working with artificial data (AKA …) ….

A final detail: the size of the datasets for different wells (after selecting level 0 and channel 0) is very homogeneous:

```
$ du -sh -L /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/*/*/*/0/0 | sort -k2
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/B/09/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/B/11/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/C/08/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/C/10/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/D/09/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/D/11/0/0/0
2.6G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/E/08/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/E/10/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/F/09/0/0/0
2.5G /data/active/fractal/tests/Temporary_data_UZH_10_well_5x5_sites/20200812-CardiomyocyteDifferentiation14-Cycle1.zarr/F/11/0/0/0
```
|
I just updated the figure in the previous comment. The memory accumulation along the 10-wells example is clear even in the case where cellpose is called sequentially (orange line), without any use of `dask.delayed`. |
Comment by me and @mfranzon: |
Wow, yeah! I am still surprised by the issue, but now that it doesn't seem to be a dask issue: if we can generate a test case of this happening, I'd be very much in favor of opening a cellpose issue with this! And then that's probably a good point to stop our digging into it. |
For a simple test case, could you just load the same image repeatedly (e.g. even just as a PNG image using imageio; it doesn't need to be Zarr if that makes the test complicated) and loop cellpose over it in a basic for loop, e.g. not even saving the results?
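A minimal sketch of such a reproducer (assuming any 2D image saved as a PNG and the `nuclei` model are enough to show the accumulation; the path and diameter are placeholders, and this is untested here):

```python
import gc
import imageio
import numpy as np
from cellpose import core, models

img = imageio.imread("test_image.png")  # placeholder: any 2D image
use_gpu = core.use_gpu()
model = models.Cellpose(gpu=use_gpu, model_type="nuclei")

for i in range(10):
    # Re-run inference on the same image; results are discarded on purpose,
    # so any memory growth across iterations is not due to stored outputs
    mask, flows, styles, diams = model.eval(
        img, channels=[0, 0], diameter=40.0  # placeholder diameter
    )
    print(f"iteration {i}: {np.max(mask)} labels")
    del mask, flows, styles, diams
    gc.collect()
```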
The discussion should probably continue in MouseLand/cellpose#539. |
Great that we have this escalated now. I'd say the work on the segmentation milestone is then done from our side. Let's follow the cellpose discussion to see if there is a good workaround, and otherwise think more broadly about handling libraries with potential memory issues :) |
The labeling task is part of fractal, and remaining issues are unrelated to this one. |
Integrate `image_labeling` and `image_labeling_whole_well` in fractal (for the current working version, not for the incoming server-based one).