Skip to content

Commit

Permalink
Merge pull request #541 from axif0/all_in_one
Browse files Browse the repository at this point in the history
Update get cmd for 'all' when retrieving all data-types.
  • Loading branch information
andrewtavis authored Jan 7, 2025
2 parents e523d81 + 248a56f commit 8073210
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 72 deletions.
20 changes: 10 additions & 10 deletions src/scribe_data/cli/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def prompt_user_download_all():
elif data_type:
if prompt_user_download_all():
parse_wd_lexeme_dump(
language=None,
language="all",
wikidata_dump_type=["form"],
data_types=[data_type],
type_output_dir=output_dir,
Expand Down Expand Up @@ -208,35 +208,35 @@ def prompt_user_download_all():
f"Updating data for language(s): {language.title()}; data type(s): {data_type.capitalize()}"
)
existing_files = list(Path(output_dir).glob(f"{language}/{data_type}.json"))
if existing_files:
if existing_files and not overwrite:
print(
f"Existing file(s) found for {language.title()} and {data_type.capitalize()} in the {output_dir} directory."
)
for idx, file in enumerate(existing_files, start=1):
print(f"{idx}. {file.name}")

print("\nChoose an option:")
print("1. Overwrite existing data (press 'o')")
print("2. Skip process (press anything else)")
user_choice = input("Enter your choice: ").strip().lower()
user_choice = questionary.confirm(
"Overwrite existing data?", default=False
).ask()

if user_choice == "o":
if user_choice:
print("Overwrite chosen. Removing existing files...")
for file in existing_files:
file.unlink()
if file.exists(): # check if the file exists before unlinking
file.unlink()
else:
print(f"Skipping update for {language.title()} {data_type}.")
return {"success": False, "skipped": True}

