Skip to content

Commit

Permalink
improve IO-performance
Browse files (browse the repository at this point in the history)
  • Loading branch information
mieskolainen committed Jul 31, 2024
1 parent 004c308 commit 27afae5
Showing 1 changed file with 122 additions and 8 deletions.
130 changes: 122 additions & 8 deletions icenet/tools/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
import socket
import copy
import glob
import time
from tqdm import tqdm
import matplotlib.pyplot as plt

import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed

from importlib import import_module
import os
import copy
Expand Down Expand Up @@ -234,7 +238,7 @@ def read_config(config_path='configs/xyz/', runmode='all'):
# Finally create the hash
args['__hash_genesis__'] = io.make_hash_sha256_object(hash_args)

print(f'.read_config: Generated config hashes', 'magenta')
print(f'Generated config hashes', 'magenta')
print(f'[__hash_genesis__] : {args["__hash_genesis__"]} ', 'magenta')

# -------------------------------------------------------------------
Expand Down Expand Up @@ -419,6 +423,38 @@ def generic_flow(rootname, func_loader, func_factor):

return args, runmode

# -------------------------------------------------------------------

def concatenate_data(data):
    """
    Merge a sequence of (X, Y, W) chunks into three concatenated arrays.

    Args:
        data: iterable whose items each unpack into exactly three
              array-like objects (X-chunk, Y-chunk, W-chunk)

    Returns:
        X, Y, W: the chunks of each kind joined along axis 0,
                 in the order they appear in `data`
    """
    print('Concatenating arrays ...')

    x_parts, y_parts, w_parts = [], [], []
    t_start = time.time()

    # Collect references only; no array data is copied in this loop
    for x_chunk, y_chunk, w_chunk in tqdm(data):
        x_parts.append(x_chunk)
        y_parts.append(y_chunk)
        w_parts.append(w_chunk)

    # One concatenation per kind, evaluated in the order X, Y, W.
    # Original note on the X concatenation: "awkward casts this"
    # (presumably X may be an awkward array -- TODO confirm)
    X, Y, W = (
        np.concatenate(parts, axis=0)
        for parts in (x_parts, y_parts, w_parts)
    )

    elapsed = time.time() - t_start
    print(f'Concatenation took {elapsed:0.2f} sec')

    return X, Y, W

def load_file_wrapper(index, filepath):
    """
    Unpickle a single file and return it together with a caller-supplied index.

    The index is passed through unchanged so that results arriving out of
    order can be put back in their original position by the caller.

    Args:
        index:    arbitrary tag returned as-is
        filepath: path of the pickle file to read (opened in binary mode)

    Returns:
        (index, obj): the given index and the unpickled object
    """
    with open(filepath, 'rb') as stream:
        payload = pickle.load(stream)

    return index, payload

# -------------------------------------------------------------------

@iceprint.icelog(LOGGER)
def read_data(args, func_loader, runmode):
Expand Down Expand Up @@ -454,14 +490,19 @@ def get_chunk_ind(N):
ids = data['ids']
info = data['info']

print(f'Saving to path: "{cache_directory}"', 'yellow')
C = get_chunk_ind(N=len(X))
print(f'Saving {len(C)} pickle files to path: "{cache_directory}"', 'yellow')

tic = time.time()

for i in tqdm(range(len(C))):
with open(f'{cache_directory}/output_{i}.pkl', 'wb') as handle:
pickle.dump([X[C[i][0]:C[i][-1]], Y[C[i][0]:C[i][-1]], W[C[i][0]:C[i][-1]], ids, info, args], \
handle, protocol=pickle.HIGHEST_PROTOCOL)

toc = time.time() - tic
print(f'Saving took {toc:0.2f} sec')

gc.collect()
io.showmem()

Expand All @@ -475,9 +516,57 @@ def get_chunk_ind(N):
print(f'"genesis" already done and the cache files are ready.', 'green')
return

## New version

"""
Using ThreadPool, not fully parallel because of GIL (Global Interpreter Lock), but
should keep memory in control (vs. ProcessPool uses processes, but memory can be a problem)
"""

num_cpus = args['num_cpus']
max_workers = multiprocessing.cpu_count() // 2 if num_cpus == 0 else num_cpus

files = os.listdir(cache_directory)
sorted_files = sorted(files, key=lambda x: int(os.path.splitext(x)[0].split('_')[1]))

filepaths = [os.path.join(cache_directory, f) for f in sorted_files]
num_files = len(filepaths)

print(f'Loading {num_files} pickle files from path: "{cache_directory}"')
print('')
print(sorted_files)
print('')

data = [None] * num_files

tic = time.time()

with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_index = {executor.submit(load_file_wrapper, i, fp): i for i, fp in enumerate(filepaths)}
for future in tqdm(as_completed(future_to_index), total=num_files):
try:
index, (X_, Y_, W_, ids, info, genesis_args) = future.result()
data[index] = (X_, Y_, W_)
except Exception as e:
msg = f'Error loading file: {filepaths[future_to_index[future]]} -- {e}'
raise Exception(msg)

