diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 87f9a7a..4133e1d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: run-tests: runs-on: ubuntu-latest env: - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" NO_CACHE: ${{ github.event.inputs.no-cache || 'false' }} steps: - name: Checkout @@ -53,7 +53,7 @@ jobs: lint_check: runs-on: ubuntu-latest env: - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" steps: - uses: actions/checkout@v3 diff --git a/.markdownlint.yaml b/.markdownlint.yaml index d6aa660..f52f5b5 100644 --- a/.markdownlint.yaml +++ b/.markdownlint.yaml @@ -1,3 +1,4 @@ MD013: false MD040: false -MD025: false \ No newline at end of file +MD025: false +MD028: false \ No newline at end of file diff --git a/README.md b/README.md index 8069f69..0fb5271 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ This is an easy-to-use evaluation script designed for JMTEB evaluation. +The JMTEB leaderboard is [here](leaderboard.md). Guidance for submission is coming soon. + ## Quick start ```bash @@ -38,4 +40,40 @@ poetry run python -m jmteb \ ``` > [!NOTE] -> Some tasks (e.g., AmazonReviewClassification in classification, JAQKET and Mr.TyDi-ja in retrieval, esci in reranking) are time-consuming and memory-consuming. Heavy retrieval tasks take hours to encode the large corpus, and use much memory for the storage of such vectors. If you want to exclude them, add `--eval_exclude "['amazon_review_classification', 'mrtydi', 'jaqket', 'esci']"`. +> Some tasks (e.g., AmazonReviewClassification in classification, JAQKET and Mr.TyDi-ja in retrieval, esci in reranking) are time-consuming and memory-consuming. Heavy retrieval tasks take hours to encode the large corpus, and use much memory for the storage of such vectors. If you want to exclude them, add `--eval_exclude "['amazon_review_classification', 'mrtydi', 'jaqket', 'esci']"`. 
Similarly, you can also use `--eval_include` to include only the evaluation datasets you want. + +> [!NOTE] +> If you want to log model predictions to further analyze the performance of your model, you may want to use `--log_predictions true` to enable all evaluators to log predictions. You can also control whether to log predictions in the config of each evaluator. + +## Multi-GPU support + +There are two ways to enable multi-GPU evaluation. + +* New class `DPSentenceBertEmbedder` ([here](src/jmteb/embedders/data_parallel_sbert_embedder.py)). + +```bash +poetry run python -m jmteb \ + --evaluators "src/configs/tasks/jsts.jsonnet" \ + --embedder DPSentenceBertEmbedder \ + --embedder.model_name_or_path "" \ + --save_dir "output/" +``` + +* With `torchrun`, multi-GPU in [`TransformersEmbedder`](src/jmteb/embedders/transformers_embedder.py) is available. For example, + +```bash +MODEL_NAME= +MODEL_KWARGS="\{\'torch_dtype\':\'torch.bfloat16\'\}" +torchrun \ + --nproc_per_node=$GPUS_PER_NODE --nnodes=1 \ + src/jmteb/__main__.py --embedder TransformersEmbedder \ + --embedder.model_name_or_path ${MODEL_NAME} \ + --embedder.pooling_mode cls \ + --embedder.batch_size 4096 \ + --embedder.model_kwargs ${MODEL_KWARGS} \ + --embedder.max_seq_length 512 \ + --save_dir "output/${MODEL_NAME}" \ + --evaluators src/jmteb/configs/jmteb.jsonnet +``` + +Note that the batch size here is the global batch size (`per_device_batch_size` × `n_gpu`). 
diff --git a/docs/results/MU-Kindai/Japanese-DiffCSE-BERT-base/summary.json b/docs/results/MU-Kindai/Japanese-DiffCSE-BERT-base/summary.json new file mode 100644 index 0000000..1b99a44 --- /dev/null +++ b/docs/results/MU-Kindai/Japanese-DiffCSE-BERT-base/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7809527709426081 + }, + "amazon_review_classification": { + "macro_f1": 0.5155899232320224 + }, + "massive_intent_classification": { + "macro_f1": 0.7879373479249787 + }, + "massive_scenario_classification": { + "macro_f1": 0.8662625888023707 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9095168116460639 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.42314124780036416 + }, + "jaqket": { + "ndcg@10": 0.36199154051747723 + }, + "mrtydi": { + "ndcg@10": 0.07810683176415421 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.6077212544951452 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.6433890489201118 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.39317174536190913 + } + }, + "STS": { + "jsick": { + "spearman": 0.754165277432144 + }, + "jsts": { + "spearman": 0.7558202366183716 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.4966545453348478 + }, + "mewsc16": { + "v_measure_score": 0.3877356318022785 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6237623762376237 + } + } +} \ No newline at end of file diff --git a/docs/results/MU-Kindai/Japanese-MixCSE-BERT-base/summary.json b/docs/results/MU-Kindai/Japanese-MixCSE-BERT-base/summary.json new file mode 100644 index 0000000..ea227c2 --- /dev/null +++ b/docs/results/MU-Kindai/Japanese-MixCSE-BERT-base/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.776174162517931 + }, + "amazon_review_classification": { + "macro_f1": 0.5085781180553806 + }, + "massive_intent_classification": { + "macro_f1": 0.7718541530739129 + }, + 
"massive_scenario_classification": { + "macro_f1": 0.8592571786794985 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9100551950168166 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.42368135774043536 + }, + "jaqket": { + "ndcg@10": 0.37721850397542034 + }, + "mrtydi": { + "ndcg@10": 0.07878085186566607 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.636999375405723 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.6413498649875696 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.397250919496823 + } + }, + "STS": { + "jsick": { + "spearman": 0.7756925231422259 + }, + "jsts": { + "spearman": 0.7652968548841591 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5262387436934941 + }, + "mewsc16": { + "v_measure_score": 0.37277574537292835 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.623321554770318 + } + } +} \ No newline at end of file diff --git a/docs/results/MU-Kindai/Japanese-SimCSE-BERT-base-sup/summary.json b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-base-sup/summary.json new file mode 100644 index 0000000..dbed068 --- /dev/null +++ b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-base-sup/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7619809437515043 + }, + "amazon_review_classification": { + "macro_f1": 0.5205592432502059 + }, + "massive_intent_classification": { + "macro_f1": 0.7789367871593064 + }, + "massive_scenario_classification": { + "macro_f1": 0.8490320705866646 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9065584234991577 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.4411487123884245 + }, + "jaqket": { + "ndcg@10": 0.39613283459361814 + }, + "mrtydi": { + "ndcg@10": 0.08154879873415645 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.6276035246534508 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.5838785018803183 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.3489329387182086 + } + }, + 
"STS": { + "jsick": { + "spearman": 0.7463567093877269 + }, + "jsts": { + "spearman": 0.7468283806971927 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.41041888940251137 + }, + "mewsc16": { + "v_measure_score": 0.45175891401665724 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6236711552090717 + } + } +} \ No newline at end of file diff --git a/docs/results/MU-Kindai/Japanese-SimCSE-BERT-base-unsup/summary.json b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-base-unsup/summary.json new file mode 100644 index 0000000..9528312 --- /dev/null +++ b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-base-unsup/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7619809437515043 + }, + "amazon_review_classification": { + "macro_f1": 0.5152108946679324 + }, + "massive_intent_classification": { + "macro_f1": 0.7895128475562229 + }, + "massive_scenario_classification": { + "macro_f1": 0.865430249169577 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9115815294581953 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.47387768939865055 + }, + "jaqket": { + "ndcg@10": 0.3956683977353904 + }, + "mrtydi": { + "ndcg@10": 0.1144234568266308 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.6416096544574569 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.7023477497744102 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.4536720868647063 + } + }, + "STS": { + "jsick": { + "spearman": 0.781770693640686 + }, + "jsts": { + "spearman": 0.7680617109850311 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5301620892693397 + }, + "mewsc16": { + "v_measure_score": 0.4034776723308173 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6238078417520311 + } + } +} \ No newline at end of file diff --git a/docs/results/MU-Kindai/Japanese-SimCSE-BERT-large-sup/summary.json b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-large-sup/summary.json new 
file mode 100644 index 0000000..b36686c --- /dev/null +++ b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-large-sup/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7725250131648236 + }, + "amazon_review_classification": { + "macro_f1": 0.5341627023771393 + }, + "massive_intent_classification": { + "macro_f1": 0.7682863192709365 + }, + "massive_scenario_classification": { + "macro_f1": 0.8639396658321546 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9094717381883379 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.47038430326303626 + }, + "jaqket": { + "ndcg@10": 0.44101304795602897 + }, + "mrtydi": { + "ndcg@10": 0.11429128335865787 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.43434267808785576 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.6240651697600803 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.3651687833824759 + } + }, + "STS": { + "jsick": { + "spearman": 0.787528927058734 + }, + "jsts": { + "spearman": 0.7781413957931619 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.48448646364489634 + }, + "mewsc16": { + "v_measure_score": 0.43168522818790694 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6235418875927891 + } + } +} \ No newline at end of file diff --git a/docs/results/MU-Kindai/Japanese-SimCSE-BERT-large-unsup/summary.json b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-large-unsup/summary.json new file mode 100644 index 0000000..f620d50 --- /dev/null +++ b/docs/results/MU-Kindai/Japanese-SimCSE-BERT-large-unsup/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7635642561809131 + }, + "amazon_review_classification": { + "macro_f1": 0.5275222511867922 + }, + "massive_intent_classification": { + "macro_f1": 0.7688060073049678 + }, + "massive_scenario_classification": { + "macro_f1": 0.8651446837233107 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 
0.9129851570116734 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.5014367709991477 + }, + "jaqket": { + "ndcg@10": 0.4583812630740073 + }, + "mrtydi": { + "ndcg@10": 0.13003320802922363 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.5508587506679636 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.7497069192695408 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.4524300499843447 + } + }, + "STS": { + "jsick": { + "spearman": 0.7984403024596518 + }, + "jsts": { + "spearman": 0.7813685476201204 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5319881995988209 + }, + "mewsc16": { + "v_measure_score": 0.4330807170988368 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6226614895870103 + } + } +} \ No newline at end of file diff --git a/docs/results/OpenAI/text-embedding-3-large/summary.json b/docs/results/OpenAI/text-embedding-3-large/summary.json new file mode 100644 index 0000000..46af0c5 --- /dev/null +++ b/docs/results/OpenAI/text-embedding-3-large/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7789727938896414 + }, + "amazon_review_classification": { + "macro_f1": 0.6043632319384946 + }, + "massive_intent_classification": { + "macro_f1": 0.8090871295952566 + }, + "massive_scenario_classification": { + "macro_f1": 0.9108443051510002 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9358042266852659 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.7240937077183436 + }, + "jaqket": { + "ndcg@10": 0.48208863565793814 + }, + "mrtydi": { + "ndcg@10": 0.3488438390945784 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.9932811349540317 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9655113335080678 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.9547126796600445 + } + }, + "STS": { + "jsick": { + "spearman": 0.8126909906411093 + }, + "jsts": { + "spearman": 0.8376863979620452 + } + }, + "Clustering": { + "livedoor_news": { + 
"v_measure_score": 0.05018478985401151 + }, + "mewsc16": { + "v_measure_score": 0.4955424351458981 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6234502302515055 + } + } +} \ No newline at end of file diff --git a/docs/results/OpenAI/text-embedding-3-small/summary.json b/docs/results/OpenAI/text-embedding-3-small/summary.json new file mode 100644 index 0000000..74cee2e --- /dev/null +++ b/docs/results/OpenAI/text-embedding-3-small/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7000818608185178 + }, + "amazon_review_classification": { + "macro_f1": 0.5592259673654241 + }, + "massive_intent_classification": { + "macro_f1": 0.7766119663088307 + }, + "massive_scenario_classification": { + "macro_f1": 0.8866536867311439 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9291728102678644 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.640150048193537 + }, + "jaqket": { + "ndcg@10": 0.3394304922804131 + }, + "mrtydi": { + "ndcg@10": 0.2002984123046011 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.9846617848570168 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9170440283351765 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.9017272741306225 + } + }, + "STS": { + "jsick": { + "spearman": 0.8083062989093882 + }, + "jsts": { + "spearman": 0.7808357024283473 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.051323988942160705 + }, + "mewsc16": { + "v_measure_score": 0.4755374215259236 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6227417640807651 + } + } +} \ No newline at end of file diff --git a/docs/results/OpenAI/text-embedding-ada-002/summary.json b/docs/results/OpenAI/text-embedding-ada-002/summary.json new file mode 100644 index 0000000..8c7a548 --- /dev/null +++ b/docs/results/OpenAI/text-embedding-ada-002/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + 
"macro_f1": 0.6441904761904762 + }, + "amazon_review_classification": { + "macro_f1": 0.5312953134953877 + }, + "massive_intent_classification": { + "macro_f1": 0.7457150118928685 + }, + "massive_scenario_classification": { + "macro_f1": 0.8689044829586676 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9303611831749345 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.6102270226904314 + }, + "jaqket": { + "ndcg@10": 0.4256467956806472 + }, + "mrtydi": { + "ndcg@10": 0.1450739420851161 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.9499224324391132 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9123300358752942 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.8197798210453923 + } + }, + "STS": { + "jsick": { + "spearman": 0.7909435250482901 + }, + "jsts": { + "spearman": 0.7894052744557472 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.060252212362740365 + }, + "mewsc16": { + "v_measure_score": 0.4691938182964486 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6239830208701805 + } + } +} \ No newline at end of file diff --git a/docs/results/cl-nagoya/sup-simcse-ja-base/summary.json b/docs/results/cl-nagoya/sup-simcse-ja-base/summary.json new file mode 100644 index 0000000..42cc5ff --- /dev/null +++ b/docs/results/cl-nagoya/sup-simcse-ja-base/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7234436301724776 + }, + "amazon_review_classification": { + "macro_f1": 0.5441445333270086 + }, + "massive_intent_classification": { + "macro_f1": 0.7951973953020242 + }, + "massive_scenario_classification": { + "macro_f1": 0.8760200177186923 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9183455876236017 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.5161990612242935 + }, + "jaqket": { + "ndcg@10": 0.5024513438428565 + }, + "mrtydi": { + "ndcg@10": 0.13976323269046823 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 
0.6807886421530585 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.6570889175649209 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.48219159577174137 + } + }, + "STS": { + "jsick": { + "spearman": 0.8282816229512862 + }, + "jsts": { + "spearman": 0.8127259236647225 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5266774168531417 + }, + "mewsc16": { + "v_measure_score": 0.5091016872016825 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6256665481692143 + } + } +} \ No newline at end of file diff --git a/docs/results/cl-nagoya/sup-simcse-ja-large/summary.json b/docs/results/cl-nagoya/sup-simcse-ja-large/summary.json new file mode 100644 index 0000000..a2d8924 --- /dev/null +++ b/docs/results/cl-nagoya/sup-simcse-ja-large/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7321444865928852 + }, + "amazon_review_classification": { + "macro_f1": 0.5475800661400465 + }, + "massive_intent_classification": { + "macro_f1": 0.7922802742146243 + }, + "massive_scenario_classification": { + "macro_f1": 0.8772172454209797 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9148471751378899 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.4683673504170269 + }, + "jaqket": { + "ndcg@10": 0.39878189118804513 + }, + "mrtydi": { + "ndcg@10": 0.11834919561027905 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.634254459552888 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.37927566884615427 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.25787534957423713 + } + }, + "STS": { + "jsick": { + "spearman": 0.837959537101532 + }, + "jsts": { + "spearman": 0.825691902117111 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5074967876488787 + }, + "mewsc16": { + "v_measure_score": 0.503782014677764 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6250885896527285 + } + } +} \ No newline at end of file diff --git 
a/docs/results/cl-nagoya/unsup-simcse-ja-base/summary.json b/docs/results/cl-nagoya/unsup-simcse-ja-base/summary.json new file mode 100644 index 0000000..3863c9e --- /dev/null +++ b/docs/results/cl-nagoya/unsup-simcse-ja-base/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7330185800774036 + }, + "amazon_review_classification": { + "macro_f1": 0.5392887528271114 + }, + "massive_intent_classification": { + "macro_f1": 0.7907120296283751 + }, + "massive_scenario_classification": { + "macro_f1": 0.8597097942715117 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9115668272308735 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.46003459081522513 + }, + "jaqket": { + "ndcg@10": 0.3945725593125862 + }, + "mrtydi": { + "ndcg@10": 0.055507775092798486 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.6025847751308843 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.5562839869857912 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.3449181162324482 + } + }, + "STS": { + "jsick": { + "spearman": 0.7849379492955117 + }, + "jsts": { + "spearman": 0.7894946592483818 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5223347838445698 + }, + "mewsc16": { + "v_measure_score": 0.37310458219601117 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.624424778761062 + } + } +} \ No newline at end of file diff --git a/docs/results/cl-nagoya/unsup-simcse-ja-large/summary.json b/docs/results/cl-nagoya/unsup-simcse-ja-large/summary.json new file mode 100644 index 0000000..d37618a --- /dev/null +++ b/docs/results/cl-nagoya/unsup-simcse-ja-large/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.767905114979583 + }, + "amazon_review_classification": { + "macro_f1": 0.5537089641846143 + }, + "massive_intent_classification": { + "macro_f1": 0.7912698845073401 + }, + "massive_scenario_classification": { + 
"macro_f1": 0.8736185210672394 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9095494729022622 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.4509073581555124 + }, + "jaqket": { + "ndcg@10": 0.34595043675331943 + }, + "mrtydi": { + "ndcg@10": 0.05750859876901772 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.550742021417855 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.6307172007359215 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.39612451822677164 + } + }, + "STS": { + "jsick": { + "spearman": 0.8014979086154339 + }, + "jsts": { + "spearman": 0.8097685749017456 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5090447587797094 + }, + "mewsc16": { + "v_measure_score": 0.4591920015613856 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6248671625929861 + } + } +} \ No newline at end of file diff --git a/docs/results/colorfulscoop/sbert-base-ja/summary.json b/docs/results/colorfulscoop/sbert-base-ja/summary.json new file mode 100644 index 0000000..2a08044 --- /dev/null +++ b/docs/results/colorfulscoop/sbert-base-ja/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7221023294352484 + }, + "amazon_review_classification": { + "macro_f1": 0.47952384496155054 + }, + "massive_intent_classification": { + "macro_f1": 0.725195343788811 + }, + "massive_scenario_classification": { + "macro_f1": 0.836177960542408 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.8997301146575819 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.21501915127957166 + }, + "jaqket": { + "ndcg@10": 0.13161989528541293 + }, + "mrtydi": { + "ndcg@10": 0.00436010196904899 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.2878020264605714 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.22397059858982324 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.12815871897103842 + } + }, + "STS": { + "jsick": { + "spearman": 0.6659298300713198 + }, + "jsts": { + 
"spearman": 0.7423952309826243 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.4298579019834722 + }, + "mewsc16": { + "v_measure_score": 0.46641671645082333 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6231013776050865 + } + } +} \ No newline at end of file diff --git a/docs/results/intfloat/multilingual-e5-base/summary.json b/docs/results/intfloat/multilingual-e5-base/summary.json new file mode 100644 index 0000000..96f9640 --- /dev/null +++ b/docs/results/intfloat/multilingual-e5-base/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.6367079139150691 + }, + "amazon_review_classification": { + "macro_f1": 0.5424265794470897 + }, + "massive_intent_classification": { + "macro_f1": 0.7277503514873049 + }, + "massive_scenario_classification": { + "macro_f1": 0.8652828949015864 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9285060467194839 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.6534478396845428 + }, + "jaqket": { + "ndcg@10": 0.5067444792013236 + }, + "mrtydi": { + "ndcg@10": 0.3837652120001251 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.8709767034225332 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9473129303429082 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.7304538728893641 + } + }, + "STS": { + "jsick": { + "spearman": 0.8128058660848744 + }, + "jsts": { + "spearman": 0.7839196475937381 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5502694126615243 + }, + "mewsc16": { + "v_measure_score": 0.41494514000218946 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6226482073127441 + } + } +} \ No newline at end of file diff --git a/docs/results/intfloat/multilingual-e5-large/summary.json b/docs/results/intfloat/multilingual-e5-large/summary.json new file mode 100644 index 0000000..a28c470 --- /dev/null +++ b/docs/results/intfloat/multilingual-e5-large/summary.json @@ -0,0 
+1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.706580687830688 + }, + "amazon_review_classification": { + "macro_f1": 0.5653992303516462 + }, + "massive_intent_classification": { + "macro_f1": 0.7577710251429624 + }, + "massive_scenario_classification": { + "macro_f1": 0.8859090262583831 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9296254722183955 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.7030214336558751 + }, + "jaqket": { + "ndcg@10": 0.5878065301444064 + }, + "mrtydi": { + "ndcg@10": 0.4363167873386172 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.8600225120389309 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9469712765040588 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.7248023877969718 + } + }, + "STS": { + "jsick": { + "spearman": 0.7840335060728089 + }, + "jsts": { + "spearman": 0.8098724997856234 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5713023706914878 + }, + "mewsc16": { + "v_measure_score": 0.4534484706354193 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.621496984746364 + } + } +} \ No newline at end of file diff --git a/docs/results/intfloat/multilingual-e5-small/summary.json b/docs/results/intfloat/multilingual-e5-small/summary.json new file mode 100644 index 0000000..99a4423 --- /dev/null +++ b/docs/results/intfloat/multilingual-e5-small/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.6214130966524566 + }, + "amazon_review_classification": { + "macro_f1": 0.5127428912860463 + }, + "massive_intent_classification": { + "macro_f1": 0.7085230519111091 + }, + "massive_scenario_classification": { + "macro_f1": 0.8622036829599259 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9303349187158247 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.6411252958220891 + }, + "jaqket": { + "ndcg@10": 0.49966509556428645 + }, + "mrtydi": { + 
"ndcg@10": 0.36054822913647616 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.8520749151982298 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.9526123412781002 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.729906931983999 + } + }, + "STS": { + "jsick": { + "spearman": 0.8150271836013705 + }, + "jsts": { + "spearman": 0.786450077409501 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.5470075389200084 + }, + "mewsc16": { + "v_measure_score": 0.391226933590049 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6219382321618744 + } + } +} \ No newline at end of file diff --git a/docs/results/oshizo/sbert-jsnli-luke-japanese-base-lite/summary.json b/docs/results/oshizo/sbert-jsnli-luke-japanese-base-lite/summary.json new file mode 100644 index 0000000..6b7309a --- /dev/null +++ b/docs/results/oshizo/sbert-jsnli-luke-japanese-base-lite/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7994675369288904 + }, + "amazon_review_classification": { + "macro_f1": 0.5748206591211895 + }, + "massive_intent_classification": { + "macro_f1": 0.8025949222725076 + }, + "massive_scenario_classification": { + "macro_f1": 0.8875250742566655 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9156331205981866 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.519938655947725 + }, + "jaqket": { + "ndcg@10": 0.4206746951743811 + }, + "mrtydi": { + "ndcg@10": 0.10116108109776817 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.4930421996747514 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.719369187830078 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.3258568875005778 + } + }, + "STS": { + "jsick": { + "spearman": 0.7211422898060521 + }, + "jsts": { + "spearman": 0.8109305772255819 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.4677177349822789 + }, + "mewsc16": { + "v_measure_score": 0.5389209739242912 + } + }, + "PairClassification": { + 
"paws_x_ja": { + "binary_f1": 0.6237623762376237 + } + } +} \ No newline at end of file diff --git a/docs/results/pkshatech/GLuCoSE-base-ja/summary.json b/docs/results/pkshatech/GLuCoSE-base-ja/summary.json new file mode 100644 index 0000000..9048691 --- /dev/null +++ b/docs/results/pkshatech/GLuCoSE-base-ja/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.8243606275521169 + }, + "amazon_review_classification": { + "macro_f1": 0.580654308041878 + }, + "massive_intent_classification": { + "macro_f1": 0.7885427536904928 + }, + "massive_scenario_classification": { + "macro_f1": 0.8794225134482166 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9190289767663239 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.6387979415478197 + }, + "jaqket": { + "ndcg@10": 0.3981609655991592 + }, + "mrtydi": { + "ndcg@10": 0.30281316435910444 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.7825765249971093 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.8206371528870603 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.5982476164344701 + } + }, + "STS": { + "jsick": { + "spearman": 0.7496711324072552 + }, + "jsts": { + "spearman": 0.824592262812859 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.49890886040948096 + }, + "mewsc16": { + "v_measure_score": 0.49676862904881375 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.663883089770355 + } + } +} \ No newline at end of file diff --git a/docs/results/pkshatech/simcse-ja-bert-base-clcmlp/summary.json b/docs/results/pkshatech/simcse-ja-bert-base-clcmlp/summary.json new file mode 100644 index 0000000..cc9f179 --- /dev/null +++ b/docs/results/pkshatech/simcse-ja-bert-base-clcmlp/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.6748573563374541 + }, + "amazon_review_classification": { + "macro_f1": 0.5084883283463678 + }, + 
"massive_intent_classification": { + "macro_f1": 0.7967050091211104 + }, + "massive_scenario_classification": { + "macro_f1": 0.871999260591497 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.914930352019688 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.41496851385134836 + }, + "jaqket": { + "ndcg@10": 0.46003031782136106 + }, + "mrtydi": { + "ndcg@10": 0.1019130492122431 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.4014036990267884 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.5962532652358485 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.2452584471710635 + } + }, + "STS": { + "jsick": { + "spearman": 0.7307715649457595 + }, + "jsts": { + "spearman": 0.8052279921326252 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.4476707933600858 + }, + "mewsc16": { + "v_measure_score": 0.5029508725037098 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6239830208701805 + } + } +} \ No newline at end of file diff --git a/docs/results/sentence-transformers/LaBSE/summary.json b/docs/results/sentence-transformers/LaBSE/summary.json new file mode 100644 index 0000000..de8fd21 --- /dev/null +++ b/docs/results/sentence-transformers/LaBSE/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7361214773958769 + }, + "amazon_review_classification": { + "macro_f1": 0.516957890685124 + }, + "massive_intent_classification": { + "macro_f1": 0.7698802987251081 + }, + "massive_scenario_classification": { + "macro_f1": 0.8835366493433755 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.9162507647227857 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.4310160105414995 + }, + "jaqket": { + "ndcg@10": 0.34245849139132745 + }, + "mrtydi": { + "ndcg@10": 0.04238747941951049 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.48918127058907085 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.7513086500303519 + }, + "nlp_journal_title_intro": { + 
"ndcg@10": 0.35089108319096984 + } + }, + "STS": { + "jsick": { + "spearman": 0.7698905918950973 + }, + "jsts": { + "spearman": 0.7612337568248777 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.4829337123233023 + }, + "mewsc16": { + "v_measure_score": 0.41471299546625956 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.623321554770318 + } + } +} \ No newline at end of file diff --git a/docs/results/sentence-transformers/stsb-xlm-r-multilingual/summary.json b/docs/results/sentence-transformers/stsb-xlm-r-multilingual/summary.json new file mode 100644 index 0000000..12f71a2 --- /dev/null +++ b/docs/results/sentence-transformers/stsb-xlm-r-multilingual/summary.json @@ -0,0 +1,62 @@ +{ + "Classification": { + "amazon_counterfactual_classification": { + "macro_f1": 0.7565022696601644 + }, + "amazon_review_classification": { + "macro_f1": 0.5131771609073525 + }, + "massive_intent_classification": { + "macro_f1": 0.7427818411370812 + }, + "massive_scenario_classification": { + "macro_f1": 0.8609512679368835 + } + }, + "Reranking": { + "esci": { + "ndcg@10": 0.901984958764163 + } + }, + "Retrieval": { + "jagovfaqs_22k": { + "ndcg@10": 0.2511106863952595 + }, + "jaqket": { + "ndcg@10": 0.21606007987072834 + }, + "mrtydi": { + "ndcg@10": 0.027590779174942116 + }, + "nlp_journal_abs_intro": { + "ndcg@10": 0.2848558252647936 + }, + "nlp_journal_title_abs": { + "ndcg@10": 0.3646520309406354 + }, + "nlp_journal_title_intro": { + "ndcg@10": 0.11545016260271045 + } + }, + "STS": { + "jsick": { + "spearman": 0.7236409557069434 + }, + "jsts": { + "spearman": 0.7843597058304203 + } + }, + "Clustering": { + "livedoor_news": { + "v_measure_score": 0.24487129939212224 + }, + "mewsc16": { + "v_measure_score": 0.304278393205056 + } + }, + "PairClassification": { + "paws_x_ja": { + "binary_f1": 0.6219686162624821 + } + } +} \ No newline at end of file diff --git a/leaderboard.md b/leaderboard.md new file mode 100644 index 0000000..107f2e9 --- 
/dev/null +++ b/leaderboard.md @@ -0,0 +1,188 @@ +# Leaderboard +This leaderboard shows the results stored under `docs/results`. The scores are all multiplied by 100. + +## Summary + +The summary shows the average scores within each task. + +| Model | Avg. | Retrieval | STS | Classification | Reranking | Clustering | PairClassification | +|:----------------------------------------------|:----------|:------------|:----------|:-----------------|:------------|:-------------|:---------------------| +| intfloat/multilingual-e5-large | **71.65** | 70.98 | 79.70 | 72.89 | 92.96 | 51.24 | 62.15 | +| pkshatech/GLuCoSE-base-ja | 70.44 | 59.02 | 78.71 | 76.82 | 91.90 | 49.78 | **66.39** | +| intfloat/multilingual-e5-base | 70.12 | 68.21 | 79.84 | 69.30 | 92.85 | 48.26 | 62.26 | +| OpenAI/text-embedding-3-large | 69.63 | **74.48** | 82.52 | **77.58** | **93.58** | 27.29 | 62.35 | +| intfloat/multilingual-e5-small | 69.52 | 67.27 | 80.07 | 67.62 | 93.03 | 46.91 | 62.19 | +| cl-nagoya/sup-simcse-ja-base | 68.56 | 49.64 | 82.05 | 73.47 | 91.83 | **51.79** | 62.57 | +| MU-Kindai/Japanese-SimCSE-BERT-large-unsup | 66.89 | 47.38 | 78.99 | 73.13 | 91.30 | 48.25 | 62.27 | +| oshizo/sbert-jsnli-luke-japanese-base-lite | 66.75 | 43.00 | 76.60 | 76.61 | 91.56 | 50.33 | 62.38 | +| OpenAI/text-embedding-3-small | 66.74 | 66.39 | 79.46 | 73.06 | 92.92 | 26.34 | 62.27 | +| cl-nagoya/sup-simcse-ja-large | 66.51 | 37.62 | **83.18** | 73.73 | 91.48 | 50.56 | 62.51 | +| cl-nagoya/unsup-simcse-ja-large | 66.27 | 40.53 | 80.56 | 74.66 | 90.95 | 48.41 | 62.49 | +| MU-Kindai/Japanese-SimCSE-BERT-base-unsup | 66.23 | 46.36 | 77.49 | 73.30 | 91.16 | 46.68 | 62.38 | +| OpenAI/text-embedding-ada-002 | 65.84 | 64.38 | 79.02 | 69.75 | 93.04 | 26.47 | 62.40 | +| MU-Kindai/Japanese-SimCSE-BERT-large-sup | 65.28 | 40.82 | 78.28 | 73.47 | 90.95 | 45.81 | 62.35 | +| MU-Kindai/Japanese-MixCSE-BERT-base | 65.14 | 42.59 | 77.05 | 72.90 | 91.01 | 44.95 | 62.33 | +| cl-nagoya/unsup-simcse-ja-base | 65.07 | 40.23 | 
78.72 | 73.07 | 91.16 | 44.77 | 62.44 | +| MU-Kindai/Japanese-DiffCSE-BERT-base | 64.77 | 41.79 | 75.50 | 73.77 | 90.95 | 44.22 | 62.38 | +| sentence-transformers/LaBSE | 64.70 | 40.12 | 76.56 | 72.66 | 91.63 | 44.88 | 62.33 | +| pkshatech/simcse-ja-bert-base-clcmlp | 64.42 | 37.00 | 76.80 | 71.30 | 91.49 | 47.53 | 62.40 | +| MU-Kindai/Japanese-SimCSE-BERT-base-sup | 64.15 | 41.32 | 74.66 | 72.76 | 90.66 | 43.11 | 62.37 | +| colorfulscoop/sbert-base-ja | 58.85 | 16.52 | 70.42 | 69.07 | 89.97 | 44.81 | 62.31 | +| sentence-transformers/stsb-xlm-r-multilingual | 58.01 | 21.00 | 75.40 | 71.84 | 90.20 | 27.46 | 62.20 | + +## Retrieval +| Model | Avg. | jagovfaqs_22k
(ndcg@10) | jaqket
(ndcg@10) | mrtydi
(ndcg@10) | nlp_journal_abs_intro
(ndcg@10) | nlp_journal_title_abs
(ndcg@10) | nlp_journal_title_intro
(ndcg@10) | +|:----------------------------------------------|:----------|:-----------------------------|:----------------------|:----------------------|:-------------------------------------|:-------------------------------------|:---------------------------------------| +| OpenAI/text-embedding-3-large | **74.48** | **72.41** | 48.21 | 34.88 | **99.33** | **96.55** | **95.47** | +| intfloat/multilingual-e5-large | 70.98 | 70.30 | **58.78** | **43.63** | 86.00 | 94.70 | 72.48 | +| intfloat/multilingual-e5-base | 68.21 | 65.34 | 50.67 | 38.38 | 87.10 | 94.73 | 73.05 | +| intfloat/multilingual-e5-small | 67.27 | 64.11 | 49.97 | 36.05 | 85.21 | 95.26 | 72.99 | +| OpenAI/text-embedding-3-small | 66.39 | 64.02 | 33.94 | 20.03 | 98.47 | 91.70 | 90.17 | +| OpenAI/text-embedding-ada-002 | 64.38 | 61.02 | 42.56 | 14.51 | 94.99 | 91.23 | 81.98 | +| pkshatech/GLuCoSE-base-ja | 59.02 | 63.88 | 39.82 | 30.28 | 78.26 | 82.06 | 59.82 | +| cl-nagoya/sup-simcse-ja-base | 49.64 | 51.62 | 50.25 | 13.98 | 68.08 | 65.71 | 48.22 | +| MU-Kindai/Japanese-SimCSE-BERT-large-unsup | 47.38 | 50.14 | 45.84 | 13.00 | 55.09 | 74.97 | 45.24 | +| MU-Kindai/Japanese-SimCSE-BERT-base-unsup | 46.36 | 47.39 | 39.57 | 11.44 | 64.16 | 70.23 | 45.37 | +| oshizo/sbert-jsnli-luke-japanese-base-lite | 43.00 | 51.99 | 42.07 | 10.12 | 49.30 | 71.94 | 32.59 | +| MU-Kindai/Japanese-MixCSE-BERT-base | 42.59 | 42.37 | 37.72 | 7.88 | 63.70 | 64.13 | 39.73 | +| MU-Kindai/Japanese-DiffCSE-BERT-base | 41.79 | 42.31 | 36.20 | 7.81 | 60.77 | 64.34 | 39.32 | +| MU-Kindai/Japanese-SimCSE-BERT-base-sup | 41.32 | 44.11 | 39.61 | 8.15 | 62.76 | 58.39 | 34.89 | +| MU-Kindai/Japanese-SimCSE-BERT-large-sup | 40.82 | 47.04 | 44.10 | 11.43 | 43.43 | 62.41 | 36.52 | +| cl-nagoya/unsup-simcse-ja-large | 40.53 | 45.09 | 34.60 | 5.75 | 55.07 | 63.07 | 39.61 | +| cl-nagoya/unsup-simcse-ja-base | 40.23 | 46.00 | 39.46 | 5.55 | 60.26 | 55.63 | 34.49 | +| sentence-transformers/LaBSE | 40.12 | 43.10 | 34.25 | 4.24 | 48.92 | 75.13 | 
35.09 | +| cl-nagoya/sup-simcse-ja-large | 37.62 | 46.84 | 39.88 | 11.83 | 63.43 | 37.93 | 25.79 | +| pkshatech/simcse-ja-bert-base-clcmlp | 37.00 | 41.50 | 46.00 | 10.19 | 40.14 | 59.63 | 24.53 | +| sentence-transformers/stsb-xlm-r-multilingual | 21.00 | 25.11 | 21.61 | 2.76 | 28.49 | 36.47 | 11.55 | +| colorfulscoop/sbert-base-ja | 16.52 | 21.50 | 13.16 | 0.44 | 28.78 | 22.40 | 12.82 | + +## STS +| Model | Avg. | jsick
(spearman) | jsts
(spearman) | +|:----------------------------------------------|:----------|:----------------------|:---------------------| +| cl-nagoya/sup-simcse-ja-large | **83.18** | **83.80** | 82.57 | +| OpenAI/text-embedding-3-large | 82.52 | 81.27 | **83.77** | +| cl-nagoya/sup-simcse-ja-base | 82.05 | 82.83 | 81.27 | +| cl-nagoya/unsup-simcse-ja-large | 80.56 | 80.15 | 80.98 | +| intfloat/multilingual-e5-small | 80.07 | 81.50 | 78.65 | +| intfloat/multilingual-e5-base | 79.84 | 81.28 | 78.39 | +| intfloat/multilingual-e5-large | 79.70 | 78.40 | 80.99 | +| OpenAI/text-embedding-3-small | 79.46 | 80.83 | 78.08 | +| OpenAI/text-embedding-ada-002 | 79.02 | 79.09 | 78.94 | +| MU-Kindai/Japanese-SimCSE-BERT-large-unsup | 78.99 | 79.84 | 78.14 | +| cl-nagoya/unsup-simcse-ja-base | 78.72 | 78.49 | 78.95 | +| pkshatech/GLuCoSE-base-ja | 78.71 | 74.97 | 82.46 | +| MU-Kindai/Japanese-SimCSE-BERT-large-sup | 78.28 | 78.75 | 77.81 | +| MU-Kindai/Japanese-SimCSE-BERT-base-unsup | 77.49 | 78.18 | 76.81 | +| MU-Kindai/Japanese-MixCSE-BERT-base | 77.05 | 77.57 | 76.53 | +| pkshatech/simcse-ja-bert-base-clcmlp | 76.80 | 73.08 | 80.52 | +| oshizo/sbert-jsnli-luke-japanese-base-lite | 76.60 | 72.11 | 81.09 | +| sentence-transformers/LaBSE | 76.56 | 76.99 | 76.12 | +| MU-Kindai/Japanese-DiffCSE-BERT-base | 75.50 | 75.42 | 75.58 | +| sentence-transformers/stsb-xlm-r-multilingual | 75.40 | 72.36 | 78.44 | +| MU-Kindai/Japanese-SimCSE-BERT-base-sup | 74.66 | 74.64 | 74.68 | +| colorfulscoop/sbert-base-ja | 70.42 | 66.59 | 74.24 | + +## Classification +| Model | Avg. | amazon_counterfactual
(macro_f1) | amazon_review
(macro_f1) | massive_intent
(macro_f1) | massive_scenario
(macro_f1) | +|:----------------------------------------------|:----------|:--------------------------------------|:------------------------------|:-------------------------------|:---------------------------------| +| OpenAI/text-embedding-3-large | **77.58** | 77.90 | **60.44** | **80.91** | **91.08** | +| pkshatech/GLuCoSE-base-ja | 76.82 | **82.44** | 58.07 | 78.85 | 87.94 | +| oshizo/sbert-jsnli-luke-japanese-base-lite | 76.61 | 79.95 | 57.48 | 80.26 | 88.75 | +| cl-nagoya/unsup-simcse-ja-large | 74.66 | 76.79 | 55.37 | 79.13 | 87.36 | +| MU-Kindai/Japanese-DiffCSE-BERT-base | 73.77 | 78.10 | 51.56 | 78.79 | 86.63 | +| cl-nagoya/sup-simcse-ja-large | 73.73 | 73.21 | 54.76 | 79.23 | 87.72 | +| MU-Kindai/Japanese-SimCSE-BERT-large-sup | 73.47 | 77.25 | 53.42 | 76.83 | 86.39 | +| cl-nagoya/sup-simcse-ja-base | 73.47 | 72.34 | 54.41 | 79.52 | 87.60 | +| MU-Kindai/Japanese-SimCSE-BERT-base-unsup | 73.30 | 76.20 | 51.52 | 78.95 | 86.54 | +| MU-Kindai/Japanese-SimCSE-BERT-large-unsup | 73.13 | 76.36 | 52.75 | 76.88 | 86.51 | +| cl-nagoya/unsup-simcse-ja-base | 73.07 | 73.30 | 53.93 | 79.07 | 85.97 | +| OpenAI/text-embedding-3-small | 73.06 | 70.01 | 55.92 | 77.66 | 88.67 | +| MU-Kindai/Japanese-MixCSE-BERT-base | 72.90 | 77.62 | 50.86 | 77.19 | 85.93 | +| intfloat/multilingual-e5-large | 72.89 | 70.66 | 56.54 | 75.78 | 88.59 | +| MU-Kindai/Japanese-SimCSE-BERT-base-sup | 72.76 | 76.20 | 52.06 | 77.89 | 84.90 | +| sentence-transformers/LaBSE | 72.66 | 73.61 | 51.70 | 76.99 | 88.35 | +| sentence-transformers/stsb-xlm-r-multilingual | 71.84 | 75.65 | 51.32 | 74.28 | 86.10 | +| pkshatech/simcse-ja-bert-base-clcmlp | 71.30 | 67.49 | 50.85 | 79.67 | 87.20 | +| OpenAI/text-embedding-ada-002 | 69.75 | 64.42 | 53.13 | 74.57 | 86.89 | +| intfloat/multilingual-e5-base | 69.30 | 63.67 | 54.24 | 72.78 | 86.53 | +| colorfulscoop/sbert-base-ja | 69.07 | 72.21 | 47.95 | 72.52 | 83.62 | +| intfloat/multilingual-e5-small | 67.62 | 62.14 | 51.27 | 70.85 | 86.22 | + +## Reranking +| 
Model | Avg. | esci
(ndcg@10) | +|:----------------------------------------------|:----------|:--------------------| +| OpenAI/text-embedding-3-large | **93.58** | **93.58** | +| OpenAI/text-embedding-ada-002 | 93.04 | 93.04 | +| intfloat/multilingual-e5-small | 93.03 | 93.03 | +| intfloat/multilingual-e5-large | 92.96 | 92.96 | +| OpenAI/text-embedding-3-small | 92.92 | 92.92 | +| intfloat/multilingual-e5-base | 92.85 | 92.85 | +| pkshatech/GLuCoSE-base-ja | 91.90 | 91.90 | +| cl-nagoya/sup-simcse-ja-base | 91.83 | 91.83 | +| sentence-transformers/LaBSE | 91.63 | 91.63 | +| oshizo/sbert-jsnli-luke-japanese-base-lite | 91.56 | 91.56 | +| pkshatech/simcse-ja-bert-base-clcmlp | 91.49 | 91.49 | +| cl-nagoya/sup-simcse-ja-large | 91.48 | 91.48 | +| MU-Kindai/Japanese-SimCSE-BERT-large-unsup | 91.30 | 91.30 | +| MU-Kindai/Japanese-SimCSE-BERT-base-unsup | 91.16 | 91.16 | +| cl-nagoya/unsup-simcse-ja-base | 91.16 | 91.16 | +| MU-Kindai/Japanese-MixCSE-BERT-base | 91.01 | 91.01 | +| cl-nagoya/unsup-simcse-ja-large | 90.95 | 90.95 | +| MU-Kindai/Japanese-DiffCSE-BERT-base | 90.95 | 90.95 | +| MU-Kindai/Japanese-SimCSE-BERT-large-sup | 90.95 | 90.95 | +| MU-Kindai/Japanese-SimCSE-BERT-base-sup | 90.66 | 90.66 | +| sentence-transformers/stsb-xlm-r-multilingual | 90.20 | 90.20 | +| colorfulscoop/sbert-base-ja | 89.97 | 89.97 | + +## Clustering +| Model | Avg. | livedoor_news
(v_measure_score) | mewsc16
(v_measure_score) | +|:----------------------------------------------|:----------|:-------------------------------------|:-------------------------------| +| cl-nagoya/sup-simcse-ja-base | **51.79** | 52.67 | 50.91 | +| intfloat/multilingual-e5-large | 51.24 | **57.13** | 45.34 | +| cl-nagoya/sup-simcse-ja-large | 50.56 | 50.75 | 50.38 | +| oshizo/sbert-jsnli-luke-japanese-base-lite | 50.33 | 46.77 | **53.89** | +| pkshatech/GLuCoSE-base-ja | 49.78 | 49.89 | 49.68 | +| cl-nagoya/unsup-simcse-ja-large | 48.41 | 50.90 | 45.92 | +| intfloat/multilingual-e5-base | 48.26 | 55.03 | 41.49 | +| MU-Kindai/Japanese-SimCSE-BERT-large-unsup | 48.25 | 53.20 | 43.31 | +| pkshatech/simcse-ja-bert-base-clcmlp | 47.53 | 44.77 | 50.30 | +| intfloat/multilingual-e5-small | 46.91 | 54.70 | 39.12 | +| MU-Kindai/Japanese-SimCSE-BERT-base-unsup | 46.68 | 53.02 | 40.35 | +| MU-Kindai/Japanese-SimCSE-BERT-large-sup | 45.81 | 48.45 | 43.17 | +| MU-Kindai/Japanese-MixCSE-BERT-base | 44.95 | 52.62 | 37.28 | +| sentence-transformers/LaBSE | 44.88 | 48.29 | 41.47 | +| colorfulscoop/sbert-base-ja | 44.81 | 42.99 | 46.64 | +| cl-nagoya/unsup-simcse-ja-base | 44.77 | 52.23 | 37.31 | +| MU-Kindai/Japanese-DiffCSE-BERT-base | 44.22 | 49.67 | 38.77 | +| MU-Kindai/Japanese-SimCSE-BERT-base-sup | 43.11 | 41.04 | 45.18 | +| sentence-transformers/stsb-xlm-r-multilingual | 27.46 | 24.49 | 30.43 | +| OpenAI/text-embedding-3-large | 27.29 | 5.02 | 49.55 | +| OpenAI/text-embedding-ada-002 | 26.47 | 6.03 | 46.92 | +| OpenAI/text-embedding-3-small | 26.34 | 5.13 | 47.55 | + +## PairClassification +| Model | Avg. | paws_x_ja
"""Generate ``leaderboard.md`` from the evaluation summaries under ``docs/results``.

Each model is expected to store its scores at
``docs/results/<org>/<model>/summary.json`` with the shape
``{task_name: {dataset_name: {metric_name: score}}}``.
All scores are multiplied by 100 in the rendered tables.
"""

