Skip to content

Commit

Permalink
Add demos
Browse files Browse the repository at this point in the history
  • Loading branch information
LucidCodeAI committed Jun 15, 2024
1 parent 0d5ed12 commit 6388344
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 174 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ venv.bak/
/data
/assets
/test/data
*.zip
/raw

#WandB
/notebooks/wandb
Expand Down
24 changes: 19 additions & 5 deletions Demo-notebooks/demo.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
License Information. The HCC-TACE-Seg collection is distributed under the CC BY 4.0 at https://creativecommons.org/licenses/by/4.0/ By downloading the data, you agree to abide by terms of this license.
Data Usage Policy

Any user accessing TCIA data must agree to:
- Not use the requested datasets, either alone or in concert with any other information, to identify or contact individual participants from whom data and/or samples were collected and follow all other conditions specified in the TCIA Site Disclaimer. Approved Users also agree not to generate and use information (e.g., facial images or comparable representations) in a manner that could allow the identities of research participants to be readily ascertained. These provisions do not apply to research investigators operating with specific IRB approval, pursuant to 45 CFR 46, to contact individuals within datasets or to obtain and use identifying information under an IRB-approved research protocol. All investigators including any Approved User conducting “human subjects research” within the scope of 45 CFR 46 must comply with the requirements contained therein.

- Acknowledge in all oral or written presentations, disclosures, or publications the specific dataset(s) or applicable accession number(s) and the NIH-designated data repositories through which the investigator accessed any data. Citation guidelines for doing this are outlined below.

- If you are considering mirroring a copy of our publicly available datasets or providing direct access to any of the TCIA data via another tool or website using the REST API (https://wiki.cancerimagingarchive.net/x/NIIiAQ) please review our Data Analysis Centers (DACs) page (https://wiki.cancerimagingarchive.net/x/x49XAQ) for more information. DACs must provide attribution and links back to this TCIA data use policy and must require downstream users to do the same.

The summary page for every TCIA dataset includes a Citations & Data Usage Policy tab. Please consult the Citation & Data Usage Policy for each Collection before using them.
- Most data are freely available to browse, download, and use for commercial, scientific and educational purposes as outlined in the Creative Commons Attribution 3.0 Unported License or the Creative Commons Attribution 4.0 International License. In rare circumstances commercial use may be prohibited using Attribution-NonCommercial 3.0 Unported (CC BY-NC 3.0) or Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0).

- Most data are immediately accessible and do not require account registration. A small subset of collections do require registration and special permission to gain access. Refer to the "Access" column on https://www.cancerimagingarchive.net/collections/ for more details.
142 changes: 50 additions & 92 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,44 @@
import streamlit as st
import torch

# Now you can import from the submodule
from diffdrr.data import read
from diffdrr.visualization import plot_drr
from drr import create_drr
from model import TACEnet
from training import loadData, sampleVolume
from demo import demonstration, get_demo_data, load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the page configuration
st.set_page_config(page_title="DRR Enhancement Model Demo", layout="wide")
st.set_page_config(page_title="Tacetastic", layout="wide")


# Caching the model loading to avoid reloading the model on each interaction
@st.cache_resource(show_spinner="Loading the model...")
def load_model():
model = TACEnet()
model.load_state_dict(
torch.load("../models/TACEnet_vessel_enhancement_deformations_30052024.pth")
)
model.eval()
return model
def cached_load_model(deformation, device):
return load_model(deformation, device)