finally:
del future # Ensure the future is deleted to free memory

toc = time.time() - tic
print(f'Loading took {toc:0.2f} sec')

X, Y, W = concatenate_data(data)
gc.collect() # Call garbage collection once after the loop

"""
## Old version
num_files = io.count_files_in_dir(cache_directory)
print(f'Loading from path: "{cache_directory}"', 'yellow')
tic = time.time()
for i in tqdm(range(num_files)):
with open(f'{cache_directory}/output_{i}.pkl', 'rb') as handle:
Expand All @@ -491,6 +580,9 @@ def get_chunk_ind(N):
X,Y,W = copy.deepcopy(X_), copy.deepcopy(Y_), copy.deepcopy(W_)
gc.collect() # important!
toc = time.time() - tic
print(f'Took {toc:0.2f} sec')
"""

print('[done]')

Expand All @@ -516,11 +608,14 @@ def read_data_processed(args, func_loader, func_factor, mvavars, runmode):
data = read_data(args=args, func_loader=func_loader, runmode=runmode)

with open(cache_filename, 'wb') as handle:
print(f'Saving <DATA> to a file: "{cache_filename}"', 'yellow')
print(f'Saving <DATA> to a pickle file: "{cache_filename}"', 'yellow')

# Disable garbage collector for speed
gc.disable()
tic = time.time()
pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
toc = time.time() - tic
print(f'Saving took {toc:0.2f} sec')
gc.enable()

# Save args
Expand All @@ -529,11 +624,14 @@ def read_data_processed(args, func_loader, func_factor, mvavars, runmode):

else:
with open(cache_filename, 'rb') as handle:
print(f'Loading <DATA> from a file: "{cache_filename}"', 'yellow')
print(f'Loading <DATA> from a pickle file: "{cache_filename}"', 'yellow')

# Disable garbage collector for speed
gc.disable()
tic = time.time()
data = pickle.load(handle)
toc = time.time() - tic
print(f'Loading took {toc:0.2f} sec')
gc.enable()

io.showmem()
Expand All @@ -552,11 +650,14 @@ def read_data_processed(args, func_loader, func_factor, mvavars, runmode):
processed_data = process_data(args=args, predata=data, func_factor=func_factor, mvavars=mvavars, runmode=runmode)

with open(cache_filename, 'wb') as handle:
print(f'Saving <PROCESSED DATA> to a file: "{cache_filename}"', 'yellow')
print(f'Saving <PROCESSED DATA> to a pickle file: "{cache_filename}"', 'yellow')

# Disable garbage collector for speed
gc.disable()
tic = time.time()
pickle.dump(processed_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
toc = time.time() - tic
print(f'Saving took {toc:0.2f} sec')
gc.enable()

# Save args
Expand All @@ -565,11 +666,14 @@ def read_data_processed(args, func_loader, func_factor, mvavars, runmode):

else:
with open(cache_filename, 'rb') as handle:
print(f'Loading <PROCESSED DATA> from a file: "{cache_filename}"', 'yellow')
print(f'Loading <PROCESSED DATA> from a pickle file: "{cache_filename}"', 'yellow')

# Disable garbage collector for speed
gc.disable()
tic = time.time()
processed_data = pickle.load(handle)
toc = time.time() - tic
print(f'Loading took {toc:0.2f} sec')
gc.enable()

gc.collect()
Expand Down Expand Up @@ -681,10 +785,17 @@ def process_data(args, predata, func_factor, mvavars, runmode):
pickle.dump(pdf, open(fmodel, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)

# Compute different data representations
print(f'Compute representations [common.func_factor]', 'green')
print(f'Compute representations [func_factor]', 'green')

tic = time.time()
output['trn'] = func_factor(x=trn.x, y=trn.y, w=trn.w, ids=trn.ids, args=args)
toc = time.time() - tic
print(f'Representations [trn] took {toc:0.2f} sec')

tic = time.time()
output['val'] = func_factor(x=val.x, y=val.y, w=val.w, ids=val.ids, args=args)
toc = time.time() - tic
print(f'Representations [val] took {toc:0.2f} sec')

## Imputate
if args['imputation_param']['active']:
Expand Down Expand Up @@ -727,7 +838,10 @@ def process_data(args, predata, func_factor, mvavars, runmode):
# Compute different data representations
print(f'Compute representations [common.func_factor]', 'green')

tic = time.time()
output['tst'] = func_factor(x=tst.x, y=tst.y, w=tst.w, ids=tst.ids, args=args)
toc = time.time() - tic
print(f'Representations [tst] took {toc:0.2f} sec')

## Imputate
if args['imputation_param']['active']:
Expand Down Expand Up @@ -1134,7 +1248,7 @@ def evaluate_models(data=None, info=None, args=None):
Saves evaluation plots to the disk
"""

print(f'Evaluation models ...', 'yellow')
print(f'Evaluating models ...', 'yellow')
print('')

# -----------------------------
Expand Down

0 comments on commit 27afae5

Please sign in to comment.