Skip to content

Commit

Permalink
Refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Jan 12, 2023
1 parent 9c248c0 commit c197f42
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 20 deletions.
3 changes: 1 addition & 2 deletions core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import pandas as pd


class AccumulateEpochLossCallback(AbstractCallback):
def __init__(self, path: str):
super().__init__()
Expand Down Expand Up @@ -41,7 +40,7 @@ def on_fit_start(self, trainer, pl_module):
print(pl_module)
print(pl_module.summarize())
print(pl_module.selected_optimizer)
print("\nTraining is starting...")
print(f"\nTraining is starting {datetime.datetime.now()}...")

def on_fit_end(self, trainer, pl_module):
training_time = time.time() - self.start_time
Expand Down
22 changes: 14 additions & 8 deletions core/dataset_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,6 @@ def __init__(self, train_set: np.ndarray, num_entities, num_relations, neg_sampl
label_smoothing_rate: float = 0.0):
super().__init__()
assert isinstance(train_set, np.ndarray)
# https://pytorch.org/docs/stable/data.html#multi-process-data-loading
# TL;DR: replace Python objects with non-refcounted representations such as Pandas, Numpy or PyArrow objects
self.train_data = train_set
self.num_entities = num_entities
self.num_relations = num_relations
Expand All @@ -276,8 +274,18 @@ def __init__(self, train_set: np.ndarray, num_entities, num_relations, neg_sampl
print('Constructing training data...')
store = mapping_from_first_two_cols_to_third(train_set)
self.train_data = torch.IntTensor(list(store.keys()))
self.train_target = list(store.values())
# https://pytorch.org/docs/stable/data.html#multi-process-data-loading
# TL;DR: replace Python objects with non-refcounted representations such as Pandas, Numpy or PyArrow objects
# Unsure whether a list of numpy arrays is non-refcounted
self.train_target = list([np.array(i) for i in store.values()])
del store
# @TODO: Investigate the reference counts when using a list of numpy arrays.
#import sys
#import gc
# print(sys.getrefcount(self.train_target))
# print(sys.getrefcount(self.train_target[0]))
# print(gc.get_referrers(self.train_target))
# print(gc.get_referrers(self.train_target[0]))

def __len__(self):
assert len(self.train_data) == len(self.train_target)
Expand All @@ -287,7 +295,7 @@ def __getitem__(self, idx):
# (1) Get the i-th unique (head, relation) pair.
x = self.train_data[idx]
# (2) Get tail entities given (1).
positives_idx = np.array(self.train_target[idx])
positives_idx = self.train_target[idx]
num_positives = len(positives_idx)
# (3) Do we need to subsample (2) to create training data points of the same size?
if num_positives < self.neg_sample_ratio:
Expand All @@ -296,16 +304,14 @@ def __getitem__(self, idx):
# (3.2) Generate more negative entities
negative_idx = torch.randint(low=0,
high=self.num_entities,
size=(self.neg_sample_ratio + self.neg_sample_ratio - num_positives,),
dtype=torch.int32)
size=(self.neg_sample_ratio + self.neg_sample_ratio - num_positives,))
else:
# (3.1) Subsample positives without replacement.
positives_idx = torch.IntTensor(np.random.choice(positives_idx, size=self.neg_sample_ratio, replace=False))
# (3.2) Generate random entities.
negative_idx = torch.randint(low=0,
high=self.num_entities,
size=(self.neg_sample_ratio,),
dtype=torch.int32)
size=(self.neg_sample_ratio,))
# (5) Create selected indexes.
y_idx = torch.cat((positives_idx, negative_idx), 0)
# (6) Create binary labels.
Expand Down
39 changes: 29 additions & 10 deletions core/executer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
import datetime

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import seed_everything

from core.knowledge_graph import KG
from core.models.base_model import BaseKGE
from core.evaluator import Evaluator
# Avoid
from core.static_funcs import *
from core.static_preprocess_funcs import preprocesses_input_args
from core.sanity_checkers import *
Expand Down Expand Up @@ -100,14 +99,25 @@ def save_trained_model(self) -> None:

def start(self) -> dict:
"""
(1) Data Preparation:
(1.1) Read, Preprocess Index, Serialize.
(1.2) Load the data that has been serialized in (1.1).
(2) Train & Eval
(3) Save the model
(4) Return a report of the training
Start training
# (1) Loading the Data
# (2) Create an evaluator object.
# (3) Create a trainer object.
# (4) Start the training
# (5) Store trained model.
# (6) Eval model if required.
Parameter
---------
Returns
-------
A dict containing information about the training and/or evaluation
"""
start_time = time.time()
print(f"Start time:{datetime.datetime.now()}")
# (1) Loading the Data
# Load the indexed data from disk, or read raw data from disk.
self.load_indexed_data() if self.is_continual_training else self.read_preprocess_index_serialize_data()
Expand All @@ -134,7 +144,12 @@ def start(self) -> dict:


class ContinuousExecute(Execute):
""" Continue training a pretrained KGE model """
""" A subclass of Execute Class for retraining
(1) Loading & Preprocessing & Serializing input data.
(2) Training & Validation & Testing
(3) Storing all necessary info
"""

def __init__(self, args):
assert os.path.exists(args.path_experiment_folder)
Expand All @@ -158,10 +173,13 @@ def __init__(self, args):
previous_args.full_storage_path = previous_args.path_experiment_folder
print('ContinuousExecute starting...')
print(previous_args)
# TODO: can we remove continuous_training from Execute ?
super().__init__(previous_args, continuous_training=True)

def continual_start(self) -> dict:
"""
Start Continual Training
(1) Initialize training.
(2) Start continual training.
(3) Save trained model.
Expand All @@ -171,7 +189,8 @@ def continual_start(self) -> dict:
Returns
-------
report:dict
A dict containing information about the training and/or evaluation
"""
# (1)
self.trainer = DICE_Trainer(args=self.args, is_continual_training=True,
Expand Down

0 comments on commit c197f42

Please sign in to comment.