import json
from collections import defaultdict
from pathlib import Path

from tabulate import tabulate

# Shorter aliases so long dataset names do not bloat the table headers.
dataset_name_aliases = {
    "amazon_counterfactual_classification": "amazon_counterfactual",
    "amazon_review_classification": "amazon_review",
    "massive_intent_classification": "massive_intent",
    "massive_scenario_classification": "massive_scenario",
}

TASK_ORDER = ["Retrieval", "STS", "Classification", "Reranking", "Clustering", "PairClassification"]
SUMMARY_KEY = "Summary"
AVG_COLUMN_NAME = "Avg."
RESULTS_DIR = Path("docs/results")
OUTPUT_FILE = "leaderboard.md"


def format_score(score: float) -> str:
    """Render a score in [0, 1] as a 2-decimal string scaled by 100."""
    return f"{score * 100:.2f}"


def collect_results() -> dict[str, dict[str, dict[str, float]]]:
    """Collect the results from the results folder.

    Returns:
        ``{task_name: {model_signature: {"<dataset><br>(<metric>)": score}}}``;
        a synthetic ``Summary`` task maps each model to its per-task averages.
    """
    all_results: dict[str, dict[str, dict[str, float]]] = defaultdict(lambda: defaultdict(dict))
    # rglob only yields existing paths, so no extra existence check is needed.
    for summary_file in RESULTS_DIR.rglob("summary.json"):
        with open(summary_file, encoding="utf-8") as f:
            summary = json.load(f)

        # Directory layout is docs/results/<org>/<model>/summary.json.
        model_signature = f"{summary_file.parent.parent.name}/{summary_file.parent.name}"

        for task_name, task_results in summary.items():
            task_results_formatted: dict[str, float] = {}
            task_scores: list[float] = []
            for dataset_name, metric_dict in task_results.items():
                # Each dataset entry holds a single {metric_name: score} pair.
                metric_name, score = next(iter(metric_dict.items()))
                dataset_name = dataset_name_aliases.get(dataset_name, dataset_name)
                task_results_formatted[f"{dataset_name}<br>({metric_name})"] = score
                task_scores.append(score)
            all_results[task_name][model_signature] = task_results_formatted
            # The summary table shows the macro-average over a task's datasets.
            all_results[SUMMARY_KEY][model_signature][task_name] = sum(task_scores) / len(task_scores)
    return all_results