# Caching the DRR generation to avoid redundant computations
@st.cache_data(show_spinner="Generating DRR...")
def cached_create_drr(_train_loader, contrast_value, height=256, width=256, rotation=0):
_volume, _target = sampleVolume(_train_loader, contrast_value=contrast_value)
subject = read(
tensor=_volume[0], label_tensor=_target[0], bone_attenuation_multiplier=5.0
)
return (
create_drr(
subject,
device="cpu",
height=height,
width=width,
mask_to_channels=False,
rotations=torch.tensor([[rotation, 0.0, 0.0]]),
),
_volume,
_target,
@st.cache_data(show_spinner="Collecting the data...")
def cached_get_demo_data():
return get_demo_data()


@st.cache_data(show_spinner="Generating the results...")
def cached_demonstration(
_volumes, _targets, _model, rotation, ef, deformation, initial_contrast, device
):
# Your existing code to generate the figure
fig = demonstration(
_volumes, _targets, _model, rotation, ef, deformation, initial_contrast, device
)
return fig


@st.cache_data
def cached_loadData():
return loadData()
st.session_state.deformation_checkbox = False

# Load the model
model = cached_load_model(
deformation=st.session_state.deformation_checkbox, device=device
)

@st.cache_data
def cached_sampleVolume(_train_loader, contrast_value):
return sampleVolume(_train_loader, contrast_value=contrast_value)
# Load the demo data
volumes, targets = cached_get_demo_data()


# Show dialog
Expand All @@ -83,8 +68,6 @@ def terms_conditions():
if "terms_conditions" not in st.session_state:
terms_conditions()

st.logo("https://avatars.githubusercontent.com/u/8323854?s=200&v=4")

# Display the DRR and enhanced DRR in the main layout
st.title("DRR Enhancement Model Demo")

Expand All @@ -108,63 +91,38 @@ def terms_conditions():
)

st.slider(
label="Select initial contrast",
min_value=0,
max_value=4000,
value=0,
step=500,
label="Select contrast agent reduction ratio",
min_value=0.0,
max_value=1.0,
value=1.0,
step=0.1,
key="contrast_slider",
help="Select the intital contrast",
help="Select the contrast reduction ratio for the vessels",
on_change=None,
label_visibility="visible",
)

"---"

# Load your model
model = load_model()
model.to(device)

# Load CT data
train_loader, val_loader = cached_loadData()
# volume, target = cached_sampleVolume(
# train_loader, contrast_value=st.session_state.contrast_slider
# )

# Initialize the DRR module for generating synthetic X-rays
drr, volume, target = cached_create_drr(
train_loader,
contrast_value=st.session_state.contrast_slider,
rotation=st.session_state.rotation_slider,
)
st.checkbox(
label="Enable deformation",
value=st.session_state.deformation_checkbox,
key="deformation_checkbox",
help="Enable deformation of the subject",
)

"---"

col1, col2, col3 = st.columns(spec=[0.4, 0.2, 0.4])
if st.sidebar.button("Generate"):
# Assuming demonstration now returns a figure directly
fig = cached_demonstration(
_volumes=volumes,
_targets=targets,
_model=model,
rotation=st.session_state.rotation_slider,
ef=st.session_state.contrast_slider,
deformation=st.session_state.deformation_checkbox,
initial_contrast=4000,
device=device,
)

with col1:
st.header("Original DRR")
axs = plot_drr(drr, ticks=False)
fig = axs[0].figure
fig.set_size_inches(2, 2) # Adjust the figure size
# Display the figure in Streamlit's main view
st.pyplot(fig)

with col3:
st.header("Enhanced DRR")
enhanced_placeholder = st.empty() # Placeholder for enhanced image

with col2:
# Add an "Enhance" button
if st.button("Enhance"):
with st.spinner("Enhancing DRR..."):
prediction, latent = model(
volume[0].unsqueeze(0).to(device), drr.to(device)
)
axs = plot_drr(prediction, ticks=False)
fig = axs[0].figure
fig.set_size_inches(2, 2) # Adjust the figure size
enhanced_placeholder.pyplot(fig)

col4, col5, col6 = st.columns(3)
col4.metric("Temperature", "70 °F", "1.2 °F")
col5.metric("Wind", "9 mph", "-8%")
col6.metric("Humidity", "86%", "4%")
31 changes: 23 additions & 8 deletions src/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def load_model(deformation, device):
)
model = TACEnet().to(device)
model.load_state_dict(torch.load(model_path))
model.eval()
return model


Expand Down Expand Up @@ -52,26 +53,37 @@ def generate_drr(subject, rotation, ef, device):
return drr_combined, drr_vessels.to(device)


def demonstration(rotation, ef, deformation=False, initial_contrast=4000, device="cpu"):
transform = get_transforms(
resize_shape=[512, 512, 96], contrast_value=initial_contrast
)
def get_demo_data():
transform = get_transforms(resize_shape=[512, 512, 96])
train_ds, _ = get_datasets(
root_dir="../data081",
collection="HCC-TACE-Seg",
seg_type="SEG",
transform=transform,
download=False,
download_len=1,
val_frac=0.2,
val_frac=0.0,
seed=42,
)
train_loader, _ = get_dataloaders(train_ds, _, batch_size=1)
batch = next(iter(train_loader))
volumes, targets = batch["image"], batch["seg"]
return volumes, targets


