Skip to content

Commit

Permalink
add window operator
Browse files Browse the repository at this point in the history
  • Loading branch information
christinadionysio committed Feb 7, 2025
1 parent 4dd5ab3 commit 0c4b074
Show file tree
Hide file tree
Showing 24 changed files with 508 additions and 182 deletions.
15 changes: 5 additions & 10 deletions src/main/python/systemds/scuro/dataloader/audio_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,18 @@

import librosa
from systemds.scuro.dataloader.base_loader import BaseLoader
from systemds.scuro.utils.schema_helpers import create_timestamps
from systemds.scuro.modality.type import ModalityType


class AudioLoader(BaseLoader):
def __init__(
self,
source_path: str,
indices: List[str],
chunk_size: Optional[int] = None,
self, source_path: str, indices: List[str], chunk_size: Optional[int] = None
):
super().__init__(source_path, indices, chunk_size)
super().__init__(source_path, indices, chunk_size, ModalityType.AUDIO)

def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
self.file_sanity_check(file)
audio, sr = librosa.load(file)
self.metadata[file] = {"sample_rate": sr, "length": audio.shape[0]}
self.metadata[file]["timestamp"] = create_timestamps(
self.metadata[file]["sample_rate"], self.metadata[file]["length"]
)
self.metadata[file] = self.modality_type.create_audio_metadata(sr, audio)

self.data.append(audio)
7 changes: 6 additions & 1 deletion src/main/python/systemds/scuro/dataloader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@

