-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
56 lines (42 loc) · 1.6 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""src/models/base.py
Base classes for models.
"""
import abc
import logging
import os
from datetime import datetime
from typing import Any, Dict, Tuple
import pandas as pd
from src.datasets import Dataset
class TrajectoryModel(abc.ABC):
    """Abstract base class for trajectory machine learning models.

    Concrete subclasses must implement ``train``, ``predict``, ``save`` and
    the ``restore`` classmethod before they can be instantiated.

    Attributes:
        dataset: The dataset object handed to the constructor.
        vocab_sizes: Whatever ``dataset.get_vocab_sizes()`` returned at
            construction time.
    """

    def __init__(self, dataset: "Dataset"):
        """Store the dataset and cache its vocabulary sizes.

        Args:
            dataset: Dataset object; it must provide ``get_vocab_sizes()``.
        """
        self.dataset = dataset
        self.vocab_sizes = dataset.get_vocab_sizes()

    @abc.abstractmethod
    def train(self, optimizer, epochs: int, batch_size: int, **kwargs):
        """Train the model.

        Args:
            optimizer: Optimizer to train with (type is subclass-specific).
            epochs: Number of training epochs.
            batch_size: Number of examples per batch.
            **kwargs: Additional subclass-specific training options.

        Raises:
            NotImplementedError: Always, when reached through ``super()``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, df: pd.DataFrame):
        """Use the model to predict (or generate) given new input data.

        Args:
            df: Input data as a pandas DataFrame.

        Raises:
            NotImplementedError: Always, when reached through ``super()``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def save(self, save_path: os.PathLike):
        """Serialize the model to a checkpoint on disk.

        Args:
            save_path: Location of the checkpoint to write.
        """

    # ``abc.abstractclassmethod`` is deprecated since Python 3.3; stacking
    # ``classmethod`` over ``abstractmethod`` is the supported spelling and
    # still marks ``restore`` as abstract.
    @classmethod
    @abc.abstractmethod
    def restore(cls, save_path: os.PathLike):
        """Restore the model from a checkpoint on disk.

        Args:
            save_path: Location of the checkpoint to read.
        """

    def __repr__(self):
        """Return the name of the concrete class."""
        return type(self).__name__
def log_start(log: logging.Logger, exp_name: str, **hparams: Any) -> datetime:
    """Write a log entry for training experiment start.

    Args:
        log: Logger that receives the INFO-level entry.
        exp_name: Name of the experiment, included in the message.
        **hparams: Hyper-parameters of the run; the collected keyword
            mapping is rendered into the message as a dict.

    Returns:
        The naive local ``datetime`` captured just before the entry is
        written.
    """
    # The timestamp is taken before logging so handler latency is not
    # excluded from the measured run time.
    start_time = datetime.now()
    log.info(f"Running experiment {exp_name} with hparams: {hparams}")
    return start_time
def log_end(log: logging.Logger, exp_name: str, start_time: datetime):
    """Log that an experiment finished and report how long it ran.

    Args:
        log: Logger that receives the INFO-level entry.
        exp_name: Name of the experiment, included in the message.
        start_time: Moment the experiment started; it is subtracted from
            the current ``datetime.now()``.

    Returns:
        The elapsed time as a ``timedelta``.
    """
    # Elapsed wall-clock time between the recorded start and now.
    duration = datetime.now() - start_time
    message = f"Experiment {exp_name} finished in {duration}"
    log.info(message)
    return duration