From c197f42a32adcc0e7399ad0dddd35cc89964303c Mon Sep 17 00:00:00 2001
From: Caglar Demir
Date: Thu, 12 Jan 2023 11:40:21 +0100
Subject: [PATCH] Refactoring.

---
 core/callbacks.py       |  3 +--
 core/dataset_classes.py | 22 ++++++++++++++--------
 core/executer.py        | 39 +++++++++++++++++++++++++++++----------
 3 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/core/callbacks.py b/core/callbacks.py
index 382189c9..0fa2756c 100644
--- a/core/callbacks.py
+++ b/core/callbacks.py
@@ -8,7 +8,6 @@
 import os
 import pandas as pd
 
-
 class AccumulateEpochLossCallback(AbstractCallback):
     def __init__(self, path: str):
         super().__init__()
@@ -41,7 +40,7 @@ def on_fit_start(self, trainer, pl_module):
         print(pl_module)
         print(pl_module.summarize())
         print(pl_module.selected_optimizer)
-        print("\nTraining is starting...")
+        print(f"\nTraining is starting at {datetime.datetime.now()}...")
 
     def on_fit_end(self, trainer, pl_module):
         training_time = time.time() - self.start_time
diff --git a/core/dataset_classes.py b/core/dataset_classes.py
index 29009140..362c9f53 100644
--- a/core/dataset_classes.py
+++ b/core/dataset_classes.py
@@ -260,8 +260,6 @@ def __init__(self, train_set: np.ndarray, num_entities, num_relations, neg_sampl
                  label_smoothing_rate: float = 0.0):
         super().__init__()
         assert isinstance(train_set, np.ndarray)
-        # https://pytorch.org/docs/stable/data.html#multi-process-data-loading
-        # TLDL; replace Python objects with non-refcounted representations such as Pandas, Numpy or PyArrow objects
         self.train_data = train_set
         self.num_entities = num_entities
         self.num_relations = num_relations
@@ -276,8 +274,18 @@ def __init__(self, train_set: np.ndarray, num_entities, num_relations, neg_sampl
         print('Constructing training data...')
         store = mapping_from_first_two_cols_to_third(train_set)
         self.train_data = torch.IntTensor(list(store.keys()))
-        self.train_target = list(store.values())
+        # https://pytorch.org/docs/stable/data.html#multi-process-data-loading
+        # TL;DR: replace Python objects with non-refcounted representations such as Pandas, Numpy or PyArrow objects.
+        # Unsure whether a list of numpy arrays is non-refcounted.
+        self.train_target = [np.array(i) for i in store.values()]
         del store
+        # @TODO: Investigate the reference counts of using a list of numpy arrays.
+        # import sys
+        # import gc
+        # print(sys.getrefcount(self.train_target))
+        # print(sys.getrefcount(self.train_target[0]))
+        # print(gc.get_referrers(self.train_target))
+        # print(gc.get_referrers(self.train_target[0]))
 
     def __len__(self):
         assert len(self.train_data) == len(self.train_target)
@@ -287,7 +295,7 @@ def __getitem__(self, idx):
         # (1) Get i.th unique (head,relation) pair.
         x = self.train_data[idx]
         # (2) Get tail entities given (1).
-        positives_idx = np.array(self.train_target[idx])
+        positives_idx = self.train_target[idx]
         num_positives = len(positives_idx)
         # (3) Do we need to subsample (2) to create training data points of same size.
         if num_positives < self.neg_sample_ratio:
@@ -296,16 +304,14 @@ def __getitem__(self, idx):
             # (3.1)
             # (3.2) Generate more negative entities
             negative_idx = torch.randint(low=0, high=self.num_entities,
-                                         size=(self.neg_sample_ratio + self.neg_sample_ratio - num_positives,),
-                                         dtype=torch.int32)
+                                         size=(self.neg_sample_ratio + self.neg_sample_ratio - num_positives,))
         else:
             # (3.1) Subsample positives without replacement.
             positives_idx = torch.IntTensor(np.random.choice(positives_idx, size=self.neg_sample_ratio,
                                                              replace=False))
             # (3.2) Generate random entities.
             negative_idx = torch.randint(low=0, high=self.num_entities,
-                                         size=(self.neg_sample_ratio,),
-                                         dtype=torch.int32)
+                                         size=(self.neg_sample_ratio,))
         # (5) Create selected indexes.
         y_idx = torch.cat((positives_idx, negative_idx), 0)
         # (6) Create binary labels.
diff --git a/core/executer.py b/core/executer.py
index 3396c575..05f8481a 100644
--- a/core/executer.py
+++ b/core/executer.py
@@ -7,13 +7,12 @@
 import datetime
 
 import numpy as np
-import torch
-import torch.nn.functional as F
 from pytorch_lightning import seed_everything
 
 from core.knowledge_graph import KG
 from core.models.base_model import BaseKGE
 from core.evaluator import Evaluator
+# TODO: Avoid `from core.static_funcs import *`.
 from core.static_funcs import *
 from core.static_preprocess_funcs import preprocesses_input_args
 from core.sanity_checkers import *
@@ -100,14 +99,25 @@ def save_trained_model(self) -> None:
 
     def start(self) -> dict:
         """
-        (1) Data Preparation:
-            (1.1) Read, Preprocess Index, Serialize.
-            (1.2) Load a data that has been in (1.1).
-        (2) Train & Eval
-        (3) Save the model
-        (4) Return a report of the training
+        Start training.
+
+        # (1) Loading the Data
+        # (2) Create an evaluator object.
+        # (3) Create a trainer object.
+        # (4) Start the training.
+        # (5) Store the trained model.
+        # (6) Evaluate the model if required.
+
+        Parameters
+        ----------
+
+        Returns
+        -------
+        A dict containing information about the training and/or evaluation.
+
         """
         start_time = time.time()
+        print(f"Start time: {datetime.datetime.now()}")
         # (1) Loading the Data
         # Load the indexed data from disk or read a raw data from disk.
         self.load_indexed_data() if self.is_continual_training else self.read_preprocess_index_serialize_data()
@@ -134,7 +144,12 @@ def start(self) -> dict:
 
 
 class ContinuousExecute(Execute):
-    """ Continue training a pretrained KGE model """
+    """ A subclass of Execute for retraining a pretrained KGE model:
+
+    (1) Loading & Preprocessing & Serializing input data.
+    (2) Training & Validation & Testing
+    (3) Storing all necessary info
+    """
 
     def __init__(self, args):
         assert os.path.exists(args.path_experiment_folder)
@@ -158,10 +173,13 @@ def __init__(self, args):
         previous_args.full_storage_path = previous_args.path_experiment_folder
         print('ContinuousExecute starting...')
         print(previous_args)
+        # TODO: Can we remove continuous_training from Execute?
         super().__init__(previous_args, continuous_training=True)
 
     def continual_start(self) -> dict:
         """
+        Start continual training.
+
         (1) Initialize training.
         (2) Start continual training.
         (3) Save trained model.
@@ -171,7 +189,8 @@ def continual_start(self) -> dict:
 
         Returns
         -------
-        report:dict
+        A dict containing information about the training and/or evaluation.
+
         """
         # (1)
         self.trainer = DICE_Trainer(args=self.args, is_continual_training=True,
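
Notes on the changes above:

The train_target change follows the linked PyTorch guidance on copy-on-read
overhead in DataLoader worker processes, but the outer self.train_target is
still a refcounted Python list, so element access in workers can still bump
refcounts; only the per-key tail collections become numpy arrays. The
commented-out @TODO can be run as a standalone check. A minimal sketch, using
only the standard library and numpy, with a toy two-element list standing in
for the real train_target:

    import sys
    import numpy as np

    train_target = [np.array([4, 7, 9]), np.array([1, 3])]
    # sys.getrefcount always reports one extra reference (its own argument).
    print(sys.getrefcount(train_target))     # references to the outer list
    print(sys.getrefcount(train_target[0]))  # references to the first array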
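The KvsSample branch in __getitem__ can also be exercised outside the Dataset
class. A minimal sketch mirroring the logic in the hunk above; the function
name and toy numbers are illustrative, and the binary-label step (6) is an
assumption (ones for positives, zeros for sampled negatives), since the diff
context ends before it:

    import numpy as np
    import torch

    def sample_targets(positives_idx: np.ndarray, neg_sample_ratio: int, num_entities: int):
        # Every item comes out 2 * neg_sample_ratio wide.
        num_positives = len(positives_idx)
        if num_positives < neg_sample_ratio:
            # Keep all positives and pad with extra random entities.
            positives_idx = torch.IntTensor(positives_idx)
            negative_idx = torch.randint(low=0, high=num_entities,
                                         size=(neg_sample_ratio + neg_sample_ratio - num_positives,))
        else:
            # Subsample positives without replacement, then draw random entities.
            positives_idx = torch.IntTensor(np.random.choice(positives_idx, size=neg_sample_ratio,
                                                             replace=False))
            negative_idx = torch.randint(low=0, high=num_entities, size=(neg_sample_ratio,))
        y_idx = torch.cat((positives_idx, negative_idx), 0)
        # Assumed step (6): 1.0 for positives, 0.0 for sampled negatives.
        y_vec = torch.cat((torch.ones(len(positives_idx)), torch.zeros(len(negative_idx))), 0)
        return y_idx, y_vec

    y_idx, y_vec = sample_targets(np.array([3, 5]), neg_sample_ratio=4, num_entities=100)
    print(y_idx.shape, y_vec.shape)  # torch.Size([8]) torch.Size([8])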
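Two smaller review observations. First, the new f-string in callbacks.py calls
datetime.datetime.now(), which assumes datetime is already imported above the
hunk's context; that import is not visible in this diff. Second, dropping
dtype=torch.int32 from the two torch.randint calls switches negative_idx to
the torch.int64 default, so the later torch.cat with the int32 positives_idx
relies on type promotion (available since PyTorch 1.5) and y_idx comes out
int64. A quick check of the second point:

    import torch

    negative_idx = torch.randint(low=0, high=10, size=(3,))
    print(negative_idx.dtype)  # torch.int64: the default when dtype is omitted
    mixed = torch.cat((torch.IntTensor([1, 2]), negative_idx), 0)
    print(mixed.dtype)         # torch.int64 via type promotion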