class BaseLoader(ABC):
def __init__(
self, source_path: str, indices: List[str], chunk_size: Optional[int] = None
self,
source_path: str,
indices: List[str],
chunk_size: Optional[int] = None,
modality_type=None,
):
"""
Base class to load raw data for a given list of indices and stores them in the data object
Expand All @@ -40,6 +44,7 @@ def __init__(
) # TODO: check what the index should be for storing the metadata (file_name, counter, ...)
self.source_path = source_path
self.indices = indices
self.modality_type = modality_type
self._next_chunk = 0
self._num_chunks = 1
self._chunk_size = None
Expand Down
7 changes: 5 additions & 2 deletions src/main/python/systemds/scuro/dataloader/text_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# -------------------------------------------------------------
from systemds.scuro.dataloader.base_loader import BaseLoader
from typing import Optional, Pattern, List, Union
from systemds.scuro.modality.type import ModalityType
import re


Expand All @@ -31,7 +32,7 @@ def __init__(
chunk_size: Optional[int] = None,
prefix: Optional[Pattern[str]] = None,
):
super().__init__(source_path, indices, chunk_size)
super().__init__(source_path, indices, chunk_size, ModalityType.TEXT)
self.prefix = prefix

def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
Expand All @@ -41,5 +42,7 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
if self.prefix:
line = re.sub(self.prefix, "", line)
line = line.replace("\n", "")
self.metadata[file] = {"length": len(line.split())}
self.metadata[file] = self.modality_type.create_text_metadata(
len(line.split()), line
)
self.data.append(line)
20 changes: 9 additions & 11 deletions src/main/python/systemds/scuro/dataloader/video_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import numpy as np

from systemds.scuro.dataloader.base_loader import BaseLoader
from systemds.scuro.utils.schema_helpers import create_timestamps
import cv2
from systemds.scuro.modality.type import ModalityType


class VideoLoader(BaseLoader):
Expand All @@ -34,7 +34,7 @@ def __init__(
indices: List[str],
chunk_size: Optional[int] = None,
):
super().__init__(source_path, indices, chunk_size)
super().__init__(source_path, indices, chunk_size, ModalityType.VIDEO)

def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
self.file_sanity_check(file)
Expand All @@ -43,16 +43,14 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
if not cap.isOpened():
raise f"Could not read video at path: {file}"

self.metadata[file] = {
"fps": cap.get(cv2.CAP_PROP_FPS),
"length": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
"width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
"height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
"num_channels": 3,
}
fps = cap.get(cv2.CAP_PROP_FPS)
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
num_channels = 3

self.metadata[file]["timestamp"] = create_timestamps(
self.metadata[file]["fps"], self.metadata[file]["length"]
self.metadata[file] = self.modality_type.create_video_metadata(
fps, length, width, height, num_channels
)

frames = []
Expand Down
33 changes: 16 additions & 17 deletions src/main/python/systemds/scuro/modality/joined.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def execute(self, starting_idx=0):
self.joined_right.data[i - starting_idx].append([])
right = np.array([])
if self.condition.join_type == "<":
while c < len(idx_2) and idx_2[c] < nextIdx[j]:
while c < len(idx_2) - 1 and idx_2[c] < nextIdx[j]:
if right.size == 0:
right = self.right_modality.data[i][c]
if right.ndim == 1:
Expand All @@ -125,7 +125,7 @@ def execute(self, starting_idx=0):
)
c = c + 1
else:
while c < len(idx_2) and idx_2[c] <= idx_1[j]:
while c < len(idx_2) - 1 and idx_2[c] <= idx_1[j]:
if idx_2[c] == idx_1[j]:
right.append(self.right_modality.data[i][c])
c = c + 1
Expand All @@ -141,18 +141,17 @@ def execute(self, starting_idx=0):

self.joined_right.data[i - starting_idx][j] = right

def apply_representation(self, representation, aggregation):
def apply_representation(self, representation, aggregation=None):
self.aggregation = aggregation
if self.chunked_execution:
return self._handle_chunked_execution(representation)
elif self.left_type.__name__.__contains__("Unimodal"):
self.left_modality.extract_raw_data()
if self.left_type == self.right_type:
self.right_modality.extract_raw_data()
elif self.right_type.__name__.__contains__("Unimodal"):
self.right_modality.extract_raw_data()
# elif self.left_type.__name__.__contains__("Unimodal"):
# self.left_modality.extract_raw_data()
# if self.left_type == self.right_type:
# self.right_modality.extract_raw_data()
# elif self.right_type.__name__.__contains__("Unimodal") and not self.right_modality.has_data():
# self.right_modality.extract_raw_data()

self.execute()
left_transformed = self._apply_representation(
self.left_modality, representation
)
Expand Down Expand Up @@ -263,12 +262,12 @@ def _apply_representation_chunked(

def _apply_representation(self, modality, representation):
transformed = representation.transform(modality)
if self.aggregation:
aggregated_data_left = self.aggregation.window(transformed)
transformed = Modality(
transformed.modality_type,
transformed.metadata,
)
transformed.data = aggregated_data_left
# if self.aggregation:
# aggregated_data_left = self.aggregation.execute(transformed)
# transformed = Modality(
# transformed.modality_type,
# transformed.metadata,
# )
# transformed.data = aggregated_data_left

return transformed
7 changes: 7 additions & 0 deletions src/main/python/systemds/scuro/modality/joined_transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.utils import pad_sequences
from systemds.scuro.representations.window import WindowAggregation


class JoinedTransformedModality(Modality):
Expand Down Expand Up @@ -68,3 +69,9 @@ def combine(self, fusion_method):
self.data[i] = np.array(r)
self.data = pad_sequences(self.data)
return self

def window(self, window_size, aggregation):
w = WindowAggregation(window_size, aggregation)
self.left_modality.data = w.execute(self.left_modality)
self.right_modality.data = w.execute(self.right_modality)
return self
80 changes: 42 additions & 38 deletions src/main/python/systemds/scuro/modality/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,41 @@
# under the License.
#
# -------------------------------------------------------------
from copy import deepcopy
from typing import List

import numpy as np

from systemds.scuro.modality.type import ModalityType
from systemds.scuro.modality.type import ModalityType, DataLayout


class Modality:

def __init__(self, modalityType: ModalityType, metadata=None):
def __init__(self, modalityType: ModalityType, metadata={}):
"""
Parent class of the different Modalities (unimodal & multimodal)
:param modality_type: Type of the modality
"""
self.modality_type = modalityType
self.schema = modalityType.get_schema()
self.metadata = metadata
self.data = []
self.data_type = None
self.cost = None
self.shape = None
self.dataIndex = None
self.metadata = metadata

@property
def data(self):
return self._data

@data.setter
def data(self, value):
"""
This method ensures that the data layout in the metadata is updated when the data changes
"""
self._data = value
self.update_metadata()

def get_modality_names(self) -> List[str]:
"""
Expand All @@ -50,10 +63,23 @@ def get_modality_names(self) -> List[str]:
]

