[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 5, 2024
1 parent 72ae995 commit d0932e9
Showing 15 changed files with 71 additions and 151 deletions.
20 changes: 12 additions & 8 deletions nemo_text_processing/text_normalization/hi/taggers/measure.py
@@ -49,12 +49,14 @@ def __init__(self, cardinal: GraphFst, decimal: GraphFst):
unit = pynutil.insert("units: \"") + unit_graph + pynutil.insert("\" ")

# Handling symbols like x, X, *, -
symbol_graph = pynini.string_map([
("x", "बाई"),
("X", "बाई"),
("*", "बाई"),
# ("-", "से")
])
symbol_graph = pynini.string_map(
[
("x", "बाई"),
("X", "बाई"),
("*", "बाई"),
# ("-", "से")
]
)

graph_measurements = (
pynutil.insert("decimal { ")
@@ -76,15 +78,17 @@ def __init__(self, cardinal: GraphFst, decimal: GraphFst):
+ unit
)

# Handling cardinal clubbed with symbol as single token
# Handling cardinal clubbed with symbol as single token
graph_measurements |= (
pynutil.insert("cardinal { ")
+ optional_graph_negative
+ pynutil.insert("integer: \"")
+ cardinal_graph
+ pynutil.insert("\"")
+ pynutil.insert(" }")
+ pynutil.insert(" units: \"") + symbol_graph + pynutil.insert("\" ")
+ pynutil.insert(" units: \"")
+ symbol_graph
+ pynutil.insert("\" ")
+ pynutil.insert("} }")
+ insert_space
+ pynutil.insert("tokens { cardinal { ")
31 changes: 6 additions & 25 deletions nemo_text_processing/text_normalization/hi/taggers/money.py
@@ -15,10 +15,7 @@
import pynini
from pynini.lib import pynutil

from nemo_text_processing.text_normalization.hi.graph_utils import (
GraphFst,
insert_space,
)
from nemo_text_processing.text_normalization.hi.graph_utils import GraphFst, insert_space
from nemo_text_processing.text_normalization.hi.utils import get_abs_path

currency_graph = pynini.string_file(get_abs_path("data/money/currency.tsv"))
@@ -44,30 +41,14 @@ def __init__(self, cardinal: GraphFst):

cardinal_graph = cardinal.final_graph

currency_major = (
pynutil.insert('currency_maj: "') + currency_graph + pynutil.insert('"')
)
integer = (
pynutil.insert('integer_part: "') + cardinal_graph + pynutil.insert('"')
)
fraction = (
pynutil.insert('fractional_part: "') + cardinal_graph + pynutil.insert('"')
)
currency_minor = (
pynutil.insert('currency_min: "')
+ pynutil.insert("centiles")
+ pynutil.insert('"')
)
currency_major = pynutil.insert('currency_maj: "') + currency_graph + pynutil.insert('"')
integer = pynutil.insert('integer_part: "') + cardinal_graph + pynutil.insert('"')
fraction = pynutil.insert('fractional_part: "') + cardinal_graph + pynutil.insert('"')
currency_minor = pynutil.insert('currency_min: "') + pynutil.insert("centiles") + pynutil.insert('"')

graph_major_only = currency_major + insert_space + integer
graph_major_and_minor = (
currency_major
+ insert_space
+ integer
+ pynini.cross(".", " ")
+ fraction
+ insert_space
+ currency_minor
currency_major + insert_space + integer + pynini.cross(".", " ") + fraction + insert_space + currency_minor
)

graph_currencies = graph_major_only | graph_major_and_minor
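
For orientation, the money tagger builds `graph_major_and_minor` by concatenating field inserts around the cardinal and currency graphs. Below is a self-contained toy miniature of that pattern; the tiny cardinal and currency maps are invented stand-ins, and only the field names and the "centiles" placeholder come from the code above:

```python
import pynini
from pynini.lib import pynutil, rewrite

# Stand-in graphs invented for this sketch; the real tagger loads the currency
# map from a .tsv file and reuses the Hindi cardinal grammar.
cardinal_graph = pynini.string_map([("5", "पाँच"), ("25", "पच्चीस")])
currency_graph = pynini.cross("₹", "रुपए")
insert_space = pynutil.insert(" ")

currency_major = pynutil.insert('currency_maj: "') + currency_graph + pynutil.insert('"')
integer = pynutil.insert('integer_part: "') + cardinal_graph + pynutil.insert('"')
fraction = pynutil.insert('fractional_part: "') + cardinal_graph + pynutil.insert('"')
currency_minor = pynutil.insert('currency_min: "') + pynutil.insert("centiles") + pynutil.insert('"')

graph_major_and_minor = (
    currency_major + insert_space + integer + pynini.cross(".", " ") + fraction + insert_space + currency_minor
)

# "₹5.25" is tagged into major and minor currency fields, e.g.
# currency_maj: "रुपए" integer_part: "पाँच" fractional_part: "पच्चीस" currency_min: "centiles"
print(rewrite.top_rewrite("₹5.25", graph_major_and_minor))
```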
@@ -32,12 +32,10 @@
from nemo_text_processing.text_normalization.hi.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.hi.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.hi.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.hi.taggers.punctuation import (
PunctuationFst,
)
from nemo_text_processing.text_normalization.hi.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.hi.taggers.time import TimeFst
from nemo_text_processing.text_normalization.hi.taggers.word import WordFst
from nemo_text_processing.text_normalization.hi.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.hi.taggers.word import WordFst


class ClassifyFst(GraphFst):
@@ -63,17 +61,14 @@ def __init__(
overwrite_cache: bool = False,
whitelist: str = None,
):
super().__init__(
name="tokenize_and_classify", kind="classify", deterministic=deterministic
)
super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

far_file = None
if cache_dir is not None and cache_dir != "None":
os.makedirs(cache_dir, exist_ok=True)
whitelist_file = os.path.basename(whitelist) if whitelist else ""
far_file = os.path.join(
cache_dir,
f"hi_tn_{deterministic}_deterministic_{input_case}_{whitelist_file}_tokenize.far",
cache_dir, f"hi_tn_{deterministic}_deterministic_{input_case}_{whitelist_file}_tokenize.far",
)
if not overwrite_cache and far_file and os.path.exists(far_file):
self.fst = pynini.Far(far_file, mode="r")["tokenize_and_classify"]
@@ -84,66 +79,48 @@ def __init__(
start_time = time.time()
cardinal = CardinalFst(deterministic=deterministic)
cardinal_graph = cardinal.fst
logging.debug(
f"cardinal: {time.time() - start_time: .2f}s -- {cardinal_graph.num_states()} nodes"
)
logging.debug(f"cardinal: {time.time() - start_time: .2f}s -- {cardinal_graph.num_states()} nodes")

start_time = time.time()
decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
decimal_graph = decimal.fst
logging.debug(
f"decimal: {time.time() - start_time: .2f}s -- {decimal_graph.num_states()} nodes"
)
logging.debug(f"decimal: {time.time() - start_time: .2f}s -- {decimal_graph.num_states()} nodes")

start_time = time.time()
fraction = FractionFst(cardinal=cardinal, deterministic=deterministic)
fraction_graph = fraction.fst
logging.debug(
f"fraction: {time.time() - start_time: .2f}s -- {fraction_graph.num_states()} nodes"
)
logging.debug(f"fraction: {time.time() - start_time: .2f}s -- {fraction_graph.num_states()} nodes")

start_time = time.time()
date = DateFst(cardinal=cardinal)
date_graph = date.fst
logging.debug(
f"date: {time.time() - start_time: .2f}s -- {date_graph.num_states()} nodes"
)
logging.debug(f"date: {time.time() - start_time: .2f}s -- {date_graph.num_states()} nodes")

start_time = time.time()
timefst = TimeFst()
time_graph = timefst.fst
logging.debug(
f"time: {time.time() - start_time: .2f}s -- {time_graph.num_states()} nodes"
)
logging.debug(f"time: {time.time() - start_time: .2f}s -- {time_graph.num_states()} nodes")

start_time = time.time()
measure = MeasureFst(cardinal=cardinal, decimal=decimal)
measure_graph = measure.fst
logging.debug(
f"measure: {time.time() - start_time: .2f}s -- {measure_graph.num_states()} nodes"
)
logging.debug(f"measure: {time.time() - start_time: .2f}s -- {measure_graph.num_states()} nodes")

start_time = time.time()
money = MoneyFst(cardinal=cardinal)
money_graph = money.fst
logging.debug(
f"money: {time.time() - start_time: .2f}s -- {money_graph.num_states()} nodes"
)
logging.debug(f"money: {time.time() - start_time: .2f}s -- {money_graph.num_states()} nodes")

start_time = time.time()
whitelist_graph = WhiteListFst(
input_case=input_case, deterministic=deterministic, input_file=whitelist
).fst
logging.debug(
f"whitelist: {time.time() - start_time: .2f}s -- {whitelist_graph.num_states()} nodes"
)
logging.debug(f"whitelist: {time.time() - start_time: .2f}s -- {whitelist_graph.num_states()} nodes")

start_time = time.time()
punctuation = PunctuationFst(deterministic=deterministic)
punct_graph = punctuation.fst
logging.debug(
f"punct: {time.time() - start_time: .2f}s -- {punct_graph.num_states()} nodes"
)
logging.debug(f"punct: {time.time() - start_time: .2f}s -- {punct_graph.num_states()} nodes")

classify = (
pynutil.add_weight(whitelist_graph, 1.01)
@@ -157,18 +134,10 @@ def __init__(self):
)

start_time = time.time()
word_graph = WordFst(
punctuation=punctuation, deterministic=deterministic
).fst
logging.debug(
f"word: {time.time() - start_time: .2f}s -- {word_graph.num_states()} nodes"
)
word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).fst
logging.debug(f"word: {time.time() - start_time: .2f}s -- {word_graph.num_states()} nodes")

punct = (
pynutil.insert("tokens { ")
+ pynutil.add_weight(punct_graph, weight=2.1)
+ pynutil.insert(" }")
)
punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=2.1) + pynutil.insert(" }")
punct = pynini.closure(
pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
| (pynutil.insert(" ") + punct),
@@ -178,16 +147,12 @@
classify |= pynutil.add_weight(word_graph, 100)
token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }")
token_plus_punct = (
pynini.closure(punct + pynutil.insert(" "))
+ token
+ pynini.closure(pynutil.insert(" ") + punct)
pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
)

graph = token_plus_punct + pynini.closure(
(
pynini.compose(
pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space
)
pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
| (pynutil.insert(" ") + punct + pynutil.insert(" "))
)
+ token_plus_punct
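
As background, the `classify` union assembled above relies on `pynutil.add_weight`: lower weights win during shortest-path decoding, so the whitelist grammar (weight 1.01) outranks the generic word fallback (weight 100). A toy illustration under that assumption; the example strings are made up:

```python
import pynini
from pynini.lib import pynutil, rewrite

# Two competing analyses of the same input token; the generic fallback gets a
# much larger (worse) weight, as word_graph does in ClassifyFst above.
whitelist_like = pynini.cross("Dr.", "डॉक्टर")
word_like = pynini.cross("Dr.", "Dr.")

classify = pynutil.add_weight(whitelist_like, 1.01) | pynutil.add_weight(word_like, 100)

# Shortest-path decoding picks the lower-weight (whitelist-style) analysis.
print(rewrite.top_rewrite("Dr.", classify))  # डॉक्टर
```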
14 changes: 1 addition & 13 deletions nemo_text_processing/text_normalization/hi/taggers/word.py
@@ -12,18 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from pynini.lib import pynutil

from nemo_text_processing.text_normalization.hi.graph_utils import (
MIN_NEG_WEIGHT,
NEMO_NOT_SPACE,
GraphFst,
convert_space,
)
from nemo_text_processing.text_normalization.hi.taggers.punctuation import PunctuationFst


import pynini
from pynini.lib import pynutil

@@ -63,7 +51,7 @@ def __init__(self, punctuation: PunctuationFst, deterministic: bool = True):
punct = punctuation.graph
default_graph = pynini.closure(pynini.difference(NEMO_NOT_SPACE, punct.project("input")), 1)
symbols_to_exclude = (pynini.union("$", "€", "₩", "£", "¥", "#", "%") | punct).optimize()

# Use HINDI_CHAR in the graph
graph = pynini.closure(pynini.difference(HINDI_CHAR, symbols_to_exclude), 1)
graph = pynutil.add_weight(graph, MIN_NEG_WEIGHT) | default_graph
44 changes: 8 additions & 36 deletions nemo_text_processing/text_normalization/hi/verbalizers/money.py
@@ -15,14 +15,8 @@
import pynini
from pynini.lib import pynutil

from nemo_text_processing.text_normalization.hi.graph_utils import (
NEMO_NOT_QUOTE,
NEMO_SPACE,
GraphFst,
)
from nemo_text_processing.text_normalization.hi.data.money.major_minor_currencies import (
major_minor_currencies,
)
from nemo_text_processing.text_normalization.hi.data.money.major_minor_currencies import major_minor_currencies
from nemo_text_processing.text_normalization.hi.graph_utils import NEMO_NOT_QUOTE, NEMO_SPACE, GraphFst


class MoneyFst(GraphFst):
@@ -42,22 +36,12 @@ class MoneyFst(GraphFst):
def __init__(self):
super().__init__(name="money", kind="verbalize")

currency_major = (
pynutil.delete('currency_maj: "')
+ pynini.closure(NEMO_NOT_QUOTE, 1)
+ pynutil.delete('"')
)
currency_major = pynutil.delete('currency_maj: "') + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete('"')

integer_part = (
pynutil.delete('integer_part: "')
+ pynini.closure(NEMO_NOT_QUOTE, 1)
+ pynutil.delete('"')
)
integer_part = pynutil.delete('integer_part: "') + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete('"')

fractional_part = (
pynutil.delete('fractional_part: "')
+ pynini.closure(NEMO_NOT_QUOTE, 1)
+ pynutil.delete('"')
pynutil.delete('fractional_part: "') + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete('"')
)

# Handles major denominations only
@@ -71,16 +55,8 @@ def __init__(self):

# Logic for handling minor denominations
for major, minor in major_minor_currencies.items():
graph_major = (
pynutil.delete('currency_maj: "')
+ pynini.accep(major)
+ pynutil.delete('"')
)
graph_minor = (
pynutil.delete('currency_min: "')
+ pynini.cross("centiles", minor)
+ pynutil.delete('"')
)
graph_major = pynutil.delete('currency_maj: "') + pynini.accep(major) + pynutil.delete('"')
graph_minor = pynutil.delete('currency_min: "') + pynini.cross("centiles", minor) + pynutil.delete('"')
graph_major_minor_partial = (
integer_part
+ pynini.accep(NEMO_SPACE)
@@ -108,11 +84,7 @@ def __init__(self):
graph_major_minor = pynini.union(*major_minor_graphs)
graph_minor_only = pynini.union(*minor_graphs)

graph = (
graph_major_only
| graph_major_minor
| pynutil.add_weight(graph_minor_only, -0.1)
)
graph = graph_major_only | graph_major_minor | pynutil.add_weight(graph_minor_only, -0.1)

delete_tokens = self.delete_tokens(graph)
self.fst = delete_tokens.optimize()
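
For context, the verbalizer lines reflowed above all follow one pattern: delete the field label and quotes, keep what is inside. A toy version of that pattern with an ASCII-only stand-in for the project's NEMO_NOT_QUOTE:

```python
import string

import pynini
from pynini.lib import pynutil, rewrite

# ASCII stand-in for NEMO_NOT_QUOTE (any allowed character except a double quote).
NOT_QUOTE = pynini.union(*(string.ascii_letters + string.digits + " "))

# Same shape as integer_part / fractional_part / currency_maj above:
# strip the field wrapper, keep the quoted content.
integer_part = pynutil.delete('integer_part: "') + pynini.closure(NOT_QUOTE, 1) + pynutil.delete('"')

print(rewrite.top_rewrite('integer_part: "five"', integer_part))  # five
```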
@@ -20,9 +20,7 @@
from nemo_text_processing.text_normalization.hi.verbalizers.measure import MeasureFst
from nemo_text_processing.text_normalization.hi.verbalizers.money import MoneyFst
from nemo_text_processing.text_normalization.hi.verbalizers.time import TimeFst
from nemo_text_processing.text_normalization.hi.verbalizers.whitelist import (
WhiteListFst,
)
from nemo_text_processing.text_normalization.hi.verbalizers.whitelist import WhiteListFst


class VerbalizeFst(GraphFst):
@@ -37,9 +35,7 @@ class VerbalizeFst(GraphFst):
"""

def __init__(self, deterministic: bool = True):
super().__init__(
name="verbalize", kind="verbalize", deterministic=deterministic
)
super().__init__(name="verbalize", kind="verbalize", deterministic=deterministic)

cardinal = CardinalFst(deterministic=deterministic)
cardinal_graph = cardinal.fst
@@ -32,10 +32,10 @@ def __init__(self, deterministic: bool = True):
chars = pynini.closure(NEMO_CHAR - " ", 1)
punct = pynini.union("!", "?", ".", ",", "-", ":", ";", "।") # Add other punctuation marks as needed
char = pynutil.delete("name:") + delete_space + pynutil.delete("\"") + chars + pynutil.delete("\"")

# Ensure no spaces around punctuation
graph = char + pynini.closure(delete_space + punct, 0, 1)

# Explicitly remove spaces before punctuation
remove_space_before_punct = pynini.cdrewrite(pynini.cross(" ", ""), "", punct, NEMO_SIGMA)
graph = graph @ remove_space_before_punct
4 changes: 3 additions & 1 deletion tests/nemo_text_processing/hi/test_cardinal.py
@@ -23,7 +23,9 @@


class TestCardinal:
normalizer = Normalizer(input_case='cased', lang='hi', cache_dir=CACHE_DIR, overwrite_cache=False, post_process=False)
normalizer = Normalizer(
input_case='cased', lang='hi', cache_dir=CACHE_DIR, overwrite_cache=False, post_process=False
)
inverse_normalizer = InverseNormalizer(lang='hi', cache_dir=CACHE_DIR, overwrite_cache=False)

@parameterized.expand(parse_test_case_file('hi/data_text_normalization/test_cases_cardinal.txt'))
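
For reference, a minimal usage sketch of the fixture constructed above. The import path and the `normalize()` call follow the library's usual entry point; the example input and its expected handling are assumptions, not taken from this diff:

```python
from nemo_text_processing.text_normalization.normalize import Normalizer

# Mirror the test fixture; cache_dir=None skips .far caching in this sketch.
normalizer = Normalizer(
    input_case='cased', lang='hi', cache_dir=None, overwrite_cache=False, post_process=False
)

# normalize() expands written forms into spoken Hindi words, e.g. digits into number names
# (illustrative input; the exact output depends on the Hindi grammars built above).
print(normalizer.normalize("१५ किलो", verbose=False))
```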