Skip to content

Commit

Permalink
set log quietness based on supplied verbosity
Browse files Browse the repository at this point in the history
  • Loading branch information
zouharvi committed Dec 5, 2024
1 parent 4b43b91 commit 4eb1d09
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions py_irt/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,11 @@
from sklearn.feature_extraction.text import CountVectorizer


# These imports are necessary to have @register run
# This import is necessary to have @register run
# pylint: disable=unused-import
from py_irt.models import (
abstract_model,
one_param_logistic,
two_param_logistic,
three_param_logistic,
four_param_logistic,
multidim_2pl,
amortized_1pl,
)
import py_irt.models

from py_irt.models import abstract_model
from py_irt.io import safe_file, write_json
from py_irt.dataset import Dataset
from py_irt.initializers import INITIALIZERS, IrtInitializer
Expand Down Expand Up @@ -81,6 +75,7 @@ def __init__(
self._pyro_model = None
self._pyro_guide = None
self._verbose = verbose
console.quiet = not self._verbose
self.best_params = None
if dataset is None:
self._dataset = Dataset.from_jsonlines(data_path, amortized=self.amortized)
Expand Down Expand Up @@ -179,7 +174,6 @@ def train(self, *, epochs: Optional[int] = None, device: str = "cpu") -> None:
responses = torch.tensor(
self._dataset.observations, dtype=torch.float, device=device
)
print(subjects.size(), items.size())
# Don't take a step here, just make sure params are initialized
# so that initializers can modify the params
_ = self._pyro_model(subjects, items, responses)
Expand All @@ -195,7 +189,8 @@ def train(self, *, epochs: Optional[int] = None, device: str = "cpu") -> None:
loss = float("inf")
best_loss = loss
current_lr = self._config.lr
with Live(table) as live:
with Live(table if self._verbose else None) as live:
live.console.quiet = not self._verbose
live.console.print(f"Training Pyro IRT Model for {epochs} epochs")
for epoch in range(epochs):
loss = svi.step(subjects, items, responses)
Expand Down

0 comments on commit 4eb1d09

Please sign in to comment.