Skip to content

Commit

Permalink
Handle special tokens like other vocabularies.
Browse files Browse the repository at this point in the history
  • Loading branch information
sychen52 committed Dec 3, 2024
1 parent aee0021 commit 81f3e7b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
38 changes: 28 additions & 10 deletions axlearn/experiments/text/gpt/vocabulary_fuji_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,19 @@ def pad_id(self) -> int:
for token in ("<|pad_id|>", "<|finetune_right_pad_id|>"):
if token in self.vocab:
return self.vocab[token]
else:
raise ValueError("Unable to infer pad token.")
raise ValueError("Unable to infer pad token.")

@property
def eos_id(self) -> Optional[int]:
if "<|end_of_text|>" in self.vocab:
return self.vocab["<|end_of_text|>"]
else:
raise NotImplementedError()
raise ValueError("Unable to infer eos token.")

@property
def bos_id(self) -> Optional[int]:
    """Returns the ID of the beginning-of-text (BOS) token.

    Returns:
        The vocabulary ID mapped to "<|begin_of_text|>".

    Raises:
        ValueError: If "<|begin_of_text|>" is not present in the vocabulary.
    """
    if "<|begin_of_text|>" in self.vocab:
        return self.vocab["<|begin_of_text|>"]
    # The message must name the BOS token so a failure here is not mistaken for the
    # EOS lookup, which raises the same exception type.
    raise ValueError("Unable to infer bos token.")

def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
"""Encodes a string to token IDs.
Expand All @@ -144,8 +148,12 @@ def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:

def helper(s):
s = s.numpy()
res = self._tokenizer.encode_batch([item.decode("utf-8") for item in s])
return tf.ragged.constant([r.ids for r in res])
res = self._tokenizer.encode_batch(
[item.decode("utf-8") for item in s], add_special_tokens=True
)
# The return does not include EOS, but we need to remove BOS.
res = [item.ids[1:] if item.ids[0] == self.bos_id else item.ids for item in res]
return tf.ragged.constant(res, dtype=tf.int32)

ret = tf.py_function(
helper, inp=[s], Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32)
Expand All @@ -158,12 +166,17 @@ def helper(s):
def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
"""Detokenizes int32 batched Tensor."""
need_unpack = False
if ids.ndim == 1:
if len(ids.shape) == 1:
ids = tf.reshape(ids, (1, -1))
need_unpack = True

def helper(ids):
s = self._tokenizer.decode_batch(ids.numpy().tolist(), skip_special_tokens=True)
ids = [ids[i].numpy() for i in range(ids.shape[0])]
ids = [
item[(item != self.bos_id) & (item != self.eos_id) & (item != self.pad_id)]
for item in ids
]
s = self._tokenizer.decode_batch(ids, skip_special_tokens=False)
return tf.convert_to_tensor(s, dtype=tf.string)

ret = tf.py_function(helper, inp=[ids], Tout=tf.string)
Expand All @@ -186,11 +199,16 @@ def encode_tf(self, s: tf.Tensor) -> tf.Tensor:

def encode(self, s: str) -> list[int]:
"""Tokenizes string to an int sequence, without adding EOS."""
return self._tokenizer.encode(s).ids
ret = self._tokenizer.encode(s, add_special_tokens=True).ids
# The return does not include EOS, but we need to remove BOS.
return ret[1:] if ret[0] == self.bos_id else ret

def _decode(self, ids: Union[list[int], tuple[int]]) -> str:
"""Detokenizes int32 iterable to a string."""
return self._tokenizer.decode(ids)
# remove BOS, EOS and PAD.
ids = np.array(ids)
ids = ids[(ids != self.bos_id) & (ids != self.eos_id) & (ids != self.pad_id)]
return self._tokenizer.decode(ids, skip_special_tokens=False)

def decode(self, ids: Union[list[int], tuple[int], jax.Array, np.ndarray]) -> str:
"""Detokenizes int32 iterable to a string, up through first EOS."""
Expand Down
6 changes: 2 additions & 4 deletions axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_encode_tf_and_decode_tf(self):
def test_tokenize_example(self):
vocab = self.vocab_cfg.instantiate()
newlines_replaced_with = "<n>"
newlines_replaced_with_id = vocab.encode(newlines_replaced_with)[1:] # remove bos token
newlines_replaced_with_id = vocab.encode(newlines_replaced_with)

# Test tokenize_example replaces newlines.
tokens = input_text.tokenize_example(
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_fake_text_lm_training_data(
processor=config_for_function(text_to_lm_training_input).set(
vocab_cfg=self.vocab_cfg,
max_len=max_len,
replace_newlines_with="<n>",
replace_newlines_with="\n",
window_size=window_size,
max_padding_fraction=max_padding_fraction,
shuffle_buffer_size=shuffle_buffer_size,
Expand Down Expand Up @@ -193,9 +193,7 @@ def test_eval_lm_processor_single_example(self, text, index_key):

input_ids, target_labels = example["input_ids"].numpy(), example["target_labels"].numpy()
self.assertEqual(128001, input_ids[0]) # EOS
self.assertEqual(128000, input_ids[1]) # BOS
non_padded_length = (target_labels == 128004).argmax()
self.assertEqual(128000, target_labels[0]) # BOS at start.
self.assertEqual(128001, target_labels[non_padded_length - 1]) # EOS.
# The inputs should be one-off the labels.
self.assertNestedAllClose(
Expand Down

0 comments on commit 81f3e7b

Please sign in to comment.