typing: add initial types #488

Merged (2 commits, Jan 20, 2025)

Changes from all commits
pyproject.toml (6 changes: 5 additions & 1 deletion)

@@ -28,7 +28,7 @@ Source = "https://github.com/sloria/TextBlob"
 [project.optional-dependencies]
 docs = ["sphinx==8.1.3", "sphinx-issues==5.0.0", "PyYAML==6.0.2"]
 tests = ["pytest", "numpy"]
-dev = ["textblob[tests]", "tox", "pre-commit>=3.5,<5.0"]
+dev = ["textblob[tests]", "tox", "pre-commit>=3.5,<5.0", "pyright", "ruff"]

 [build-system]
 requires = ["flit_core<4"]
@@ -86,6 +86,7 @@ select = [
     "I",  # isort
     "UP",  # pyupgrade
     "W",  # pycodestyle warning
+    "TC",  # flake8-typechecking
 ]

 [tool.ruff.lint.per-file-ignores]
@@ -96,3 +97,6 @@ markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "numpy: marks tests that require numpy",
 ]
+
+[tool.pyright]
+include = ["src/**", "tests/**"]
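Note: pyright reads the [tool.pyright] table straight from pyproject.toml, so this scopes checking to the package and test sources, and the new TC rules flag imports that are needed only for annotations. A minimal sketch of the pattern those rules push toward (module and names are illustrative, not from this PR):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Needed only by the annotation below; skipped at runtime, which
        # avoids import cost and circular-import risk.
        from collections.abc import Iterable


    def count_tokens(tokens: Iterable[str]) -> int:
        # The postponed-evaluation future import keeps the annotation lazy.
        return sum(1 for _ in tokens)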
src/textblob/_text.py (23 changes: 11 additions & 12 deletions)

@@ -124,7 +124,7 @@ def keys(self):
     def values(self):
         return self._lazy("values")

-    def update(self, *args):
+    def update(self, *args, **kwargs):
         return self._lazy("update", *args)

     def pop(self, *args):
@@ -324,10 +324,10 @@ def penntreebank2universal(token, tag):
     ("cry", -1.00): set((":'(", ":'''(", ";'(")),
 }

-RE_EMOTICONS = [
+TEMP_RE_EMOTICONS = [
     r" ?".join([re.escape(each) for each in e]) for v in EMOTICONS.values() for e in v
 ]
-RE_EMOTICONS = re.compile(r"(%s)($|\s)" % "|".join(RE_EMOTICONS))
+RE_EMOTICONS = re.compile(r"(%s)($|\s)" % "|".join(TEMP_RE_EMOTICONS))

 # Handle sarcasm punctuation (!).
 RE_SARCASM = re.compile(r"\( ?\! ?\)")
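Note: the intermediate list now gets its own name (TEMP_RE_EMOTICONS) instead of reusing RE_EMOTICONS, because re-binding one module-level name from list[str] to re.Pattern[str] leaves the checker with an ambiguous declared type. The same idea in miniature (names illustrative):

    import re

    # Two names, two precise types: list[str] here, re.Pattern[str] below.
    escaped = [re.escape(s) for s in (":-)", ":-(")]
    RE_SMILEYS = re.compile("|".join(escaped))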
@@ -490,9 +490,9 @@ class Lexicon(lazydict):
     def __init__(
         self,
         path="",
-        morphology=None,
-        context=None,
-        entities=None,
+        morphology="",
+        context="",
+        entities="",
         NNP="NNP",
         language=None,
     ):
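Note: switching these defaults from None to "" is presumably so each parameter keeps the single inferred type str instead of the implicit union str | None, which every use would otherwise have to narrow. Sketch (illustrative):

    def load(path: str = "") -> str:
        # "" is falsy, so the "not provided" case still works without
        # widening the parameter type to `str | None`.
        return path or "spelling.txt"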
@@ -724,7 +724,7 @@ def apply(self, tokens):
                 t[i] = [t[i][0], r[1]]
         return t[len(o) : -len(o)]

-    def insert(self, i, tag1, tag2, cmd="prevtag", x=None, y=None):
+    def insert(self, i, tag1, tag2, cmd="prevtag", x=None, y=None, *args):
         """Inserts a new rule that updates words with tag1 to tag2,
         given constraints x and y, e.g., Context.append("TO < NN", "VB")
         """
@@ -739,7 +739,7 @@ def insert(self, i, tag1, tag2, cmd="prevtag", x=None, y=None):
     def append(self, *args, **kwargs):
         self.insert(len(self) - 1, *args, **kwargs)

-    def extend(self, rules=None):
+    def extend(self, rules=None, *args):
         if rules is None:
             rules = []
         for r in rules:
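Note: these rule classes derive from built-in containers, and a checker verifies that overrides stay call-compatible with the base method. dict.update accepts keyword arguments, so an override without **kwargs is reported as incompatible; widening insert and extend with *args serves the same purpose for the list subclass. A minimal sketch of the dict case (illustrative):

    class LazyDict(dict):
        # dict.update(**kwargs) exists, so a compatible override must
        # accept keyword arguments too, even if it only forwards them.
        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)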
@@ -1570,9 +1570,8 @@ def parse(

 TOKENS = "tokens"

-
 class TaggedString(str):
-    def __new__(self, string, tags=None, language=None):
+    def __new__(cls, string, tags=None, language=None):
         """Unicode string with tags and language attributes.
         For example: TaggedString("cat/NN/NP", tags=["word", "pos", "chunk"]).
         """
@@ -1588,7 +1587,7 @@ def __new__(self, string, tags=None, language=None):
                 for s in string
             ]
             string = "\n".join(" ".join("/".join(token) for token in s) for s in string)
-        s = str.__new__(self, string)
+        s = str.__new__(cls, string)
         s.tags = list(tags)
         s.language = language
         return s
@@ -1634,7 +1633,7 @@ def language(self):
         return self._language

     @classmethod
-    def train(self, s, path="spelling.txt"):
+    def train(cls, s, path="spelling.txt"):
         """Counts the words in the given string and saves the probabilities at the given path.
         This can be used to generate a new model for the Spelling() constructor.
         """
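Note: __new__ is an implicit staticmethod that receives the class as its first argument, so naming that parameter cls rather than self is both the convention and what lets the checker type str.__new__(cls, ...) correctly; the @classmethod train gets the same fix. The str-subclass construction pattern in miniature (illustrative):

    class Tagged(str):
        def __new__(cls, text, language="en"):
            # str is immutable, so construction happens in __new__;
            # passing cls keeps subclasses constructing themselves.
            s = str.__new__(cls, text)
            s.language = language
            return s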
src/textblob/base.py (30 changes: 19 additions & 11 deletions)

@@ -5,10 +5,16 @@
 All base classes are defined in the same module, ``textblob.base``.
 """

+from __future__ import annotations
+
 from abc import ABCMeta, abstractmethod
+from typing import TYPE_CHECKING

 import nltk

+if TYPE_CHECKING:
+    from typing import Any, AnyStr
+
 ##### POS TAGGERS #####


@@ -19,11 +25,11 @@ class BaseTagger(metaclass=ABCMeta):
     """

     @abstractmethod
-    def tag(self, text, tokenize=True):
+    def tag(self, text: str, tokenize=True) -> list[tuple[str, str]]:
         """Return a list of tuples of the form (word, tag)
         for a given set of text or BaseBlob instance.
         """
-        return
+        ...


 ##### NOUN PHRASE EXTRACTORS #####
@@ -36,29 +42,29 @@ class BaseNPExtractor(metaclass=ABCMeta):
     """

     @abstractmethod
-    def extract(self, text):
+    def extract(self, text: str) -> list[str]:
         """Return a list of noun phrases (strings) for a body of text."""
-        return
+        ...


 ##### TOKENIZERS #####


-class BaseTokenizer(nltk.tokenize.api.TokenizerI, metaclass=ABCMeta):
+class BaseTokenizer(nltk.tokenize.api.TokenizerI, metaclass=ABCMeta):  # pyright: ignore
     """Abstract base class from which all Tokenizer classes inherit.
     Descendant classes must implement a ``tokenize(text)`` method
     that returns a list of noun phrases as strings.
     """

     @abstractmethod
-    def tokenize(self, text):
+    def tokenize(self, text: str) -> list[str]:
         """Return a list of tokens (strings) for a body of text.

         :rtype: list
         """
-        return
+        ...

-    def itokenize(self, text, *args, **kwargs):
+    def itokenize(self, text: str, *args, **kwargs):
         """Return a generator that generates tokens "on-demand".

         .. versionadded:: 0.6.0
@@ -81,6 +87,8 @@ class BaseSentimentAnalyzer(metaclass=ABCMeta):
     results of analysis.
     """

+    _trained: bool
+
     kind = DISCRETE

     def __init__(self):
@@ -91,7 +99,7 @@ def train(self):
         self._trained = True

     @abstractmethod
-    def analyze(self, text):
+    def analyze(self, text) -> Any:
         """Return the result of of analysis. Typically returns either a
         tuple, float, or dictionary.
         """
@@ -111,6 +119,6 @@ class BaseParser(metaclass=ABCMeta):
     """

     @abstractmethod
-    def parse(self, text):
+    def parse(self, text: AnyStr):
         """Parses the text."""
-        return
+        ...
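Note: two idioms here are worth calling out. Abstract bodies end in ... instead of a bare return, so the checker does not infer an unconditional None return that contradicts annotations like list[str]; and the class-level annotation _trained: bool declares an attribute that is only ever assigned inside __init__ and train, so it is visible on every instance. A concrete subclass would look roughly like this (illustrative, not part of the PR):

    from textblob.base import BaseTagger


    class WhitespaceTagger(BaseTagger):
        # Matches the abstract signature: str in, (word, tag) pairs out.
        def tag(self, text: str, tokenize=True) -> list[tuple[str, str]]:
            return [(word, "NN") for word in text.split()]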
src/textblob/blob.py (8 changes: 4 additions & 4 deletions)

@@ -138,9 +138,9 @@ def lemmatize(self, pos=None):
         lemmatizer = nltk.stem.WordNetLemmatizer()
         return lemmatizer.lemmatize(self.string, tag)

-    PorterStemmer = nltk.stem.porter.PorterStemmer()
-    LancasterStemmer = nltk.stem.lancaster.LancasterStemmer()
-    SnowballStemmer = nltk.stem.snowball.SnowballStemmer("english")
+    PorterStemmer = nltk.stem.PorterStemmer()
+    LancasterStemmer = nltk.stem.LancasterStemmer()
+    SnowballStemmer = nltk.stem.SnowballStemmer("english")

     # added 'stemmer' on lines of lemmatizer
     # based on nltk
@@ -308,7 +308,7 @@ def _initialize_models(
     obj.tokenizer = _validated_param(
         tokenizer,
         "tokenizer",
-        base_class=(BaseTokenizer, nltk.tokenize.api.TokenizerI),
+        base_class=(BaseTokenizer, nltk.tokenize.api.TokenizerI),  # pyright: ignore
         default=BaseBlob.tokenizer,
         base_class_name="BaseTokenizer",
     )
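Note: PorterStemmer, LancasterStemmer, and SnowballStemmer are re-exported at the nltk.stem package level, which is the documented public path; referencing the re-export rather than the private submodule is what the checker resolves cleanly, and runtime behavior is unchanged. Usage stays the same (illustrative):

    import nltk

    stemmer = nltk.stem.SnowballStemmer("english")
    print(stemmer.stem("running"))  # "run"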
src/textblob/classifiers.py (4 changes: 2 additions & 2 deletions)

@@ -510,8 +510,8 @@ def update(


 class MaxEntClassifier(NLTKClassifier):
-    __doc__ = nltk.classify.maxent.MaxentClassifier.__doc__
-    nltk_class = nltk.classify.maxent.MaxentClassifier
+    __doc__ = nltk.classify.MaxentClassifier.__doc__
+    nltk_class = nltk.classify.MaxentClassifier

     def prob_classify(self, text):
         """Return the label probability distribution for classifying a string
src/textblob/decorators.py (13 changes: 12 additions & 1 deletion)

@@ -1,9 +1,18 @@
 """Custom decorators."""

+from __future__ import annotations
+
 from functools import wraps
+from typing import TYPE_CHECKING

 from textblob.exceptions import MissingCorpusError

+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from typing import TypeVar
+
+    ReturnType = TypeVar("ReturnType")
+

 class cached_property:
     """A property that is only computed once per instance and then replaces
@@ -24,7 +33,9 @@ def __get__(self, obj, cls):
         return value


-def requires_nltk_corpus(func):
+def requires_nltk_corpus(
+    func: Callable[..., ReturnType],
+) -> Callable[..., ReturnType]:
     """Wraps a function that requires an NLTK corpus. If the corpus isn't found,
     raise a :exc:`MissingCorpusError`.
     """
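Note: threading a TypeVar through the decorator means the wrapped function keeps its return type: requires_nltk_corpus(f) is Callable[..., ReturnType], so callers still see f's real return type rather than Any. (Callable[..., X] does erase the parameter types; preserving those too would take ParamSpec.) A self-contained sketch of the same pattern (illustrative):

    from collections.abc import Callable
    from functools import wraps
    from typing import TypeVar

    R = TypeVar("R")


    def logged(func: Callable[..., R]) -> Callable[..., R]:
        # R threads the return type through, so logged(parse)("x")
        # is typed like parse("x"), not Any.
        @wraps(func)
        def wrapper(*args, **kwargs) -> R:
            print(f"calling {func.__name__}")
            return func(*args, **kwargs)

        return wrapper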