This is the official repo for the paper "Video-TransUNet: Temporally Blended Vision Transformer for CT VFSS Instance Segmentation".
Chengxi Zeng, Xinyu Yang, Majid Mirmehdi, Alberto M Gambaruto and Tilo Burghardt
SPIE International Conference on Machine Vision
Please also see our latest update using the Swin Transformer:
"Video-SwinUNet: Spatio-temporal Deep Learning Framework for VFSS Instance Segmentation".
IEEE International Conference on Image Processing
GitHub
We propose Video-TransUNet, a deep architecture for instance segmentation in medical CT videos, constructed by integrating temporal feature blending into the TransUNet deep learning framework. In particular, our approach amalgamates strong frame representation via a ResNet CNN backbone, multi-frame feature blending via a Temporal Context Module (TCM), non-local attention via a Vision Transformer, and reconstructive capabilities for multiple targets via a UNet-based convolutional-deconvolutional architecture with multiple heads. We show that this new network design significantly outperforms other state-of-the-art systems when tested on the segmentation of the bolus and pharynx/larynx in Videofluoroscopic Swallowing Study (VFSS) CT sequences. On our VFSS2022 dataset it achieves a Dice coefficient of
(a) Multi-frame ResNet-50-based feature extractor; (b) Temporal Context Module for temporal feature blending across frames; (c) Vision Transformer Block for non-local attention-based learning of multi-frame encoded input; (d) Cascaded expansive decoder with skip connections as used in original UNet architectures, however, here with multiple prediction heads co-learning the two instances of clinical interest.
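For orientation, the sketch below shows how the four stages (a)-(d) could fit together in PyTorch. It is a minimal illustrative skeleton under assumed shapes and layer sizes, not the released implementation: the module names (`TemporalContextModule`, `VideoTransUNetSketch`), the simple 3D-convolutional blending, and the omission of skip connections are simplifications made for clarity.

```python
import torch
import torch.nn as nn
import torchvision


class TemporalContextModule(nn.Module):
    """Blends per-frame features along the temporal axis (assumed: a simple 3D convolution)."""
    def __init__(self, channels: int):
        super().__init__()
        self.blend = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))

    def forward(self, feats: torch.Tensor) -> torch.Tensor:   # feats: [B, T, C, H, W]
        x = feats.permute(0, 2, 1, 3, 4)                      # -> [B, C, T, H, W]
        x = self.blend(x)
        return x.mean(dim=2)                                  # collapse time -> [B, C, H, W]


class VideoTransUNetSketch(nn.Module):
    """(a) per-frame ResNet-50 encoder, (b) TCM blending, (c) ViT block, (d) decoder with multiple heads."""
    def __init__(self, num_targets: int = 2, embed_dim: int = 768):
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=False)
        self.encoder = nn.Sequential(*list(backbone.children())[:-2])    # (a) frame-wise features
        self.tcm = TemporalContextModule(2048)                            # (b) temporal blending
        self.proj = nn.Conv2d(2048, embed_dim, kernel_size=1)
        vit_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=12, batch_first=True)
        self.vit = nn.TransformerEncoder(vit_layer, num_layers=4)         # (c) non-local attention
        self.decoder = nn.Sequential(                                     # (d) expansive decoder
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=2, stride=2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2), nn.ReLU(inplace=True),
        )                                                                 # (skip connections omitted for brevity)
        # one prediction head per instance of clinical interest (bolus, pharynx/larynx)
        self.heads = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=1) for _ in range(num_targets)])

    def forward(self, clip: torch.Tensor):                     # clip: [B, T, 3, H, W]
        b, t = clip.shape[:2]
        feats = self.encoder(clip.flatten(0, 1))               # [B*T, 2048, h, w]
        feats = feats.view(b, t, *feats.shape[1:])             # [B, T, 2048, h, w]
        fused = self.proj(self.tcm(feats))                     # [B, D, h, w]
        h, w = fused.shape[-2:]
        tokens = self.vit(fused.flatten(2).transpose(1, 2))    # [B, h*w, D]
        grid = tokens.transpose(1, 2).reshape(b, -1, h, w)     # back to a feature map
        dec = self.decoder(grid)
        return [head(dec) for head in self.heads]              # one logit map per target


if __name__ == "__main__":
    model = VideoTransUNetSketch()
    outputs = model(torch.randn(1, 4, 3, 256, 256))            # e.g. a 4-frame clip
    print([o.shape for o in outputs])
```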
Based on four sample frames (top), we show boundary segmentations (lower rows) and Grad-CAM output (upper rows) for TransUNet and our model, highlighting where the models focus their attention. For each sample image, results for the bolus and the pharynx are shown side by side on the left and right, respectively. Note the much more target-focused results of our model.
torch == 1.10.1
torchvision
torchsummary
numpy == 1.21.5
scipy
skimage
matplotlib
PIL
mmcv == 1.5.0
Medpy
- R50-ViT-B_16
- ViT-B_16
- ViT-L_16

Pretrained models can be downloaded via this link.
Our data ethics approval currently permits use and presentation of the data in the paper only; it does not yet allow a full public release.
To fully utilise the temporal blending feature of the model, sequential image data should be converted to numpy arrays and concatenated in the format [T, H, W] for grayscale data and [T, C, H, W] for colour data, as shown in the sketch below. A small batch size is recommended due to the size of the data and the nature of the TCM components.
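A minimal sketch of this preprocessing step is given below. The directory layout (`frames/*.png`) and output filename are hypothetical; only NumPy and Pillow from the requirements above are used.

```python
import glob

import numpy as np
from PIL import Image

# Hypothetical layout: all frames of one sequence live in frames/, named so that
# lexicographic order matches temporal order.
frame_paths = sorted(glob.glob("frames/*.png"))

# Grayscale (single-channel) sequence -> [T, H, W]
gray_clip = np.stack([np.array(Image.open(p).convert("L")) for p in frame_paths], axis=0)

# Colour sequence -> [T, C, H, W] (channels moved in front of the spatial dims)
rgb_clip = np.stack([np.array(Image.open(p).convert("RGB")) for p in frame_paths], axis=0)
rgb_clip = rgb_clip.transpose(0, 3, 1, 2)

np.save("sequence_gray.npy", gray_clip)    # hypothetical output consumed by the data loader
print(gray_clip.shape, rgb_clip.shape)
```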
Train:
python train.py --dataset Synapse --vit_name R50-ViT-B_16
Test:
python test.py --dataset Synapse --vit_name R50-ViT-B_16
Vision Transformer
TransUNet
TCM
@misc{zeng2022videotransunet,
doi = {10.48550/ARXIV.2208.08315},
url = {https://arxiv.org/abs/2208.08315},
author = {Zeng, Chengxi and Yang, Xinyu and Mirmehdi, Majid and Gambaruto, Alberto M and Burghardt, Tilo},
keywords = {Image and Video Processing (eess.IV), Computer Vision and Pattern Recognition (cs.CV)},
title = {Video-TransUNet: Temporally Blended Vision Transformer for CT VFSS Instance Segmentation},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International}
}