def build_markdown_tables(all_results: dict[str, dict[str, dict[str, float]]]) -> dict[str, str]:
    """Create one markdown (pipe-format) table per task.

    Rows are sorted by the average score (descending) and the best value of
    every column — ties included — is rendered in bold.
    """
    markdown_tables: dict[str, str] = {}
    for task_name, task_results in all_results.items():
        # Column order: canonical task order for the summary table,
        # otherwise the dataset keys of the first model encountered.
        if task_name == SUMMARY_KEY:
            dataset_keys = TASK_ORDER
        else:
            dataset_keys = list(task_results[next(iter(task_results))].keys())

        header: list[str] = ["Model", AVG_COLUMN_NAME, *dataset_keys]
        table_list: list[list[str | float]] = []
        for model_signature, dataset_scores in task_results.items():
            model_scores = [dataset_scores[k] for k in dataset_keys]
            average_score = sum(model_scores) / len(model_scores)
            table_list.append([model_signature, average_score, *model_scores])

        # Sort rows by the average score, best model first.
        avg_idx = header.index(AVG_COLUMN_NAME)
        table_list.sort(key=lambda row: row[avg_idx], reverse=True)

        # Bold the highest score of each column (ties are all bolded).
        for column_name in [AVG_COLUMN_NAME, *dataset_keys]:
            col_idx = header.index(column_name)
            best = max(row[col_idx] for row in table_list)
            for row in table_list:
                formatted = format_score(row[col_idx])
                row[col_idx] = f"**{formatted}**" if row[col_idx] == best else formatted

        # Reuse the header built above instead of constructing it a second time.
        table_list.insert(0, header)
        markdown_tables[task_name] = tabulate(table_list, headers="firstrow", tablefmt="pipe")
    return markdown_tables


def write_leaderboard(markdown_tables: dict[str, str]) -> None:
    """Dump the markdown tables to the leaderboard file."""
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        f.write("# Leaderboard\n")
        f.write(
            "This leaderboard shows the results stored under `docs/results`. The scores are all multiplied by 100.\n\n"
        )
        for task_name in [SUMMARY_KEY, *TASK_ORDER]:
            f.write(f"## {task_name}\n")
            if task_name == SUMMARY_KEY:
                f.write("\nThe summary shows the average scores within each task.\n\n")
            f.write(markdown_tables[task_name])
            f.write("\n\n")


if __name__ == "__main__":
    write_leaderboard(build_markdown_tables(collect_results()))
