Merge pull request #7 from spaceml-org/sdoml_latents_support
add sdoml support
Sceki authored Dec 19, 2024
2 parents c48ac6e + 328e780 commit 84c2869
Showing 5 changed files with 129 additions and 2 deletions.
2 changes: 2 additions & 0 deletions environment.yml
@@ -17,6 +17,8 @@ dependencies:
   - wandb
   - netcdf4
   - cftime
+  - huggingface_hub
+  - h5py
   - pip:
     - nrlmsise00
     - spaceweather
16 changes: 15 additions & 1 deletion karman/dataset.py
@@ -23,6 +23,8 @@ def __init__(
         soho_resolution=60, # 1 hour
         lag_minutes_nrlmsise00=2*24*60,
         nrlmsise00_resolution=60, # 1 hour
+        lag_minutes_sdoml_latents=2 * 24 * 60,
+        sdoml_latents_resolution=12,
         features_to_exclude_thermo=[
             "all__dates_datetime__",
             "tudelft_thermo__satellite__",
@@ -42,6 +44,7 @@ def __init__(
         features_to_exclude_soho=['all__dates_datetime__',
                                   'source__gaps_flag__'],
         features_to_exclude_nrlmsise00=['all__dates_datetime__'],
+        features_to_exclude_sdoml_latents=['all__dates_datetime__'],
         min_date=pd.to_datetime("2000-07-29 00:59:47"),
         max_date=pd.to_datetime("2024-05-31 23:59:32"),
         max_altitude=np.inf,
@@ -58,6 +61,7 @@ def __init__(
         goes_1405nm_path=None, #"../data/goes_data/goes_1405nm_sw.csv"
         soho_path=None, #"../data/soho_data/soho_data.csv"
         nrlmsise00_path=None, #"../data/nrlmsise00_data/nrlmsise00_time_series.csv"
+        sdoml_latents_path=None, #"../data/sdoml_latents/sdofm_nvae_embeddings_pca_50.csv"
         torch_type=torch.float32,
         target_type="log_density",
         exclude_mask='exclude_mask.pk'
@@ -104,6 +108,7 @@ def __init__(
         self.features_to_exclude_omni_magnetic_field = features_to_exclude_omni_magnetic_field
         self.features_to_exclude_goes = features_to_exclude_goes
         self.features_to_exclude_soho = features_to_exclude_soho
+        self.features_to_exclude_sdoml_latents = features_to_exclude_sdoml_latents
         self.features_to_exclude_nrlmsise00 = features_to_exclude_nrlmsise00

         self.min_date = min_date
@@ -228,6 +233,15 @@ def __init__(
                 soho_resolution,
                 self.features_to_exclude_soho,
             )
+        if sdoml_latents_path is not None:
+            print("Loading SDO-FM Latents.")
+            self._add_time_series_data(
+                "sdoml_latents",
+                sdoml_latents_path,
+                lag_minutes_sdoml_latents,
+                sdoml_latents_resolution,
+                self.features_to_exclude_sdoml_latents,
+            )
         print("Creating thermospheric density dataset")
         self.data_thermo = {}
         self.data_thermo["data"] = pd.read_csv(self.thermo_path)
@@ -451,7 +465,7 @@ def _add_time_series_data(
"""
# Data loading:
self.time_series_data[data_name] = {}
if data_name in ["omni_indices", "omni_solar_wind", "omni_magnetic_field","goes_256nm","goes_284nm","goes_304nm","goes_1175nm","goes_1216nm","goes_1335nm","goes_1405nm","soho"]:
if data_name in ["omni_indices", "omni_solar_wind", "omni_magnetic_field","goes_256nm","goes_284nm","goes_304nm","goes_1175nm","goes_1216nm","goes_1335nm","goes_1405nm","soho","sdoml_latents"]:
self.time_series_data[data_name]["data"] = pd.read_csv(data_path)
# we now index the data by the datetime column, and sort it by the index. The reason is that it is then easier to resample
self.time_series_data[data_name]["data"].index = pd.to_datetime(self.time_series_data[data_name]["data"]["all__dates_datetime__"])
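For reference, here is a minimal sketch of how the new arguments can be wired up when constructing the dataset. The class name ThermosphericDensityDataset and the CSV path are assumptions for illustration; the three sdoml_latents_* arguments, their defaults, and the expected PCA-latents CSV come from the diff above.

    from karman.dataset import ThermosphericDensityDataset  # class name is an assumption

    dataset = ThermosphericDensityDataset(
        # PCA-reduced SDO-FM latents CSV produced by process_sdoml_latents.py:
        sdoml_latents_path="../data/sdoml_latents/sdofm_nvae_embeddings_pca_50.csv",
        lag_minutes_sdoml_latents=2 * 24 * 60,  # two days of look-back history
        sdoml_latents_resolution=12,            # cadence in minutes, as for the other *_resolution args
    )

When sdoml_latents_path is not None, the constructor registers the series via _add_time_series_data under the name "sdoml_latents", exactly as for the OMNI, GOES and SOHO inputs.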
42 changes: 42 additions & 0 deletions scripts/input_data_prep/download_sdoml_latents.py
@@ -0,0 +1,42 @@
import os
import argparse
import time
import sys

from pyfiglet import Figlet
from termcolor import colored

from huggingface_hub import login
from huggingface_hub import hf_hub_download

def download_sdoml_latents():
    print('SDO-FM Latents Data Downloading')
    f = Figlet(font='5lineoblique')
    print(colored(f.renderText('KARMAN 2.0'), 'red'))
    f = Figlet(font='digital')
    print(colored(f.renderText("Downloading SDO-FM Latents Data"), 'blue'))
    #print(colored(f'Version {karman.__version__}\n','blue'))

    parser = argparse.ArgumentParser(description='SDO-FM Latents Data Downloading', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--token', type=str, help='Hugging Face login token')
    parser.add_argument('--sdoml_latents_data_dir', type=str, default='../../data/sdoml_latents', help='Path where to store the SDO-FM latents data')
    opt = parser.parse_args()

    login(token=opt.token)

    repo_id = "SpaceML/SDO-FM"
    filename = "sdofm_nvae_embeddings.h5"
    if not os.path.exists(opt.sdoml_latents_data_dir):
        os.makedirs(opt.sdoml_latents_data_dir)
    # download the latents file from the Hugging Face Hub
    file_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=opt.sdoml_latents_data_dir)
    print(f"Downloaded file at: {file_path}")


if __name__ == "__main__":
    time_start = time.time()
    download_sdoml_latents()
    print('\nTotal duration: {}'.format(time.time() - time_start))
    sys.exit(0)
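An example invocation, run from scripts/input_data_prep (the token value is a placeholder for a real Hugging Face access token):

    python download_sdoml_latents.py --token hf_XXXXXXXX --sdoml_latents_data_dir ../../data/sdoml_latents

The script logs in to the Hugging Face Hub and downloads sdofm_nvae_embeddings.h5 from the SpaceML/SDO-FM repository into the target directory, creating it if needed.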
69 changes: 69 additions & 0 deletions scripts/input_data_prep/process_sdoml_latents.py
@@ -0,0 +1,69 @@
import os
import argparse
import h5py
import pandas as pd

import time
import sys

from pyfiglet import Figlet
from termcolor import colored
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def process_sdoml_latents():
    print('SDO-FM Latents Data Processing')
    f = Figlet(font='5lineoblique')
    print(colored(f.renderText('KARMAN 2.0'), 'red'))
    f = Figlet(font='digital')
    print(colored(f.renderText("SDO-FM Latents Data Processing"), 'blue'))
    #print(colored(f'Version {karman.__version__}\n','blue'))

    parser = argparse.ArgumentParser(description='SDO-FM Latents Data Processing', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--sdoml_latents_dir', type=str, default='../../data/sdoml_latents', help='SDO-FM latents data directory: this is also used to store the processed data.')
    parser.add_argument('--pca_components', type=int, default=50, help='Number of PCA components used to reduce the dimensionality of the latents data.')

    opt = parser.parse_args()

    print('Reading SDO-FM Latents Data')
    # we start by loading the SDO-FM latents:
    with h5py.File(os.path.join(opt.sdoml_latents_dir, 'sdofm_nvae_embeddings.h5'), 'r') as f:
        data_tmp = {key: f[key][:] for key in f.keys()}

    # create the datetime column from the stored date components
    dates = pd.to_datetime({
        'year': data_tmp['year'],
        'month': data_tmp['month'],
        'day': data_tmp['day'],
        'hour': data_tmp['hour'],
        'minute': data_tmp['minute']
    })
    print("Done, now reducing the dimensionality via PCA")

    # standardize the latents before PCA, so that each dimension contributes comparably
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(data_tmp['latent'])

    pca = PCA(n_components=opt.pca_components)
    data_pca = pca.fit_transform(data_scaled)

    print("Done, now saving the PCA latents")
    # let's create the dataframe to store the PCA latents
    df = {}
    df['all__dates_datetime__'] = dates
    for i in range(data_pca.shape[1]):
        df[f'sdofm__latent_{i}__'] = data_pca[:, i]
    df = pd.DataFrame(df)
    df.to_csv(os.path.join(opt.sdoml_latents_dir, f'sdofm_nvae_embeddings_pca_{opt.pca_components}.csv'), index=False)


if __name__ == "__main__":
    time_start = time.time()
    process_sdoml_latents()
    print('\nTotal duration: {}'.format(time.time() - time_start))
    sys.exit(0)
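A quick sanity check on the processed output, as a sketch assuming the default directory and --pca_components 50 (the column names follow the script above):

    import pandas as pd

    df = pd.read_csv(
        "../../data/sdoml_latents/sdofm_nvae_embeddings_pca_50.csv",
        parse_dates=["all__dates_datetime__"],
    )
    print(df.shape)      # (n_samples, 51): one datetime column + 50 PCA latents
    print(df.columns[1]) # 'sdofm__latent_0__'

This CSV is the format expected by the new sdoml_latents_path argument in karman/dataset.py.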
2 changes: 1 addition & 1 deletion setup.py
@@ -36,7 +36,7 @@ def read_package_variable(key):
     packages=find_packages(),
     url='https://github.com/spaceml-org/karman',
     install_requires=['numpy', 'torch','matplotlib','scikit-learn','pandas','tables','tqdm','pyfiglet>=0.8.0','termcolor','wandb','pyatmos','spaceweather','nrlmsise00','tft-torch'],
-    extras_require={'dev': ['pytest', 'coverage', 'pytest-xdist','netcdf4','cftime','flake8']},
+    extras_require={'dev': ['pytest', 'coverage', 'pytest-xdist','netcdf4','cftime','flake8','huggingface_hub', 'h5py']},
     classifiers=['License :: OSI Approved :: GNU General Public License v3 (GPLv3)', 'Programming Language :: Python :: 3'],
     include_package_data=True,
     package_data={
