Commit

dropping setter interaction and casting to string
rhysdg committed Jul 9, 2024
1 parent fb465b0 commit 5c57af9
Showing 3 changed files with 3,917 additions and 33 deletions.
16 changes: 12 additions & 4 deletions clip/utils/tokenization_utils.py
@@ -433,8 +433,10 @@ def __init__(self, **kwargs):

         # 4. If some of the special tokens are not part of the vocab, we add them, at the end.
         # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers`
+        #Adding str(token) to resolve AddedToken unhashable type
+
         self._add_tokens(
-            [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder],
+            [token for token in self.all_special_tokens_extended if str(token) not in self._added_tokens_encoder],
             special_tokens=True,
         )

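The str(token) cast above works around the vendored AddedToken (see the tokenization_utils_base.py change below) being an eq=True, frozen=False dataclass, which Python leaves unhashable, so a plain token-in-dict membership test raises TypeError. A minimal, self-contained sketch of the failure mode and the workaround, using a trimmed-down stand-in class and illustrative dict contents:

from dataclasses import dataclass


@dataclass(frozen=False, eq=True)
class AddedToken:
    # Trimmed stand-in for the vendored class: eq=True with frozen=False makes
    # the dataclass machinery set __hash__ to None, so instances are unhashable.
    content: str

    def __str__(self):
        return self.content


added_tokens_encoder = {"<|startoftext|>": 49406, "<|endoftext|>": 49407}
token = AddedToken("<|endoftext|>")

# token in added_tokens_encoder  ->  TypeError: unhashable type: 'AddedToken'
print(str(token) in added_tokens_encoder)  # True: the cast probes the string keys instead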
@@ -552,8 +554,13 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to
             elif special_tokens:
                 # doing token.special=True changes the normalization! will fix in rust
                 # this is important and the only reason why the AddedTokens in each class are normalized by default
-                token.__setstate__({"special": True, "normalized": token.normalized})
-            if token in self._added_tokens_decoder:
+                #token.__setstate__({"special": True, "normalized": token.normalized})
+                #token.__setstate__({"special": True, "normalized": token.normalized})
+                token.special = True
+                token.normalized = token.normalized
+
+            #resolving unhashable type AddedToken with str(token)
+            if str(token) in self._added_tokens_decoder:
                 continue
             if not token.special and token.normalized and getattr(self, "do_lower_case", False):
                 # Normalize if requested
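The dropped __setstate__ call is the upstream transformers idiom for flipping flags on the Rust-backed tokenizers.AddedToken (the inline comment notes that assigning token.special directly changes normalization on that object). The pure-Python AddedToken vendored in tokenization_utils_base.py below has no such setter side effect, and it defines __getstate__ but no __setstate__, so plain attribute assignment is the natural replacement. A rough sketch of the new style, assuming the vendored class from the change below is in scope; the token value is illustrative:

# Assumes: from clip.utils.tokenization_utils_base import AddedToken  (vendored class below)
token = AddedToken("<|endoftext|>", special=False, normalized=True)

# The replaced upstream call,
#     token.__setstate__({"special": True, "normalized": token.normalized})
# would raise AttributeError on this object: the vendored dataclass defines
# __getstate__ only (__setstate__ is something the Rust-backed tokenizers.AddedToken supports).
token.special = True
token.normalized = token.normalized  # no-op, kept to mirror the replaced dict literal

print(token.special, token.normalized)  # True True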
@@ -576,9 +583,10 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to
         self._update_trie()
         return added_tokens

+    #Adding str(token) to resolve AddedToken unhashable type
     def _update_trie(self, unique_no_split_tokens: Optional[str] = []):
         for token in self._added_tokens_decoder.values():
-            if token not in self.tokens_trie._tokens:
+            if str(token) not in self.tokens_trie._tokens:
                 self.tokens_trie.add(token.content)
         for token in unique_no_split_tokens:
             if token not in self.tokens_trie._tokens:
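_update_trie runs into the same hashability wall: judging by the add call, tokens_trie._tokens holds the word strings already inserted, and probing that set with an unhashable AddedToken raises before the membership test can answer. A small self-contained sketch of the guarded insert, where TinyTrie and the dict contents are illustrative stand-ins rather than the transformers Trie itself:

from dataclasses import dataclass


@dataclass(frozen=False, eq=True)
class AddedToken:
    # Same trimmed stand-in as above; unhashable, str() returns the content.
    content: str

    def __str__(self):
        return self.content


class TinyTrie:
    # Stand-in for the tokenizer's Trie: _tokens is assumed to be the set of
    # word strings already added, which is what the str(token) guard probes.
    def __init__(self):
        self._tokens = set()

    def add(self, word: str):
        self._tokens.add(word)


trie = TinyTrie()
added_tokens_decoder = {49407: AddedToken("<|endoftext|>")}

for token in added_tokens_decoder.values():
    if str(token) not in trie._tokens:  # hashable string probe, not the AddedToken itself
        trie.add(token.content)

print(trie._tokens)  # {'<|endoftext|>'}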
58 changes: 29 additions & 29 deletions clip/utils/tokenization_utils_base.py
@@ -68,42 +68,42 @@
 if is_flax_available():
     import jax.numpy as jnp  # noqa: F401

-if is_tokenizers_available():
-    from tokenizers import AddedToken
-    from tokenizers import Encoding as EncodingFast
-else:
+#if is_tokenizers_available():
+#    from tokenizers import AddedToken
+#    from tokenizers import Encoding as EncodingFast
+#else:

-    @dataclass(frozen=False, eq=True)
-    class AddedToken:
-        """
-        AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the
-        way it should behave.
+@dataclass(frozen=False, eq=True)
+class AddedToken:
+    """
+    AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the
+    way it should behave.

-        The `normalized` will default to `not special` if it is not specified, similarly to the definition in
-        `tokenizers`.
-        """
+    The `normalized` will default to `not special` if it is not specified, similarly to the definition in
+    `tokenizers`.
+    """

-        def __init__(
-            self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None
-        ):
-            self.content = content
-            self.single_word = single_word
-            self.lstrip = lstrip
-            self.rstrip = rstrip
-            self.special = special
-            self.normalized = normalized if normalized is not None else not special
+    def __init__(
+        self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None
+    ):
+        self.content = content
+        self.single_word = single_word
+        self.lstrip = lstrip
+        self.rstrip = rstrip
+        self.special = special
+        self.normalized = normalized if normalized is not None else not special

-        def __getstate__(self):
-            return self.__dict__
+    def __getstate__(self):
+        return self.__dict__

-        def __str__(self):
-            return self.content
+    def __str__(self):
+        return self.content

-    @dataclass
-    class EncodingFast:
-        """This is dummy class because without the `tokenizers` library we don't have these objects anyway"""
+@dataclass
+class EncodingFast:
+    """This is dummy class because without the `tokenizers` library we don't have these objects anyway"""

-        pass
+    pass


 logger = logging.get_logger(__name__)
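With the tokenizers import commented out, the pure-Python AddedToken above is always the one in use, which is exactly why the str(token) casts in tokenization_utils.py are needed. A quick sanity check of its behaviour, sketched on the assumption that the vendored class is importable from this module:

# Assumes: from clip.utils.tokenization_utils_base import AddedToken
tok = AddedToken("<|endoftext|>", special=True)

print(str(tok))             # <|endoftext|>   __str__ returns the raw content
print(tok.normalized)       # False           normalized defaults to `not special`
print(AddedToken.__hash__)  # None            eq=True, frozen=False leaves the class unhashable

try:
    {tok: 0}  # using the token itself as a dict key
except TypeError as err:
    print(err)              # unhashable type: 'AddedToken'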
Diff for the remaining changed file not shown.
