[SYSTEMDS-3830] Add join operator to Scuro #2220

Closed · wants to merge 17 commits
6 changes: 3 additions & 3 deletions .github/workflows/python.yml
@@ -114,10 +114,10 @@ jobs:
             torch \
             librosa \
             h5py \
             nltk \
             gensim \
-            black
+            black \
+            opt-einsum

       - name: Build Python Package
         run: |
           cd src/main/python
9 changes: 7 additions & 2 deletions src/main/python/systemds/scuro/dataloader/audio_loader.py
@@ -18,10 +18,11 @@
 # under the License.
 #
 # -------------------------------------------------------------
-from typing import List, Optional
+from typing import List, Optional, Union

 import librosa
 from systemds.scuro.dataloader.base_loader import BaseLoader
+from systemds.scuro.utils.schema_helpers import create_timestamps


 class AudioLoader(BaseLoader):
@@ -33,7 +34,11 @@ def __init__(
     ):
         super().__init__(source_path, indices, chunk_size)

-    def extract(self, file: str):
+    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
         self.file_sanity_check(file)
         audio, sr = librosa.load(file)
+        self.metadata[file] = {"sample_rate": sr, "length": audio.shape[0]}
+        self.metadata[file]["timestamp"] = create_timestamps(
+            self.metadata[file]["sample_rate"], self.metadata[file]["length"]
+        )
         self.data.append(audio)
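The new metadata block relies on create_timestamps from systemds.scuro.utils.schema_helpers, whose body is not part of this diff. A minimal sketch of the behavior its two call sites imply (evenly spaced per-sample timestamps derived from a rate and a length); the implementation below is an assumption, not the PR's code:

import numpy as np

def create_timestamps(frequency: float, length: int) -> np.ndarray:
    # Assumed behavior: one timestamp (in seconds) per sample or frame at a
    # constant rate, so element i sits at i / frequency.
    return np.arange(length) / frequency

# Example: a 44100-sample clip at librosa.load's default 22050 Hz gets
# timestamps running from 0.0 up to roughly 2.0 seconds.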
55 changes: 46 additions & 9 deletions src/main/python/systemds/scuro/dataloader/base_loader.py
@@ -35,32 +35,68 @@ def __init__(
         (otherwise please provide your own Dataloader that knows about the file name convention)
         """
         self.data = []
+        self.metadata = (
+            {}
+        )  # TODO: check what the index should be for storing the metadata (file_name, counter, ...)
         self.source_path = source_path
         self.indices = indices
-        self.chunk_size = chunk_size
-        self.next_chunk = 0
+        self._next_chunk = 0
+        self._num_chunks = 1
+        self._chunk_size = None

-        if self.chunk_size:
-            self.num_chunks = int(len(self.indices) / self.chunk_size)
+        if chunk_size:
+            self.chunk_size = chunk_size
+
+    @property
+    def chunk_size(self):
+        return self._chunk_size
+
+    @chunk_size.setter
+    def chunk_size(self, value):
+        self._chunk_size = value
+        self._num_chunks = int(len(self.indices) / self._chunk_size)
+
+    @property
+    def num_chunks(self):
+        return self._num_chunks
+
+    @property
+    def next_chunk(self):
+        return self._next_chunk

     def load(self):
         """
         Takes care of loading the raw data either chunk-wise (if a chunk size is defined) or all at once
         """
-        if self.chunk_size:
+        if self._chunk_size:
             return self._load_next_chunk()

         return self._load(self.indices)

+    def update_chunk_sizes(self, other):
+        if not self._chunk_size and not other.chunk_size:
+            return
+
+        if not other.chunk_size or (
+            self._chunk_size and self._chunk_size < other.chunk_size
+        ):
+            other.chunk_size = self.chunk_size
+        else:
+            self.chunk_size = other.chunk_size
+
     def _load_next_chunk(self):
         """
         Loads the next chunk of data
         """
         self.data = []
         next_chunk_indices = self.indices[
-            self.next_chunk * self.chunk_size : (self.next_chunk + 1) * self.chunk_size
+            self._next_chunk
+            * self._chunk_size : (self._next_chunk + 1)
+            * self._chunk_size
         ]
-        self.next_chunk += 1
+        self._next_chunk += 1
         return self._load(next_chunk_indices)

     def _load(self, indices: List[str]):
@@ -73,13 +109,14 @@ def _load(self, indices: List[str]):
         else:
             self.extract(self.source_path, indices)

-        return self.data
+        return self.data, self.metadata

     @abstractmethod
     def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
         pass

-    def file_sanity_check(self, file):
+    @staticmethod
+    def file_sanity_check(file):
         """
         Checks if the file can be found and is not empty
         """
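Taken together, the refactor makes chunking stateful behind properties: assigning chunk_size recomputes the chunk count, load() serves one chunk per call, _load() now returns (data, metadata), and update_chunk_sizes aligns two loaders on the smaller defined chunk size. A sketch of the intended call pattern (loader arguments and paths are hypothetical); note that num_chunks floor-divides, so a trailing partial chunk is not counted:

# Hypothetical usage: iterate audio chunk by chunk while keeping a second
# modality's loader on the same chunk size before a join.
ids = ["clip_00", "clip_01", "clip_02", "clip_03", "clip_04", "clip_05"]
audio = AudioLoader("data/audio/", indices=ids, chunk_size=2)
text = TextLoader("data/text/", indices=ids)

audio.update_chunk_sizes(text)  # text adopts chunk_size=2

for _ in range(audio.num_chunks):  # 6 / 2 == 3 chunks
    data, metadata = audio.load()  # loads two clips per call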
6 changes: 3 additions & 3 deletions src/main/python/systemds/scuro/dataloader/json_loader.py
@@ -21,7 +21,7 @@
 import json

 from systemds.scuro.dataloader.base_loader import BaseLoader
-from typing import Optional, List
+from typing import Optional, List, Union


 class JSONLoader(BaseLoader):
class JSONLoader(BaseLoader):
@@ -35,9 +35,9 @@ def __init__(
         super().__init__(source_path, indices, chunk_size)
         self.field = field

-    def extract(self, file: str, indices: List[str]):
+    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
         self.file_sanity_check(file)
         with open(file) as f:
             json_file = json.load(f)
-            for idx in indices:
+            for idx in index:
                 self.data.append(json_file[idx][self.field])
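JSONLoader now takes the shared extract signature and reads self.field out of one top-level record per index. A hypothetical input file illustrating the layout the loop expects (file name and contents are made up):

# Hypothetical "data/instances.json":
#   {"0": {"text": "hello", "label": 1},
#    "1": {"text": "world", "label": 0}}
loader = JSONLoader("data/instances.json", indices=["0", "1"], field="text")
data, metadata = loader.load()  # data == ["hello", "world"]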
5 changes: 3 additions & 2 deletions src/main/python/systemds/scuro/dataloader/text_loader.py
@@ -19,7 +19,7 @@
 #
 # -------------------------------------------------------------
 from systemds.scuro.dataloader.base_loader import BaseLoader
-from typing import Optional, Pattern, List
+from typing import Optional, Pattern, List, Union
 import re


@@ -34,11 +34,12 @@ def __init__(
         super().__init__(source_path, indices, chunk_size)
         self.prefix = prefix

-    def extract(self, file: str):
+    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
         self.file_sanity_check(file)
         with open(file) as text_file:
             for i, line in enumerate(text_file):
                 if self.prefix:
                     line = re.sub(self.prefix, "", line)
                 line = line.replace("\n", "")
+                self.metadata[file] = {"length": len(line.split())}
                 self.data.append(line)
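TextLoader strips an optional regex prefix from each line and now records a word count per file. A hypothetical transcript and call (paths and prefix pattern are made up); note that self.metadata[file] is reassigned on every iteration, so only the final line's length survives for multi-line files:

# Hypothetical "data/transcripts/clip_00.txt", one utterance per line:
#   SPEAKER_00: thanks for joining
loader = TextLoader("data/transcripts/", indices=["clip_00"],
                    prefix=r"SPEAKER_\d+: ")
data, metadata = loader.load()
# data[-1] == "thanks for joining"; the recorded "length" is 3 words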
21 changes: 19 additions & 2 deletions src/main/python/systemds/scuro/dataloader/video_loader.py
@@ -18,11 +18,12 @@
 # under the License.
 #
 # -------------------------------------------------------------
-from typing import List, Optional
+from typing import List, Optional, Union

 import numpy as np

 from systemds.scuro.dataloader.base_loader import BaseLoader
+from systemds.scuro.utils.schema_helpers import create_timestamps
 import cv2

@@ -35,9 +36,25 @@ def __init__(
     ):
         super().__init__(source_path, indices, chunk_size)

-    def extract(self, file: str):
+    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
         self.file_sanity_check(file)
         cap = cv2.VideoCapture(file)

+        if not cap.isOpened():
+            raise ValueError(f"Could not read video at path: {file}")
+
+        self.metadata[file] = {
+            "fps": cap.get(cv2.CAP_PROP_FPS),
+            "length": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
+            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
+            "num_channels": 3,
+        }
+
+        self.metadata[file]["timestamp"] = create_timestamps(
+            self.metadata[file]["fps"], self.metadata[file]["length"]
+        )
+
         frames = []
         while cap.isOpened():
             ret, frame = cap.read()
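The diff is truncated just after cap.read(); a typical continuation of such an OpenCV read loop (an assumption, not necessarily the PR's exact code) checks the return flag, converts BGR to RGB, and releases the capture:

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # decoder exhausted: stop reading
                break
            # OpenCV decodes frames as BGR; convert for RGB-based models.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()
        self.data.append(np.stack(frames))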