(>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] + [[package]] name = "aiohttp" version = "3.9.5" @@ -343,35 +374,6 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} -[[package]] -name = "cmake" -version = "3.29.3" -description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software" -optional = false -python-versions = ">=3.7" -files = [ - {file = "cmake-3.29.3-py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:355f515826023338094514a2181724e297ed2145bc0792dacaa9ed3772b98733"}, - {file = "cmake-3.29.3-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ab5eb91e7f5bbfc2f0e23c964c3a3e74c6e6a26e9b59b57b87192d249b1b7162"}, - {file = "cmake-3.29.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ae9e5dcd77822f89e042ad820ef25a52327bb0d15fd7a492ad4886edb31fae52"}, - {file = "cmake-3.29.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b09d1f0f46a880fdfc50374917fd4c850d9428b244535343bb5411658a36e202"}, - {file = "cmake-3.29.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d05cf16a6fb370cc344b3552ab321524cba1f067da240876c09cab571bf6ec0"}, - {file = "cmake-3.29.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c0a23fbb3daeecdc42d233c1a2df233714c2db59e75ab154e2af469c1c308a5"}, - {file = "cmake-3.29.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1037218e135302f396eca444e24ca892d8a440589f1a859313e06484f10c350f"}, - {file = "cmake-3.29.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c84eead2ea6f596fe5ac58beedbfc9bc1f460c410c481348b3783b4794f4b1a2"}, - {file = "cmake-3.29.3-py3-none-musllinux_1_1_aarch64.whl", hash = 
"sha256:e1fd53ca2f24dc0aad54934c2472cb83e273b94b4bad23fcdbd438515881f5a7"}, - {file = "cmake-3.29.3-py3-none-musllinux_1_1_i686.whl", hash = "sha256:00225a2be8422d4b6f2ad2da10d7dfe2ad844748bd1defa94f236bfabb0d2d44"}, - {file = "cmake-3.29.3-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:28fe371f1865943118a0f669af87344c799751f85a5be084197c006ee6329d89"}, - {file = "cmake-3.29.3-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:ad184528fa9560bf4167279e8e4e7168a5fa1cc87a9f0b4b99ffbc79588b0cf9"}, - {file = "cmake-3.29.3-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:40cd0ec1310e52fa29b4e2b07829d56ae95f01ea0b2479ece359259849269f86"}, - {file = "cmake-3.29.3-py3-none-win32.whl", hash = "sha256:a2c15ab9e4922d71d98a6495a5fd661dd00b3d4ada79a3d183f996fff45db011"}, - {file = "cmake-3.29.3-py3-none-win_amd64.whl", hash = "sha256:dd8aaffe5d8dc2dd41421dc63c39b64df30a7109392e276e2b6d021805b770e9"}, - {file = "cmake-3.29.3-py3-none-win_arm64.whl", hash = "sha256:6672a873855e9a8f954390d0352c1d09b034a36b5f4cc5da012ae292f28623f7"}, - {file = "cmake-3.29.3.tar.gz", hash = "sha256:d04adb1a8b878e92a734742cb0db9c59e3828abcf8ec9c930eb8a01faa00c9df"}, -] - -[package.extras] -test = ["coverage (>=4.2)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)"] - [[package]] name = "colorama" version = "0.4.6" @@ -801,6 +803,20 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "intel-openmp" +version = "2021.4.0" +description = "Intel OpenMP* Runtime Library" +optional = false +python-versions = "*" +files = [ + {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, + {file = 
"intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, + {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, +] + [[package]] name = "ipadic" version = "1.0.0" @@ -939,17 +955,6 @@ files = [ {file = "jsonnet_binary-0.17.0-pp37-pypy37_pp73-win32.whl", hash = "sha256:846735c55cf704acb071932dd2c4a22afc7cc77b0a90884080e97f58c7df75a0"}, ] -[[package]] -name = "lit" -version = "18.1.7" -description = "A Software Testing Tool" -optional = false -python-versions = "*" -files = [ - {file = "lit-18.1.7-py3-none-any.whl", hash = "sha256:684629e3af788bd0a61ca253a2dbb46bda8fd40c9022e3925d8ff067b67549f7"}, - {file = "lit-18.1.7.tar.gz", hash = "sha256:2ddd9be26bdcc6da03aea3ec456c6945eb5a09dbde548d3500bff9b8ed4763bb"}, -] - [[package]] name = "loguru" version = "0.7.2" @@ -1048,6 +1053,24 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mkl" +version = "2021.4.0" +description = "Intel® oneAPI Math Kernel Library" +optional = false +python-versions = "*" +files = [ + {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, + {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, + {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, + {file = 
"mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, +] + +[package.dependencies] +intel-openmp = "==2021.*" +tbb = "==2021.*" + [[package]] name = "mpmath" version = "1.3.0" @@ -1310,162 +1333,146 @@ files = [ ] [[package]] -name = "nvidia-cublas-cu11" -version = "11.10.3.66" +name = "nvidia-cublas-cu12" +version = "12.1.3.1" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"}, - {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cuda-cupti-cu11" -version = "11.7.101" +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" description = "CUDA profiling tools runtime libs." 
optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:e0cfd9854e1f2edaa36ca20d21cd0bdd5dcfca4e3b9e130a082e05b33b6c5895"}, - {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-win_amd64.whl", hash = "sha256:7cc5b8f91ae5e1389c3c0ad8866b3b016a175e827ea8f162a672990a402ab2b0"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cuda-nvrtc-cu11" -version = "11.7.99" +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"}, - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"}, - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cuda-runtime-cu11" -version = "11.7.99" +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" 
files = [ - {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"}, - {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cudnn-cu11" -version = "8.5.0.96" +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" description = "cuDNN runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, - {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, ] [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-cublas-cu12 = "*" [[package]] -name = "nvidia-cufft-cu11" -version = "10.9.0.58" +name = "nvidia-cufft-cu12" +version = "11.0.2.54" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl", hash = "sha256:222f9da70c80384632fd6035e4c3f16762d64ea7a843829cb278f98b3cb7dd81"}, - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-win_amd64.whl", hash = "sha256:c4d316f17c745ec9c728e30409612eaf77a8404c3733cdf6c9c1569634d1ca03"}, + {file = 
"nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, ] [[package]] -name = "nvidia-curand-cu11" -version = "10.2.10.91" +name = "nvidia-curand-cu12" +version = "10.3.2.106" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:eecb269c970fa599a2660c9232fa46aaccbf90d9170b96c462e13bcb4d129e2c"}, - {file = "nvidia_curand_cu11-10.2.10.91-py3-none-win_amd64.whl", hash = "sha256:f742052af0e1e75523bde18895a9ed016ecf1e5aa0ecddfcc3658fd11a1ff417"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" - [[package]] -name = "nvidia-cusolver-cu11" -version = "11.4.0.1" +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:72fa7261d755ed55c0074960df5904b65e2326f7adce364cbe4945063c1be412"}, - {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:700b781bfefd57d161443aff9ace1878584b93e0b2cfef3d6e9296d96febbf99"}, - {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-win_amd64.whl", hash = "sha256:00f70b256add65f8c1eb3b6a65308795a93e7740f6df9e273eccbba770d370c4"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = 
"sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, ] [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" [[package]] -name = "nvidia-cusparse-cu11" -version = "11.7.4.91" +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:a3389de714db63321aa11fbec3919271f415ef19fda58aed7f2ede488c32733d"}, - {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-win_amd64.whl", hash = "sha256:304a01599534f5186a8ed1c3756879282c72c118bc77dd890dc1ff868cad25b9"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, ] [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-nvjitlink-cu12 = "*" [[package]] -name = "nvidia-nccl-cu11" -version = "2.14.3" +name = "nvidia-nccl-cu12" +version = "2.20.5" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:5e5534257d1284b8e825bc3a182c6f06acd6eb405e9f89d49340e98cd8f136eb"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, ] [[package]] -name 
= "nvidia-nvtx-cu11" -version = "11.7.91" -description = "NVIDIA Tools Extension" +name = "nvidia-nvjitlink-cu12" +version = "12.5.82" +description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:b22c64eee426a62fc00952b507d6d29cf62b4c9df7a480fcc417e540e05fd5ac"}, - {file = "nvidia_nvtx_cu11-11.7.91-py3-none-win_amd64.whl", hash = "sha256:dfd7fcb2a91742513027d63a26b757f38dd8b07fecac282c4d132a9d373ff064"}, + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"}, + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"}, ] -[package.dependencies] -setuptools = "*" -wheel = "*" +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] [[package]] name = "openai" @@ -1733,6 +1740,35 @@ files = [ {file = "protobuf-5.27.1.tar.gz", hash = "sha256:df5e5b8e39b7d1c25b186ffdf9f44f40f810bbcc9d2b71d9d3156fee5a9adf15"}, ] +[[package]] +name = "psutil" +version = "6.0.0" +description = "Cross-platform lib for process and system monitoring in Python." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = 
"sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + [[package]] name = "py" version = "1.11.0" @@ -2478,21 +2514,6 @@ files = [ {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"}, ] -[[package]] -name = "setuptools" -version = "70.0.0" -description = "Easily download, build, install, upgrade, and uninstall Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, - {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", 
"pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] - [[package]] name = "six" version = "1.16.0" @@ -2599,6 +2620,33 @@ files = [ [package.dependencies] mpmath = ">=1.1.0,<1.4.0" +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + +[[package]] +name = "tbb" +version = "2021.13.0" +description = "Intel® oneAPI Threading Building Blocks (oneTBB)" +optional = false +python-versions = "*" +files = [ + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"}, + {file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"}, + {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"}, +] + [[package]] name = "threadpoolctl" version = "3.5.0" @@ -2792,58 +2840,57 @@ files = [ [[package]] name = "torch" -version = "2.0.0" +version = "2.3.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.0.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c9090bda7d2eeeecd74f51b721420dbeb44f838d4536cc1b284e879417e3064a"}, - {file = 
"torch-2.0.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bd42db2a48a20574d2c33489e120e9f32789c4dc13c514b0c44272972d14a2d7"}, - {file = "torch-2.0.0-1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:8969aa8375bcbc0c2993e7ede0a7f889df9515f18b9b548433f412affed478d9"}, - {file = "torch-2.0.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ab2da16567cb55b67ae39e32d520d68ec736191d88ac79526ca5874754c32203"}, - {file = "torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7a9319a67294ef02459a19738bbfa8727bb5307b822dadd708bc2ccf6c901aca"}, - {file = "torch-2.0.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:9f01fe1f6263f31bd04e1757946fd63ad531ae37f28bb2dbf66f5c826ee089f4"}, - {file = "torch-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:527f4ae68df7b8301ee6b1158ca56350282ea633686537b30dbb5d7b4a52622a"}, - {file = "torch-2.0.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:ce9b5a49bd513dff7950a5a07d6e26594dd51989cee05ba388b03e8e366fd5d5"}, - {file = "torch-2.0.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:53e1c33c6896583cdb9a583693e22e99266444c4a43392dddc562640d39e542b"}, - {file = "torch-2.0.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:09651bff72e439d004c991f15add0c397c66f98ab36fe60d5514b44e4da722e8"}, - {file = "torch-2.0.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d439aec349c98f12819e8564b8c54008e4613dd4428582af0e6e14c24ca85870"}, - {file = "torch-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2802f84f021907deee7e9470ed10c0e78af7457ac9a08a6cd7d55adef835fede"}, - {file = "torch-2.0.0-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:01858620f25f25e7a9ec4b547ff38e5e27c92d38ec4ccba9cfbfb31d7071ed9c"}, - {file = "torch-2.0.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:9a2e53b5783ef5896a6af338b36d782f28e83c8ddfc2ac44b67b066d9d76f498"}, - {file = "torch-2.0.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ec5fff2447663e369682838ff0f82187b4d846057ef4d119a8dea7772a0b17dd"}, - {file = 
"torch-2.0.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:11b0384fe3c18c01b8fc5992e70fc519cde65e44c51cc87be1838c1803daf42f"}, - {file = "torch-2.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:e54846aa63855298cfb1195487f032e413e7ac9cbfa978fda32354cc39551475"}, - {file = "torch-2.0.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:cc788cbbbbc6eb4c90e52c550efd067586c2693092cf367c135b34893a64ae78"}, - {file = "torch-2.0.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:d292640f0fd72b7a31b2a6e3b635eb5065fcbedd4478f9cad1a1e7a9ec861d35"}, - {file = "torch-2.0.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:6befaad784004b7af357e3d87fa0863c1f642866291f12a4c2af2de435e8ac5c"}, - {file = "torch-2.0.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a83b26bd6ae36fbf5fee3d56973d9816e2002e8a3b7d9205531167c28aaa38a7"}, - {file = "torch-2.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:c7e67195e1c3e33da53954b026e89a8e1ff3bc1aeb9eb32b677172d4a9b5dcbf"}, - {file = "torch-2.0.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6e0b97beb037a165669c312591f242382e9109a240e20054d5a5782d9236cad0"}, - {file = "torch-2.0.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:297a4919aff1c0f98a58ebe969200f71350a1d4d4f986dbfd60c02ffce780e99"}, + {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, + {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, + {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, + {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, + {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, + {file = 
"torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, + {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, + {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, + {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, + {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, + {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, + {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, + {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, + {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, + {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, + {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, + {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, + {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, + {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, + {file = 
"torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, ] [package.dependencies] filelock = "*" +fsspec = "*" jinja2 = "*" +mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" -nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu11 = {version = "11.7.101", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu11 = {version = "10.2.10.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu11 = {version = "11.4.0.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu11 = {version = "11.7.4.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu11 = {version = "2.14.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu11 = {version = "11.7.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and 
platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -typing-extensions = "*" +triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.9.1)"] [[package]] name = "tqdm" @@ -2942,40 +2989,26 @@ vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.0.0" +version = "2.3.1" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"}, - {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"}, - {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"}, - {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"}, - {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"}, - {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"}, - {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"}, - {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"}, - {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"}, - {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, - {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, - {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, - {file = "triton-2.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fedce6a381901b1547e0e7e1f2546e4f65dca6d91e2d8a7305a2d1f5551895be"}, - {file = 
"triton-2.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75834f27926eab6c7f00ce73aaf1ab5bfb9bec6eb57ab7c0bfc0a23fac803b4c"}, - {file = "triton-2.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0117722f8c2b579cd429e0bee80f7731ae05f63fe8e9414acd9a679885fcbf42"}, - {file = "triton-2.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcd9be5d0c2e45d2b7e6ddc6da20112b6862d69741576f9c3dbaf941d745ecae"}, - {file = "triton-2.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42a0d2c3fc2eab4ba71384f2e785fbfd47aa41ae05fa58bf12cb31dcbd0aeceb"}, - {file = "triton-2.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52c47b72c72693198163ece9d90a721299e4fb3b8e24fd13141e384ad952724f"}, + {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, + {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, + {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, + {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, + {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, + {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, ] [package.dependencies] -cmake = "*" filelock = "*" -lit = "*" -torch = "*" [package.extras] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", 
"scipy (>=1.7.1)"] -tutorials = ["matplotlib", "pandas", "tabulate"] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "typing-extensions" @@ -3053,20 +3086,6 @@ files = [ {file = "wasabi-0.10.1.tar.gz", hash = "sha256:c8e372781be19272942382b14d99314d175518d7822057cb7a97010c4259d249"}, ] -[[package]] -name = "wheel" -version = "0.43.0" -description = "A built-package format for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "wheel-0.43.0-py3-none-any.whl", hash = "sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81"}, - {file = "wheel-0.43.0.tar.gz", hash = "sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85"}, -] - -[package.extras] -test = ["pytest (>=6.0.0)", "setuptools (>=65)"] - [[package]] name = "win32-setctime" version = "1.1.0" @@ -3383,4 +3402,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "e40f842f52270eeceaf8810710c4709262b2ca68b3c38ef6117f0e50d81de8ff" +content-hash = "a2c9ed2cef63429fda1482752acb674fe3b39b94498bbe2c177d0b8ac9558c44" diff --git a/pyproject.toml b/pyproject.toml index 2cc9b97..a2eb2b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding name = "JMTEB" packages = [{from = "src", include = "jmteb"}] readme = "README.md" -version = "1.2.0" +version = "1.3.0" [tool.poetry.dependencies] python = ">=3.10,<4.0" @@ -23,7 +23,7 @@ transformers = {extras = ["ja", "sentencepiece"], version = "^4.38.1"} datasets = ">=2.17" sentence-transformers = "^3.0.0" pytest = "7.1.3" -torch = "2.0.0" # this version is needed to avoid "libcu*" related errors +torch = "^2.3" pydantic = "^2.6.3" eval-type-backport = "^0.1.3" smart-open = "^7.0.1" @@ -31,15 +31,18 @@ openai = "^1.16.2" pytest-mock 
= "^3.14.0" tiktoken = "^0.6.0" numpy = "^1.26" +accelerate = "^0.31.0" +tabulate = "^0.9.0" [tool.poetry.group.dev.dependencies] black = "^23.11.0" isort = "^5.12.0" mypy = "^1.7.1" flake8 = "^7.0.0" +tabulate = "^0.9.0" [tool.black] line-length = 119 [tool.isort] -profile = "black" +profile = "black" \ No newline at end of file diff --git a/src/jmteb/__main__.py b/src/jmteb/__main__.py index bb9af7f..55e830a 100644 --- a/src/jmteb/__main__.py +++ b/src/jmteb/__main__.py @@ -37,6 +37,10 @@ def main( dataset_name=eval_name, task_name=evaluator.__class__.__name__.replace("Evaluator", ""), ) + if getattr(evaluator, "log_predictions", False): + score_recorder.record_predictions( + metrics, eval_name, evaluator.__class__.__name__.replace("Evaluator", "") + ) logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}") @@ -60,6 +64,9 @@ def main( parser.add_argument("--overwrite_cache", type=bool, default=False, help="Overwrite the save_dir if it exists") parser.add_argument("--eval_include", type=list[str], default=None, help="Evaluators to include.") parser.add_argument("--eval_exclude", type=list[str], default=None, help="Evaluators to exclude.") + parser.add_argument( + "--log_predictions", type=bool, default=False, help="Whether to log predictions for all evaulators." 
+ ) args = parser.parse_args() @@ -99,6 +106,11 @@ def main( f"Please check {args.evaluators}" ) + if args.log_predictions: + for k, v in args.evaluators.items(): + if hasattr(v, "log_predictions"): + args.evaluators[k].log_predictions = True + main( text_embedder=args.embedder, evaluators=args.evaluators, diff --git a/src/jmteb/configs/tasks/mrtydi.jsonnet b/src/jmteb/configs/tasks/mrtydi.jsonnet index be7cc5b..db2bf9e 100644 --- a/src/jmteb/configs/tasks/mrtydi.jsonnet +++ b/src/jmteb/configs/tasks/mrtydi.jsonnet @@ -26,6 +26,7 @@ name: 'mrtydi-corpus', }, }, + "doc_chunk_size":10000 }, }, } diff --git a/src/jmteb/embedders/__init__.py b/src/jmteb/embedders/__init__.py index 07399c5..f28f038 100644 --- a/src/jmteb/embedders/__init__.py +++ b/src/jmteb/embedders/__init__.py @@ -1,3 +1,7 @@ from jmteb.embedders.base import TextEmbedder +from jmteb.embedders.data_parallel_sbert_embedder import ( + DataParallelSentenceBertEmbedder, +) from jmteb.embedders.openai_embedder import OpenAIEmbedder from jmteb.embedders.sbert_embedder import SentenceBertEmbedder +from jmteb.embedders.transformers_embedder import TransformersEmbedder diff --git a/src/jmteb/embedders/base.py b/src/jmteb/embedders/base.py index c74fddf..afefec1 100644 --- a/src/jmteb/embedders/base.py +++ b/src/jmteb/embedders/base.py @@ -5,6 +5,7 @@ from pathlib import Path import numpy as np +import torch import tqdm from loguru import logger @@ -14,7 +15,11 @@ class TextEmbedder(ABC): The base class of text embedder. """ - def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: + convert_to_tensor: bool + convert_to_numpy: bool + _chunk_size: int = 262144 # 2^18 + + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray | torch.Tensor: """Convert a text string or a list of texts to embedding. 
Args: @@ -36,9 +41,9 @@ def _batch_encode_and_save_on_disk( text_list: list[str], save_path: str | PathLike[str], prefix: str | None = None, - batch_size: int = 64, + batch_size: int = 262144, dtype: str = "float32", - ) -> np.memmap: + ) -> np.memmap | torch.Tensor: """ Encode a list of texts and save the embeddings on disk using memmap. @@ -52,18 +57,24 @@ def _batch_encode_and_save_on_disk( num_samples = len(text_list) output_dim = self.get_output_dim() - embeddings = np.memmap(save_path, dtype=dtype, mode="w+", shape=(num_samples, output_dim)) + if self.convert_to_numpy: + embeddings = np.memmap(save_path, dtype=dtype, mode="w+", shape=(num_samples, output_dim)) + else: + embeddings = torch.empty((num_samples, output_dim), dtype=self._torch_dtype_parser(dtype)) with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar: for i in range(0, num_samples, batch_size): batch = text_list[i : i + batch_size] - batch_embeddings = self.encode(batch, prefix=prefix) - batch_embeddings = np.asarray(batch_embeddings, dtype=dtype) + batch_embeddings: np.ndarray | torch.Tensor = self.encode(batch, prefix=prefix) embeddings[i : i + batch_size] = batch_embeddings pbar.update(len(batch)) - embeddings.flush() - return np.memmap(save_path, dtype=dtype, mode="r", shape=(num_samples, output_dim)) + if self.convert_to_numpy: + embeddings.flush() + return np.memmap(save_path, dtype=dtype, mode="r", shape=(num_samples, output_dim)) + else: + torch.save(embeddings, save_path) + return embeddings def batch_encode_with_cache( self, @@ -71,9 +82,8 @@ def batch_encode_with_cache( prefix: str | None = None, cache_path: str | PathLike[str] | None = None, overwrite_cache: bool = False, - batch_size: int = 64, dtype: str = "float32", - ) -> np.ndarray: + ) -> np.ndarray | torch.Tensor: """ Encode a list of texts and save the embeddings on disk using memmap if cache_path is provided. @@ -82,13 +92,12 @@ def batch_encode_with_cache( prefix (str, optional): the prefix to use for encoding. 
Default to None. cache_path (str, optional): path to save the embeddings. Defaults to None. overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False. - batch_size (int): batch size. Defaults to 64. dtype (str, optional): data type. Defaults to "float32". """ if cache_path is None: logger.info("Encoding embeddings") - return self.encode(text_list, prefix=prefix).astype(dtype) + return self.encode(text_list, prefix=prefix) if Path(cache_path).exists() and not overwrite_cache: logger.info(f"Loading embeddings from {cache_path}") @@ -96,6 +105,38 @@ def batch_encode_with_cache( logger.info(f"Encoding and saving embeddings to {cache_path}") embeddings = self._batch_encode_and_save_on_disk( - text_list, cache_path, prefix=prefix, batch_size=batch_size, dtype=dtype + text_list, cache_path, prefix=prefix, batch_size=self._chunk_size, dtype=dtype ) return embeddings + + @staticmethod + def _torch_dtype_parser(dtype: str | torch.dtype) -> torch.dtype | str: + if dtype == "auto": + return dtype + elif isinstance(dtype, str): + dtype = dtype.replace("torch.", "") + if hasattr(torch, dtype): + dtype = getattr(torch, dtype) + if isinstance(dtype, torch.dtype): + return dtype + raise ValueError(f"Invalid torch dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Expected `dtype` as `str` or `torch.dtype`, but got {type(dtype)}!") + + def _model_kwargs_parser(self, model_kwargs: dict | None) -> dict: + if not model_kwargs: + return {} + + if "torch_dtype" in model_kwargs: + model_kwargs["torch_dtype"] = self._torch_dtype_parser(model_kwargs["torch_dtype"]) + return model_kwargs + + def set_output_tensor(self): + self.convert_to_numpy = False + self.convert_to_tensor = True + + def set_output_numpy(self): + self.convert_to_numpy = True + self.convert_to_tensor = False diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py new file mode 100644 
index 0000000..6fb7e87 --- /dev/null +++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import sys +from typing import Literal + +import numpy as np +import torch +from accelerate.utils import find_executable_batch_size +from loguru import logger +from sentence_transformers import SentenceTransformer +from sentence_transformers.quantization import quantize_embeddings +from sentence_transformers.util import truncate_embeddings +from torch import Tensor +from tqdm.autonotebook import trange + +from jmteb.embedders.base import TextEmbedder + + +class DPSentenceTransformer(SentenceTransformer): + """SentenceBERT with pytorch torch.nn.DataParallel""" + + def __init__(self, sbert_model: SentenceTransformer): + super(DPSentenceTransformer, self).__init__() + self.dp_model = torch.nn.DataParallel(sbert_model) + self.sbert = self.dp_model.module + + def forward(self, *args, **kargs): + return self.dp_model.forward(*args, **kargs) + + def encode( + self, + sentences: str | list[str], + prompt_name: str | None = None, + prompt: str | None = None, + batch_size: int = 64, + show_progress_bar: bool | None = None, + output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding", + precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str = None, + normalize_embeddings: bool = False, + ) -> list[Tensor] | np.ndarray | Tensor: + self.eval() + if show_progress_bar is None: + logger.remove() + logger.add(sys.stderr, level="INFO") + + if convert_to_tensor: + convert_to_numpy = False + + if output_value != "sentence_embedding": + convert_to_tensor = False + convert_to_numpy = False + + input_was_string = False + if isinstance(sentences, str) or not hasattr( + sentences, "__len__" + ): # Cast an individual sentence to a list with length 1 + sentences = [sentences] + input_was_string = True 
+ + if prompt is None: + if prompt_name is not None: + try: + prompt = self.sbert.prompts[prompt_name] + except KeyError: + raise ValueError( + f"Prompt name '{prompt_name}' not found in the configured " + f"prompts dictionary with keys {list(self.sbert.prompts.keys())!r}." + ) + elif self.default_prompt_name is not None: + prompt = self.sbert.prompts.get(self.sbert.default_prompt_name, None) + else: + if prompt_name is not None: + logger.warning( + "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. " + "Ignoring the `prompt_name` in favor of `prompt`." + ) + + extra_features = {} + if prompt is not None: + sentences = [prompt + sentence for sentence in sentences] + + # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling + # Tracking the prompt length allow us to remove the prompt during pooling + tokenized_prompt = self.sbert.tokenize([prompt]) + if "input_ids" in tokenized_prompt: + extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1 + + all_embeddings = [] + length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + + for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + features = self.sbert.tokenize(sentences_batch) + features.update(extra_features) + + with torch.no_grad(): + out_features = self.forward(features) + + out_features["sentence_embedding"] = truncate_embeddings( + out_features["sentence_embedding"], self.sbert.truncate_dim + ) + + if output_value == "token_embeddings": + embeddings = [] + for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]): + last_mask_id = len(attention) - 1 + while last_mask_id > 0 and attention[last_mask_id].item() == 0: + last_mask_id -= 1 + + embeddings.append(token_emb[0 : last_mask_id + 
1]) + elif output_value is None: # Return all outputs + embeddings = [] + for sent_idx in range(len(out_features["sentence_embedding"])): + row = {name: out_features[name][sent_idx] for name in out_features} + embeddings.append(row) + else: # Sentence embeddings + embeddings = out_features[output_value] + embeddings = embeddings.detach() + if normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + # fixes for #522 and #487 to avoid oom problems on gpu with large datasets + if convert_to_numpy: + embeddings = embeddings.cpu() + + all_embeddings.extend(embeddings) + + all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] + + if precision and precision != "float32": + all_embeddings = quantize_embeddings(all_embeddings, precision=precision) + + if convert_to_tensor: + if len(all_embeddings): + if isinstance(all_embeddings, np.ndarray): + all_embeddings = torch.from_numpy(all_embeddings) + else: + all_embeddings = torch.stack(all_embeddings) + else: + all_embeddings = torch.Tensor() + elif convert_to_numpy: + if not isinstance(all_embeddings, np.ndarray): + if all_embeddings and all_embeddings[0].dtype == torch.bfloat16: + all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings]) + else: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + elif isinstance(all_embeddings, np.ndarray): + all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings] + + if input_was_string: + all_embeddings = all_embeddings[0] + + return all_embeddings + + +class DataParallelSentenceBertEmbedder(TextEmbedder): + """SentenceBERT embedder with pytorch data parallel""" + + def __init__( + self, + model_name_or_path: str, + batch_size: int = 64, + normalize_embeddings: bool = False, + max_seq_length: int | None = None, + add_eos: bool = False, + truncate_dim: int | None = None, + model_kwargs: dict | None = None, + tokenizer_kwargs: dict | None = None, + 
auto_find_batch_size: bool = True, + ) -> None: + model_kwargs = self._model_kwargs_parser(model_kwargs) + model = SentenceTransformer( + model_name_or_path, + trust_remote_code=True, + truncate_dim=truncate_dim, + model_kwargs=model_kwargs, # https://github.com/UKPLab/sentence-transformers/blob/84f69fee6dcde023f46a8807e89bc99a7700ba82/sentence_transformers/SentenceTransformer.py#L81-L105 # noqa: E501 + tokenizer_kwargs=tokenizer_kwargs, + ) + self.dp_model = DPSentenceTransformer(sbert_model=model) + self.model = self.dp_model.sbert + if max_seq_length: + self.model.max_seq_length = max_seq_length + self.initital_batch_size = batch_size + self.batch_size = int(self.initital_batch_size) + self.normalize_embeddings = normalize_embeddings + self.max_seq_length = getattr(self.model, "max_seq_length", None) + self.add_eos = add_eos + self.auto_find_batch_size = auto_find_batch_size + + if "torch_dtype" in model_kwargs: + self.set_output_tensor() + else: + self.set_output_numpy() + + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: + if self.add_eos: + text = self._add_eos_func(text) + if self.auto_find_batch_size: + # wrap function + @find_executable_batch_size(starting_batch_size=self.batch_size) + def _encode_with_auto_batch_size(batch_size, self, text, prefix): + out = self.dp_model.encode( + text, + prompt=prefix, + convert_to_numpy=self.convert_to_numpy, + convert_to_tensor=self.convert_to_tensor, + batch_size=batch_size, + normalize_embeddings=self.normalize_embeddings, + ) + + self.batch_size = batch_size + return out + + return _encode_with_auto_batch_size(self, text, prefix) + else: + return self.dp_model.encode( + text, + prompt=prefix, + convert_to_numpy=self.convert_to_numpy, + convert_to_tensor=self.convert_to_tensor, + batch_size=self.batch_size, + normalize_embeddings=self.normalize_embeddings, + ) + + def _add_eos_func(self, text: str | list[str]) -> str | list[str]: + try: + eos_token = getattr(self.model.tokenizer, 
"eos_token") + except AttributeError: + return text + + if isinstance(text, str): + return text + eos_token + elif isinstance(text, list): + return [t + eos_token for t in text] + + def get_output_dim(self) -> int: + return self.model.get_sentence_embedding_dimension() diff --git a/src/jmteb/embedders/openai_embedder.py b/src/jmteb/embedders/openai_embedder.py index 0108c83..6ea8b8f 100644 --- a/src/jmteb/embedders/openai_embedder.py +++ b/src/jmteb/embedders/openai_embedder.py @@ -60,6 +60,9 @@ def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None else: self.dim = dim + self.convert_to_tensor = False + self.convert_to_numpy = True + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: kwargs = {"dimensions": self.dim} if self.model != "text-embedding-ada-002" else {} # specifying `dimensions` is not allowed for "text-embedding-ada-002" diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index bab3da7..0188e7d 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -17,9 +17,18 @@ def __init__( normalize_embeddings: bool = False, max_seq_length: int | None = None, add_eos: bool = False, + truncate_dim: int | None = None, + model_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, ) -> None: - self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, tokenizer_kwargs=tokenizer_kwargs) + model_kwargs = self._model_kwargs_parser(model_kwargs) + self.model = SentenceTransformer( + model_name_or_path, + trust_remote_code=True, + truncate_dim=truncate_dim, + model_kwargs=model_kwargs, # https://github.com/UKPLab/sentence-transformers/blob/84f69fee6dcde023f46a8807e89bc99a7700ba82/sentence_transformers/SentenceTransformer.py#L81-L105 # noqa: E501 + tokenizer_kwargs=tokenizer_kwargs, + ) if max_seq_length: self.model.max_seq_length = max_seq_length @@ -29,13 +38,19 @@ def __init__( self.max_seq_length = 
getattr(self.model, "max_seq_length", None) self.add_eos = add_eos + if "torch_dtype" in model_kwargs: + self.set_output_tensor() + else: + self.set_output_numpy() + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: if self.add_eos: text = self._add_eos_func(text) return self.model.encode( text, prompt=prefix, - convert_to_numpy=True, + convert_to_numpy=self.convert_to_numpy, + convert_to_tensor=self.convert_to_tensor, batch_size=self.batch_size, device=self.device, normalize_embeddings=self.normalize_embeddings, diff --git a/src/jmteb/embedders/transformers_embedder.py b/src/jmteb/embedders/transformers_embedder.py new file mode 100644 index 0000000..0592061 --- /dev/null +++ b/src/jmteb/embedders/transformers_embedder.py @@ -0,0 +1,185 @@ +import json +import os +from pathlib import Path + +import numpy as np +import torch +from accelerate import PartialState +from accelerate.utils import gather_object +from loguru import logger +from sentence_transformers.models import Pooling +from tqdm.autonotebook import trange +from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer + +from jmteb.embedders.base import TextEmbedder + + +class TransformersEmbedder(TextEmbedder): + def __init__( + self, + model_name_or_path: str, + batch_size: int = 32, + device: str | None = None, + normalize_embeddings: bool = False, + max_seq_length: int | None = None, + add_eos: bool = False, + truncate_dim: int | None = None, + pooling_config: str | None = "1_Pooling/config.json", + pooling_mode: str | None = None, + model_kwargs: dict = {}, + tokenizer_kwargs: dict = {}, + ) -> None: + model_kwargs = self._model_kwargs_parser(model_kwargs) + self.model: PreTrainedModel = AutoModel.from_pretrained( + model_name_or_path, trust_remote_code=True, **model_kwargs + ) + self.batch_size = batch_size + if not device and torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = device + self.normalize_embeddings = 
normalize_embeddings + + self.distributed_state = PartialState() if torch.cuda.device_count() > 1 and self.device == "cuda" else None + if self.distributed_state: + self.model.to(self.distributed_state.device) + else: + self.model.to(self.device) + logger.info(f"{self.model.device=}, {torch.cuda.device_count()=}") + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs) + + self.max_seq_length = getattr(self.model, "max_seq_length", None) + if max_seq_length: + self.max_seq_length = max_seq_length + self.add_eos = add_eos + self.truncate_dim = truncate_dim + + if pooling_mode: + pooling_config: dict = { + "word_embedding_dimension": getattr(self.model.config, "hidden_size"), + "pooling_mode": pooling_mode, + } + else: + pooling_config: dict = self._load_pooling_config(os.path.join(model_name_or_path, pooling_config)) + + self.pooling = Pooling( + word_embedding_dimension=pooling_config.get("word_embedding_dimension"), + pooling_mode=pooling_config.get("pooling_mode", None), + pooling_mode_cls_token=pooling_config.get("pooling_mode_cls_token", False), + pooling_mode_max_tokens=pooling_config.get("pooling_mode_max_tokens", False), + pooling_mode_mean_tokens=pooling_config.get("pooling_mode_mean_tokens", False), + pooling_mode_mean_sqrt_len_tokens=pooling_config.get("pooling_mode_mean_sqrt_len_tokens", False), + pooling_mode_weightedmean_tokens=pooling_config.get("pooling_mode_weightedmean_tokens", False), + pooling_mode_lasttoken=pooling_config.get("pooling_mode_lasttoken", False), + include_prompt=pooling_config.get("include_prompt", True), + ) + + if self.truncate_dim: + self.output_dim = min(self.pooling.get_sentence_embedding_dimension(), self.truncate_dim) + else: + self.output_dim = self.pooling.get_sentence_embedding_dimension() + + if "torch_dtype" in model_kwargs: + self.set_output_tensor() + else: + self.set_output_numpy() + + def get_output_dim(self) -> int: + return self.output_dim + + def encode( + 
self, + text: str | list[str], + prefix: str | None = None, + show_progress_bar: bool = True, + ): + if isinstance(text, str): + text = [text] + text_was_str = True + else: + text_was_str = False + + all_embeddings = [] + length_sorted_idx = np.argsort([-len(t) for t in text]) + text_sorted = [text[idx] for idx in length_sorted_idx] + + for start_index in trange(0, len(text), self.batch_size, desc="Batches", disable=not show_progress_bar): + text_batch = text_sorted[start_index : start_index + self.batch_size] + if self.distributed_state: + batch_embeddings = self._encode_batch_distributed(text_batch, prefix) + else: + batch_embeddings = self._encode_batch(text_batch, prefix) + all_embeddings.extend(batch_embeddings) + + all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] + + if len(all_embeddings): + all_embeddings = torch.stack(all_embeddings) + else: + all_embeddings = torch.Tensor() + + if text_was_str: + res = all_embeddings.view(-1) + else: + res = all_embeddings + + if self.convert_to_numpy: + return res.numpy() + else: + return res + + def _encode_batch(self, text: list[str], prefix: str | None = None) -> torch.Tensor: + if prefix: + text = [prefix + t for t in text] + + if self.add_eos: + text = self._add_eos_func(text) + + encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.model.device) + model_output = self.model(**encoded_input) + last_hidden_states = model_output["last_hidden_state"] + features = { + "input_ids": encoded_input["input_ids"], + "attention_mask": encoded_input["attention_mask"], + "token_embeddings": last_hidden_states, + } + if "token_type_ids" in encoded_input: + features["token_type_ids"] = encoded_input["token_type_ids"] + + if prefix: + features["prompt_length"] = self.tokenizer([prefix], return_tensors="pt")["input_ids"].shape[-1] - 1 + + # TODO: feature["token_weights_sum"] + + with torch.no_grad(): + sentence_embeddings = 
self.pooling.forward(features)["sentence_embedding"] + if self.truncate_dim: + sentence_embeddings = sentence_embeddings[..., : self.truncate_dim] + if self.normalize_embeddings: + sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings + + def _encode_batch_distributed(self, text: list[str], prefix: str | None = None) -> torch.Tensor: + batch_gather = [] + with self.distributed_state.split_between_processes(text) as t: + sentence_embeddings = self._encode_batch(t, prefix) + batch_gather.extend(sentence_embeddings.to("cpu")) + + batch_embeddings = gather_object(batch_gather) + return torch.stack(batch_embeddings) + + def _add_eos_func(self, text: list[str]) -> list[str]: + try: + eos_token = getattr(self.tokenizer, "eos_token") + except AttributeError: + return text + + return [t + eos_token for t in text] + + def _load_pooling_config(self, config) -> dict: + if Path(config).is_file(): + with open(Path(config), "r") as fin: + return json.load(fin) + else: + logger.warning("No pooling config found, create a mean pooling!") + return {"word_embedding_dimension": getattr(self.model.config, "hidden_size"), "pooling_mode": "mean"} diff --git a/src/jmteb/evaluators/base.py b/src/jmteb/evaluators/base.py index 7b47379..a94c96c 100644 --- a/src/jmteb/evaluators/base.py +++ b/src/jmteb/evaluators/base.py @@ -19,11 +19,13 @@ class EvaluationResults: metric_value (float): Value of the main metric. details (dict[str, Any]): Details of the evaluation. This included some additional metrics or values that are used to derive the main metric. 
+ predictions (list[Any]): Predictions (such as, (text, y_true, y_pred)) """ metric_name: str metric_value: float details: dict[str, Any] + predictions: list[Any] | None = None def as_dict(self) -> dict[str, Any]: return { diff --git a/src/jmteb/evaluators/classification/__init__.py b/src/jmteb/evaluators/classification/__init__.py index 6c85424..4c1bfbb 100644 --- a/src/jmteb/evaluators/classification/__init__.py +++ b/src/jmteb/evaluators/classification/__init__.py @@ -1,3 +1,7 @@ from .classifiers import Classifier, KnnClassifier, LogRegClassifier -from .data import ClassificationDataset, ClassificationInstance +from .data import ( + ClassificationDataset, + ClassificationInstance, + ClassificationPrediction, +) from .evaluator import ClassificationEvaluator diff --git a/src/jmteb/evaluators/classification/data.py b/src/jmteb/evaluators/classification/data.py index 5885471..ba5eb8d 100644 --- a/src/jmteb/evaluators/classification/data.py +++ b/src/jmteb/evaluators/classification/data.py @@ -13,6 +13,13 @@ class ClassificationInstance: label: int +@dataclass +class ClassificationPrediction: + text: str + label: int + prediction: int + + class ClassificationDataset(ABC): @abstractmethod def __len__(self): diff --git a/src/jmteb/evaluators/classification/evaluator.py b/src/jmteb/evaluators/classification/evaluator.py index dbe2d8e..457d949 100644 --- a/src/jmteb/evaluators/classification/evaluator.py +++ b/src/jmteb/evaluators/classification/evaluator.py @@ -11,7 +11,7 @@ from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults from .classifiers import Classifier, KnnClassifier, LogRegClassifier -from .data import ClassificationDataset +from .data import ClassificationDataset, ClassificationPrediction class ClassificationEvaluator(EmbeddingEvaluator): @@ -28,6 +28,7 @@ class ClassificationEvaluator(EmbeddingEvaluator): The first one is specified as the main index. classifiers (dict[str, Classifier]): classifiers to be evaluated. 
prefix (str | None): prefix for sentences. Defaults to None. + log_predictions (bool): whether to log predictions of each datapoint. """ def __init__( @@ -38,6 +39,7 @@ def __init__( average: str = "macro", classifiers: dict[str, Classifier] | None = None, prefix: str | None = None, + log_predictions: bool = False, ) -> None: self.train_dataset = train_dataset self.val_dataset = val_dataset @@ -52,6 +54,7 @@ def __init__( if average_name.strip().lower() in ("micro", "macro", "samples", "weighted", "binary") ] or ["macro"] self.prefix = prefix + self.log_predictions = log_predictions self.main_metric = f"{self.average[0]}_f1" def __call__( @@ -119,6 +122,7 @@ def __call__( "val_scores": val_results, "test_scores": test_results, }, + predictions=self._format_predictions(self.test_dataset, y_pred) if self.log_predictions else None, ) @staticmethod @@ -128,3 +132,14 @@ def _compute_metrics(y_pred: np.ndarray, y_true: list[int], average: list[float] for average_method in average: classifier_results[f"{average_method}_f1"] = f1_score(y_true, y_pred, average=average_method) return classifier_results + + @staticmethod + def _format_predictions(dataset: ClassificationDataset, y_pred: np.ndarray) -> list[ClassificationPrediction]: + texts = [item.text for item in dataset] + y_true = [item.label for item in dataset] + y_pred = y_pred.tolist() + assert len(texts) == len(y_true) == len(y_pred) + return [ + ClassificationPrediction(text=text, label=label, prediction=pred) + for text, label, pred in zip(texts, y_true, y_pred) + ] diff --git a/src/jmteb/evaluators/clustering/__init__.py b/src/jmteb/evaluators/clustering/__init__.py index 5164b12..e22bd0d 100644 --- a/src/jmteb/evaluators/clustering/__init__.py +++ b/src/jmteb/evaluators/clustering/__init__.py @@ -1,2 +1,2 @@ -from .data import ClusteringDataset, ClusteringInstance +from .data import ClusteringDataset, ClusteringInstance, ClusteringPrediction from .evaluator import ClusteringEvaluator diff --git 
a/src/jmteb/evaluators/clustering/data.py b/src/jmteb/evaluators/clustering/data.py index ee2ec4f..64d9608 100644 --- a/src/jmteb/evaluators/clustering/data.py +++ b/src/jmteb/evaluators/clustering/data.py @@ -13,6 +13,13 @@ class ClusteringInstance: label: int +@dataclass +class ClusteringPrediction: + text: str + label: int + prediction: int + + class ClusteringDataset(ABC): @abstractmethod def __len__(self): diff --git a/src/jmteb/evaluators/clustering/evaluator.py b/src/jmteb/evaluators/clustering/evaluator.py index d8ef443..4f3cd3c 100644 --- a/src/jmteb/evaluators/clustering/evaluator.py +++ b/src/jmteb/evaluators/clustering/evaluator.py @@ -18,7 +18,7 @@ from jmteb.embedders.base import TextEmbedder from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults -from .data import ClusteringDataset +from .data import ClusteringDataset, ClusteringPrediction class ClusteringEvaluator(EmbeddingEvaluator): @@ -32,11 +32,13 @@ def __init__( test_dataset: ClusteringDataset, prefix: str | None = None, random_seed: int | None = None, + log_predictions: bool = False, ) -> None: self.val_dataset = val_dataset self.test_dataset = test_dataset self.prefix = prefix self.random_seed = random_seed + self.log_predictions = log_predictions self.main_metric = "v_measure_score" def __call__( @@ -80,20 +82,21 @@ def __call__( logger.info("Fitting clustering model...") val_results = {} for model_name, model_constructor in model_constructors.items(): - val_results[model_name] = self._evaluate_clustering_model(val_embeddings, val_labels, model_constructor()) + val_results[model_name], _ = self._evaluate_clustering_model( + val_embeddings, val_labels, model_constructor() + ) optimal_clustering_model_name = sorted( val_results.items(), key=lambda res: res[1][self.main_metric], reverse=True, )[0][0] - test_results = { - optimal_clustering_model_name: self._evaluate_clustering_model( - test_embeddings, - test_labels, - model_constructors[optimal_clustering_model_name](), - ) - 
} + test_scores, test_predictions = self._evaluate_clustering_model( + test_embeddings, + test_labels, + model_constructors[optimal_clustering_model_name](), + ) + test_results = {optimal_clustering_model_name: test_scores} return EvaluationResults( metric_name=self.main_metric, @@ -103,12 +106,15 @@ def __call__( "val_scores": val_results, "test_scores": test_results, }, + predictions=( + self._format_predictions(self.test_dataset, test_predictions) if self.log_predictions else None + ), ) @staticmethod def _evaluate_clustering_model( embeddings: np.ndarray, y_true: list[int], clustering_model: ClusterMixin - ) -> dict[str, float]: + ) -> tuple[dict[str, float], list[int]]: y_pred = clustering_model.fit_predict(embeddings) h_score, c_score, v_score = homogeneity_completeness_v_measure( labels_pred=y_pred, labels_true=np.array(y_true) @@ -118,4 +124,10 @@ def _evaluate_clustering_model( "v_measure_score": v_score, "homogeneity_score": h_score, "completeness_score": c_score, - } + }, y_pred.tolist() + + @staticmethod + def _format_predictions(dataset: ClusteringDataset, predictions: list[int]) -> list[ClusteringPrediction]: + return [ + ClusteringPrediction(item.text, item.label, prediction) for item, prediction in zip(dataset, predictions) + ] diff --git a/src/jmteb/evaluators/pair_classification/evaluator.py b/src/jmteb/evaluators/pair_classification/evaluator.py index 6ec30d0..280bbfb 100644 --- a/src/jmteb/evaluators/pair_classification/evaluator.py +++ b/src/jmteb/evaluators/pair_classification/evaluator.py @@ -22,6 +22,8 @@ class PairClassificationEvaluator(EmbeddingEvaluator): test_dataset (PairClassificationDataset): test dataset sentence1_prefix (str | None): prefix for sentence1. Defaults to None. sentence2_prefix (str | None): prefix for sentence2. Defaults to None. + + # NOTE: Don't log predictions, as predictions by different metrics could be different. 
""" def __init__( diff --git a/src/jmteb/evaluators/reranking/__init__.py b/src/jmteb/evaluators/reranking/__init__.py index 9931fcb..023120b 100644 --- a/src/jmteb/evaluators/reranking/__init__.py +++ b/src/jmteb/evaluators/reranking/__init__.py @@ -1,6 +1,7 @@ from .data import ( RerankingDoc, RerankingDocDataset, + RerankingPrediction, RerankingQuery, RerankingQueryDataset, ) diff --git a/src/jmteb/evaluators/reranking/data.py b/src/jmteb/evaluators/reranking/data.py index 4875729..e4a5228 100644 --- a/src/jmteb/evaluators/reranking/data.py +++ b/src/jmteb/evaluators/reranking/data.py @@ -22,6 +22,13 @@ class RerankingDoc: text: str +@dataclass +class RerankingPrediction: + query: str + relevant_docs: list[RerankingDoc] + reranked_relevant_docs: list[RerankingDoc] + + class RerankingQueryDataset(ABC): @abstractmethod def __len__(self): @@ -47,6 +54,23 @@ def __getitem__(self, idx) -> RerankingDoc: def __eq__(self, __value: object) -> bool: return False + def _build_idx_docid_mapping(self, dataset_attr_name: str = "dataset") -> None: + self.idx_to_docid: dict = {} + self.docid_to_idx: dict = {} + id_key: str = getattr(self, "id_key", None) + dataset = getattr(self, dataset_attr_name) + if id_key: + for idx, doc_dict in enumerate(dataset): + self.idx_to_docid[idx] = doc_dict[id_key] + self.docid_to_idx[doc_dict[id_key]] = idx + elif isinstance(dataset[0], RerankingDoc): + for idx, doc in enumerate(dataset): + doc: RerankingDoc + self.idx_to_docid[idx] = doc.id + self.docid_to_idx[doc.id] = idx + else: + raise ValueError(f"Invalid dataset type: list[{type(dataset[0])}]") + class HfRerankingQueryDataset(RerankingQueryDataset): def __init__( @@ -131,6 +155,7 @@ def __init__(self, path: str, split: str, name: str | None = None, id_key: str = self.dataset = datasets.load_dataset(path, split=split, name=name, trust_remote_code=True) self.id_key = id_key self.text_key = text_key + self._build_idx_docid_mapping() def __len__(self): return len(self.dataset) @@ -157,6 
+182,7 @@ def __init__(self, filename: str, id_key: str = "docid", text_key: str = "text") self.dataset = corpus self.id_key = id_key self.text_key = text_key + self._build_idx_docid_mapping() def __len__(self): return len(self.dataset) diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py index 1029aab..5c4ba34 100644 --- a/src/jmteb/evaluators/reranking/evaluator.py +++ b/src/jmteb/evaluators/reranking/evaluator.py @@ -10,11 +10,18 @@ import tqdm from loguru import logger from torch import Tensor +from torch import distributed as dist from jmteb.embedders.base import TextEmbedder from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults -from .data import RerankingDocDataset, RerankingQueryDataset +from .data import ( + RerankingDoc, + RerankingDocDataset, + RerankingPrediction, + RerankingQuery, + RerankingQueryDataset, +) T = TypeVar("T") @@ -30,6 +37,8 @@ class RerankingEvaluator(EmbeddingEvaluator): ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Documented Cumulative Gain). query_prefix (str | None): prefix for queries. Defaults to None. doc_prefix (str | None): prefix for documents. Defaults to None. + log_predictions (bool): whether to log predictions of each datapoint. Defaults to False. + top_n_docs_to_log (int): log only top n documents. Defaults to 5. 
""" def __init__( @@ -40,6 +49,8 @@ def __init__( ndcg_at_k: list[int] | None = None, query_prefix: str | None = None, doc_prefix: str | None = None, + log_predictions: bool = False, + top_n_docs_to_log: int = 5, ) -> None: self.test_query_dataset = test_query_dataset self.val_query_dataset = val_query_dataset @@ -48,6 +59,8 @@ def __init__( self.main_metric = f"ndcg@{self.ndcg_at_k[0]}" self.query_prefix = query_prefix self.doc_prefix = doc_prefix + self.log_predictions = log_predictions + self.top_n_docs_to_log = top_n_docs_to_log def __call__( self, @@ -55,6 +68,7 @@ def __call__( cache_dir: str | PathLike[str] | None = None, overwrite_cache: bool = False, ) -> EvaluationResults: + model.set_output_tensor() if cache_dir is not None: Path(cache_dir).mkdir(parents=True, exist_ok=True) @@ -90,7 +104,7 @@ def __call__( val_results = {} for dist_name, dist_func in dist_functions.items(): - val_results[dist_name] = self._compute_metrics( + val_results[dist_name], _ = self._compute_metrics( query_dataset=self.val_query_dataset, query_embeddings=val_query_embeddings, doc_embeddings=doc_embeddings, @@ -99,14 +113,13 @@ def __call__( sorted_val_results = sorted(val_results.items(), key=lambda res: res[1][self.main_metric], reverse=True) optimal_dist_name = sorted_val_results[0][0] - test_results = { - optimal_dist_name: self._compute_metrics( - query_dataset=self.test_query_dataset, - query_embeddings=test_query_embeddings, - doc_embeddings=doc_embeddings, - dist_func=dist_functions[optimal_dist_name], - ) - } + scores, reranked_docs_list = self._compute_metrics( + query_dataset=self.test_query_dataset, + query_embeddings=test_query_embeddings, + doc_embeddings=doc_embeddings, + dist_func=dist_functions[optimal_dist_name], + ) + test_results = {optimal_dist_name: scores} return EvaluationResults( metric_name=self.main_metric, @@ -116,27 +129,42 @@ def __call__( "val_scores": val_results, "test_scores": test_results, }, + predictions=( + self._format_predictions( + 
self.test_query_dataset, self.doc_dataset, reranked_docs_list, self.top_n_docs_to_log + ) + if self.log_predictions + else None + ), ) def _compute_metrics( self, query_dataset: RerankingQueryDataset, - query_embeddings: np.ndarray, - doc_embeddings: np.ndarray, + query_embeddings: np.ndarray | Tensor, + doc_embeddings: np.ndarray | Tensor, dist_func: Callable[[Tensor, Tensor], Tensor], - ) -> dict[str, float]: + ) -> tuple[dict[str, float], list[list[str | int]]]: doc_indices = {item.id: i for i, item in enumerate(self.doc_dataset)} results: dict[str, float] = {} with tqdm.tqdm(total=len(query_dataset), desc="Reranking docs") as pbar: - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + if dist.is_torchelastic_launched(): + device = f"cuda:{dist.get_rank()}" + else: + device = "cuda" + else: + device = "cpu" reranked_docs_list = [] for i, item in enumerate(query_dataset): - query_embedding = convert_to_tensor(query_embeddings[i], device=device) - doc_embedding = convert_to_tensor( - np.array([doc_embeddings[doc_indices[retrieved_doc]] for retrieved_doc in item.retrieved_docs]), - device=device, + query_embedding = to_tensor(query_embeddings[i], device=device) + doc_embedding = torch.stack( + [ + Tensor(doc_embeddings[doc_indices[retrieved_doc]]).to(device=device) + for retrieved_doc in item.retrieved_docs + ] ) similarity = dist_func(query_embedding, doc_embedding) @@ -155,7 +183,34 @@ def _compute_metrics( for k in self.ndcg_at_k: results[f"ndcg@{k}"] = ndcg_at_k(retrieved_docs_list, relevance_scores_list, reranked_docs_list, k) - return results + return results, reranked_docs_list + + @staticmethod + def _format_predictions( + query_dataset: RerankingQueryDataset, + doc_dataset: RerankingDocDataset, + reranked_docs_list: list[list], + top_n_to_log: int, + ) -> list[RerankingPrediction]: + predictions = [] + for q, pred_docids in zip(query_dataset, reranked_docs_list): + q: RerankingQuery + golden_docs: list[RerankingDoc] = 
[ + doc_dataset[doc_dataset.docid_to_idx[docid]] for docid in q.retrieved_docs + ] + pred_docids = pred_docids[:top_n_to_log] + pred_docs: list[RerankingDoc] = [ + doc_dataset[doc_dataset.docid_to_idx[pred_docid]] for pred_docid in pred_docids + ] + logger.info(f"{golden_docs=}") + logger.info(f"{pred_docs=}") + prediction = RerankingPrediction( + query=q.query, + relevant_docs=golden_docs, + reranked_relevant_docs=pred_docs, + ) + predictions.append(prediction) + return predictions def ndcg_at_k( @@ -179,7 +234,7 @@ def ndcg_at_k( return total_ndcg_scores / len(retrieved_docs_list) -def convert_to_tensor(embeddings: np.ndarray | Tensor, device: str) -> Tensor: +def to_tensor(embeddings: np.ndarray | Tensor, device: str) -> Tensor: if not isinstance(embeddings, Tensor): embeddings = torch.tensor(embeddings) if len(embeddings.shape) == 1: diff --git a/src/jmteb/evaluators/retrieval/__init__.py b/src/jmteb/evaluators/retrieval/__init__.py index c63354c..73d4e33 100644 --- a/src/jmteb/evaluators/retrieval/__init__.py +++ b/src/jmteb/evaluators/retrieval/__init__.py @@ -1,6 +1,7 @@ from .data import ( RetrievalDoc, RetrievalDocDataset, + RetrievalPrediction, RetrievalQuery, RetrievalQueryDataset, ) diff --git a/src/jmteb/evaluators/retrieval/data.py b/src/jmteb/evaluators/retrieval/data.py index 70c69a4..4c8c30b 100644 --- a/src/jmteb/evaluators/retrieval/data.py +++ b/src/jmteb/evaluators/retrieval/data.py @@ -21,6 +21,13 @@ class RetrievalDoc: text: str +@dataclass +class RetrievalPrediction: + query: str + relevant_docs: list[RetrievalDoc] + predicted_relevant_docs: list[RetrievalDoc] + + class RetrievalQueryDataset(ABC): @abstractmethod def __len__(self): @@ -46,6 +53,23 @@ def __getitem__(self, idx) -> RetrievalDoc: def __eq__(self, __value: object) -> bool: return False + def _build_idx_docid_mapping(self, dataset_attr_name: str = "dataset") -> None: + self.idx_to_docid: dict = {} + self.docid_to_idx: dict = {} + id_key: str = getattr(self, "id_key", None) + 
dataset = getattr(self, dataset_attr_name) + if id_key: + for idx, doc_dict in enumerate(dataset): + self.idx_to_docid[idx] = doc_dict[id_key] + self.docid_to_idx[doc_dict[id_key]] = idx + elif isinstance(dataset[0], RetrievalDoc): + for idx, doc in enumerate(dataset): + doc: RetrievalDoc + self.idx_to_docid[idx] = doc.id + self.docid_to_idx[doc.id] = idx + else: + raise ValueError(f"Invalid dataset type: list[{type(dataset[0])}]") + class HfRetrievalQueryDataset(RetrievalQueryDataset): def __init__( @@ -124,6 +148,7 @@ def __init__(self, path: str, split: str, name: str | None = None, id_key: str = self.dataset = datasets.load_dataset(path, split=split, name=name, trust_remote_code=True) self.id_key = id_key self.text_key = text_key + self._build_idx_docid_mapping() def __len__(self): return len(self.dataset) @@ -150,6 +175,7 @@ def __init__(self, filename: str, id_key: str = "docid", text_key: str = "text") self.dataset = corpus self.id_key = id_key self.text_key = text_key + self._build_idx_docid_mapping() def __len__(self): return len(self.dataset) diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index 64d48d9..73c0981 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -11,11 +11,18 @@ import tqdm from loguru import logger from torch import Tensor +from torch import distributed as dist from jmteb.embedders.base import TextEmbedder from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults -from .data import RetrievalDocDataset, RetrievalQueryDataset +from .data import ( + RetrievalDoc, + RetrievalDocDataset, + RetrievalPrediction, + RetrievalQuery, + RetrievalQueryDataset, +) T = TypeVar("T") @@ -33,6 +40,8 @@ class RetrievalEvaluator(EmbeddingEvaluator): accuracy_at_k (list[int] | None): accuracy in top k hits. query_prefix (str | None): prefix for queries. Defaults to None. doc_prefix (str | None): prefix for documents. 
Defaults to None. + log_predictions (bool): whether to log predictions of each datapoint. Defaults to False. + top_n_docs_to_log (int): log only top n documents that are predicted as relevant. Defaults to 5. """ def __init__( @@ -45,6 +54,8 @@ def __init__( ndcg_at_k: list[int] | None = None, query_prefix: str | None = None, doc_prefix: str | None = None, + log_predictions: bool = False, + top_n_docs_to_log: int = 5, ) -> None: self.val_query_dataset = val_query_dataset self.test_query_dataset = test_query_dataset @@ -59,6 +70,8 @@ def __init__( self.query_prefix = query_prefix self.doc_prefix = doc_prefix + self.log_predictions = log_predictions + self.top_n_docs_to_log = top_n_docs_to_log def __call__( self, @@ -66,6 +79,7 @@ def __call__( cache_dir: str | PathLike[str] | None = None, overwrite_cache: bool = False, ) -> EvaluationResults: + model.set_output_tensor() if cache_dir is not None: Path(cache_dir).mkdir(parents=True, exist_ok=True) @@ -102,7 +116,7 @@ def __call__( val_results = {} for dist_name, dist_func in dist_functions.items(): - val_results[dist_name] = self._compute_metrics( + val_results[dist_name], _ = self._compute_metrics( query_dataset=self.val_query_dataset, query_embeddings=val_query_embeddings, doc_embeddings=doc_embeddings, @@ -111,14 +125,13 @@ def __call__( sorted_val_results = sorted(val_results.items(), key=lambda res: res[1][self.main_metric], reverse=True) optimal_dist_name = sorted_val_results[0][0] - test_results = { - optimal_dist_name: self._compute_metrics( - query_dataset=self.test_query_dataset, - query_embeddings=test_query_embeddings, - doc_embeddings=doc_embeddings, - dist_func=dist_functions[optimal_dist_name], - ) - } + test_scores, test_predictions = self._compute_metrics( + query_dataset=self.test_query_dataset, + query_embeddings=test_query_embeddings, + doc_embeddings=doc_embeddings, + dist_func=dist_functions[optimal_dist_name], + ) + test_results = {optimal_dist_name: test_scores} return EvaluationResults( 
metric_name=self.main_metric, @@ -128,26 +141,34 @@ def __call__( "val_scores": val_results, "test_scores": test_results, }, + predictions=test_predictions, ) def _compute_metrics( self, query_dataset: RetrievalQueryDataset, - query_embeddings: np.ndarray, - doc_embeddings: np.ndarray, + query_embeddings: np.ndarray | Tensor, + doc_embeddings: np.ndarray | Tensor, dist_func: Callable[[Tensor, Tensor], Tensor], - ) -> dict[str, dict[str, float]]: + ) -> tuple[dict[str, dict[str, float]], list[RetrievalPrediction]]: results: dict[str, float] = {} - + predictions: list[RetrievalPrediction] = [] if self.log_predictions else None with tqdm.tqdm(total=len(doc_embeddings), desc="Retrieval doc chunks") as pbar: top_k_indices_chunks: list[np.ndarray] = [] top_k_scores_chunks: list[np.ndarray] = [] for offset in range(0, len(doc_embeddings), self.doc_chunk_size): doc_embeddings_chunk = doc_embeddings[offset : offset + self.doc_chunk_size] - device = "cuda" if torch.cuda.is_available() else "cpu" - query_embeddings = convert_to_tensor(query_embeddings, device=device) - doc_embeddings_chunk = convert_to_tensor(doc_embeddings_chunk, device=device) + if torch.cuda.is_available(): + if dist.is_torchelastic_launched(): + device = f"cuda:{dist.get_rank()}" + else: + device = "cuda" + else: + device = "cpu" + + query_embeddings = to_tensor(query_embeddings, device=device) + doc_embeddings_chunk = to_tensor(doc_embeddings_chunk, device=device) similarity = dist_func(query_embeddings, doc_embeddings_chunk) top_k = min(self.max_top_k, similarity.shape[1]) # in case the corpus is smaller than max_top_k @@ -172,13 +193,44 @@ def _compute_metrics( golden_doc_ids = [item.relevant_docs for item in query_dataset] retrieved_doc_ids = [[self.doc_dataset[i].id for i in indices] for indices in sorted_top_k_indices] + predictions = ( + self._format_predictions(query_dataset, self.doc_dataset, retrieved_doc_ids, self.top_n_docs_to_log) + if self.log_predictions + else None + ) + for k in 
self.accuracy_at_k: results[f"accuracy@{k}"] = accuracy_at_k(golden_doc_ids, retrieved_doc_ids, k) for k in self.ndcg_at_k: results[f"ndcg@{k}"] = ndcg_at_k(golden_doc_ids, retrieved_doc_ids, k) results[f"mrr@{self.max_top_k}"] = mrr_at_k(golden_doc_ids, retrieved_doc_ids, self.max_top_k) - return results + return results, predictions + + @staticmethod + def _format_predictions( + query_dataset: RetrievalQueryDataset, + doc_dataset: RetrievalDocDataset, + retrieved_doc_ids: list[list], + top_n_to_log: int, + ) -> list[RetrievalPrediction]: + predictions = [] + for q, pred_docids in zip(query_dataset, retrieved_doc_ids): + q: RetrievalQuery + golden_docs: list[RetrievalDoc] = [ + doc_dataset[doc_dataset.docid_to_idx[docid]] for docid in q.relevant_docs + ] + pred_docids = pred_docids[:top_n_to_log] + pred_docs: list[RetrievalDoc] = [ + doc_dataset[doc_dataset.docid_to_idx[pred_docid]] for pred_docid in pred_docids + ] + prediction = RetrievalPrediction( + query=q.query, + relevant_docs=golden_docs, + predicted_relevant_docs=pred_docs, + ) + predictions.append(prediction) + return predictions def accuracy_at_k(relevant_docs: list[list[T]], top_hits: list[list[T]], k: int) -> float: @@ -228,7 +280,7 @@ def ndcg_at_k(relevant_docs: list[list[T]], top_hits: list[list[T]], k: int) -> return total_ndcg_scores / len(relevant_docs) -def convert_to_tensor(embeddings: np.ndarray | Tensor, device: str) -> Tensor: +def to_tensor(embeddings: np.ndarray | Tensor, device: str) -> Tensor: if not isinstance(embeddings, Tensor): embeddings = torch.tensor(embeddings) if len(embeddings.shape) == 1: diff --git a/src/jmteb/evaluators/sts/__init__.py b/src/jmteb/evaluators/sts/__init__.py index 502fdc8..665402c 100644 --- a/src/jmteb/evaluators/sts/__init__.py +++ b/src/jmteb/evaluators/sts/__init__.py @@ -1,2 +1,2 @@ -from .data import STSDataset, STSInstance +from .data import STSDataset, STSInstance, STSPrediction from .evaluator import STSEvaluator diff --git 
a/src/jmteb/evaluators/sts/data.py b/src/jmteb/evaluators/sts/data.py index 02504e7..a2166a5 100644 --- a/src/jmteb/evaluators/sts/data.py +++ b/src/jmteb/evaluators/sts/data.py @@ -14,6 +14,15 @@ class STSInstance: score: float +@dataclass +class STSPrediction: + sentence1: str + sentence2: str + true_score: float + predicted_score: float + similarity_function_name: str + + class STSDataset(ABC): @abstractmethod def __len__(self): diff --git a/src/jmteb/evaluators/sts/evaluator.py b/src/jmteb/evaluators/sts/evaluator.py index 8a20b3d..b7b8eb8 100644 --- a/src/jmteb/evaluators/sts/evaluator.py +++ b/src/jmteb/evaluators/sts/evaluator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math from dataclasses import dataclass from os import PathLike from pathlib import Path @@ -13,7 +14,7 @@ from jmteb.embedders.base import TextEmbedder from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults -from .data import STSDataset +from .data import STSDataset, STSInstance, STSPrediction class STSEvaluator(EmbeddingEvaluator): @@ -33,12 +34,14 @@ def __init__( test_dataset: STSDataset, sentence1_prefix: str | None = None, sentence2_prefix: str | None = None, + log_predictions: bool = False, ) -> None: self.val_dataset = val_dataset self.test_dataset = test_dataset self.sentence1_prefix = sentence1_prefix self.sentence2_prefix = sentence2_prefix self.main_metric = "spearman" + self.log_predictions = log_predictions def __call__( self, model: TextEmbedder, cache_dir: str | PathLike[str] | None = None, overwrite_cache: bool = False @@ -68,7 +71,7 @@ def __call__( val_results = {} for sim_name, sim_func in similarity_functions.items(): - val_results[sim_name] = self._compute_similarity( + val_results[sim_name], _ = self._compute_similarity( val_embeddings1, val_embeddings2, val_golden_scores, sim_func ) @@ -79,34 +82,59 @@ def __call__( )[ 0 ][0] - test_results = { - optimal_similarity_name: self._compute_similarity( - test_embeddings1, - test_embeddings2, 
- test_golden_scores, - similarity_functions[optimal_similarity_name], - ) - } + test_eval_scores, test_sim_scores = self._compute_similarity( + test_embeddings1, + test_embeddings2, + test_golden_scores, + similarity_functions[optimal_similarity_name], + ) return EvaluationResults( metric_name=self.main_metric, - metric_value=test_results[optimal_similarity_name][self.main_metric], + metric_value=test_eval_scores[self.main_metric], details={ "optimal_similarity_metric": optimal_similarity_name, "val_scores": val_results, - "test_scores": test_results, + "test_scores": {optimal_similarity_name: test_eval_scores}, }, + predictions=( + self._format_predictions(self.test_dataset, test_sim_scores, optimal_similarity_name) + if self.log_predictions + else None + ), ) @staticmethod def _compute_similarity( embeddings1: Tensor, embeddings2: Tensor, golden_scores: list, similarity_func: Callable - ) -> dict[str, float]: - test_sim_score = similarity_func(embeddings1, embeddings2).cpu() + ) -> tuple[dict[str, float], list[float]]: + sim_scores = similarity_func(embeddings1, embeddings2).cpu() + if isinstance(sim_scores, Tensor) and sim_scores.dtype is torch.bfloat16: + sim_scores = sim_scores.float() + pearson = pearsonr(golden_scores, sim_scores)[0] + spearman = spearmanr(golden_scores, sim_scores)[0] return { - "pearson": pearsonr(golden_scores, test_sim_score)[0], - "spearman": spearmanr(golden_scores, test_sim_score)[0], - } + "pearson": pearson if not math.isnan(pearson) else 0.0, + "spearman": spearman if not math.isnan(spearman) else 0.0, + }, sim_scores.tolist() + + @staticmethod + def _format_predictions( + dataset: STSDataset, sim_scores: list[float], similarity_function_name: str + ) -> list[STSPrediction]: + predictions = [] + for item, sim_score in zip(dataset, sim_scores): + item: STSInstance + predictions.append( + STSPrediction( + sentence1=item.sentence1, + sentence2=item.sentence2, + true_score=item.score, + predicted_score=sim_score, + 
similarity_function_name=similarity_function_name, + ) + ) + return predictions def _convert_to_embeddings( self, diff --git a/src/jmteb/utils/score_recorder.py b/src/jmteb/utils/score_recorder.py index e63b00d..afbf22c 100644 --- a/src/jmteb/utils/score_recorder.py +++ b/src/jmteb/utils/score_recorder.py @@ -3,6 +3,7 @@ import json from abc import ABC, abstractmethod from collections import defaultdict +from dataclasses import asdict from os import PathLike from pathlib import Path from typing import Any @@ -30,6 +31,12 @@ def save_to_json(scores: EvaluationResults | dict[Any, Any], filename: str | Pat with open(filename, "w") as fout: json.dump(scores, fout, indent=4, ensure_ascii=False) + @staticmethod + def save_prediction_to_jsonl(predictions: list[Any], filename: str | PathLike[str]) -> None: + with open(filename, "w") as fout: + for prediction in predictions: + fout.write(json.dumps(asdict(prediction), ensure_ascii=False) + "\n") + def record_task_scores(self, scores: EvaluationResults, dataset_name: str, task_name: str) -> None: if not self.save_dir: return @@ -39,6 +46,13 @@ def record_task_scores(self, scores: EvaluationResults, dataset_name: str, task_ self.scores[task_name][dataset_name] = scores self.save_to_json(self.scores[task_name][dataset_name].as_dict(), save_filename) + def record_predictions(self, results: EvaluationResults, dataset_name: str, task_name: str) -> None: + if not self.save_dir: + return + save_filename = Path(self.save_dir) / task_name / f"predictions_{dataset_name}.jsonl" + save_filename.parent.mkdir(parents=True, exist_ok=True) + self.save_prediction_to_jsonl(results.predictions, save_filename) + def record_summary(self): if not self.save_dir: return diff --git a/tests/conftest.py b/tests/conftest.py index b6d95ff..504284d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser): class DummyTextEmbedder(TextEmbedder): + def 
__init__(self, model_kwargs: dict | None = None) -> None: + self.model_kwargs = self._model_kwargs_parser(model_kwargs) + self.convert_to_tensor = self.model_kwargs.get("torch_dtype", None) is None + self.convert_to_numpy = not self.convert_to_tensor + def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray: if isinstance(text, str): batch_size = 1 diff --git a/tests/embedders/test_dp_sbert.py b/tests/embedders/test_dp_sbert.py new file mode 100644 index 0000000..028e240 --- /dev/null +++ b/tests/embedders/test_dp_sbert.py @@ -0,0 +1,38 @@ +import numpy as np +import torch + +from jmteb.embedders.data_parallel_sbert_embedder import ( + DataParallelSentenceBertEmbedder, +) + +MODEL_NAME_OR_PATH = "prajjwal1/bert-tiny" +OUTPUT_DIM = 128 + + +class TestDPSentenceBertEmbedder: + def setup_class(cls): + cls.model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH) + + def test_encode(self): + embeddings = self.model.encode("任意のテキスト") + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == (OUTPUT_DIM,) + + def test_get_output_dim(self): + assert self.model.get_output_dim() == OUTPUT_DIM + + def test_tokenizer_kwargs(self): + assert self.model.model.tokenizer.sep_token == "[SEP]" + model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, tokenizer_kwargs={"sep_token": ""}) + assert model.model.tokenizer.sep_token == "" + + def test_model_kwargs(self): + model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.float16 + + def test_bf16(self): + # As numpy doesn't support native bfloat16, add a test case for bf16 + model = DataParallelSentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.bfloat16 diff --git a/tests/embedders/test_sbert.py b/tests/embedders/test_sbert.py index 
77f0585..aa184f9 100644 --- a/tests/embedders/test_sbert.py +++ b/tests/embedders/test_sbert.py @@ -1,4 +1,5 @@ import numpy as np +import torch from jmteb.embedders.sbert_embedder import SentenceBertEmbedder @@ -22,3 +23,14 @@ def test_tokenizer_kwargs(self): assert self.model.model.tokenizer.sep_token == "[SEP]" model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, tokenizer_kwargs={"sep_token": ""}) assert model.model.tokenizer.sep_token == "" + + def test_model_kwargs(self): + model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.float16 + + def test_bf16(self): + # As numpy doesn't support native bfloat16, add a test case for bf16 + model = SentenceBertEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.bfloat16 diff --git a/tests/embedders/test_transformers.py b/tests/embedders/test_transformers.py new file mode 100644 index 0000000..0ab4943 --- /dev/null +++ b/tests/embedders/test_transformers.py @@ -0,0 +1,41 @@ +import numpy as np +import torch + +from jmteb.embedders.transformers_embedder import TransformersEmbedder + +MODEL_NAME_OR_PATH = "prajjwal1/bert-tiny" +OUTPUT_DIM = 128 + + +class TestTransformersEmbedder: + def setup_class(cls): + cls.model = TransformersEmbedder(MODEL_NAME_OR_PATH) + + def test_encode(self): + embeddings = self.model.encode("任意のテキスト") + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == (OUTPUT_DIM,) + + def test_encode_list(self): + embeddings = self.model.encode(["任意のテキスト", "hello world", "埋め込み"]) + assert isinstance(embeddings, np.ndarray) + assert embeddings.shape == (3, OUTPUT_DIM) + + def test_get_output_dim(self): + assert self.model.get_output_dim() == OUTPUT_DIM + + def test_tokenizer_kwargs(self): + assert self.model.tokenizer.sep_token == "[SEP]" + model = 
TransformersEmbedder(MODEL_NAME_OR_PATH, tokenizer_kwargs={"sep_token": ""}) + assert model.tokenizer.sep_token == "" + + def test_model_kwargs(self): + model = TransformersEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.float16 + + def test_bf16(self): + # As numpy doesn't support native bfloat16, add a test case for bf16 + model = TransformersEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16}) + assert model.convert_to_tensor + assert model.encode("任意のテキスト").dtype is torch.bfloat16 diff --git a/tests/evaluator/test_classification_evaluator.py b/tests/evaluator/test_classification_evaluator.py index bce9964..77cc542 100644 --- a/tests/evaluator/test_classification_evaluator.py +++ b/tests/evaluator/test_classification_evaluator.py @@ -2,6 +2,7 @@ ClassificationDataset, ClassificationEvaluator, ClassificationInstance, + ClassificationPrediction, KnnClassifier, LogRegClassifier, ) @@ -44,6 +45,21 @@ def test_classification_evaluator(embedder): assert set(value.keys()) == expected_metrics +def test_classification_evaluator_with_predictions(embedder): + evaluator = ClassificationEvaluator( + train_dataset=DummyClassificationDataset(), + val_dataset=DummyClassificationDataset(), + test_dataset=DummyClassificationDataset(), + classifiers={ + "logreg": LogRegClassifier(), + "knn": KnnClassifier(k=2, distance_metric="cosine"), + }, + log_predictions=True, + ) + results = evaluator(model=embedder) + assert all([isinstance(result, ClassificationPrediction) for result in results.predictions]) + + def test_classification_evaluator_with_prefix(embedder): evaluator_with_prefix = ClassificationEvaluator( train_dataset=DummyClassificationDataset(), @@ -90,3 +106,22 @@ def test_classification_jsonl_dataset_equal(): assert dummy_jsonl_dataset_1 == dummy_jsonl_dataset_2 dummy_jsonl_dataset_2.label_key = "LABEL" assert dummy_jsonl_dataset_1 != 
dummy_jsonl_dataset_2 + + +def test_classification_prediction_logging(embedder): + dataset = DummyClassificationDataset() + evaluator = ClassificationEvaluator( + train_dataset=dataset, + val_dataset=dataset, + test_dataset=dataset, + classifiers={ + "logreg": LogRegClassifier(), + "knn": KnnClassifier(k=2, distance_metric="cosine"), + }, + log_predictions=True, + ) + results = evaluator(model=embedder) + assert isinstance(results.predictions, list) + assert [p.text for p in results.predictions] == [d.text for d in dataset] + assert [p.label for p in results.predictions] == [d.label for d in dataset] + assert all([isinstance(p.prediction, int) for p in results.predictions]) diff --git a/tests/evaluator/test_clustering_evaluator.py b/tests/evaluator/test_clustering_evaluator.py index 217d850..50880bd 100644 --- a/tests/evaluator/test_clustering_evaluator.py +++ b/tests/evaluator/test_clustering_evaluator.py @@ -2,6 +2,7 @@ ClusteringDataset, ClusteringEvaluator, ClusteringInstance, + ClusteringPrediction, ) from jmteb.evaluators.clustering.data import JsonlClusteringDataset @@ -39,6 +40,14 @@ def test_kmeans_clustering(embedder): assert set(results.details[score_splitname][clustering_model].keys()) == expected_metrics +def test_clustering_with_predictions(embedder): + evaluator = ClusteringEvaluator( + val_dataset=DummyClusteringDataset(), test_dataset=DummyClusteringDataset(), log_predictions=True + ) + results = evaluator(model=embedder) + assert all([isinstance(p, ClusteringPrediction) for p in results.predictions]) + + def test_clustering_with_prefix(embedder): evaluator_with_prefix = ClusteringEvaluator( val_dataset=DummyClusteringDataset(), diff --git a/tests/evaluator/test_reranking_evaluator.py b/tests/evaluator/test_reranking_evaluator.py index ef847a9..e9ea894 100644 --- a/tests/evaluator/test_reranking_evaluator.py +++ b/tests/evaluator/test_reranking_evaluator.py @@ -1,3 +1,5 @@ +from loguru import logger + from jmteb.evaluators.reranking import ( 
RerankingDoc, RerankingDocDataset, @@ -8,17 +10,20 @@ from jmteb.evaluators.reranking.data import ( JsonlRerankingDocDataset, JsonlRerankingQueryDataset, + RerankingPrediction, ) EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_distance_metric"} EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"} QUERY_PREFIX = "クエリ: " DOC_PREFIX = "ドキュメント: " +TOP_N_DOCS_TO_LOG = 4 class DummyDocDataset(RerankingDocDataset): def __init__(self, prefix: str = ""): self._items = [RerankingDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)] + self._build_idx_docid_mapping("_items") def __len__(self): return len(self._items) @@ -60,6 +65,22 @@ def test_reranking_evaluator(embedder): assert any(score.startswith(metric) for metric in ["ndcg"]) +def test_reranking_evaluator_with_predictions(embedder): + evaluator = RerankingEvaluator( + val_query_dataset=DummyQueryDataset(), + test_query_dataset=DummyQueryDataset(), + doc_dataset=DummyDocDataset(), + log_predictions=True, + top_n_docs_to_log=TOP_N_DOCS_TO_LOG, + ) + results = evaluator(model=embedder) + logger.info(f"{results.predictions=}") + for p in results.predictions: + assert isinstance(p, RerankingPrediction) + assert len(p.reranked_relevant_docs) <= TOP_N_DOCS_TO_LOG + assert all([isinstance(doc, RerankingDoc) for doc in p.reranked_relevant_docs]) + + def test_reranking_evaluator_with_prefix(embedder): evaluator_with_prefix = RerankingEvaluator( val_query_dataset=DummyQueryDataset(), diff --git a/tests/evaluator/test_retrieval_evaluator.py b/tests/evaluator/test_retrieval_evaluator.py index fa52c52..2bc5e32 100644 --- a/tests/evaluator/test_retrieval_evaluator.py +++ b/tests/evaluator/test_retrieval_evaluator.py @@ -2,6 +2,7 @@ RetrievalDoc, RetrievalDocDataset, RetrievalEvaluator, + RetrievalPrediction, RetrievalQuery, RetrievalQueryDataset, ) @@ -14,11 +15,13 @@ EXPECTED_DIST_FUNC_NAMES = {"cosine_similarity", "euclidean_distance", "dot_score"} QUERY_PREFIX = 
"クエリ: " DOC_PREFIX = "ドキュメント: " +TOP_N_DOCS_TO_LOG = 4 class DummyDocDataset(RetrievalDocDataset): def __init__(self, prefix: str = ""): self._items = [RetrievalDoc(id=str(i), text=f"{prefix}dummy document {i}") for i in range(30)] + self._build_idx_docid_mapping("_items") def __len__(self): return len(self._items) @@ -60,6 +63,28 @@ def test_retrieval_evaluator(embedder): assert any(score.startswith(metric) for metric in ["accuracy", "mrr", "ndcg"]) +def test_retrieval_evaluator_with_predictions(embedder): + dummy_query_dataset = DummyQueryDataset() + dummy_doc_dataset = DummyDocDataset() + evaluator = RetrievalEvaluator( + val_query_dataset=dummy_query_dataset, + test_query_dataset=dummy_query_dataset, + doc_dataset=dummy_doc_dataset, + accuracy_at_k=[1, 3, 5, 10], + ndcg_at_k=[1, 3, 5], + doc_chunk_size=3, + log_predictions=True, + top_n_docs_to_log=TOP_N_DOCS_TO_LOG, + ) + results = evaluator(model=embedder) + assert [p.query for p in results.predictions] == [q.query for q in dummy_query_dataset] + assert all([isinstance(p, RetrievalPrediction) for p in results.predictions]) + for p in results.predictions: + assert isinstance(p, RetrievalPrediction) + assert len(p.predicted_relevant_docs) == TOP_N_DOCS_TO_LOG + assert all([isinstance(doc, RetrievalDoc) for doc in p.predicted_relevant_docs]) + + def test_retrieval_evaluator_with_prefix(embedder): evaluator_with_prefix = RetrievalEvaluator( val_query_dataset=DummyQueryDataset(), diff --git a/tests/evaluator/test_sts_evaluator.py b/tests/evaluator/test_sts_evaluator.py index 69469cc..d7a6d1c 100644 --- a/tests/evaluator/test_sts_evaluator.py +++ b/tests/evaluator/test_sts_evaluator.py @@ -1,4 +1,4 @@ -from jmteb.evaluators.sts import STSDataset, STSEvaluator, STSInstance +from jmteb.evaluators.sts import STSDataset, STSEvaluator, STSInstance, STSPrediction from jmteb.evaluators.sts.data import JsonlSTSDataset EXPECTED_OUTPUT_DICT_KEYS = {"val_scores", "test_scores", "optimal_similarity_metric"} @@ -37,6 +37,12 @@ 
def test_sts(embedder): assert set(results.details[score_splitname][dist].keys()) == EXPECTED_METRIC_NAMES +def test_sts_with_predictions(embedder): + evaluator = STSEvaluator(val_dataset=DummySTSDataset(), test_dataset=DummySTSDataset(), log_predictions=True) + results = evaluator(model=embedder) + assert all([isinstance(result, STSPrediction) for result in results.predictions]) + + def test_sts_with_prefix(embedder): evaluator_with_prefix = STSEvaluator( val_dataset=DummySTSDataset(), diff --git a/tests/test_main.py b/tests/test_main.py index 4ac552e..ee81fb5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -23,8 +23,10 @@ def test_main_cli(): command = [ "python", "-m", "jmteb", "--embedder", "tests.conftest.DummyTextEmbedder", + "--embedder.model_kwargs", '{"torch_dtype": "torch.float16"}', "--save_dir", f, "--eval_include", '["jsts"]', + "--log_predictions", "true", ] # fmt: on result = subprocess.run(command)