query_result = query_data(
query_data(
languages=[language_or_sub_language],
data_type=data_types,
output_dir=output_dir,
overwrite=overwrite,
interactive=interactive,
)

if not all_bool and not query_result.get("skipped", False):
if not all_bool:
print(f"Updated data was saved in: {Path(output_dir).resolve()}.")

else:
Expand Down
17 changes: 12 additions & 5 deletions src/scribe_data/cli/total.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,22 @@ def get_total_lexemes(language, data_type, do_print=True):
str
A formatted string indicating the language, data type and total number of lexemes, if found.
"""

if language is not None and language.startswith("Q") and language[1:].isdigit():
language_qid = language
if (
language is not None
and (language.startswith("Q") or language.startswith("q"))
and language[1:].isdigit()
):
language_qid = language.capitalize()

else:
language_qid = get_qid_by_input(language)

if data_type is not None and data_type.startswith("Q") and data_type[1:].isdigit():
data_type_qid = data_type
if (
data_type is not None
and (data_type.startswith("Q") or data_type.startswith("q"))
and data_type[1:].isdigit()
):
data_type_qid = data_type.capitalize()

else:
data_type_qid = get_qid_by_input(data_type)
Expand Down
10 changes: 5 additions & 5 deletions src/scribe_data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,8 @@ def _load_json(package_path: str, file_name: str) -> Any:
-------
A python entity representing the JSON content.
"""
with (
resources.files(package_path)
.joinpath(file_name)
.open(encoding="utf-8") as in_stream
):
json_file = resources.files(package_path).joinpath(file_name)
with json_file.open(encoding="utf-8") as in_stream:
return json.load(in_stream)


Expand Down Expand Up @@ -547,6 +544,9 @@ def format_sublanguage_name(lang, language_metadata=_languages):
> format_sublanguage_name("english", language_metadata)
'English'
"""
if (lang.startswith("Q") or lang.startswith("q")) and lang[1:].isdigit():
return lang

for main_lang, lang_data in language_metadata.items():
# If it's not a sub-language, return the original name.
if main_lang == lang:
Expand Down
7 changes: 0 additions & 7 deletions src/scribe_data/wikidata/query_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def query_data(

# MARK: Run Queries

# Run queries and format data.
for q in tqdm(
queries_to_run,
desc="Data updated",
Expand Down Expand Up @@ -216,7 +215,6 @@ def query_data(

else:
print(f"Skipping update for {lang.title()} {target_type}.")
return {"success": False, "skipped": True}

print(f"Querying and formatting {lang.title()} {target_type}")

Expand Down Expand Up @@ -342,8 +340,3 @@ def query_data(
print(
f"Successfully queried and formatted data for {lang.title()} {target_type}."
)
return {"success": True, "skipped": False}


# if __name__ == "__main__":
# query_data()
153 changes: 110 additions & 43 deletions tests/cli/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,37 +62,46 @@ def test_invalid_arguments(self):

# MARK: All Data

# @patch("scribe_data.cli.get.query_data")
# @patch("scribe_data.cli.get.prompt_user_download_all", return_value=False)
# def test_get_all_data_types_for_language(self, mock_prompt, mock_query_data):
# """
# Test retrieving all data types for a specific language.

# Ensures that `query_data` is called properly when `--all` flag is used with a language.
# """
# get_data(all_bool=True, language="English")
# mock_query_data.assert_called_once_with(
# languages=["English"],
# data_type=None,
# output_dir="scribe_data_json_export",
# overwrite=False,
# )

# @patch("scribe_data.cli.get.query_data")
# @patch("scribe_data.cli.get.prompt_user_download_all", return_value=False)
# def test_get_all_languages_for_data_type(self, mock_prompt, mock_query_data):
# """
# Test retrieving all languages for a specific data type.

# Ensures that `query_data` is called properly when `--all` flag is used with a data type.
# """
# get_data(all_bool=True, data_type="nouns")
# mock_query_data.assert_called_once_with(
# languages=None,
# data_type=["nouns"],
# output_dir="scribe_data_json_export",
# overwrite=False,
# )
@patch("scribe_data.cli.get.query_data")
@patch("scribe_data.cli.get.parse_wd_lexeme_dump")
@patch("scribe_data.cli.get.questionary.confirm")
def test_get_all_data_types_for_language_user_says_yes(
self, mock_questionary_confirm, mock_parse, mock_query_data
):
"""
Test the behavior when the user agrees to query Wikidata directly.
This test checks that `parse_wd_lexeme_dump` is called with the correct parameters
when the user confirms they want to query Wikidata.
"""
mock_questionary_confirm.return_value.ask.return_value = True

get_data(all_bool=True, language="English")

mock_parse.assert_called_once_with(
language="English",
wikidata_dump_type=["form"],
data_types=None, # because data_types = [data_type] if provided else None
type_output_dir="scribe_data_json_export", # default for JSON
)
mock_query_data.assert_not_called()

@patch("scribe_data.cli.get.parse_wd_lexeme_dump")
def test_get_all_languages_and_data_types(self, mock_parse):
"""
Test retrieving all languages for a specific data type.
Ensures that `query_data` is called properly when `--all` flag is used with a data type.
"""
get_data(all_bool=True)

mock_parse.assert_called_once_with(
language="all",
wikidata_dump_type=["form", "translations"],
data_types="all",
type_output_dir="scribe_data_json_export",
wikidata_dump_path=None,
)

# MARK: Language and Data Type

Expand Down Expand Up @@ -208,45 +217,103 @@ def test_get_data_with_overwrite_false(self, mock_query_data):
interactive=False,
)

# MARK : User Chooses to skip
# MARK: User Chooses Skip

@patch("scribe_data.cli.get.query_data")
@patch("scribe_data.cli.get.Path.glob")
@patch("builtins.input", return_value="s")
def test_user_skips_existing_file(self, mock_input, mock_glob, mock_query_data):
@patch(
"scribe_data.cli.get.Path.glob",
return_value=[Path("./test_output/English/nouns.json")],
)
@patch("scribe_data.cli.get.questionary.confirm")
def test_user_skips_existing_file(
self, mock_questionary_confirm, mock_glob, mock_query_data
):
"""
Test the behavior when the user chooses to skip an existing file.
Ensures that the file is not overwritten and the function returns the correct result.
"""
mock_glob.return_value = [Path("./test_output/English/nouns.json")]
mock_questionary_confirm.return_value.ask.return_value = False
result = get_data(
language="English", data_type="nouns", output_dir="./test_output"
)