def copy_from_instance(self):
"""
Create a copy of the modality instance
"""
return type(self)(self.modality_type, self.metadata)

def update_metadata(self):
md_copy = self.metadata
"""
Updates the metadata of the modality (i.e.: updates timestamps)
"""
if (
not self.has_metadata()
or not self.has_data()
or len(self.data) < len(self.metadata)
):
return

md_copy = deepcopy(self.metadata)
self.metadata = {}
for i, (md_k, md_v) in enumerate(md_copy.items()):
updated_md = self.modality_type.update_metadata(md_v, self.data[i])
Expand All @@ -63,6 +89,10 @@ def get_metadata_at_position(self, position: int):
return self.metadata[self.dataIndex][position]

def flatten(self):
"""
Flattens modality data by row-wise concatenation
Prerequisite for some ML-models
"""
for num_instance, instance in enumerate(self.data):
if type(instance) is np.ndarray:
self.data[num_instance] = instance.flatten()
Expand All @@ -75,39 +105,13 @@ def flatten(self):
return self

def get_data_layout(self):
if not self.data:
return self.data

if isinstance(self.data[0], list):
return "list_of_lists_of_numpy_array"
elif isinstance(self.data[0], np.ndarray):
return "list_of_numpy_array"

def get_data_shape(self):
layout = self.get_data_layout()
if not layout:
return None

if layout == "list_of_lists_of_numpy_array":
return self.data[0][0].shape
elif layout == "list_of_numpy_array":
return self.data[0].shape

def get_data_dtype(self):
layout = self.get_data_layout()
if not layout:
return None

if layout == "list_of_lists_of_numpy_array":
return self.data[0][0].dtype
elif layout == "list_of_numpy_array":
return self.data[0].dtype

def update_data_layout(self):
if not self.data:
return
if self.has_metadata():
return list(self.metadata.values())[0]["data_layout"]["representation"]

return None

self.schema["data_layout"]["representation"] = self.get_data_layout()
def has_data(self):
return self.data is not None and len(self.data) != 0

self.shape = self.get_data_shape()
self.schema["data_layout"]["type"] = self.get_data_dtype()
def has_metadata(self):
return self.metadata is not None and self.metadata != {}
11 changes: 6 additions & 5 deletions src/main/python/systemds/scuro/modality/transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from systemds.scuro.modality.joined import JoinedModality
from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.aggregate import Aggregation
from systemds.scuro.representations.window import WindowAggregation


Expand All @@ -36,7 +37,6 @@ def __init__(self, modality_type, transformation, metadata):
"""
super().__init__(modality_type, metadata)
self.transformation = transformation
self.data = []

def copy_from_instance(self):
return type(self)(self.modality_type, self.transformation, self.metadata)
Expand All @@ -46,7 +46,7 @@ def join(self, right, join_condition):
if type(right).__name__.__contains__("Unimodal"):
if right.data_loader.chunk_size:
chunked_execution = True
elif right.data is None or len(right.data) == 0:
elif not right.has_data():
right.extract_raw_data()

joined_modality = JoinedModality(
Expand All @@ -59,15 +59,16 @@ def join(self, right, join_condition):

if not chunked_execution:
joined_modality.execute(0)
joined_modality.joined_right.update_metadata()

return joined_modality

def window(self, windowSize, aggregationFunction, fieldName=None):
def window(self, windowSize, aggregation):
transformed_modality = TransformedModality(
self.modality_type, "window", self.metadata
)
w = WindowAggregation(windowSize, aggregationFunction)
transformed_modality.data = w.window(self)
w = WindowAggregation(windowSize, Aggregation(aggregation))
transformed_modality.data = w.execute(self)

return transformed_modality

Expand Down
Loading

0 comments on commit 0c4b074

Please sign in to comment.