
update maze-dataset dep and poetry lockfile #213

Merged
merged 12 commits on May 14, 2024
2 changes: 1 addition & 1 deletion maze_transformer/evaluation/baseline_models.py
@@ -15,8 +15,8 @@
     get_origin_tokens,
     get_path_tokens,
     get_target_tokens,
-    strings_to_coords,
 )
+from maze_dataset.tokenization.util import strings_to_coords
 from transformer_lens import HookedTransformer
 
 from maze_transformer.training.config import ConfigHolder
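The next several files make the same one-line change: maze-dataset 0.5.x exposes the string/coordinate helpers from maze_dataset.tokenization.util rather than tokenization.token_utils. A minimal sketch of the relocated import, assuming the when_noncoord keyword of strings_to_coords behaves as in the released 0.5.x package:

    # Sketch only: illustrates the relocated import used throughout this PR.
    # The when_noncoord keyword is assumed from maze-dataset 0.5.x.
    from maze_dataset.tokenization.util import strings_to_coords

    tokens = ["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"]
    # Special (non-coordinate) tokens are skipped rather than raising an error.
    coords = strings_to_coords(tokens, when_noncoord="skip")
    print(coords)  # expected: coordinate pairs for (0,0) and (0,1)
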
6 changes: 3 additions & 3 deletions maze_transformer/evaluation/eval_model.py
@@ -16,12 +16,12 @@
 )
 from maze_dataset.tokenization import MazeTokenizer
 from maze_dataset.tokenization.token_utils import (
-    WhenMissing,
     get_context_tokens,
     get_path_tokens,
     remove_padding_from_token_str,
-    strings_to_coords,
 )
+from maze_dataset.tokenization.util import strings_to_coords
+from maze_dataset.utils import WhenMissing
 
 # muutils
 from muutils.mlutils import chunks
@@ -143,7 +143,7 @@ def predict_maze_paths(
         smart_max_new_tokens
     ), "if max_new_tokens is None, smart_max_new_tokens must be True"
 
-    maze_tokenizer: MazeTokenizer = model.config.maze_tokenizer
+    maze_tokenizer: MazeTokenizer = model.tokenizer._maze_tokenizer
 
     contexts_lists: list[list[str]] = [
         get_context_tokens(tokens) for tokens in tokens_batch
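The second hunk above also changes where the MazeTokenizer is read from: off the model's tokenizer wrapper rather than the ConfigHolder. A hedged sketch of that access pattern with a fallback for older checkpoints; the get_maze_tokenizer helper is hypothetical, and the private _maze_tokenizer attribute is taken from the diff itself:

    # Sketch only: not part of the PR. `model` is assumed to be a trained
    # maze-transformer model (e.g. a ZanjHookedTransformer) whose tokenizer
    # wraps the underlying MazeTokenizer.
    from maze_dataset.tokenization import MazeTokenizer

    def get_maze_tokenizer(model) -> MazeTokenizer:
        # New access path used in this PR: the tokenizer wrapper holds the
        # MazeTokenizer in a private attribute.
        maze_tok = getattr(model.tokenizer, "_maze_tokenizer", None)
        if maze_tok is None:
            # Older access path that this PR replaces.
            maze_tok = model.config.maze_tokenizer
        return maze_tok
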
2 changes: 1 addition & 1 deletion maze_transformer/evaluation/plotting.py
@@ -32,7 +32,7 @@ def plot_predicted_paths(
     if n_mazes is None:
         n_mazes = len(dataset)
 
-    dataset_tokens = dataset.as_tokens(model.config.maze_tokenizer)[:n_mazes]
+    dataset_tokens = dataset.as_tokens(model.tokenizer._maze_tokenizer)[:n_mazes]
 
     # predict
     predictions: list[list[str | tuple[int, int]]] = predict_maze_paths(
2 changes: 1 addition & 1 deletion maze_transformer/mechinterp/plot_attention.py
@@ -19,7 +19,7 @@
 from maze_dataset.plotting.plot_tokens import plot_colored_text
 from maze_dataset.plotting.print_tokens import color_tokens_cmap
 from maze_dataset.tokenization import MazeTokenizer
-from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable
+from maze_dataset.tokenization.util import coord_str_to_tuple_noneable
 
 # Utilities
 from muutils.json_serialize import SerializableDataclass, serializable_dataclass
2 changes: 1 addition & 1 deletion maze_transformer/mechinterp/residual_stream_structure.py
@@ -12,7 +12,7 @@
 # maze_dataset
 from maze_dataset.constants import _SPECIAL_TOKENS_ABBREVIATIONS
 from maze_dataset.tokenization import MazeTokenizer
-from maze_dataset.tokenization.token_utils import strings_to_coords
+from maze_dataset.tokenization.util import strings_to_coords
 
 # scipy
 from scipy.spatial.distance import pdist, squareform
26 changes: 23 additions & 3 deletions maze_transformer/training/train_model.py
@@ -1,5 +1,6 @@
 import json
 import typing
+import warnings
 from pathlib import Path
 from typing import Union
 
@@ -122,9 +123,28 @@ def train_model(
f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset"
)
else:
raise ValueError(
f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False"
)
datasets_cfg_diff: dict = dataset.cfg.diff(cfg.dataset_cfg)
if datasets_cfg_diff == {
"applied_filters": {
"self": [
{
"name": "collect_generation_meta",
"args": (),
"kwargs": {},
}
],
"other": [],
}
}:
warnings.warn(
f"dataset has different config than cfg.dataset_cfg, but the only difference is in applied_filters, so using passed dataset. This is due to fast dataset loading collecting generation metadata for performance reasons"
)

else:
raise ValueError(
f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False",
f"{datasets_cfg_diff = }",
)

logger.progress(f"finished getting training dataset with {len(dataset)} samples")
# validation dataset, if applicable
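The new branch tolerates exactly one benign difference between the passed dataset's config and cfg.dataset_cfg: the collect_generation_meta filter that fast dataset loading records. A small sketch of the same decision pulled into a helper, assuming the config's .diff() returns a nested dict shaped like the comparison literal in the hunk above (the check_dataset_cfg name is hypothetical):

    import warnings

    # The only difference that fast dataset loading is expected to introduce,
    # copied from the comparison literal in the hunk above.
    _BENIGN_DIFF = {
        "applied_filters": {
            "self": [{"name": "collect_generation_meta", "args": (), "kwargs": {}}],
            "other": [],
        }
    }

    def check_dataset_cfg(dataset_cfg, expected_cfg) -> None:
        """Warn on the known filter-only diff, raise on any other mismatch."""
        cfg_diff: dict = dataset_cfg.diff(expected_cfg)
        if not cfg_diff:
            return  # configs match exactly
        if cfg_diff == _BENIGN_DIFF:
            warnings.warn(
                "dataset config differs only in applied_filters "
                "(collect_generation_meta); using the passed dataset"
            )
        else:
            raise ValueError(f"dataset config mismatch: {cfg_diff = }")
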
16 changes: 8 additions & 8 deletions notebooks/residual_stream_decoding.ipynb

Large diffs are not rendered by default.

197 changes: 87 additions & 110 deletions notebooks/train_model.ipynb

Large diffs are not rendered by default.

2,724 changes: 1,367 additions & 1,357 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyproject.toml
@@ -10,15 +10,16 @@ repository = "https://github.com/understanding-search/maze-transformer"
 [tool.poetry.dependencies]
 python = ">=3.10,<3.13"
 # dataset
-maze-dataset = "^0.4.5"
+maze-dataset = "^0.5.2"
 # transformers
 torch = ">=1.13.1"
-transformer-lens = "1.14.0"
+transformer-lens = "^1.14.0"
 transformers = ">=4.34" # Dependency in transformer-lens 1.14.0
 # utils
 muutils = "^0.5.5"
 zanj = "^0.2.0"
-wandb = "^0.13.5" # note: TransformerLens forces us to use 0.13.5
+# wandb = "^0.13.5" # note: TransformerLens forces us to use 0.13.5
+wandb = "^0.17.0"
 fire = "^0.5.0"
 typing-extensions = "^4.8.0"
 # plotting
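Since maze-dataset is still a 0.x release, Poetry's caret constraint pins the minor version: ^0.5.2 resolves to >=0.5.2,<0.6.0. A quick sanity check (not part of the PR) that an environment satisfies the new floor, using only the standard library and assuming a plain MAJOR.MINOR.PATCH version string:

    # Sketch only: verify the installed maze-dataset against ^0.5.2
    # (>=0.5.2,<0.6.0 under Poetry's caret rules for 0.x versions).
    from importlib.metadata import version

    installed = version("maze-dataset")
    major, minor, patch, *_ = (int(part) for part in installed.split("."))
    assert (major, minor) == (0, 5) and patch >= 2, (
        f"maze-dataset {installed} does not satisfy ^0.5.2"
    )
    print(f"maze-dataset {installed} satisfies ^0.5.2")
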