def demonstration(
volumes,
targets,
model,
rotation,
ef,
deformation=False,
initial_contrast=4000,
device="cpu",
):

volumes = add_vessel_contrast(volumes, targets, contrast_value=initial_contrast)

model = load_model(deformation, device)
subject = read(
tensor=volumes[0],
label_tensor=targets[0],
Expand All @@ -83,5 +95,8 @@ def demonstration(rotation, ef, deformation=False, initial_contrast=4000, device
subject = apply_deformation(subject)
drr_combined, drr_target = generate_drr(subject, rotation, ef, device)

prediction, latent = model(targets.to(device), drr_combined)
plot_results(drr_combined, drr_target, prediction, latent, vmax=25)
with torch.no_grad(): # Disable gradient computation
prediction, latent = model(targets.to(device), drr_combined)
return plot_results(
drr_combined.cpu(), drr_target.cpu(), prediction.cpu(), latent.cpu(), vmax=25
)
91 changes: 91 additions & 0 deletions src/pages/Surgeon Demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import random
import time

import streamlit as st

st.set_page_config(page_title="Surgeon Demo", page_icon="👋")

# Mock Patient Details
patient_info = {
"Name": "John Doe",
"Age": 34,
"Gender": "Male",
"MRN": "123456789",
"Admission Date": "2023-04-01",
"Diagnosis": "Appendicitis",
"Surgeon Assigned": "Dr. Jane Smith",
"Procedure": "TACE",
}

# Vital Signs (Mock, could be dynamically updated)
vital_signs = {
"SpO2": "98%",
"Blood Pressure": "120/80 mmHg",
"Respiratory Rate": "16 breaths/min",
}

# Display Patient Information in Sidebar
st.sidebar.title("Patient Information")
for key, value in patient_info.items():
st.sidebar.text(f"{key}: {value}")

st.sidebar.title("Vital Signs")
for key, value in vital_signs.items():
st.sidebar.text(f"{key}: {value}")

# Initialize session state variables if they don't exist
if "rotation" not in st.session_state:
st.session_state.rotation = -45 # Start at 0 degrees
if "rotation_direction" not in st.session_state:
st.session_state.rotation_direction = 1 # Start moving towards -45
if "confidence" not in st.session_state:
st.session_state.confidence = 98 # Start with high confidence
if "spo2" not in st.session_state:
st.session_state.spo2 = random.randint(
80, 88
) # Random SpO2 value between 94% and 98%

st.image(r"D:\Programming\AI\5ARIP10-ITP-T3G3\src\pages\xray_demo.gif", width=400)


# Placeholder for dynamic metrics
col1, col2, col3 = st.columns(3)
rotation_placeholder = col1.empty()
confidence_placeholder = col2.empty()
spo2_placeholder = col3.empty()

while True:
# Update rotation
if st.session_state.rotation_direction == 1:
if st.session_state.rotation + 10 <= 45:
st.session_state.rotation += 10
else:
st.session_state.rotation = 45 # Correct to max 45 if overshooting
st.session_state.rotation_direction = -1
else:
if st.session_state.rotation - 10 >= -45:
st.session_state.rotation -= 10
else:
st.session_state.rotation = -45 # Correct to min -45 if overshooting
st.session_state.rotation_direction = 1

# Update confidence based on rotation
angle_from_zero = abs(st.session_state.rotation)
st.session_state.confidence = round(
98 - (15 * (angle_from_zero / 45)), 2
) # Decreases from 98% to 65%

# Randomly vary SpO2 value
st.session_state.spo2 = random.randint(80, 88)

# Display updated metrics
rotation_placeholder.metric(
"Rotation", f"{st.session_state.rotation} °C", delta=None
)
confidence_placeholder.metric(
"Confidence", f"{st.session_state.confidence}%", delta=None
)
spo2_placeholder.metric("Heart Rate", f"{st.session_state.spo2} BPM", delta=None)

# Sleep for the duration of the GIF's rotation cycle or any desired update interval
time.sleep(0.157)
Binary file added src/pages/xray_demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 6388344

Please sign in to comment.