Skip to content

Commit

Permalink
added blank features to the batch
Browse files Browse the repository at this point in the history
  • Loading branch information
TjarkMiener committed Dec 2, 2024
1 parent a0e82b2 commit 76bf4d4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 57 deletions.
48 changes: 40 additions & 8 deletions dl1_data_handler/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class DLDataLoader(Sequence):
Size of the batch to load the data.
random_seed : int, optional
Whether to shuffle the data after each epoch with a provided random seed.
sort_by_intensity : bool, optional
Whether to sort the events based on the hillas intensity for stereo analysis.
stack_telescope_images : bool, optional
Whether to stack the telescope images for stereo analysis.
Methods:
--------
Expand All @@ -41,6 +45,8 @@ def __init__(
tasks,
batch_size=64,
random_seed=None,
sort_by_intensity=False,
stack_telescope_images=False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -51,19 +57,35 @@ def __init__(
self.batch_size = batch_size
self.random_seed = random_seed
self.on_epoch_end()
self.stack_telescope_images = stack_telescope_images
self.sort_by_intensity = sort_by_intensity

# Get the input shape for the convolutional neural network
self.image_shape = self.DLDataReader.image_mappers[self.DLDataReader.cam_name].image_shape
if self.DLDataReader.__class__.__name__ == "DLImageReader":
self.channel_shape = len(self.DLDataReader.channels)
elif self.DLDataReader.__class__.__name__ == "DLWaveformReader":
self.channel_shape = self.DLDataReader.sequence_length

self.input_shape = (
self.image_shape,
self.image_shape,
self.channel_shape,
)

if self.DLDataReader.mode == "mono":
self.image_shape = self.DLDataReader.image_mappers[self.DLDataReader.cam_name].image_shape
self.input_shape = (
len(self.DLDataReader.tel_ids),
self.image_shape,
self.image_shape,
self.channel_shape,
)
elif self.DLDataReader.mode == "stereo":
if self.stack_telescope_images:
# Reshape inputs into proper dimensions for the stereo analysis with stacked images
self.input_shape = (
self.image_shape,
self.image_shape,
len(self.DLDataReader.tel_ids) * self.channel_shape,
)
else:
self.input_shape = (110, 110, 2)


def __len__(self):
"""
Expand Down Expand Up @@ -108,13 +130,23 @@ def __getitem__(self, index):
# Generate indices of the batch
batch_indices = self.indices[index * self.batch_size : (index + 1) * self.batch_size]
if self.DLDataReader.mode == "mono":
features, batch = self.DLDataReader.mono_batch_generation(
batch = self.DLDataReader.mono_batch_generation(
batch_indices=batch_indices,
)
features = {"input": batch["features"].data}
elif self.DLDataReader.mode == "stereo":
features, batch = self.DLDataReader.stereo_batch_generation(
batch = self.DLDataReader.stereo_batch_generation(
batch_indices=batch_indices,
)
batch_grouped = batch.group_by(["obs_id", "event_id", "tel_type_id"])
# Sort events based on their telescope types by the hillas intensity in a given batch if requested
if self.sort_by_intensity:
for batch_grouped in batch_grouped.groups:
batch_grouped.sort(["hillas_intensity"], reverse=True)
print(batch_grouped)

if self.stack_telescope_images:
print(features)
# Generate the labels for each task
labels = {}
if "type" in self.tasks:
Expand Down
93 changes: 44 additions & 49 deletions dl1_data_handler/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,9 @@ def _multiplicity_cut_subarray(table, key_colnames):
# Constrcut the example identifiers for all files
self.example_identifiers = vstack(example_identifiers)
self.example_identifiers.sort(["obs_id", "event_id", "tel_id", "tel_type_id"])
self.example_identifiers_grouped = self.example_identifiers.group_by(
["obs_id", "event_id"]
)
# Unique example identifiers by events
self.unique_example_identifiers = unique(
self.example_identifiers, keys=["obs_id", "event_id"]
Expand Down Expand Up @@ -676,9 +679,7 @@ def _get_parameters(self, batch, dl1b_parameter_list) -> np.array:
dl1b_parameters.append([np.stack(parameters)])
return np.array(dl1b_parameters)

def mono_batch_generation(
self, batch_indices, dl1b_parameter_list=None
) -> (dict, Table):
def mono_batch_generation(self, batch_indices) -> (dict, Table):
"""
Generate a batch of events for mono mode.
Expand Down Expand Up @@ -713,19 +714,11 @@ def mono_batch_generation(
)
# Retrieve the batch from the example identifiers via indexing
batch = self.example_identifiers.loc[batch_indices]
# Retrieve the features from child classes
features = self._get_features(batch)
# Retrieve the dl1b parameters if requested
if dl1b_parameter_list is not None:
features["parameters"] = self._get_parameters(
batch,
dl1b_parameter_list,
)
return features, batch
# Append the features from child classes to the batch
batch = self._append_features(batch)
return batch

def stereo_batch_generation(
self, batch_indices, dl1b_parameter_list=None
) -> (dict, Table):
def stereo_batch_generation(self, batch_indices) -> (dict, Table):
"""
Generate a batch of events for stereo mode.
Expand Down Expand Up @@ -759,27 +752,29 @@ def stereo_batch_generation(
# Need this PR https://github.com/astropy/astropy/pull/15826
# waiting astropy v7.0.0
# Once available, the batch_generation can be shared with "mono"
example_identifiers_grouped = self.example_identifiers.group_by(
["obs_id", "event_id"]
)
batch = example_identifiers_grouped.groups[batch_indices]
# Sort events based on their telescope types by the hillas intensity in a given batch
batch.sort(
["obs_id", "event_id", "tel_type_id", "hillas_intensity"], reverse=True
)
batch.sort(["obs_id", "event_id", "tel_type_id"])
# Retrieve the features from child classes
features = self._get_features(batch)
# Retrieve the dl1b parameters if requested
if dl1b_parameter_list is not None:
features["parameters"] = self._get_parameters(
batch,
dl1b_parameter_list,
)
return features, batch
batch = self.example_identifiers_grouped.groups[batch_indices]
# Append the features from child classes to the batch
batch = self._append_features(batch)
# Add blank features for missing telescopes in the batch
batch_grouped = batch.group_by(["obs_id", "event_id"])
for batch_grouped in batch_grouped.groups:
for tel_type_id, tel_type in enumerate(self.selected_telescopes):
for tel_id in self.selected_telescopes[tel_type]:
# Check if the telescope is missing in the batch
if tel_id not in batch_grouped["tel_id"]:
blank_features = batch_grouped.copy()[0]
blank_features["table_index"] = -1
blank_features["tel_type_id"] = tel_type_id
blank_features["tel_id"] = tel_id
blank_features["hillas_intensity"] = 0.0
blank_features["features"] = np.zeros_like(blank_features["features"])
batch.add_row(blank_features)
# Sort the batch with the new rows of blank features
batch.sort(["obs_id", "event_id", "tel_type_id", "tel_id"])
return batch

@abstractmethod
def _get_features(self, batch) -> dict:
def _append_features(self, batch) -> dict:
pass


Expand Down Expand Up @@ -941,17 +936,17 @@ def __init__(
"CTAFIELD_4_TRANSFORM_OFFSET"
]

def _get_features(self, batch) -> dict:
def _append_features(self, batch) -> dict:
"""
Retrieve images of a given batch as features.
Append images to a given batch as features.
This method processes a batch of events to retrieve images as input features for the neural networks.
This method processes a batch of events to append images as input features for the neural networks.
It reads the image data from the specified files, applies any necessary transformations, and maps
the images using the appropriate ``ImageMapper``.
Parameters
----------
batch : Table
batch : astropy.table.Table
A table containing information at minimum the following columns:
- "file_index": List of indices corresponding to the files.
- "table_index": List of indices corresponding to the event tables.
Expand All @@ -960,9 +955,8 @@ def _get_features(self, batch) -> dict:
Returns
-------
dict
A dictionary containing the extracted features with the key ``input``,
which maps to a numpy array of the processed images.
batch : astropy.table.Table
The input batch with the appended processed images as features.
"""
images = []
for file_idx, table_idx, tel_type_id, tel_id in zip(
Expand All @@ -989,7 +983,8 @@ def _get_features(self, batch) -> dict:
images.append(self.image_mappers[camera_type].map_image(unmapped_image))
else:
images.append(unmapped_image)
return {"input": np.array(images)}
batch.add_column(images, name="features", index=7)
return batch


def get_unmapped_waveform(
Expand Down Expand Up @@ -1192,11 +1187,11 @@ def __init__(
"CTAFIELD_5_TRANSFORM_OFFSET"
]

def _get_features(self, batch) -> dict:
def _append_features(self, batch) -> dict:
"""
Retrieve waveforms of a given batch as features.
Append waveforms to a given batch as features.
This method processes a batch of events to retrieve waveforms as input features for the neural networks.
This method processes a batch of events to append waveforms as input features for the neural networks.
It reads the waveform data from the specified files, applies any necessary transformations, and maps
the waveforms using the appropriate ``ImageMapper``.
Expand All @@ -1211,9 +1206,8 @@ def _get_features(self, batch) -> dict:
Returns
-------
dict
A dictionary containing the extracted features with the key ``input``,
which maps to a numpy array of the processed waveforms.
batch : astropy.table.Table
The input batch with the appended processed waveforms as features.
"""
waveforms = []
for file_idx, table_idx, tel_type_id, tel_id in zip(
Expand Down Expand Up @@ -1252,4 +1246,5 @@ def _get_features(self, batch) -> dict:
waveforms.append(self.image_mappers[camera_type].map_image(unmapped_waveform))
else:
waveforms.append(unmapped_waveform)
return {"input": np.array(waveforms)}
batch.add_column(waveforms, name="features", index=7)
return batch

0 comments on commit 76bf4d4

Please sign in to comment.