
Refactor train #28

Open · wants to merge 49 commits into main
Conversation

@TibbersHao (Member)

This PR addresses issue #26 by revamping the current code structure to be more unit-test friendly and modular. The work is 80% finished and the pytests are ready to be executed.

Major Feature Upgrade

  1. Updated the training script to use small functions for each step; common functions have been saved as utility functions.
  2. Updated docstrings for each function.
  3. Each function is accompanied by one or several pytests.
  4. Added tests for model building and training.

Minor Feature Upgrade

  1. Updated QLTY version

Known Issues

  1. The model loading function, which leverages DLSIA's pre-built loading function, apparently has mismatched model weights in state_dict() compared to the trained model. The cause is unknown and the investigation is still under way (a diagnostic sketch follows this list).
  2. Package compatibility: at this moment the pydantic version has been temporarily downgraded to 1.10.15 to mitigate the incompatibility with the older version of tiled[all] in the dev requirements. Updating the tiled version will require a rewrite of the array space allocation function using DataSource from tiled, and this will be in the scope of a separate PR.
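
A hedged diagnostic sketch (not part of the PR) for narrowing down known issue 1: compare the state_dict of the in-memory trained model against the one reloaded via DLSIA's loading function. The function and variable names here are placeholders.

```python
import torch


def diff_state_dicts(trained_model, loaded_model, atol=1e-6):
    """Report keys, shapes, and values that differ between two models."""
    trained_sd = trained_model.state_dict()
    loaded_sd = loaded_model.state_dict()
    # Keys present in one model but not the other
    only_trained = set(trained_sd) - set(loaded_sd)
    only_loaded = set(loaded_sd) - set(trained_sd)
    if only_trained or only_loaded:
        print("Key mismatch:", only_trained, only_loaded)
    # For shared keys, report shape and value differences
    for key in sorted(set(trained_sd) & set(loaded_sd)):
        a, b = trained_sd[key], loaded_sd[key]
        if a.shape != b.shape:
            print(f"{key}: shape {tuple(a.shape)} vs {tuple(b.shape)}")
        elif not torch.allclose(a, b, atol=atol):
            print(f"{key}: max abs diff {(a - b).abs().max().item():.3e}")
```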

To Do's

  1. Adding a quick inference step right after training
  2. Breaking down the long crop_seg_save() function used in the inference script into manageable small functions
  3. Adding pytests for the tiled array allocation function
  4. Finishing the inference script
  5. Format cleaning
  6. Deleting no-longer-used functions and merging seg_utils.py into utils.py for clarity

@TibbersHao added the bug, documentation, enhancement, and dependencies labels on Jul 24, 2024
@dylanmcreynolds (Member)

I pulled this down and ran pytest. I see three pytest errors. The build is also failing with a flake8 error. You should run both flake8 . and pytest before committing. Since this project uses pre-commit, you can also do the following:

pre-commit install

Then, if you want to test things out without committing, run:

pre-commit run --all-files

This is the equivalent of doing what pre-commit does when you commit, which includes running flake8 and black.

One other tip... after pre-commit runs, black will reformat your code. You can then run the following to add those changes to your next commit:

git add .

@TibbersHao (Member Author)

Sounds good. Initially I was planning to run the pre-commit check and fix all format-related changes when I wrap up the inference part, but I will include these changes with my next commit.

On the pytest side, could you send me the errors you got from pytest? I was expecting only one known error, but you got three instead.

@TibbersHao assigned TibbersHao and unassigned Wiebke on Jul 29, 2024

@dylanmcreynolds (Member) left a comment


Nice work. I added a few comments.

yield parameters_dict

@pytest.fixture
def io_parameters(parameters_dict):
Member:


I think we only want one fixture called parameters_dict, and remove io_parameters, network_name, model_parameters.

Member Author:


In my first iteration I only had one fixture for all three sets of parameters; then, as development evolved, I started to see different scenarios where only one or two sets are used at once, so I made this workaround to reduce redundancy in other parts of the code. But I am happy to change back to a single fixture.

Member:


But you can get it easily with just the parameters_dict fixture, no?

Member Author:


There is one more step needed. parameters_dict is just the dictionary read from the YAML file, and it has not been validated by pydantic:
io_parameters, network_name, model_parameters = validate_parameters(parameters_dict)
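
A hedged sketch of how a test could rely on just the parameters_dict fixture, as suggested above; the import path and the assertions are assumptions, not the PR's actual test code.

```python
# Illustrative only: derive the validated parameter sets inside the test from
# the single parameters_dict fixture, instead of keeping three extra fixtures.
from utils import validate_parameters  # import path is an assumption


def test_validate_parameters(parameters_dict):
    io_parameters, network_name, model_parameters = validate_parameters(parameters_dict)
    # The exact assertions depend on the example YAML backing the fixture.
    assert io_parameters is not None
    assert isinstance(network_name, str)
    assert model_parameters is not None
```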



@pytest.fixture
def data_tensor(normed_data):
Member:


If this is only being used in one test, it does not need to be a fixture. Just call the one line in your test:

 data_tensor = torch.from_numpy(normed_data)
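
A hedged sketch of the suggested change; the test name is a placeholder and the normed_data fixture is assumed to return a numpy array, as in the discussion.

```python
import torch


def test_training_prep(normed_data):  # test name is a placeholder
    # One-line conversion done directly in the test instead of via a fixture
    data_tensor = torch.from_numpy(normed_data)
    assert data_tensor.shape == normed_data.shape
```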

Member Author:


Good catch, modified in the next commit


@pytest.fixture
def mask_tensor(mask_array):
mask_tensor = torch.from_numpy(mask_array)
Member:


Does not need to be a fixture if only used once.

Member Author:


Modified in the next commit

@TibbersHao (Member Author)

@dylanmcreynolds @Wiebke the latest commit covers the following upgrades to the pytests:

  • Expanded the positive test cases to cover all 4 available model options
  • Added negative test cases for a bad model name
  • Used session-scoped temporary directories to handle model writing, with automatic cleanup afterwards (see the sketch below)
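
A hedged sketch of the session-scoped temporary model directory mentioned in the last bullet; the fixture name is a placeholder, not the PR's actual code.

```python
import pytest


@pytest.fixture(scope="session")
def model_directory(tmp_path_factory):
    # One temporary directory shared by all model-writing tests in the session;
    # pytest manages cleanup of these base directories between runs.
    return tmp_path_factory.mktemp("trained_models")
```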

Please give it another check and let me know if you have any comments. I believe this PR should be reviewed and merged before more features are added, as it is already quite big at the moment.

Comment on lines 5 to 7
data_tiled_uri: http://0.0.0.0:8888/api/v1/metadata/reconstruction/rec20190524_085542_clay_testZMQ_8bit/20190524_085542_clay_testZMQ_
data_tiled_api_key:
mask_tiled_uri: https://tiled-seg.als.lbl.gov/api/v1/metadata/mlex_store/mlex_segm_user/rec20190524_085542_clay_testZMQ_8bit/a558f7b1773855b9453bf6c91079699d
mask_tiled_uri: http://0.0.0.0:8888/api/v1/metadata/mlex_store/user1/rec20190524_085542_clay_testZMQ_8bit/pytest
Member:


Just a note, with the updated Tiled server on SPIN, the clay dataset is now under: https://tiled-seg.als.lbl.gov/api/v1/metadata/reconstruction/als_20190524_085542_clay/20190524_085542_clay_testZMQ_ and an example mask is at https://tiled-seg.als.lbl.gov/api/v1/metadata/mlex_store/highres_spin/als_20190524_085542_clay/c35880e427239fb216ce73f3974ffe5e

Member Author:


Sounds good, captured and updated.

@TibbersHao (Member Author)

As requested by @Wiebke, a change log to capture the differences for testing pipeline connections:

The Revamped version:

Training

  1. Provide I/O and model parameters in a YAML file, following the examples in the "example_yamls" folder. Note that the structure and entries of the YAML file remain the same as in the current version on the main branch.
  2. Activate the conda environment.
  3. Run the training script: python src/train.py <yaml file path>; the command remains the same as in the current version.
  4. The training script will create a model-saving directory at models_dir from the YAML file, then train the model and save it to that directory along with the metrics. This part remains the same as the current version.
  5. Once the training from the last step has finished, train.py will kick off a quick inference using the trained model on the annotated images and save the result in the seg_tiled_uri indicated in the YAML file.

Inference

  1. Provide I/O and model parameters in a YAML file, following the examples in the "example_yamls" folder. Note that the structure and entries of the YAML file remain the same as in the current version on the main branch. mask_tiled_uri and mask_tiled_api_key are optional in this case.
  2. Activate the conda environment.
  3. Run the full inference script: python src/segment.py <yaml file path>; the command remains the same as in the current version.
  4. The inference script now only covers full inference.
  5. The result will be saved in the seg_tiled_uri indicated in the YAML file. This remains the same as the current version.

Wiebke added 3 commits August 30, 2024 16:20
This is obsolete code that is never called, as `using_qlty` is set to `False` by default and is also set to `False` in the only place where it is specified.
Obsolete code that is never applied.

@Wiebke (Member) left a comment


Thank you for your hard work on this and providing the high-level description. 

I was actually mostly looking for information on what needs to change in other repositories. I gather quick inference does not rely on uid_retrieve being set to anything, and we now need to reduce the two function calls here to a single one, i.e. use only segmentation.py#L77-L95.

Overall, this refactoring is great. The modularity of functions is greatly improved, making the code more organized and easier to navigate.
I have a few suggestions and observations regarding the refactoring and testing, some of which I think are important to incorporate under this PR.

TiledDataSet
TiledDataSet was created as a subclass of torch.utils.data.Dataset with the intention of making it easier to construct data loaders. Since then the code has evolved. Neither the transform nor the qlty-patching functionality is still in use; instead, the class is used only for retrieval of frames. In some places where the class functionality could be used, Tiled clients are accessed directly. I think now is the time to remove any obsolete code and make sure the class reflects how we actually use it. While the new parameters make it very clear in which "mode" the class is currently used, they are in principle already captured by the presence of a mask. I suggest either simplifying this class (removing all qlty functionality) or creating dedicated classes for the different use cases.

validate_parameters
This function passes input parameters to the respective correct parameter class and otherwise mostly relies on Pydantic's type validation. I think we are missing out on the opportunity to have all parameters validated after this function call. In the current version, we still cannot trust that all input parameters are actually valid, as some of the validation happens only when networks are built (for activation, convolution, ...) or while training (for the criterion). I suggest restructuring the parameter/network classes such that all parameters are actually validated once the classes are instantiated (making use of Pydantic's validators). I would further connect the build_network methods to the respective classes. We could make use of the factory pattern to eliminate some lengthy if/then/else statements keyed on the network name; a sketch of this idea follows below. When upgrading Tiled, we would need to follow the Pydantic migration guide to replace obsolete decorators.
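
A hedged sketch of the factory idea, assuming Pydantic v1 style (matching the pinned pydantic 1.10.x); the parameter classes, fields, and return values are placeholders rather than the repository's actual classes or dlsia's API.

```python
from pydantic import BaseModel


class MSDNetParameters(BaseModel):
    num_layers: int = 10  # placeholder field

    def build_network(self):
        # In the real code this would construct and return the dlsia MSDNet.
        return f"MSDNet with {self.num_layers} layers"


class TUNetParameters(BaseModel):
    depth: int = 4  # placeholder field

    def build_network(self):
        return f"TUNet with depth {self.depth}"


# Factory: map network names to their parameter classes, replacing long
# if/elif chains keyed on the network name.
PARAMETER_CLASSES = {"MSDNet": MSDNetParameters, "TUNet": TUNetParameters}


def build_from_name(network_name: str, raw_parameters: dict):
    try:
        cls = PARAMETER_CLASSES[network_name]
    except KeyError:
        raise ValueError(f"Unsupported network: {network_name}")
    return cls(**raw_parameters).build_network()
```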

utils.py
While the modularity of functions has greatly improved, many functions have been moved to utils, even though they are closely connected to specific steps or concepts. I suggest organizing them differently, perhaps grouping them according to suitable topics, maybe model_io, data_io, network, ...
 
Testing:
The previously mentioned disconnect between fixtures and test cases remains. This makes the test cases less legible and points to the test cases having some overall interdependency. I am concerned that the pattern of placing the actual functions to be tested in fixtures may in the future just cause fixture creation to fail rather than causing test cases to fail. I think the high-level concern here is that the test setup essentially does end-to-end testing, but in steps. The numbering of the test files further implies a dependency that I would recommend steering away from. For example, any training or inference steps that require a TiledDataSet could be made independent of Tiled by mocking the class (see the sketch after this section).
You have clearly already put a lot of thought into the test setup, and it is a challenging aspect of this project, but please revisit this, taking into account the importance of independent, isolated testing.
Additionally, the inclusion of the one set of bad parameters causes some complicated logic which is difficult to trace back due to the issue above. In functions intended to test parameter validation, there is a check whether a parameter is of type AssertionError, but within the test case it is not immediately clear why the passed parameter (a fixture) would be of that type. Consider creating a dedicated test case for failure rather than skipping this case in all other tests.
Also note that within the GitHub action, testing causes 206 warnings (some of which have been addressed in Tiled, #bluesky/tiled/676), some related to the test setup, and there are some important aspects that are not being tested (e.g. ensure_parent_containers).
Furthermore, the testing and example YAMLs do not fully reflect how training and inference are run from our frontend application. This means issues like validate_parameters only working if a mask_uri is provided (which is not required in inference) are not being caught.
Finally (and this may be one of the more controversial points of this review), I do not think all tests included right now are really necessary, e.g. for functions that are essentially just getters (passing parameters to a parameter class and checking if they are still the same) or are close to trivial (find_device or create-directory).
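
A hedged sketch of mocking TiledDataSet so tests do not need a live Tiled server; the helper name, shapes, and return values are illustrative assumptions. Any training or inference helper that only iterates over the dataset could then be exercised against this mock directly.

```python
from unittest.mock import MagicMock

import numpy as np


def make_mock_tiled_dataset(num_frames=3, frame_shape=(32, 32)):
    dataset = MagicMock()
    dataset.__len__.return_value = num_frames
    # Each index returns a random float32 frame, standing in for a Tiled read
    dataset.__getitem__.side_effect = (
        lambda idx: np.random.rand(*frame_shape).astype(np.float32)
    )
    return dataset
```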

Performance:
This is more for a future PR, perhaps in collaboration with @Giselleu

  • Inference normalizes data on the CPU before moving the data to the device. Should this be the transform that was removed from TiledDataSet?
  • segment_single_frame calls .eval().to(device) on the given network. I assume that later calls thus no longer cause a move of network weights to the GPU? segment_single_frame also moves and concatenates on the CPU, which will cause stitching to happen on the CPU as well.
  • Training loads all data and masks into one numpy array each. Is there a limit on the number of slices for which this is feasible?

Final things I noticed while taking a look at the overall current state, not necessarily this PR:

  • MSDNets have a final_layer parameter to which we pass final_layer=nn.Softmax(dim=1), but all networks apply a final_layer within segment_single_frame. Does that mean we pass MSDNet results through a final_layer twice? Could this be another reason (aside from the number of layers being set too low) for our consistently sub-par MSDNet performance? (Typically everything is segmented as a single class.) @phzwart
  • The functions for partial inference and full inference appear to have a large amount of overlap and only differ in the retrieval of the network and device, as well as the qlty setup, which has already happened for partial inference.
  • build_network indicates that we support different data dimensionalities, but I do not see that reflected in testing.
  • The ensemble build methods make use of a lot of default parameters that we do not expose to the user. Is this intended? If the function you crafted here to set up networks does not exist in dlsia itself, should this functionality be moved there?
  • CrossEntropyLoss is the only loss tested for, but we have exposed alternative losses in the interface. However, specifying any loss without the parameters weights, ignore_index, or size_average will cause errors. Should we support only CrossEntropyLoss?
  • As of mlexchange/mlex_highres_segmentation@f64ad39 mask indices are written as a list of integers rather than a list of strings.

I suggest we meet offline to discuss further and come up with a plan to tackle the observations above and/or record issues for future work.

@TibbersHao (Member Author)

TibbersHao commented Sep 5, 2024

Thanks for the feedback. Here is a summary of the plans to move this PR forward:

TiledDataset

  • Trim unnecessary and no-longer-used code (e.g. qlty cropping) to make the class clearer.
  • In order to address the different needs for Tiled clients during training / quick inference / full inference, @Wiebke will make a PR to this branch to split the TiledDataset class into two classes for specific use cases.
  • To accommodate that, changes in the pipeline scripts are needed.

validate_parameters

  • For parameters that need to be converted before use (e.g. criterion, loss, weights), use a pydantic pre-validator to handle the conversion within the class.
  • For parameters that are not independent of each other (e.g. weights should have length num_of_classes), explore pydantic root validators to check the internal relationship (a sketch follows this list).
  • For the validate_parameters function in utils.py, move the assertion checks for mask_uri and bad model names to the IOParameters pydantic class in parameters.py.
  • Split IOParameters into two classes, one for training and one for full inference.
  • In the validate_parameters function, add an argument indicating whether the session is training or full inference, then call the respective IO parameter class.
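
A hedged sketch of the pre-/root-validator idea in Pydantic v1 style (matching the pinned pydantic 1.10.x); the class and field names are illustrative, not the repository's schema.

```python
from typing import List, Optional

import torch.nn as nn
from pydantic import BaseModel, root_validator, validator


class CriterionParameters(BaseModel):
    criterion: str = "CrossEntropyLoss"
    weights: Optional[List[float]] = None
    num_classes: int = 2

    @validator("criterion", pre=True)
    def normalize_criterion(cls, value):
        # Convert user input to a canonical torch.nn class name before use
        name = str(value).strip()
        if not hasattr(nn, name):
            raise ValueError(f"Unknown criterion: {name}")
        return name

    @root_validator
    def check_weights_length(cls, values):
        # Cross-field check: weights, if given, must match the number of classes
        weights, num_classes = values.get("weights"), values.get("num_classes")
        if weights is not None and len(weights) != num_classes:
            raise ValueError("weights must have length num_classes")
        return values
```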

utils.py

  • Revisit the grouping of functions to make it more logical. For example, init_tiled_dataset should be moved to tiled_dataset.py and load_model should be moved to segment.py.
  • Modify the normalization function so that it performs frame-by-frame normalization when a stack of arrays is provided, to keep training and inference consistent (a sketch follows this list). Issue #29 (Inconsistency of normalization between training and inference) has been created to address this.
  • For both inference modes, group the common parts into a new function, put this function in segment.py, and call it to reduce code redundancy.
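
A hedged sketch of frame-by-frame normalization for a stack (issue #29); the function name and the min-max scaling convention are assumptions, not the repository's code.

```python
import numpy as np


def normalize_per_frame(stack: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # stack shape: (num_frames, height, width); scale each frame to [0, 1]
    mins = stack.min(axis=(1, 2), keepdims=True)
    maxs = stack.max(axis=(1, 2), keepdims=True)
    return (stack - mins) / (maxs - mins + eps)
```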

build_network

  • For MSDNet, check whether the final layer needs to be defined during model building, as this may cause a double softmax when combined with the later prediction scripts (see the small check after this list). Set it to None and verify that the pipeline still runs properly.
  • Currently the pipeline only supports cross-entropy loss. Issue #30 (Expand criterion pool to support various types of losses) has been created to address the need for additional loss options and to make them compatible with the build_criterion function currently sitting in train.py.
  • Regarding data shapes and dimensions, write docstrings to document the different cases and corresponding expectations.
  • Refactor the current function-oriented script into an object-oriented design, where creating / loading models is achieved through different attributes. (Low priority.)
  • For the ensemble methods, check whether there is code that can be called directly from dlsia. Issue #31 (Trimming of Ensemble Network Code) has been created to address this.
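
As referenced in the first bullet, a small self-contained check (plain torch, no dlsia calls) illustrating why a double softmax is suspicious: applying Softmax twice flattens the class probabilities.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[4.0, 1.0, 0.0]])
softmax = nn.Softmax(dim=1)

once = softmax(logits)
twice = softmax(once)
print(once)   # strongly favors the first class
print(twice)  # noticeably flatter distribution
```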

Testing

  • Modify the test functions so they no longer call the functions under test and collect outputs inside fixtures; instead, define test parameters for each test (with parametrization), call the function within the test itself, and do the assertion checks there (a sketch follows this list). There is no need to use the same sets of parameters for the entire testing pipeline.
  • For testing load_model within test_inference.py, to lift the dependency between training tests and inference tests, provide a set of training parameters and run the train_model function as a first step, then test loading.
  • Add tests that address different dimensions of data_shape for build_network.
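
A hedged sketch of the parametrization idea from the first bullet; the network names, the build_network import path and signature, and the raised exception type are all assumptions.

```python
import pytest

from utils import build_network  # import path is an assumption


@pytest.mark.parametrize("network_name", ["MSDNet", "TUNet", "TUNet3+", "SMSNet"])
def test_build_network_positive(network_name, parameters_dict):
    network = build_network(network_name, parameters_dict)  # signature assumed
    assert network is not None


def test_build_network_bad_name(parameters_dict):
    with pytest.raises(ValueError):  # exact exception type is an assumption
        build_network("NotARealNetwork", parameters_dict)
```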

Performance

  • Using a fairly large dataset, write a series of tests to check whether there is a limit on the number of mask and data slices that can be loaded into a single array (a sketch follows this list). In this case, artificial masks created by randomization are sufficient.
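
A hedged sketch of such a load test using randomly generated artificial masks; the sizes and dtype are arbitrary assumptions.

```python
import numpy as np
import pytest


@pytest.mark.parametrize("num_slices", [16, 128, 1024])
def test_mask_stack_allocation(num_slices):
    # Artificial integer masks standing in for annotated slices
    masks = np.random.randint(0, 4, size=(num_slices, 512, 512), dtype=np.uint8)
    assert masks.shape == (num_slices, 512, 512)
```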

@phzwart (Collaborator)

phzwart commented Sep 6, 2024 via email

Wiebke and others added 9 commits September 5, 2024 18:32
New structure easily enables inference on a subset of the data.
Note: `partial_inference` function currently has no test.
Circumvents `AttributeError: 'PatchedStreamingResponse' object has no attribute 'background'`, updating Tiled to `v0.1.0b8` will resolve this too.
…-dataset

Refactor `TiledDataset` into `TiledDataset` and `TiledMaskedDataset`