# Validate the skip result.
self.assertEqual(result, {"success": False, "skipped": True})
mock_query_data.assert_not_called()

# MARK : User Chooses to overwrite
# MARK: User Chooses Overwrite

@patch("scribe_data.cli.get.query_data")
@patch("scribe_data.cli.get.Path.glob")
@patch("builtins.input", return_value="o")
@patch("scribe_data.cli.get.Path.unlink")
@patch(
"scribe_data.cli.get.Path.glob",
return_value=[Path("./test_output/English/nouns.json")],
)
@patch("scribe_data.cli.get.questionary.confirm")
def test_user_overwrites_existing_file(
self, mock_unlink, mock_input, mock_glob, mock_query_data
self, mock_questionary_confirm, mock_glob, mock_query_data
):
"""
Test the behavior when the user chooses to overwrite an existing file.
Ensures that the file is overwritten and the function returns the correct result.
"""
mock_glob.return_value = [Path("./test_output/English/nouns.json")]
mock_questionary_confirm.return_value.ask.return_value = True
get_data(language="English", data_type="nouns", output_dir="./test_output")
mock_unlink.assert_called_once_with()

mock_query_data.assert_called_once_with(
languages=["English"],
data_type=["nouns"],
output_dir="./test_output",
overwrite=False,
interactive=False,
)

# MARK: Translations

@patch("scribe_data.cli.get.parse_wd_lexeme_dump")
def test_get_translations_no_language_specified(self, mock_parse):
"""
Test behavior when no language is specified for 'translations'.
Expect language="all".
"""
get_data(data_type="translations")
mock_parse.assert_called_once_with(
language="all",
wikidata_dump_type=["translations"],
type_output_dir="scribe_data_json_export", # default output dir for JSON
wikidata_dump_path=None,
)

@patch("scribe_data.cli.get.parse_wd_lexeme_dump")
def test_get_translations_with_specific_language(self, mock_parse):
"""
Test behavior when a specific language is provided for 'translations'.
Expect parse_wd_lexeme_dump to be called with that language.
"""
get_data(
language="Spanish", data_type="translations", output_dir="./test_output"
)
mock_parse.assert_called_once_with(
language="Spanish",
wikidata_dump_type=["translations"],
type_output_dir="./test_output",
wikidata_dump_path=None,
)

@patch("scribe_data.cli.get.parse_wd_lexeme_dump")
def test_get_translations_with_dump(self, mock_parse):
"""
Test behavior when a Wikidata dump path is specified for 'translations'.
Even with a language, it should call parse_wd_lexeme_dump
passing that dump path.
"""
get_data(
language="German", data_type="translations", wikidata_dump="./wikidump.json"
)
mock_parse.assert_called_once_with(
language="German",
wikidata_dump_type=["translations"],
type_output_dir="scribe_data_json_export", # default for JSON
wikidata_dump_path="./wikidump.json",
)
15 changes: 13 additions & 2 deletions tests/load/test_update_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,22 @@ def test_format_sublanguage_name_positive(lang, expected_output):
assert utils.format_sublanguage_name(lang) == expected_output


@pytest.mark.parametrize(
"lang, expected_output",
[
("Q42", "Q42"), # test that any QID is returned
("Q1860", "Q1860"),
],
)
def test_format_sublanguage_name_qid_positive(lang, expected_output):
assert utils.format_sublanguage_name(lang) == expected_output


def test_format_sublanguage_name_negative():
with pytest.raises(ValueError) as excp:
_ = utils.format_sublanguage_name("Silence")
_ = utils.format_sublanguage_name("Newspeak")

assert str(excp.value) == "Silence is not a valid language or sub-language."
assert str(excp.value) == "Newspeak is not a valid language or sub-language."


def test_list_all_languages():
Expand Down

0 comments on commit 8073210

Please sign in to comment.