Skip to content

Commit

Permalink
refactor: move dataloader creation out of analysis function (ensure c…
Browse files Browse the repository at this point in the history
…onsistency due to rand perm in prepare_dataset)
  • Loading branch information
fabioseel committed Dec 16, 2024
1 parent 6f5861e commit 8fb3892
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
17 changes: 4 additions & 13 deletions retinal_rl/analysis/channel_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,13 @@ class HistogramAnalysis:

def spectral_analysis(
device: torch.device,
imageset: Imageset,
dataloader: DataLoader[tuple[Tensor, Tensor, int]],
brain: Brain,
max_sample_size: int = 0,
) -> dict[str, SpectralAnalysis]:
brain.eval()
brain.to(device)
_, cnn_layers = get_cnn_circuit(brain)

# Prepare dataset
dataloader = _prepare_dataset(imageset, max_sample_size)

# Initialize results
results: dict[str, SpectralAnalysis] = {}

Expand All @@ -73,17 +69,13 @@ def spectral_analysis(

def histogram_analysis(
device: torch.device,
imageset: Imageset,
dataloader: DataLoader[tuple[Tensor, Tensor, int]],
brain: Brain,
max_sample_size: int = 0,
) -> dict[str, HistogramAnalysis]:
brain.eval()
brain.to(device)
_, cnn_layers = get_cnn_circuit(brain)

# Prepare dataset
dataloader = _prepare_dataset(imageset, max_sample_size) # TODO: Move outside?

# Initialize results
results: dict[str, HistogramAnalysis] = {}

Expand All @@ -101,7 +93,7 @@ def histogram_analysis(
return results


def _prepare_dataset(
def prepare_dataset(
imageset: Imageset, max_sample_size: int = 0
) -> DataLoader[tuple[Tensor, Tensor, int]]:
"""Prepare dataset and dataloader for analysis."""
Expand Down Expand Up @@ -394,9 +386,8 @@ def _plot_receptive_fields(ax: Axes, rf: FloatArray):


def analyze_input(
device: torch.device, imageset: Imageset, max_sample_size: int
device: torch.device, dataloader: DataLoader[tuple[Tensor, Tensor, int]]
) -> tuple[SpectralAnalysis, HistogramAnalysis]:
dataloader = _prepare_dataset(imageset, max_sample_size)
spectral_result = _layer_spectral_analysis(device, dataloader, nn.Identity())
histogram_result = _layer_pixel_histograms(device, dataloader, nn.Identity())
return spectral_result, histogram_result
Expand Down
13 changes: 6 additions & 7 deletions runner/frameworks/classification/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,10 @@ def analyze(
log.save_dict(cfg.analyses_dir / f"receptive_fields_epoch_{epoch}.json", rf_result)

if cfg.channel_analysis:
spectral_result = channel_ana.spectral_analysis(
device, test_set, brain, cfg.plot_sample_size
)
histogram_result = channel_ana.histogram_analysis(
device, test_set, brain, cfg.plot_sample_size
)
# Prepare dataset
dataloader = channel_ana.prepare_dataset(test_set, cfg.plot_sample_size)
spectral_result = channel_ana.spectral_analysis(device, dataloader, brain)
histogram_result = channel_ana.histogram_analysis(device, dataloader, brain)
channel_ana.plot(
log,
rf_result,
Expand Down Expand Up @@ -142,8 +140,9 @@ def _extended_initialization_plots(
if channel_analysis:
# Input 'rfs' is just the colors
rf_result = np.eye(input_shape[0])[:, :, np.newaxis, np.newaxis]
dataloader = channel_ana.prepare_dataset(train_set, max_sample_size)
spectral_result, histogram_result = channel_ana.analyze_input(
device, train_set, max_sample_size
device, dataloader
)
channel_ana.input_plot(
log,
Expand Down

0 comments on commit 8fb3892

Please sign in to comment.