diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 000000000..59c2ee387 Binary files /dev/null and b/.DS_Store differ diff --git a/LICENSE b/LICENSE index 261eeb9e9..1490794b8 100644 --- a/LICENSE +++ b/LICENSE @@ -178,7 +178,7 @@ APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" + boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2023 Alibaba Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -199,3 +199,202 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +------------------------------------------------------------------------------- + +Code in data_juicer/ops/common/helper_func.py, data_juicer/ops/deduplicator/document_deduplicator.py, +data_juicer/ops/deduplicator/document_simhash_deduplicator.py, data_juicer/ops/filter/character_repetition_filter.py, +data_juicer/ops/filter/flagged_words_filter.py, data_juicer/ops/filter/perplexity_filter.py, +data_juicer/ops/filter/special_characters_filter.py, data_juicer/ops/filter/stopwords_filter.py, +data_juicer/ops/filter/word_repetition_filter.py, data_juicer/ops/mapper/punctuation_normalization_mapper.py, +data_juicer/ops/mapper/remove_long_words_mapper.py, app.py is adapted from +https://huggingface.co/spaces/huggingface/text-data-filtering or +https://github.com/bigscience-workshop/data-preparation + + Copyright [2021] [Bigscience] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------- + +Code in data_juicer/ops/deduplicator/document_minhash_deduplicator.py is +adapted from +https://github.com/bigcode-project/bigcode-dataset + + Copyright 2022 bigcode authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------- + +Code in data_juicer/ops/mapper/clean_copyright_mapper.py, data_juicer/ops/mapper/clean_html_mapper.py, +data_juicer/ops/mapper/expand_macro_mapper.py, data_juicer/ops/mapper/remove_bibliography_mapper.py, +data_juicer/ops/mapper/remove_comments_mapper.py, data_juicer/ops/mapper/remove_header_mapper.py, +is adapted from +https://github.com/togethercomputer/RedPajama-Data + + Copyright 2023 RedPajama authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------- + +The implementations of gpt_evaluator in tools/evaluator/gpt_eval/gpt_evaluator.py +is adapted from https://github.com/lm-sys/FastChat (Apache License) + +Copyright (c) 2023 The FastChat Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +----------------------------------------------------- + + +The implementations of checkpoint converter in tools/converter/ +convert_gpt_to_transformers.py and tools/converter/modeling_megatron_llama.py +are adapted from https://github.com/huggingface/transformers (Apache License) + +Copyright (c) 2022 EleutherAI and the HuggingFace Inc. team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +----------------------------------------------------- + +Code in thirdparty/Megatron-LM +is adapted from https://github.com/NVIDIA/Megatron-LM + +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +----------------------------------------------------- + +Code in thirdparty/helm +is adapted from https://github.com/stanford-crfm/helm (Apache License) + +Copyright (c) 2023 The helm Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + +----------------------------------------------------- + +Code in tests/run.py is adapted from https://github +.com/alibaba/FederatedScope/blob/master/tests/run.py (Apache License) + +Copyright (c) 2023 The FederatedScope Team + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + +----------------------------------------------------- + +Code in utils/logger_utils.py is adapted from https://github.com/MegEngine/ +YOLOX/blob/main/yolox/utils/logger.py (Apache License) + +Copyright 2021 Megvii, Base Detection + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + diff --git a/README.md b/README.md index 8e498944f..c2af1a99f 100644 --- a/README.md +++ b/README.md @@ -1 +1,214 @@ -# data-juicer \ No newline at end of file +# Data-Juicer: A Data-Centric Text Processing System for Large Language Models + +![Data-Juicer](docs/imgs/data-juicer.png "Data-Juicer") + +![](https://img.shields.io/badge/language-Python-214870.svg) +![](https://img.shields.io/badge/license-Apache--2.0-000000.svg) +[![Contributing](https://img.shields.io/badge/Contribution-welcome-brightgreen.svg)](docs/DeveloperGuide.md) + +[![Document_List](https://img.shields.io/badge/Docs-English-blue?logo=Markdown)](#documentation-|-文档) +[![文档列表](https://img.shields.io/badge/文档-中文-blue?logo=Markdown)](README_ZH.md) +[![API Reference](https://img.shields.io/badge/Docs-API_Reference-blue?logo=Markdown)](https://alibaba.github.io/data-juicer/) +[![ModelScope-10+ Demos](https://img.shields.io/badge/ModelScope-10+_Demos-4e29ff.svg?logo=)](#demos) +[![ModelScope-20+_Refined_Datasets](https://img.shields.io/badge/ModelScope-20+_Refined_Datasets-4e29ff.svg?logo=)](https://modelscope.cn/datasets?organization=Data-Juicer&page=1) + +[![QualityClassifier](https://img.shields.io/badge/Tools-Quality_Classifier-saddlebrown?logo=Markdown)](tools/quality_classifier/README.md) +[![AutoEvaluation](https://img.shields.io/badge/Tools-Auto_Evaluation-saddlebrown?logo=Markdown)](tools/evaluator/README.md) + +Data-Juicer is a data-centric text processing system to make data higher-quality, juicier, and more digestible for LLMs. +This project is being actively updated and maintained, and we will periodically enhance and add more features and data recipes. We welcome you to join us in promoting LLM data development and research! + +---- + +Table of Contents +================= + +* [Data-Juicer: A Data-Centric Text Processing System for Large Language Models](#data-juicer-a-data-centric-text-processing-system-for-large-language-models) +* [Table of Contents](#table-of-contents) + * [Features](#features) + * [Prerequisites](#prerequisites) + * [Installation](#installation) + * [Quick Start](#quick-start) + * [Data Processing](#data-processing) + * [Data Analysis](#data-analysis) + * [Data Visualization](#data-visualization) + * [Build Up Config Files](#build-up-config-files) + * [Preprocess raw data (Optional)](#preprocess-raw-data-optional) + * [Documentation | 文档](#documentation-|-文档) + * [Data Recipes](#data-recipes) + * [Demos](#demos) + * [License](#license) + * [Contributing](#contributing) + * [References](#references) + +## Features + +- **Broad Range of Operators**: Equipped with 50+ core [operators (OPs)](docs/Operators.md), including Formatters, Mappers, Filters, Deduplicators, and beyond. + +- **Specialized Toolkits**: Feature-rich specialized toolkits such as [Text Quality Classifier](tools/quality_classifier/README.md), [Dataset Splitter](tools/preprocess/README.md), [Analysers](#data-analysis), [Evaluators](tools/evaluator/README.md), and more that elevate your dataset handling capabilities. + +- **Systematic & Reusable**: Empowering users with a systematic library of reusable [config recipes](configs) and [OPs](docs/Operators.md), designed to function independently of specific datasets, models, or tasks. + +- **Data-in-the-loop**: Allowing detailed data analyses with an automated report generation feature for a deeper understanding of your dataset. Coupled with real-time multi-dimension automatic evaluation capabilities, it supports a [feedback loop](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary) at multiple stages in the LLM development process. + +- **Comprehensive Processing Recipes**: Offering tens of [pre-built data processing recipes](configs/refine_recipe/README.md) for pre-training, SFT, en, zh, and more scenarios. + +- **User-Friendly Experience**: Designed for simplicity, with [comprehensive documentation](#documentation-|-文档), [easy start guides](#quick-start) and [demo configs](configs/), and intuitive configuration with simple adding/removing OPs from [existing configs](configs/config_all.yaml). + +- **Flexible & Extensible**: Accommodating most types of data formats (e.g., jsonl, parquet, csv, ...) and allowing flexible combinations of OPs. Feel free to [implement your own OPs](docs/DeveloperGuide.md#build-your-own-ops) for customizable data processing. + +- **Enhanced Efficiency**: Providing a speedy data processing pipeline requiring less memory, optimized for maximum productivity. + +## Prerequisites + +- Recommend Python==3.8 +- gcc >= 5 (at least C++14 support) + +## Installation + +- Run the following commands to install the latest `data_juicer` version in + editable mode: +```shell +cd +pip install -v -e .[all] +``` + +- Or install optional dependencies: +```shell +cd +pip install -v -e . # install a minimal dependencies +pip install -v -e .[tools] # install a subset of tools dependencies +``` + +The dependency options are listed below: + +| Tag | Description | +|----------|------------------------------------------------------------------------| +| . | Install minimal dependencies for basic Data-Juicer. | +| .[all] | Install all optional dependencies (all of the following) | +| .[dev] | Install dependencies for developing the package as contributors | +| .[tools] | Install dependencies for dedicated tools, such as quality classifiers. | + +- Installation check: +```python +import data_juicer as dj +print(dj.__version__) +``` + +## Quick Start + + +### Data Processing + +- Run `process_data.py` tool with your config as the argument to process + your dataset. + +```shell +python tools/process_data.py --config configs/demo/process.yaml +``` + +- **NOTICE**: For some operators that involve third-party models or resources which are not stored locally on your computer, it might be slow for the first running because these ops need to download corresponding resources into a directory first. +The default download cache directory is `~/.cache/data_juicer`. Change the cache location by setting the shell environment variable, `DATA_JUICER_CACHE_HOME` to another directory, and you can also change `DATA_JUICER_MODELS_CACHE` or `DATA_JUICER_ASSETS_CACHE` in the same way: + +```shell +# cache home +export DATA_JUICER_CACHE_HOME="/path/to/another/directory" +# cache models +export DATA_JUICER_MODELS_CACHE="/path/to/another/directory/models" +# cache assets +export DATA_JUICER_ASSETS_CACHE="/path/to/another/directory/assets" +``` + +### Data Analysis +- Run `analyze_data.py` tool with your config as the argument to analyse your dataset. + +```shell +python tools/analyze_data.py --config configs/demo/analyser.yaml +``` + +- **NOTICE**: Analyser only compute stats of Filter ops. So extra Mapper or Deduplicator ops will be ignored in the analysis process. + +### Data Visualization + +- Run `app.py` tool to visualize your dataset in your browser. + +```shell +streamlit run app.py +``` + +### Build Up Config Files + +- Config files specify some global arguments, and an operator list for the + data process. You need to set: + - Global arguments: input/output dataset path, number of workers, etc. + - Operator list: list operators with their arguments used to process the dataset. +- You can build up your own config files by: + - ➖:Modify from our example config file [`config_all.yaml`](configs/config_all.yaml) which includes **all** ops and default + arguments. You just need to **remove** ops that you won't use and refine + some arguments of ops. + - ➕:Build up your own config files **from scratch**. You can refer our + example config file [`config_all.yaml`](configs/config_all.yaml), [op documents](docs/Operators.md), and advanced [Build-Up Guide for developers](docs/DeveloperGuide.md#build-your-own-configs). + - Besides the yaml files, you also have the flexibility to specify just + one (of several) parameters on the command line, which will override + the values in yaml files, e.g., `python xxx.py --config configs/demo/process.yaml --language_id_score_filter.lang=en` +- The basic config format and definition is shown below. + + ![Basic config example of format and definition](docs/config_def.png "Basic config file example") + +### Preprocess Raw Data (Optional) +- Our formatters support some common input dataset formats for now: + - Multi-sample in one file: jsonl/json, parquet, csv/tsv, etc. + - Single-sample in one file: txt, code, docx, pdf, etc. +- However, data from different sources are complicated and diverse. Such as: + - [Raw arxiv data downloaded from S3](https://info.arxiv.org/help/bulk_data_s3.html) include thousands of tar files and even more gzip files in them, and expected tex files are embedded in the gzip files so they are hard to obtain directly. + - Some crawled data include different kinds of files (pdf, html, docx, etc.). And extra information like tables, charts, and so on is hard to extract. +- It's impossible to handle all kinds of data in Data-Juicer, issues/PRs are welcome to contribute to process new data types! +- Thus, we provide some **common preprocessing tools** in [`tools/preprocess`](tools/preprocess/) for you to preprocess these data. + - You are welcome to make your contributions to new preprocessing tools for the community. + - We **highly recommend** that complicated data can be preprocessed to jsonl or parquet files. + +## Documentation | 文档 + +- [Overall](README.md) | [概览](README_ZH.md) +- [Operator Zoo](docs/Operators.md) | [算子库](docs/Operators_ZH.md) +- [Configs](configs/README.md) | [配置系统](configs/README_ZH.md) +- [Developer Guide](docs/DeveloperGuide.md) | [开发者指南](docs/DeveloperGuide_ZH.md) +- Dedicated Toolkits | 专用工具箱 + - [Quality Classifier](tools/quality_classifier/README.md) | [质量分类器](tools/quality_classifier/README_ZH.md) + - [Auto Evaluation](tools/evaluator/README.md) | [自动评测](tools/evaluator/README_ZH.md) + - [Preprocess](tools/preprocess/README.md) | [前处理](tools/preprocess/README_ZH.md) + - [Postprocess](tools/postprocess/README.md) | [后处理](tools/postprocess/README_ZH.md) +- [Third-parties (LLM Ecosystems)](thirdparty/README.md) | [第三方库(大语言模型生态)](thirdparty/README_ZH.md) +- [API references](https://alibaba.github.io/data-juicer/) + +## Data Recipes +- [Recipes for data process in BLOOM](configs/bloom/README.md) +- [Recipes for data process in RedPajama](configs/redpajama/README.md) +- [Refined recipes for pretraining data](configs/refine_recipe/README.md) +- [Refined recipes for SFT data](configs/refine_recipe/README.md#L28) + +## Demos +- Introduction to Data-Juicer [[ModelScope](https://modelscope.cn/studios/Data-Juicer/overview_scan/summary)] +- Data Visualization: + - Basic Statistics [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_statistics/summary)] + - Lexical Diversity [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_diversity/summary)] + - Operator Effect [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_op_effect/summary)] +- Data Processing: + - Scientific Literature (e.g. [ArXiv](https://info.arxiv.org/help/bulk_data_s3.html)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sci_data/summary)] + - Programming Code (e.g. [TheStack](https://huggingface.co/datasets/bigcode/the-stack)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_code_data/summary)] + - Chinese Instruction Data (e.g. [Alpaca-CoT](https://huggingface.co/QingyiSi/Alpaca-CoT)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/sft_data_zh/summary)] +- Tool Pool: + - Quality Classifier for CommonCrawl [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_quality_classifier/summary)] + - Auto Evaluation on [HELM](https://github.com/stanford-crfm/helm) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/auto_evaluation_helm/summary)] + - Data Sampling and Mixture [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_mixture/summary)] +- Data Process Loop [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary)] +- Data Process HPO [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_hpo/summary)] + +## License +Data-Juicer is released under Apache License 2.0. + +## Contributing +We greatly welcome contributions of new features, bug fixes, and discussions. Please refer to [How-to Guide for Developers](docs/DeveloperGuide.md). + +## References +Our paper is coming soon! diff --git a/README_ZH.md b/README_ZH.md new file mode 100644 index 000000000..f63973b3a --- /dev/null +++ b/README_ZH.md @@ -0,0 +1,214 @@ +# Data-Juicer: 为大语言模型提供更高质量、更丰富、更易“消化”的数据 + +![Data-Juicer](docs/imgs/data-juicer.png "Data-Juicer") + +![](https://img.shields.io/badge/language-Python-214870.svg) +![](https://img.shields.io/badge/license-Apache--2.0-000000.svg) +[![Contributing](https://img.shields.io/badge/Contribution-welcome-brightgreen.svg)](docs/DeveloperGuide_ZH.md) + +[![Document_List](https://img.shields.io/badge/Docs-English-blue?logo=Markdown)](#documentation-|-文档) +[![文档列表](https://img.shields.io/badge/文档-中文-blue?logo=Markdown)](README_ZH.md) +[![API Reference](https://img.shields.io/badge/Docs-API_Reference-blue?logo=Markdown)](https://alibaba.github.io/data-juicer/) +[![ModelScope-10+ Demos](https://img.shields.io/badge/ModelScope-10+_Demos-4e29ff.svg?logo=)](#demos) +[![ModelScope-20+_Refined_Datasets](https://img.shields.io/badge/ModelScope-20+_Refined_Datasets-4e29ff.svg?logo=)](#data-recipes) + +[![QualityClassifier](https://img.shields.io/badge/Tools-Quality_Classifier-saddlebrown?logo=Markdown)](tools/quality_classifier/README_ZH.md) +[![AutoEvaluation](https://img.shields.io/badge/Tools-Auto_Evaluation-saddlebrown?logo=Markdown)](tools/evaluator/README_ZH.md) + +Data-Juicer 是一个以数据为中心的文本处理系统,旨在为大语言模型 (LLM) 提供更高质量、更丰富、更易“消化”的数据。 +本项目在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。欢迎您加入我们推进 LLM 数据的开发和研究工作! + +---- + +目录 +=== + +* [Data-Juicer: 为大语言模型提供更高质量、更丰富、更易“消化”的数据](#data-juicer-为大语言模型提供更高质量、更丰富、更易“消化”的数据) +* [目录](#目录) + * [特点](#特点) + * [前置条件](#前置条件) + * [安装](#安装) + * [快速上手](#快速上手) + * [数据处理](#数据处理) + * [数据分析](#数据分析) + * [数据可视化](#数据可视化) + * [构建配置文件](#构建配置文件) + * [预处理原始数据(可选)](#预处理原始数据(可选)) + * [Documentation | 文档](#documentation-|-文档) + * [数据处理菜谱](#数据处理菜谱) + * [演示样例](#演示样例) + * [开源协议](#开源协议) + * [贡献](#贡献) + * [参考文献](#参考文献) + +## 特点 + +* **丰富的算子**: 内置了 50 多个核心 [算子(OPs)](docs/Operators_ZH.md),包括 Formatters,Mappers,Filters,Deduplicators 等。 + +* **专业的工具库**: 提供功能丰富的专业工具库,例如 [文本质量打分器](tools/quality_classifier/README_ZH.md), [数据分割器](tools/preprocess/README_ZH.md), [分析器](#数据分析), [评估器](tools/evaluator/README_ZH.md) 等,提升您的数据处理能力。 + +* **系统化 & 可复用**: 为用户提供系统化且可复用的[配置菜谱](configs)和[算子库](docs/Operators_ZH.md),旨在让数据处理独立于特定的数据集、模型或任务运行。 + +* **数据反馈回路**: 支持详细的数据分析,并提供自动报告生成功能,使您深入了解您的数据集。结合实时多维度自动评估功能,支持在 LLM 开发过程的多个阶段进行[反馈循环](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary)。 + +* **全面的处理菜谱**: 为预训练、SFT、中英文等场景提供数十种[预构建的数据处理菜谱](configs/refine_recipe/README_ZH.md)。 + +* **用户友好**: 设计简单易用,提供全面的[文档](#documentation-|-文档)、简易[入门指南](#快速上手)和[演示配置](configs/),并且可以轻松地添加/删除[现有配置](configs/config_all.yaml)中的算子。 + +* **灵活 & 易扩展**: 支持大多数数据格式(如jsonl、parquet、csv等),并允许灵活组合算子。支持[自定义算子](docs/DeveloperGuide_ZH.md#构建自己的算子),以执行定制化的数据处理。 + +* **效率增强**: 提供高效的数据处理流水线,减少内存占用,提高生产力。 + +## 前置条件 + +* 推荐 Python==3.8 +* gcc >= 5 (at least C++14 support) + +## 安装 + +* 运行以下命令以安装 `data_juicer` 可编辑模式的最新版本 + +```shell +cd +pip install -v -e .[all] +``` + +* 或是安装可选的依赖项: + +```shell +cd +pip install -v -e . # 安装最小依赖 +pip install -v -e .[tools] # 安装部分工具库的依赖 +``` + +依赖选项如下表所示: + +| 标签 | 描述 | +|----------|----------------------------------------------| +| . | 安装支持 Data-Juicer 基础功能的最小依赖项 | +| .[all] | 安装所有可选依赖项(即下面所有依赖项) | +| .[dev] | 安装作为贡献者开发 Data-Juicer 所需的依赖项 | +| .[tools] | 安装专用工具库(如质量分类器)所需的依赖项 | + +* 核验安装是否成功: + +```python +import data_juicer as dj +print(dj.__version__) +``` + +## 快速上手 + +### 数据处理 + +* 以配置文件路径作为参数来运行 `process_data.py` 来处理数据集。 + +```shell +python tools/process_data.py --config configs/demo/process.yaml +``` + +* **注意**: 使用未保存在本地的第三方模型或资源的算子第一次运行可能会很慢,因为这些算子需要将相应的资源下载到缓存目录中。默认的下载缓存目录为`~/.cache/data_juicer`。您可通过设置 shell 环境变量 `DATA_JUICER_CACHE_HOME` 更改缓存目录位置,您也可以通过同样的方式更改 `DATA_JUICER_MODELS_CACHE` 或 `DATA_JUICER_ASSETS_CACHE` 来分别修改模型缓存或资源缓存目录: + +```shell +# 缓存主目录 +export DATA_JUICER_CACHE_HOME="/path/to/another/directory" +# 模型缓存目录 +export DATA_JUICER_MODELS_CACHE="/path/to/another/directory/models" +# 资源缓存目录 +export DATA_JUICER_ASSETS_CACHE="/path/to/another/directory/assets" +``` + +### 数据分析 + +- 以配置文件路径为参数运行 `analyze_data.py` 来分析数据集。 + +```shell +python tools/analyze_data.py --config configs/demo/analyser.yaml +``` + +* **注意**: Analyser 只计算 Filter 算子的状态,其他的算子(例如 Mapper 和 Deduplicator)会在分析过程中被忽略。 + +### 数据可视化 + +* 运行 `app.py` 来在浏览器中可视化您的数据集。 + +```shell +streamlit run app.py +``` + +### 构建配置文件 + +* 配置文件包含一系列全局参数和用于数据处理的算子列表。您需要设置: + * 全局参数: 输入/输出 数据集路径,worker 进程数量等。 + * 算子列表:列出用于处理数据集的算子及其参数。 +* 您可以通过如下方式构建自己的配置文件: + * ➖:修改我们的样例配置文件 [`config_all.yaml`](configs/config_all.yaml)。该文件包含了**所有**算子以及算子对应的默认参数。您只需要**移除**不需要的算子并重新设置部分算子的参数即可。 + * ➕:从头开始构建自己的配置文件。您可以参考我们提供的样例配置文件 [`config_all.yaml`](configs/config_all.yaml),[算子文档](docs/Operators_ZH.md),以及 [开发者指南](docs/DeveloperGuide_ZH.md#构建自己的算子). + * 除了使用 yaml 文件外,您还可以在命令行上指定一个或多个参数,这些参数将覆盖 yaml 文件中的值,例如:`python xxx.py --config configs/demo/process.yaml --language_id_score_filter.lang=en` +* 基础的配置项格式及定义如下图所示 + + ![基础配置项格式及定义样例](docs/config_def.png "基础配置文件样例") + +### 预处理原始数据(可选) + +* 我们的 Formatter 目前支持一些常见的输入数据集格式: + * 单个文件中包含多个样本:jsonl/json、parquet、csv/tsv 等。 + * 单个文件中包含单个样本:txt、code、docx、pdf 等。 +* 但来自不同源的数据是复杂和多样化的,例如: + * [从 S3 下载的 arxiv 原始数据](https://info.arxiv.org/help/bulk_data_s3.html) 包括数千个 tar 文件以及更多的 gzip 文件,并且所需的 tex 文件在 gzip 文件中,很难直接获取。 + * 一些爬取的数据包含不同类型的文件(pdf、html、docx 等),并且很难提取额外的信息,例如表格、图表等。 +* Data-Juicer 不可能处理所有类型的数据,欢迎提 Issues/PRs,贡献对新数据类型的处理能力! +* 因此我们在 [`tools/preprocess`](tools/preprocess) 中提供了一些**常见的预处理工具**,用于预处理这些类型各异的数据。 + * 欢迎您为社区贡献新的预处理工具。 + * 我们**强烈建议**将复杂的数据预处理为 jsonl 或 parquet 文件。 + +## Documentation | 文档 + +* [Overall](README.md) | [概览](README_ZH.md) +* [Operator Zoo](docs/Operators.md) | [算子库](docs/Operators_ZH.md) +* [Configs](configs/README.md) | [配置系统](configs/README_ZH.md) +* [Developer Guide](docs/DeveloperGuide.md) | [开发者指南](docs/DeveloperGuide_ZH.md) +* Dedicated Toolkits | 专用工具箱 + * [Quality Classifier](tools/quality_classifier/README.md) | [质量分类器](tools/quality_classifier/README_ZH.md) + * [Auto Evaluation](tools/evaluator/README.md) | [自动评测](tools/evaluator/README_ZH.md) + * [Preprocess](tools/preprocess/README.md) | [前处理](tools/preprocess/README_ZH.md) + * [Postprocess](tools/postprocess/README.md) | [后处理](tools/postprocess/README_ZH.md) +* [Third-parties (LLM Ecosystems)](thirdparty/README.md) | [第三方库(大语言模型生态)](thirdparty/README_ZH.md) +* [API references](https://alibaba.github.io/data-juicer/) + +## 数据处理菜谱 + +* [BLOOM 数据处理菜谱](configs/bloom/README_ZH.md) +* [RedPajama 数据处理菜谱](configs/redpajama/README_ZH.md) +* [预训练数据增强菜谱](configs/refine_recipe/README_ZH.md) +* [SFT数据增强菜谱](configs/refine_recipe/README_ZH.md#L32) + +## 演示样例 + +* Data-Juicer 介绍 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/overview_scan/summary)] +* 数据可视化: + * 基础指标统计 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_statistics/summary)] + * 词汇多样性 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_diversity/summary)] + * 算子效果 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_op_effect/summary)] +* 数据处理: + * 科学文献 (例如 [ArXiv](https://info.arxiv.org/help/bulk_data_s3.html)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sci_data/summary)] + * 编程代码 (例如 [TheStack](https://huggingface.co/datasets/bigcode/the-stack)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_code_data/summary)] + * 中文指令数据 (例如 [Alpaca-CoT](https://huggingface.co/QingyiSi/Alpaca-CoT)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/sft_data_zh/summary)] +* 工具池: + * CommonCrawl 质量分类器 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_quality_classifier/summary)] + * 基于 [HELM](https://github.com/stanford-crfm/helm) 的自动评测 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/auto_evaluation_helm/summary)] + * 数据采样及混合 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_mixture/summary)] +* 数据处理回路 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary)] +* 数据处理 HPO [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_hpo/summary)] + +## 开源协议 + +Data-Juicer 在 Apache License 2.0 协议下发布。 + +## 贡献 + +我们非常欢迎贡献新功能、修复漏洞以及讨论。请参考[开发者指南](docs/DeveloperGuide_ZH.md)。 + +## 参考文献 + +我们的论文即将发布! diff --git a/app.py b/app.py new file mode 100644 index 000000000..f7f842e9d --- /dev/null +++ b/app.py @@ -0,0 +1,765 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +import copy +import math +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import streamlit as st +import yaml +from loguru import logger + +import plotly.express as px +from data_juicer.analysis.diversity_analysis import (DiversityAnalysis, + get_diversity, + prepare_diversity_model) +from data_juicer.config import init_configs +from data_juicer.core import Analyser, Executor +from data_juicer.ops.base_op import OPERATORS +from data_juicer.utils.logger_utils import get_log_file_path + + +@st.cache_data +def convert_csv(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_csv().encode('utf-8') + + +@st.cache_data +def convert_jsonl(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_json(orient='records', lines=True).encode('utf-8') + + +@st.cache_data +def get_diversity_model(lang): + diversity_model = prepare_diversity_model(lang) + return diversity_model + + +@st.cache_data +def postproc_diversity(dataframe, **kwargs): + df = get_diversity(dataframe, **kwargs) + return df + + +def read_log_file(): + log_f_path = get_log_file_path() + if log_f_path is None or not os.path.exists(log_f_path): + return '' + sys.stdout.flush() + with open(log_f_path, 'r') as f: + return f.read() + + +def pretty_out(d): + res = '' + process = '' + op_names = set(OPERATORS.modules.keys()) + for key, value in d.items(): + if key == 'process': + process = yaml.dump(value, + allow_unicode=True, + default_flow_style=False) + elif key == 'config' or key.split('.')[0] in op_names: + continue + else: + res += f'{key}:\n \t {value}\n' + res += 'process:\n' + \ + '\n'.join(['\t' + line for line in process.splitlines()]) + + return res + + +def parse_cfg(): + + cfg_file = st.session_state.input_cfg_file + cfg_cmd = st.session_state.input_cfg_cmd + + cfg_f_name = 'null' + del_cfg_file = False + if cfg_file is not None: + cfg_f_name = cfg_file.name + file_contents = cfg_file.getvalue() + with open(cfg_f_name, 'wb') as f: + f.write(file_contents) + cfg_cmd = f'--config {cfg_f_name}' + del_cfg_file = True + + args_in_cmd = cfg_cmd.split() + + if len(args_in_cmd) >= 2 and args_in_cmd[0] == '--config': + cfg_f_name = args_in_cmd[1] + else: + st.warning('Please specify a config command or upload a config file.') + st.stop() + + if not os.path.exists(cfg_f_name): + st.warning('do not parse' + f'config file does not exist with cfg_f_name={cfg_f_name}') + st.stop() + + with open(cfg_f_name, 'r') as cfg_f: + specified_cfg = yaml.safe_load(cfg_f) + + try: + parsed_cfg = init_configs(args=args_in_cmd) + st.session_state.cfg = parsed_cfg + if del_cfg_file: + os.remove(cfg_f_name) + return pretty_out(parsed_cfg), pretty_out(specified_cfg), parsed_cfg + except Exception as e: + return str(e), pretty_out(specified_cfg), None + + +def analyze_and_show_res(): + images_ori = [] + cfg = st.session_state.get('cfg', parse_cfg()[2]) + if cfg is None: + raise ValueError('you have not specify valid cfg') + # force generating separate figures + cfg['save_stats_in_one_file'] = True + + logger.info('=========Stage 1: analyze original data=========') + analyzer = Analyser(cfg) + dataset = analyzer.run() + + analysis_res_ori = pd.read_csv( + os.path.join(analyzer.analysis_path, 'overall.csv')) + for f_path in os.listdir(analyzer.analysis_path): + if '.png' in f_path and 'all-stats' in f_path: + images_ori.append(os.path.join(analyzer.analysis_path, f_path)) + + st.session_state.dataset = dataset + st.session_state.orginal_overall = analysis_res_ori + st.session_state.original_imgs = images_ori + + +def process_and_show_res(): + images_processed = [] + cfg = st.session_state.get('cfg', parse_cfg()[2]) + if cfg is None: + raise ValueError('you have not specify valid cfg') + # force generating separate figures + cfg['save_stats_in_one_file'] = True + logger.info('=========Stage 2: process original data=========') + executor = Executor(cfg) + dataset = executor.run() + + logger.info('=========Stage 3: analyze the processed data==========') + analysis_res_processed = pd.DataFrame() + try: + cfg_for_processed_data = copy.deepcopy(cfg) + cfg_for_processed_data.dataset_path = cfg.export_path + + cfg_for_processed_data.export_path = os.path.dirname( + cfg.export_path) + '_processed/data.jsonl' + cfg_for_processed_data.text_keys_to_load = [cfg.text_key_to_process] + analyzer = Analyser(cfg_for_processed_data) + analyzer.run() + analysis_res_processed = pd.read_csv( + os.path.join(analyzer.analysis_path, 'overall.csv')) + for f_path in os.listdir(analyzer.analysis_path): + if '.png' in f_path and 'all-stats' in f_path: + images_processed.append( + os.path.join(analyzer.analysis_path, f_path)) + except Exception as e: + st.warning(f'Something error with {str(e)}') + + logger.info('=========Stage 4: Render the analysis results==========') + st.session_state.dataset = dataset + st.session_state.processed_overall = analysis_res_processed + st.session_state.processed_imgs = images_processed + + +def get_min_max_step(data): + max_value = np.max(data) + if max_value > 2.0: + min_value = 0 + max_value = int(max_value + 1) + step = 1 + else: + min_value = 0.0 + max_value = max(1.0, max_value) + step = 0.01 + return min_value, max_value, step + + +op_stats_dict = { + 'alphanumeric_filter': ['alpha_token_ratio', 'alnum_ratio'], + 'average_line_length_filter': ['avg_line_length'], + 'character_repetition_filter': ['char_rep_ratio'], + 'flagged_words_filter': ['flagged_words_ratio'], + 'language_id_score_filter': ['lang', 'lang_score'], + 'maximum_line_length_filter': ['max_line_length'], + 'perplexity_filter': ['perplexity'], + 'special_characters_filter': ['special_char_ratio'], + 'stopwords_filter': ['stopwords_ratio'], + 'text_length_filter': ['text_len'], + 'words_num_filter': ['num_words'], + 'word_repetition_filter': ['word_rep_ratio'], +} + + +class Visualize: + + @staticmethod + def filter_dataset(dataset): + + text = dataset['text'] + if 'stats' not in dataset.features: + stats = pd.DataFrame(dataset['stats.meta']) + else: + stats = pd.DataFrame(dataset['stats']) + stats['text'] = text + + non_num_list = ['lang'] + min_cutoff_list = [ + 'lang_score', + 'stopwords_ratio', + ] + max_cutoff_list = [ + 'flagged_words_ratio', + 'max_ppl', + ] + mask_list = ['text'] + + cfg = st.session_state.get('cfg', None) + if cfg is None: + return + + def set_sliders(total_stats, ordered): + stats = copy.deepcopy(total_stats) + conds = list() + index = 1 + for op_cfg in cfg.process: + op_name = list(op_cfg.keys())[0] + op_stats = op_stats_dict.get(op_name, []) + + cutoff_ratio = None + + with st.sidebar.expander(f'{index} {op_name}'): + + for column_name in op_stats: + if column_name not in stats: + continue + data = stats[column_name] + + if column_name in non_num_list: + options = ['all'] + list(set(data)) + label = f'Which {column_name} would \ + you like to keep?' + + selected = st.selectbox( + label=label, + options=options, + ) + if selected == 'all': + cond = [True] * len(data) + else: + cond = data == selected + Visualize.display_discarded_ratio( + cond, column_name) + + elif column_name in min_cutoff_list: + label = f'If the {column_name} of a document \ + is lower than this number, \ + the document is removed.' + + low, high, step = get_min_max_step(data) + + cutoff_ratio = st.slider(label, + low, + high, + low, + step=step) + cond = data >= cutoff_ratio + Visualize.display_discarded_ratio( + cond, column_name) + + elif column_name in max_cutoff_list: + label = f'If the {column_name} of a document \ + is higher than this number, \ + the document is removed.' + + low, high, step = get_min_max_step(data) + cutoff_ratio = st.slider(label, + low, + high, + high, + step=step) + cond = data <= cutoff_ratio + + Visualize.display_discarded_ratio( + cond, column_name) + elif column_name not in mask_list: + # lower + label = f'If the {column_name} of a document \ + is lower than this number, \ + the document is removed.' + + low, high, step = get_min_max_step(data) + + cutoff_ratio_l = st.slider(label, + low, + high, + low, + step=step) + cond_l = data >= cutoff_ratio_l + + Visualize.display_discarded_ratio( + cond_l, column_name) + + # higher + label = f'If the {column_name} of a document \ + is higher than this number, \ + the document is removed.' + + cutoff_ratio_h = st.slider(label, + low, + high, + high, + step=step) + + cond_h = data <= cutoff_ratio_h + Visualize.display_discarded_ratio( + cond_h, column_name) + cond = [ + low & high + for low, high in zip(cond_l, cond_h) + ] + + cutoff_ratio = (cutoff_ratio_l, cutoff_ratio_h) + + if column_name not in mask_list: + Visualize.draw_hist(data, cutoff_ratio) + conds.append({ + (' '.join([str(index), op_name]), column_name): + cond + }) + + if ordered: + stats = stats.loc[cond] + index += 1 + return conds, stats + + st.sidebar.subheader('Parameters of filter ops') + ordered = st.sidebar.checkbox('Process by op order') + conds, filtered_stats = set_sliders(stats, ordered) + + st.subheader('How many samples do you want to show?') + show_num = st.number_input( + label='How many samples do you want to show?', + value=5, + label_visibility='hidden') + if ordered: + all_conds = [ + True if i in filtered_stats.index else False + for i in range(len(stats)) + ] + else: + all_conds = np.all([list(cond.values())[0] for cond in conds], + axis=0) + ds = pd.DataFrame(dataset) + Visualize.display_dataset(ds, all_conds, show_num, 'Retained sampels', + 'docs') + st.download_button('Download Retained data as JSONL', + data=convert_jsonl(ds.loc[all_conds]), + file_name='retained.jsonl') + Visualize.display_dataset(ds, np.invert(all_conds), show_num, + 'Discarded sampels', 'docs') + st.download_button('Download Discarded data as JSONL', + data=convert_jsonl(ds.loc[np.invert(all_conds)]), + file_name='discarded.jsonl') + display_discarded_details = st.checkbox( + 'Display discarded documents by filter details') + + show_stats = copy.deepcopy(stats) + bar_labels = [] + bar_sizes = [] + for item in conds: + for op_key, cond in item.items(): + op_name, column_name = op_key + if column_name not in mask_list: + sub_stats = show_stats[[column_name, 'text']] + if display_discarded_details: + Visualize.display_dataset( + sub_stats, + np.invert(cond) if len(cond) > 0 else [], + show_num, + # f'Discarded documents for the filter on \ + f'{op_name} {column_name} filtered ', + 'docs', + ) + before_filtered_num = len(show_stats.index) + if ordered: + show_stats = show_stats.loc[cond] + retained = np.sum(1 * cond) + filtered = before_filtered_num - len(show_stats.index) + else: + retained = np.sum(1 * cond) + filtered = before_filtered_num - retained + + bar_sizes.append(retained) + bar_sizes.append(filtered) + bar_labels.append(f'{op_name}\n{column_name}') + + bar_title = 'Effect of Filter OPs' + Visualize.draw_stack_bar(bar_sizes, bar_labels, len(stats.index), + bar_title) + + @staticmethod + def diversity(): + with st.expander('Diversity for sft dataset', expanded=False): + dataset = st.session_state.get('dataset', None) + cfg = st.session_state.get('cfg', parse_cfg()[2]) + if dataset: + + col1, col2, col3, col4 = st.columns(4) + with col1: + label = 'Which language of your dataset' + options = ['en', 'zh'] + lang_select = st.selectbox( + label=label, + options=options, + ) + with col2: + top_k_verbs = st.number_input( + 'Set the top_k nums of verbs', value=20) + with col3: + top_k_nouns = st.number_input( + 'Set the top_k nums of nouns', value=4) + with col4: + threshold = st.slider('Count threshold', + min_value=0, + value=32, + max_value=100, + step=1) + + disversity_btn = st.button('Analyse_diversity', + use_container_width=True) + output_path = os.path.join(os.path.dirname(cfg.export_path), + 'analysis') + raw_df = None + if disversity_btn: + try: + diversity_analysis = DiversityAnalysis( + dataset, output_path) + with st.spinner('Wait for analyze diversity...'): + raw_df = diversity_analysis.compute( + lang_or_model=get_diversity_model(lang_select), + column_name=cfg.text_key_to_process) + + st.session_state[f'diversity{lang_select}'] = raw_df + + except Exception as e: + st.warning(f'Error {str(e)} in {lang_select}') + else: + raw_df = st.session_state.get(f'diversity{lang_select}', + None) + + if raw_df is not None: + df = postproc_diversity(raw_df, + top_k_verbs=top_k_verbs, + top_k_nouns=top_k_nouns) + df = df[df['count'] >= threshold] + Visualize.draw_sunburst(df, + path=['verb', 'noun'], + values='count') + + st.download_button( + label='Download diversity data as CSV', + data=convert_csv(df), + file_name='diversity.csv', + mime='text/csv', + ) + else: + st.warning('Please analyze original data first') + + @staticmethod + def draw_sunburst(df, path, values): + + fig = px.sunburst(df, path=path, values=values) + fig.update_layout(margin=dict(l=0, r=0, t=0, b=0), + font_family='Times New Roman', + font=dict(size=40)) + st.plotly_chart(fig, use_container_width=True) + + @staticmethod + def draw_stack_bar(bar_sizes, bar_labels, total_num, title=''): + filtered_size = [ + k / total_num * 100 for i, k in enumerate(bar_sizes[::-1]) + if i % 2 == 0 + ] + retain_size = [ + k / total_num * 100 for i, k in enumerate(bar_sizes[::-1]) + if i % 2 != 0 + ] + plt.clf() + plt.title(title) + bar_labels = bar_labels[::-1] + # retained + r_bars = plt.barh(bar_labels, + retain_size, + label='Retained', + height=0.5, + color='limegreen') + + # filtered + f_bars = plt.barh(bar_labels, + filtered_size, + label='Filtered', + left=retain_size, + height=0.5, + color='orangered') + + for idx, bar in enumerate(r_bars): + width = bar.get_width() + plt.text(bar.get_x() + width / 2, + bar.get_y() + bar.get_height() / 2, + f'{retain_size[idx]:.2f}%', + ha='center', + va='center') + + for idx, bar in enumerate(f_bars): + width = bar.get_width() + plt.text(bar.get_x() + width / 2, + bar.get_y() + bar.get_height() / 2, + f'{filtered_size[idx]:.2f}%', + ha='center', + va='center') + + plt.legend() + plt.gcf() + st.pyplot(plt, use_container_width=True) + + @staticmethod + def draw_pie(bar_labels, big_sizes, small_labels, bar_sizes): + plt.clf() + + # filter op circle + plt.pie(big_sizes, labels=bar_labels, startangle=90, frame=True) + # retained and filtered circle + plt.pie(bar_sizes, + labels=small_labels, + radius=0.7, + rotatelabels=True, + startangle=90, + labeldistance=0.7) + centre_circle = plt.Circle((0, 0), 0.4, color='white', linewidth=0) + fig = plt.gcf() + fig.gca().add_artist(centre_circle) + + plt.axis('equal') + plt.tight_layout() + st.pyplot(plt, use_container_width=True) + + @staticmethod + def display_discarded_ratio(cond, key): + if len(cond) > 0: + st.caption( + f':red[{(len(cond) - np.sum(1*cond)) / len(cond) * 100:.2f}%] \ + of the total (:red[{len(cond)}]) is discarded with {key}.') + else: + st.caption(f':red[{0:.2f}%] \ + of the total (:red[0]) is discarded with {key}.') + + @staticmethod + def display_dataset(dataframe, cond, show_num, desp, type, all=True): + examples = dataframe.loc[cond] + if all or len(examples) > 0: + st.subheader( + f'{desp}: :red[{len(examples)}] of ' + f'{len(dataframe.index)} {type} ' + f'(:red[{len(examples)/len(dataframe.index) * 100:.2f}%])') + + # st.markdown('Click on a column to sort by it, \ + # place the cursor on the text to display it.') + st.dataframe(examples[:show_num], use_container_width=True) + + @staticmethod + def draw_hist(data, cutoff=None): + + fig, ax = plt.subplots() + data_num = len(data) + if data_num >= 100: + rec_bins = int(math.sqrt(len(data))) + else: + rec_bins = 50 + + if data_num > 0: + ax.hist(data, bins=rec_bins, density=True) + if hasattr(data, 'name'): + ax.set_title(data.name) + + if isinstance(cutoff, (float, int)): + ax.axvline(x=cutoff, color='r', linestyle='dashed') + elif isinstance(cutoff, tuple) and len(cutoff) == 2: + ax.axvline(x=cutoff[0], color='r', linestyle='dashed') + ax.axvline(x=cutoff[1], color='r', linestyle='dashed') + st.pyplot(fig) + + @staticmethod + def setup(): + st.set_page_config( + page_title='Data-Juicer', + page_icon=':smile', + layout='wide', + # initial_sidebar_state="expanded", + ) + + readme_link = 'https://code.alibaba-inc.com/DAIL-LLM/' \ + 'data_juicer/blob/master/README.md' + + st.markdown( + '
Data-Juicer ' + '
', + unsafe_allow_html=True, + ) + st.markdown( + f'
A Dataset Preparation System for Large Models, \ + see more detail in Readme
', + unsafe_allow_html=True, + ) + + @staticmethod + def parser(): + with st.expander('Configuration', expanded=True): + st.markdown('Please specify the cfg via ' + '(i) specifying the cfg file path with commands or ' + '(ii) uploading the cfg file.') + + col1, col2 = st.columns(2) + with col1: + example_cfg_f = os.path.abspath( + os.path.join(os.path.dirname(__file__), + './configs/demo.yaml')) + st.text_area(label='(i) Input Cfg Commands', + key='input_cfg_cmd', + value=f'--config {example_cfg_f}') + example_my_cmd = '--dataset_path ./demo/demo-dataset.jsonl ' \ + '--export_path '\ + './outputs/demo/demo-processed.jsonl' + + st.text_area( + label='cmd example. (the cmd-args will override ' + 'yaml-file-args)', + disabled=True, + value=f'--config {example_cfg_f} {example_my_cmd}') + + with col2: + st.file_uploader(label='(ii) Input Cfg File', + key='input_cfg_file', + type=['yaml']) + + btn_show_cfg = st.button('1. Parse Cfg', use_container_width=True) + if btn_show_cfg: + text1, text2, cfg = parse_cfg() + st.session_state.cfg_text1 = text1 + st.session_state.cfg_text2 = text2 + + else: + text1 = st.session_state.get('cfg_text1', '') + text2 = st.session_state.get('cfg_text2', '') + + col3, col4 = st.columns(2) + with col3: + st.text_area(label='Parsed Cfg (in memory)', value=text1) + with col4: + st.text_area(label='Specified Cfg (in yaml file)', value=text2) + + @staticmethod + def analyze_process(): + start_btn = st.button( + '2. Start to analyze original data (per filter op)', + use_container_width=True) + start_btn_process = st.button('3. Start to process data', + use_container_width=True) + + # with st.expander('Log', expanded=False): + # logs = st.Textbox(show_label=False) + # demo.load(read_log_file, inputs=None, outputs=logs, every=1) + + with st.expander('Data Analysis Results', expanded=False): + + if start_btn: + with st.spinner('Wait for analyze...'): + analyze_and_show_res() + + if start_btn_process: + with st.spinner('Wait for process...'): + process_and_show_res() + + orginal_overall = st.session_state.get('orginal_overall', None) + original_imgs = st.session_state.get('original_imgs', []) + processed_overall = st.session_state.get('processed_overall', None) + processed_imgs = st.session_state.get('processed_imgs', []) + + col1, col2 = st.columns(2) + with col1: + st.caption('Original Data') + st.dataframe(orginal_overall, use_container_width=True) + for img in original_imgs: + st.image(img, output_format='png') + + with col2: + st.caption('Processed Data') + st.dataframe(processed_overall, use_container_width=True) + for img in processed_imgs: + st.image(img, output_format='png') + + @staticmethod + def filter(): + with st.expander('Effect of Filter OPs', expanded=False): + dataset = st.session_state.get('dataset', None) + if dataset: + Visualize.filter_dataset(dataset) + else: + st.warning('Please analyze original data first') + + @staticmethod + def auxiliary(): + st.markdown('[WIP] Auxiliary Models on Processed Data') + col1, col2 = st.columns(2) + with col1: + with st.expander('Quality Scorer', expanded=False): + wiki_socre_btn = st.button('Run Wiki-score classifier', + use_container_width=True) + + if wiki_socre_btn: + st.warning('No support for now') + + wikibook_score_btn = st.button('Run WikiBook-score classifier', + use_container_width=True) + if wikibook_score_btn: + st.warning('No support for now') + + with col2: + with st.expander('[WIP] Proxy LM Models Training', expanded=False): + st.file_uploader(label='LM Training Cfg File', type=['yaml']) + st.button('Train proxy model') + st.markdown('[Training Monitoring](http://' + '8.130.26.137:8083/dail/' + 'llama-re-2nd?workspace=user-dail)') + + @staticmethod + def visualize(): + Visualize.setup() + Visualize.parser() + Visualize.analyze_process() + Visualize.filter() + Visualize.diversity() + Visualize.auxiliary() + + +def main(): + Visualize.visualize() + + +if __name__ == '__main__': + main() diff --git a/configs/.DS_Store b/configs/.DS_Store new file mode 100644 index 000000000..824f46ab1 Binary files /dev/null and b/configs/.DS_Store differ diff --git a/configs/README.md b/configs/README.md new file mode 100644 index 000000000..e800b645b --- /dev/null +++ b/configs/README.md @@ -0,0 +1,32 @@ +# Config Files + +This folder contains some configuration files to allow users to easily understand the configuration methods of various functions and quickly reproduce the processing flow of different datasets. + +## Usage + +```shell +# To process your dataset. +python tools/process_data.py --config xxx.yaml +# To analyse your dataset. +python tools/analyze_data.py --config xxx.yaml +``` + +## Categories + +The current configuration files are classified into the subsequent categories. + +### Demo + +Demo configuration files are used to help users quickly familiarize the basic functions of Data-Juicer. Please refer to the [demo](demo) folder for details. + + +### Redpajama + +We have reproduced the processing flow of some redpajama datasets. Please refer to the [redpajama](redpajama) folder for details. + +### Bloom + +We have reproduced the processing flow of some bloom datasets. please refer to the [bloom](bloom) folder for details. + +### Refine_recipe +We have refined some open source datasets (including SFT datasets) by using Data-Juicer and have provided configuration files for the refine flow. please refer to the [refine_recipe](refine_recipe) folder for details. \ No newline at end of file diff --git a/configs/README_ZH.md b/configs/README_ZH.md new file mode 100644 index 000000000..4132cff5a --- /dev/null +++ b/configs/README_ZH.md @@ -0,0 +1,33 @@ +# 配置文件 + +此文件夹包含一些配置文件,帮助用户轻松理解各种功能的配置方法,并快速复现开源数据集的处理流程。 + +## 用法 + +```shell +#处理数据集 +python tools/process_data.py --config xxx.yaml + +#分析数据集 +python tools/analyze_data.py --config xxx.yaml +``` + +## 分类 + +配置文件分为以下几类。 + +### Demo + +Demo 配置文件用于帮助用户快速熟悉 Data-Juicer 的基本功能,请参阅 [demo](demo) 文件夹以获取详细说明。 + + +### Redpajama + +我们已经复现了部分 Redpajama 数据集的处理流程,请参阅 [redpajama](redpajama) 文件夹以获取详细说明。 + +### Bloom + +我们已经重现了部分 Bloom 数据集的处理流程,请参阅 [bloom](bloom) 文件夹以获取详细说明。 + +### Refine_recipe +我们使用 Data-Juicer 更细致地处理了一些开源数据集(包含 SFT 数据集),并提供了处理流程的配置文件。请参阅 [refine_recipe](refine_recipe) 文件夹以获取详细说明。 \ No newline at end of file diff --git a/configs/bloom/README.md b/configs/bloom/README.md new file mode 100644 index 000000000..f83d1f6cf --- /dev/null +++ b/configs/bloom/README.md @@ -0,0 +1,8 @@ +# Bloom Config Files + +This folder contains example configuration files to easily and quickly reproduce the processing flow of the [ROOTS](https://github.com/bigscience-workshop/data-preparation) dataset, created by the BigScience initiative to train the BLOOM models. + +## Oscar +The raw data files can be downloaded as described in [Bloom/Oscar](https://github.com/bigscience-workshop/data-preparation/tree/main/preprocessing/training/01b_oscar_cleaning_and_filtering). Then use [bloom-oscar.yaml](bloom-oscar.yaml) to perform the whole processing. + +An analysis of our reproduction will be published soon. diff --git a/configs/bloom/README_ZH.md b/configs/bloom/README_ZH.md new file mode 100644 index 000000000..49df48975 --- /dev/null +++ b/configs/bloom/README_ZH.md @@ -0,0 +1,9 @@ +# Bloom 配置文件 + +此文件夹包含的配置文件用于轻松复现 [ROOTS](https://github.com/bigscience-workshop/data-preparation) 的处理流程,该数据集由 BigScience 创建并用于训练 BLOOM 模型。 + +## Oscar + +原始文件可以参照 [Bloom/Oscar](https://github.com/bigscience-workshop/data-preparation/tree/main/preprocessing/training/01b_oscar_cleaning_and_filtering) 下载,然后使用 [bloom-oscar.yaml](bloom-oscar.yaml) 进行完整的处理流程。 + +对我们复现结果的分析将在稍后发布。 diff --git a/configs/bloom/bloom-oscar.yaml b/configs/bloom/bloom-oscar.yaml new file mode 100644 index 000000000..b9b9f0860 --- /dev/null +++ b/configs/bloom/bloom-oscar.yaml @@ -0,0 +1,58 @@ +# Process config example for Oscar used in BLOOM + +# global parameters +project_name: 'bloom_oscar' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +export_path: '/path/to/result/dataset.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + # filter English corpus + - language_id_score_filter: + lang: en + min_score: 0.8 + + # basic process for regular English text + - whitespace_normalization_mapper: + - punctuation_normalization_mapper: + - fix_unicode_mapper: + - remove_words_with_incorrect_substrings_mapper: + - remove_long_words_mapper: + max_len: 25 + + # basic filter rules for regular English text + - words_num_filter: + min_num: 20 + max_num: 100000 + - character_repetition_filter: + rep_len: 10 + min_ratio: 0.0 + max_ratio: 0.106 + - word_repetition_filter: + rep_len: 5 + min_ratio: 0.0 + max_ratio: 0.19 + - special_characters_filter: + min_ratio: 0.0 + max_ratio: 0.4 + - stopwords_filter: + lang: en + min_ratio: 0.3 + - flagged_words_filter: + lang: en + max_ratio: 0.01 + - perplexity_filter: + lang: en + max_ppl: 1500 + + # basic deduplication rules for regular English text + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/config_all.yaml b/configs/config_all.yaml new file mode 100644 index 000000000..64882b183 --- /dev/null +++ b/configs/config_all.yaml @@ -0,0 +1,148 @@ +# Process config example including: +# - all global arguments +# - all ops and their default arguments + +# global parameters +project_name: 'all' # project name for distinguish your configs +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file with weights(0.0-1.0), 1.0 as default. + # Accepted format: 'weight1(optional) dataset1-path weight2(optional) dataset2-path ' +export_path: '/path/to/result/dataset.jsonl' # path to processed result dataset. Supported suffixes include ['jsonl', 'json', 'parquet'] +export_shard_size: 0 # Shard size of exported dataset in Byte. In default, it's 0, which means export the whole dataset into only one file. If it's set a positive number, the exported dataset will be split into several dataset shards, and the max size of each shard won't larger than the export_shard_size +np: 4 # number of subprocess to process your dataset +text_key_to_process: 'content' # the key name of field where the sample texts to be processed, e.g., `text`, `text.instruction`, `text.output`, ...' + # Note: currently, we support specify only ONE key for each op, for cases requiring multiple keys, users can specify the op multiple times +text_keys_to_load: # the key name of field where the sample texts stored in the original data + - 'text' +suffixes: [] # the suffix of files that will be read. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx'] +use_cache: true # whether to use the cache management of hugging face datasets. It might take up lots of disk space when using cache +ds_cache_dir: '~/.cache/huggingface/datasets' # cache dir for hugging face datasets. In default it's the default cache dir "~/.cache/huggingface/datasets". If this argument is reset by users, it will override the default cache dir +use_checkpoint: false # whether to use the checkpoint management to save the latest version of dataset to work dir when processing. Rerun the same config will reload the checkpoint and skip ops before it. Cache will be disabled when using checkpoint. If args of ops before the checkpoint are changed, all ops will be rerun from the beginning. +temp_dir: null # the path to the temp directory to store intermediate caches when cache is disabled, these cache files will be removed on-the-fly. In default, it's None, so the temp dir will be specified by system. NOTICE: you should be caution when setting this argument because it might cause unexpected program behaviors when this path is set to an unsafe directory. +open_tracer: false # whether to open the tracer to trace the changes during process. It might take more time when opening tracer +op_list_to_trace: [] # only ops in this list will be traced by tracer. If it's empty, all ops will be traced. Only available when tracer is opened. +trace_num: 10 # number of samples to show the differences between datasets before and after each op. Only available when tracer is opened. + +# only for data analysis +save_stats_in_one_file: false # whether to store all stats result into one file + +# process schedule: a list of several process operators with their arguments +process: + # Mapper ops. Most of these ops need no arguments. + - clean_email_mapper: # remove emails from text. + - clean_html_mapper: # remove html formats form text. + - clean_ip_mapper: # remove ip addresses from text. + - clean_links_mapper: # remove web links from text. + - clean_copyright_mapper: # remove copyright comments. + - expand_macro_mapper: # expand macro definitions in Latex text. + - fix_unicode_mapper: # fix unicode errors in text. + - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. + - remove_bibliography_mapper: # remove bibliography from Latex text. + - remove_comments_mapper: # remove comments from Latex text, code, etc. + doc_type: tex # comment type you want to remove. Only support 'tex' for now. + inline: true # whether to remove inline comments + multiline: true # whether to remove multiline comments + - remove_header_mapper: # remove header texts from Latex text. + drop_no_head: true # whether to drop sample texts without headers + - remove_long_words_mapper: # remove much too long words from text. + min_len: 1 # the min word length to keep words. + max_len: 128 # the max word length to keep words. + - remove_specific_chars_mapper: # remove characters specified by users + chars_to_remove: '◆●■►▼▲▴∆▻▷❖♡□' # a string or a list including those characters that need to be removed + - remove_table_text_mapper: # remove possible table texts from text. + min_col: 2 # the min num of columns in tables to remove + max_col: 20 # the max num of columns in tables to remove + - remove_words_with_incorrect_substrings_mapper: # remove words with incorrect substrings from text. + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove + - sentence_split_mapper: # split text to sentences and join them with '\n' + lang: 'en' # split text in what language + - whitespace_normalization_mapper: # normalize different kinds of whitespaces to English whitespace. + + # Filter ops + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. + tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens. + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.9 # the max ratio of filter range + - average_line_length_filter: # filter text with the average length of lines out of specific range. + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - character_repetition_filter: # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range + - flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value + lang: en # consider flagged words in what language + tokenization: false # whether to use model to tokenize documents + max_ratio: 0.0045 # the max ratio to filter text + flagged_words_dir: ./assets # directory to store flagged words dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value + lang: en # keep text in what language + min_score: 0.8 # the min language scores to filter text + - maximum_line_length_filter: # filter text with the maximum length of lines out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - perplexity_filter: # filter text with perplexity score out of specific range + lang: en # compute perplexity in what language + max_ppl: 1500 # the max perplexity score to filter text + - special_characters_filter: # filter text with special-char ratio out of specific range + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.25 # the max ratio of filter range + - stopwords_filter: # filter text with stopword ratio smaller than a specific min value + lang: en # consider stopwords in what language + tokenization: false # whether to use model to tokenize documents + min_ratio: 0.3 # the min ratio to filter text + stopwords_dir: ./assets # directory to store stopwords dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - text_length_filter: # filter text with length out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - words_num_filter: # filter text with number of words out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + min_num: 10 # the min number of filter range + max_num: 10000 # the max number of filter range + - word_repetition_filter: # filter text with the word repetition ratio out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + rep_len: 10 # repetition length for word-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range + - suffix_filter: # filter to keep samples with specified suffix. + suffixes: [] # the suffix of text that will be keep. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx'] + - specified_field_filter: # filter text with the specified field info out of specific range + text_key: '' # the target key corresponding to multi-level field information need to be separated by '.' + target_value: [] # the range of specified field information corresponding to the samples that need to be retained + - specified_numeric_field_filter: # filter text with the specified numeric field info out of specific range + text_key: '' # the target key corresponding to multi-level field information need to be separated by '.' + min_value: 0 # the min filter value in SpecifiedNumericField op + max_value: 10000 # the max filter value in SpecifiedNumericField op + + # Deduplicator ops + - document_deduplicator: # deduplicate text samples using md5 hashing exact matching method + lowercase: false # whether to convert text to lower case + ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations + - document_simhash_deduplicator: # deduplicate text samples using SimHash-LSH method + tokenization: space # tokenization method for text. One of [space, punctuation, character] + window_size: 6 # window size of shingling + num_blocks: 6 # number of blocks in SimHash computing + hamming_distance: 4 # the max hamming distance to regard 2 samples as similar enough pair. Should be less than num_blocks always + lowercase: true # whether to convert text to lower case + ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash. + + # Selector ops + - topk_specified_field_selector: # selector to select top samples based on the sorted specified field + text_key: '' # the target keys corresponding to multi-level field information need to be separated by '.' + top_ratio: # ratio of selected top samples + topk: # number of selected top sample + reverse: True # determine the sorting rule, if reverse=True, then sort in descending order + - frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value + text_key: '' # the target keys corresponding to multi-level field information need to be separated by '.' + top_ratio: # ratio of selected top specified field value + topk: # number of selected top specified field value + reverse: True # determine the sorting rule, if reverse=True, then sort in descending order diff --git a/configs/demo/analyser.yaml b/configs/demo/analyser.yaml new file mode 100644 index 000000000..3d1e1e40c --- /dev/null +++ b/configs/demo/analyser.yaml @@ -0,0 +1,17 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo-analyser' +dataset_path: './demos/data/demo-dataset.jsonl' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +export_path: './outputs/demo-analyser/demo-analyser-result.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - language_id_score_filter: + lang: 'en' + - perplexity_filter: + lang: 'en' + max_ppl: 1500 diff --git a/configs/demo/dedup.yaml b/configs/demo/dedup.yaml new file mode 100644 index 000000000..991302d9f --- /dev/null +++ b/configs/demo/dedup.yaml @@ -0,0 +1,22 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo-dedup' +dataset_path: './demos/data/demo-dataset-deduplication.jsonl' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +open_tracer: true + +export_path: './outputs/demo-dedup/demo-dedup-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - language_id_score_filter: + lang: en + min_score: 0.5 +# - document_deduplicator: +# lowercase: false +# ignore_non_character: false + - document_minhash_deduplicator: + tokenization: 'character' diff --git a/configs/demo/process.yaml b/configs/demo/process.yaml new file mode 100644 index 000000000..93aa95698 --- /dev/null +++ b/configs/demo/process.yaml @@ -0,0 +1,14 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo-process' +dataset_path: './demos/data/demo-dataset.jsonl' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +export_path: './outputs/demo-process/demo-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - language_id_score_filter: + lang: 'zh' diff --git a/configs/redpajama/README.md b/configs/redpajama/README.md new file mode 100644 index 000000000..3b317cb9e --- /dev/null +++ b/configs/redpajama/README.md @@ -0,0 +1,96 @@ +# Redpajama Config Files + +This folder contains example configuration files to easily and quickly reproduce the processing flow of [Redpajama](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep). + +## Arxiv +The raw data files can be downloaded from the same AWS link as in [Redpajama/Arxiv](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/arxiv). + +Once downloaded, use [raw_arxiv_to_jsonl.py](../../tools/preprocess/raw_arxiv_to_jsonl.py) to convert from the original format to `jsonl` that data-juicer can handle easily: + +```shell +python tools/preprocess/raw_arxiv_to_jsonl.py \ + --arxiv_src_dir \ + --target_dir \ + --temp_dir \ + --num_proc +``` + +After conversion, modify the path configurations in [redpajama-arxiv.yaml](redpajama-arxiv.yaml) and execute the following command to reproduce the processing flow of redpajama: + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-arxiv.yaml +``` + +### Comparison + +| | num_samples | num_tokens | peak_memory | wall_time | +| --- | :---: | :---: | :---: | --- | +| redpajama | 1,724,497 | 30,667,506,934 | 35GB |`total: 11h52min` | +| data-juicer | 2,675,426| 30,338,153,178 | 21GB | preprocess: 5h21min
read+unify: 25min
remove_header_mapper: 5min
remove_comments_mapper: 3min
remove_bibliography_mapper: 4min
expand_macro_mapper: 5min19s
text_length_filter: 4min
export: 43min
`total: 6h53min` | + +## Books + +The raw data files can be downloaded from the same HuggingFace datasets as in [Redpajama/Books](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/book). + +Once downloaded, modify the path configurations in [redpajama-books.yaml](redpajama-books.yaml) and execute the following command to reproduce the processing flow of redpajama. + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-books.yaml +``` + +### Comparison + +| | num_samples | num_tokens | peak_memory | wall_time | +| --- | :---: | :---: | :---: | --- | +| redpajama | 205,183 | 25,962,395,123 | 450GB | split_for_dedup: 5min
dedup: 117min
`total: 122min` | +| data-juicer | 207,902 | 26,108,635,683 | 96GB | read+unify: 20min
compute_hash: 78min
dedup: 3min
export: 3min
`total: 114min` | + +## Code + +The raw data files can be downloaded from Google BigQuery as in [Redpajama/Code](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/github). + +Once downloaded, unzip and delete files whose extensions are not in the following whitelist: + +```text +.asm, .bat, .cmd, .c, .h, .cs, .cpp, .hpp, .c++, .h++, .cc, .hh, .C, .H, .cmake, .css, .dockerfile, .f90, .f, .f03, .f08, .f77, .f95, .for, .fpp, .go, .hs, .html, .java, .js, .jl, .lua, .md, .markdown, .php, .php3, .php4, .php5, .phps, .phpt, .pl, .pm, .pod, .perl, ps1, .psd1, .psm1, .py, .rb, .rs, .sql, .scala, .sh, .bash, .command, .zsh, .ts, .tsx, .tex, .vb, Dockerfile, Makefile, .xml, .rst, .m, .smali +``` + +After preparation, modify the path configurations in [redpajama-code.yaml](redpajama-code.yaml) and execute the following command to reproduce the processing flow of redpajama: + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-code.yaml +``` + +### Comparison + +| | num_samples | num_tokens | peak_memory | wall_time | +| --- | :---: | :---: | :---: | --- | +| redpajama | 73,208,524 | 150,390,270,060| 212GB | local-dedup: 37h
global-dedup: 1h
merge-dedup: 6h
filter: 17h
`total: 61h` | +| data-juicer | 73,169,889| 150,310,903,230| 370GB | preprocess: 5h21min
read+unify: 12h
document_deduplicator: 20h
clean_copyright_mappe: 3h
maximum_line_length_filter: 2.5h
average_line_length_filter: 2h
alphanumeric_filter: 13h
export: 2.5h
`total: 59h` | + +## StackExchange + +The raw data files can be downloaded from the same Archive link as in [Redpajama/Stack_exchange](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/stack_exchange). + +Once downloaded, use [raw_stackexchange_to_jsonl.py](../../tools/preprocess/raw_stackexchange_to_jsonl.py) to convert from the original format to `jsonl` that data-juicer can handle easily: + +```shell +python tools/preprocess/raw_arxiv_stackexchange_to_jsonl.py \ + --src_dir \ + --target_dir \ + --topk \ + --num_proc \ +``` + +After conversion, modify the path configurations in [redpajama-stackexchange.yaml](redpajama-stackexchange.yaml) and execute the following command to reproduce the processing flow of redpajama: + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-stackexchange.yaml +``` + +### Comparison + +| | num_samples | num_tokens | peak_memory | wall_time | +| --- | :---: | :---: | :---: | --- | +| redpajama | 29,825,086 | 20,502,757,123 | >500GB | filter: 170min
postprocess: 90min
`total: 260min` | +| data-juicer | 29,825,086 | 20,628,082,262 | 100GB | preprocess: 210min
read+unify: 86min
clean_html: 15min
language_id_score_filter: 18min
`total: 391min` | diff --git a/configs/redpajama/README_ZH.md b/configs/redpajama/README_ZH.md new file mode 100644 index 000000000..8b942440b --- /dev/null +++ b/configs/redpajama/README_ZH.md @@ -0,0 +1,97 @@ +# Redpajama 配置文件 + +此文件夹包含的配置文件用于轻松复现 [Redpajama](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep) 的处理流程。 + +## Arxiv + +原始数据文件从 [Redpajama/Arxiv](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/arxiv) 中相同的 AWS 链接下载。 + +下载完成后,使用 [raw_arxiv_to_jsonl.py](../../tools/preprocess/raw_arxiv_to_jsonl.py) 将原始格式转换为 data-juicer 易于处理的格式: + +```shell +python tools/preprocess/raw_arxiv_to_jsonl.py \ + --arxiv_src_dir \ + --target_dir \ + --temp_dir \ + --num_proc +``` + +预处理完成后,修改 [redpajama-arxiv.yaml](redpajama-arxiv.yaml) 中的数据路径,执行以下命令复现 redpajama 的处理流程: + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-arxiv.yaml +``` + +### 指标对比 + +| | 样本数 | 令牌数 | 峰值内存 | 运行时间 | +| --- | :---: | :---: | :---: | --- | +| redpajama | 1,724,497 | 30,667,506,934 | 35GB |`total: 11h52min` | +| data-juicer | 2,675,426| 30,338,153,178 | 21GB | preprocess: 5h21min
read+unify: 25min
remove_header_mapper: 5min
remove_comments_mapper: 3min
remove_bibliography_mapper: 4min
expand_macro_mapper: 5min19s
text_length_filter: 4min
export: 43min
`total: 6h53min` | + +## Books + +原始数据文件从 [Redpajama/Books](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/book) 中相同的 HuggingFace 链接下载。 + +下载完成后,修改 [redpajama-books.yaml](redpajama-books.yaml) 中的数据路径,执行以下命令复现 redpajama 的处理流程: + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-books.yaml +``` + +### 指标对比 + +| | 样本数 | 令牌数 | 峰值内存 | 运行时间 | +| --- | :---: | :---: | :---: | --- | +| redpajama | 205,183 | 25,962,395,123 | 450GB | split_for_dedup: 5min
dedup: 117min
`total: 122min` | +| data-juicer | 207,902 | 26,108,635,683 | 96GB | read+unify: 20min
compute_hash: 78min
dedup: 3min
export: 3min
`total: 114min` | + +## Code + +原始数据文件从 [Redpajama/Code](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/github) 中相同的 Google BigQuery 获取。 + +下载完成后,解压缩并删除扩展名不在以下白名单中的其他文件: + +```text +.asm, .bat, .cmd, .c, .h, .cs, .cpp, .hpp, .c++, .h++, .cc, .hh, .C, .H, .cmake, .css, .dockerfile, .f90, .f, .f03, .f08, .f77, .f95, .for, .fpp, .go, .hs, .html, .java, .js, .jl, .lua, .md, .markdown, .php, .php3, .php4, .php5, .phps, .phpt, .pl, .pm, .pod, .perl, ps1, .psd1, .psm1, .py, .rb, .rs, .sql, .scala, .sh, .bash, .command, .zsh, .ts, .tsx, .tex, .vb, Dockerfile, Makefile, .xml, .rst, .m, .smali +``` + +修改 [redpajama-code.yaml](redpajama-code.yaml) 中的数据路径,执行以下命令复现 redpajama 的处理流程: + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-code.yaml +``` + +### 指标对比 + +| | 样本数 | 令牌数 | 峰值内存 | 运行时间 | +| --- | :---: | :---: | :---: | --- | +| redpajama | 73,208,524 | 150,390,270,060| 212GB | local-dedup: 37h
global-dedup: 1h
merge-dedup: 6h
filter: 17h
`total: 61h` | +| data-juicer | 73,169,889| 150,310,903,230| 370GB | preprocess: 5h21min
read+unify: 12h
document_deduplicator: 20h
clean_copyright_mappe: 3h
maximum_line_length_filter: 2.5h
average_line_length_filter: 2h
alphanumeric_filter: 13h
export: 2.5h
`total: 59h` | + +## StackExchange + +原始数据文件从 [Redpajama/Stack_exchange](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/stack_exchange) 中相同的 Archive 链接获取。 + +下载完成后,使用 [raw_stackexchange_to_jsonl.py](../../tools/preprocess/raw_stackexchange_to_jsonl.py) 将原始格式转换为 data-juicer 易于处理的格式: + +```shell +python tools/preprocess/raw_arxiv_stackexchange_to_jsonl.py \ + --src_dir \ + --target_dir \ + --topk \ + --num_proc \ +``` + +预处理完成后,修改 [redpajama-stackexchange.yaml](redpajama-stackexchange.yaml) 中的数据路径,执行以下命令复现 redpajama 的处理流程: + +```shell +python tools/process_data.py --config configs/redpajama/redpajama-stackexchange.yaml +``` + +### 指标对比 + +| | 样本数 | 令牌数 | 峰值内存 | 运行时间 | +| --- | :---: | :---: | :---: | --- | +| redpajama | 29,825,086 | 20,502,757,123 | >500GB | filter: 170min
postprocess: 90min
`total: 260min` | +| data-juicer | 29,825,086 | 20,628,082,262 | 100GB | preprocess: 210min
read+unify: 86min
clean_html: 15min
language_id_score_filter: 18min
`total: 391min` | diff --git a/configs/redpajama/redpajama-arxiv.yaml b/configs/redpajama/redpajama-arxiv.yaml new file mode 100644 index 000000000..0885fdfed --- /dev/null +++ b/configs/redpajama/redpajama-arxiv.yaml @@ -0,0 +1,22 @@ +# Process config example for Arxiv dataset + +# global parameters +project_name: 'Arxiv' +dataset_path: '/path/to/your/dataset/dir/or/file' # path to your dataset directory or file +np: 32 # number of subprocess to process your dataset + +export_path: '/path/to/your/exported/dataset/file' + +# process schedule +# a list of several process operators with their arguments +process: + - remove_header_mapper: + drop_no_head: true + - remove_comments_mapper: + doc_type: ['md', 'tex'] + inline: true + multiline: true + - remove_bibliography_mapper: + - expand_macro_mapper: + - text_length_filter: + min_len: 1 diff --git a/configs/redpajama/redpajama-books.yaml b/configs/redpajama/redpajama-books.yaml new file mode 100644 index 000000000..49b82eac4 --- /dev/null +++ b/configs/redpajama/redpajama-books.yaml @@ -0,0 +1,19 @@ +# Process config example for Books used in RedPajam + +# global parameters +project_name: 'RedPajam-books' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +export_path: '/path/to/result/dataset.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - document_simhash_deduplicator: + tokenization: character + window_size: 6 + lowercase: True + ignore_pattern: '[^\w]+' + num_blocks: 6 + hamming_distance: 5 diff --git a/configs/redpajama/redpajama-code.yaml b/configs/redpajama/redpajama-code.yaml new file mode 100644 index 000000000..4ebabdfd0 --- /dev/null +++ b/configs/redpajama/redpajama-code.yaml @@ -0,0 +1,32 @@ +# Process config example for codes used in RedPajam + +# global parameters +project_name: 'RedPajam-codes' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +export_path: '/path/to/result/dataset.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - document_deduplicator: + - clean_copyright_mapper: + - maximum_line_length_filter: + min_len: 1 + max_len: 1000 + - average_line_length_filter: + min_len: 1 + max_len: 100 + - alphanumeric_filter: + min_ratio: 0.25 + max_ratio: 1.0 + - alphanumeric_filter: + tokenization: True + min_ratio: 1.5 + - suffix_filter: + suffixes: [".asm", ".bat", ".cmd", ".c", ".h", ".cs", ".cpp", ".hpp", ".c++", ".h++", ".cc", ".hh", ".C", ".H", ".cmake", ".css", + ".dockerfile", ".f90", ".f", ".f03", ".f08", ".f77", ".f95", ".for", ".fpp", ".go", ".hs", ".html", ".java", ".js", + ".jl", ".lua", ".md", ".markdown", ".php", ".php3", ".php4", ".php5", ".phps", ".phpt", ".pl", ".pm", ".pod", ".perl", + ".ps1", ".psd1", ".psm1", ".py", ".rb", ".rs", ".sql", ".scala", ".sh", ".bash", ".command", ".zsh", ".ts", ".tsx", + ".tex", ".vb", "Dockerfile", "Makefile", ".xml", ".rst", ".m", ".smali"] diff --git a/configs/redpajama/redpajama-stackexchange.yaml b/configs/redpajama/redpajama-stackexchange.yaml new file mode 100644 index 000000000..4a2eded8b --- /dev/null +++ b/configs/redpajama/redpajama-stackexchange.yaml @@ -0,0 +1,17 @@ +# Process config example for stackexchange used in RedPajam + +# global parameters +project_name: 'RedPajam-stackexchange' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset +use_cache: False + +export_path: '/path/to/result/dataset.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - clean_html_mapper: + - language_id_score_filter: + lang: '' + min_score: 0.0 diff --git a/configs/refine_recipe/.DS_Store b/configs/refine_recipe/.DS_Store new file mode 100644 index 000000000..9f02008f0 Binary files /dev/null and b/configs/refine_recipe/.DS_Store differ diff --git a/configs/refine_recipe/README.md b/configs/refine_recipe/README.md new file mode 100644 index 000000000..21a3d365e --- /dev/null +++ b/configs/refine_recipe/README.md @@ -0,0 +1,37 @@ +# Refined open source dataset by Data-Juicer + +We found that there are still some "bad" samples in existing processed datasets (e.g. RedPajama, The Pile, etc.). So we use our Data-Juicer to refine them and try to feed them to LLMs for better performance. + +We use simple 3-σ rule to set the hyperparameters for ops in each recipe. + +## Before and after refining for Pretraining Dataset + +| subset | #samples before | #samples after | keep ratio | config link | data link | source | +|----------------------|:---------------------------:|:--------------:|:----------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------| +| Arxiv | 1,724,497 | 1,655,259 | 95.99% | [redpajama-arxiv-refine.yaml](redpajama-arxiv-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-arxiv-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-arxiv-refined-by-data-juicer/summary) | Redpajama | +| Books | 205,182 | 195,983 | 95.51% | [redpajama-book-refine.yaml](redpajama-book-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-book-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-book-refined-by-data-juicer/summary) | Redpajama | +| Wikipedia | 29,834,171 | 26,990,659 | 90.47% | [redpajama-wiki-refine.yaml](redpajama-wiki-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-wiki-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-wiki-refined-by-data-juicer/summary) | Redpajama | +| C4 | 364,868,892 | 346,217,856 | 94.89% | [redpajama-c4-refine.yaml](redpajama-c4-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-c4-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-c4-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2019-30 | 81,085,420 | 36,557,283 | 45.08% | [redpajama-cc-refine/](redpajama-cc-2019-30-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2019-30-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2019-30-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2020-05 | 90,850,492 | 42,612,596 | 46.90% | [redpajama-cc-refine/](redpajama-cc-2020-05-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2020-05-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2020-05-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2021-04 | 98,878,523 | 44,724,752 | 45.23% | [redpajama-cc-refine/](redpajama-cc-2021-04-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2021-04-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2021-04-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2022-05 | 94,058,868 | 42,648,496 | 45.34% | [redpajama-cc-refine/](redpajama-cc-2022-05-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2022-05-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2022-05-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2023-06 | 111,402,716 | 50,643,699 | 45.46% | [redpajama-cc-refine/](redpajama-cc-2023-06-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2023-06-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2023-06-refined-by-data-juicer/summary) | Redpajama | +| Github Code | 73,208,524
+ 21,387,703 | 49,279,344 | 52.09% | [redpajama-code-refine.yaml](github_code/redpajama-code-refine.yaml)
[stack-code-refine.yaml](github_code/stack-code-refine.yaml)
[redpajama-stack-code-deduplicate.yaml](github_code/redpajama-stack-code-deduplicate.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-stack-code-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-stack-code-refined-by-data-juicer/summary) | Redpajama
The Stack | +| StackExchange | 45,447,328 | 26,309,203 | 57.89% | [redpajama-pile-stackexchange-refine.yaml](redpajama-pile-stackexchange-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-pile-stackexchange-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-pile-stackexchange-refined-by-data-juicer/summary) | Redpajama
The Pile | +| EuroParl | 69,814 | 61,601 | 88.23% | [pile-europarl-refine.yaml](pile-europarl-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-europarl-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-europarl-refined-by-data-juicer/summary) | The Pile | +| FreeLaw | 3,562,015 | 2,942,612 | 82.61% | [pile-freelaw-refine.yaml](pile-freelaw-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-freelaw-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-freelaw-refined-by-data-juicer/summary) | The Pile | +| HackerNews | 373,027 | 371,331 | 99.55% | [pile-hackernews-refine.yaml](pile-hackernews-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-hackernews-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-hackernews-refined-by-data-juicer/summary) | The Pile | +| NIH ExPorter | 939,661 | 858,492 | 91.36% | [pile-nih-refine.yaml](pile-nih-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-hin-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-nih-refined-by-data-juicer/summary) | The Pile | +| PhilPapers | 32,782 | 29,117 | 88.82% | [pile-philpaper-refine.yaml](pile-philpaper-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-philpaper-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-philpaper-refined-by-data-juicer/summary) | The Pile | +| PubMed Abstracts | 15,518,009 | 15,009,325 | 96.72% | [pile-pubmed-abstract-refine.yaml](pile-pubmed-abstract-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-pubmed-abstract-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-pubmed-abstracts-refined-by-data-juicer/summary) | The Pile | +| PubMed Central | 3,098,930 | 2,694,860 | 86.96% | [pile-pubmed-central-refine.yaml](pile-pubmed-central-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-pubmed-central-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-pubmed-central-refined-by-data-juicer/summary) | The Pile | +| USPTO | 5,883,024 | 4,516,283 | 46.77% | [pile-uspto-refine.yaml](pile-uspto-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-uspto-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-uspto-refined-by-data-juicer/summary) | The Pile | + + +## Before and after refining for Alpaca-CoT Dataset + +| subset | #samples before | #samples after | keep ratio | config link | data link | source | +|------------------|:-------------------------:|:--------------------------------------:|:----------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------| +| Alpaca-Cot EN | 136,219,879 | Non-dedup: 104,573,711
Dedup: TBD | 76.77% | [alpaca-cot-en-refine.yaml](alpaca_cot/alpaca-cot-en-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/SFT/alpaca-cot-en-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-en-refined-by-data-juicer/summary) | [39 Subsets of Alpaca-CoT](alpaca_cot/README.md) | +| Alpaca-Cot ZH | 21,197,246 | 9,873,214 | 46.58% | [alpaca-cot-zh-refine.yaml](alpaca_cot/alpaca-cot-zh-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/SFT/alpaca-cot-zh-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-zh-refined-by-data-juicer/summary) | [28 Subsets of Alpaca-CoT](alpaca_cot/README.md) | diff --git a/configs/refine_recipe/README_ZH.md b/configs/refine_recipe/README_ZH.md new file mode 100644 index 000000000..1670e4a7b --- /dev/null +++ b/configs/refine_recipe/README_ZH.md @@ -0,0 +1,37 @@ +# 使用Data-Juicer完善开源数据集 + +我们发现在现有的已经处理过的数据集(如 Redpajama,The Pile 等)中仍然存在一些“脏”数据样本。所以我们使用我们的 Data-Juicer 来完善这些数据集,并尝试将它们提供给 LLM 以获得更好的性能。 + +我们使用简单的 3-σ 规则来设置每个数据处理菜谱中的算子的超参数。 + +## 完善前后的预训练数据集 + +| 数据子集 | 完善前的样本数目 | 完善后的样本数目 | 样本保留率 | 配置链接 | 数据链接 | 来源 | +|----------------------|:---------------------------:|:--------------:|:---------:|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------| +| Arxiv | 1,724,497 | 1,655,259 | 95.99% | [redpajama-arxiv-refine.yaml](redpajama-arxiv-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-arxiv-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-arxiv-refined-by-data-juicer/summary) | Redpajama | +| Books | 205,182 | 195,983 | 95.51% | [redpajama-book-refine.yaml](redpajama-book-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-book-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-book-refined-by-data-juicer/summary) | Redpajama | +| Wikipedia | 29,834,171 | 26,990,659 | 90.47% | [redpajama-wiki-refine.yaml](redpajama-wiki-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-wiki-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-wiki-refined-by-data-juicer/summary) | Redpajama | +| C4 | 364,868,892 | 346,217,856 | 94.89% | [redpajama-c4-refine.yaml](redpajama-c4-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-c4-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-c4-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2019-30 | 81,085,420 | 36,557,283 | 45.08% | [redpajama-cc-refine/](redpajama-cc-2019-30-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2019-30-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2019-30-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2020-05 | 90,850,492 | 42,612,596 | 46.90% | [redpajama-cc-refine/](redpajama-cc-2020-05-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2020-05-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2020-05-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2021-04 | 98,878,523 | 44,724,752 | 45.23% | [redpajama-cc-refine/](redpajama-cc-2021-04-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2021-04-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2021-04-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2022-05 | 94,058,868 | 42,648,496 | 45.34% | [redpajama-cc-refine/](redpajama-cc-2022-05-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2022-05-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2022-05-refined-by-data-juicer/summary) | Redpajama | +| Common Crawl 2023-06 | 111,402,716 | 50,643,699 | 45.46% | [redpajama-cc-refine/](redpajama-cc-2023-06-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-cc-refine-results/redpajama-cc-2023-06-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-cc-2023-06-refined-by-data-juicer/summary) | Redpajama | +| Github Code | 73,208,524
+ 21,387,703 | 49,279,344 | 52.09% | [redpajama-code-refine.yaml](github_code/redpajama-code-refine.yaml)
[stack-code-refine.yaml](github_code/stack-code-refine.yaml)
[redpajama-stack-code-deduplicate.yaml](github_code/redpajama-stack-code-deduplicate.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-stack-code-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-stack-code-refined-by-data-juicer/summary) | Redpajama
The Stack | +| StackExchange | 45,447,328 | 26,309,203 | 57.89% | [redpajama-pile-stackexchange-refine.yaml](redpajama-pile-stackexchange-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/redpajama-pile-stackexchange-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/redpajama-pile-stackexchange-refined-by-data-juicer/summary) | Redpajama
The Pile | +| EuroParl | 69,814 | 61,601 | 88.23% | [pile-europarl-refine.yaml](pile-europarl-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-europarl-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-europarl-refined-by-data-juicer/summary) | The Pile | +| FreeLaw | 3,562,015 | 2,942,612 | 82.61% | [pile-freelaw-refine.yaml](pile-freelaw-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-freelaw-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-freelaw-refined-by-data-juicer/summary) | The Pile | +| HackerNews | 373,027 | 371,331 | 99.55% | [pile-hackernews-refine.yaml](pile-hackernews-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-hackernews-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-hackernews-refined-by-data-juicer/summary) | The Pile | +| NIH ExPorter | 939,661 | 858,492 | 91.36% | [pile-nih-refine.yaml](pile-nih-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-hin-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-nih-refined-by-data-juicer/summary) | The Pile | +| PhilPapers | 32,782 | 29,117 | 88.82% | [pile-philpaper-refine.yaml](pile-philpaper-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-philpaper-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-philpaper-refined-by-data-juicer/summary) | The Pile | +| PubMed Abstracts | 15,518,009 | 15,009,325 | 96.72% | [pile-pubmed-abstract-refine.yaml](pile-pubmed-abstract-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-pubmed-abstract-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-pubmed-abstracts-refined-by-data-juicer/summary) | The Pile | +| PubMed Central | 3,098,930 | 2,694,860 | 86.96% | [pile-pubmed-central-refine.yaml](pile-pubmed-central-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-pubmed-central-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-pubmed-central-refined-by-data-juicer/summary) | The Pile | +| USPTO | 5,883,024 | 4,516,283 | 46.77% | [pile-uspto-refine.yaml](pile-uspto-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/pretraining/the-pile-uspto-refine-result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/the-pile-uspto-refined-by-data-juicer/summary) | The Pile | + + +## 完善前后的Alpaca-CoT数据集 + +| 数据子集 | 完善前的样本数目 | 完善后的样本数目 | 样本保留率 | 配置链接 | 数据链接 | 来源 | +|-------------------|:------------------------:|:----------------------------------:|:---------:|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------| +| Alpaca-Cot EN | 136,219,879 | 未去重版本: 104,573,711
去重版本: TBD | 76.77% | [alpaca-cot-en-refine.yaml](alpaca_cot/alpaca-cot-en-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/SFT/alpaca-cot-en-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-en-refined-by-data-juicer/summary) | [来自Alpaca-CoT的39个子集](alpaca_cot/README_ZH.md) | +| Alpaca-Cot ZH | 21,197,246 | 9,873,214 | 46.58% | [alpaca-cot-zh-refine.yaml](alpaca_cot/alpaca-cot-zh-refine.yaml) | [Aliyun](http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/SFT/alpaca-cot-zh-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-zh-refined-by-data-juicer/summary) | [来自Alpaca-CoT的28个子集](alpaca_cot/README_ZH.md) | diff --git a/configs/refine_recipe/alpaca_cot/README.md b/configs/refine_recipe/alpaca_cot/README.md new file mode 100644 index 000000000..67468a2e2 --- /dev/null +++ b/configs/refine_recipe/alpaca_cot/README.md @@ -0,0 +1,112 @@ +# Refine Alpaca-CoT Config Files + +This folder contains some configuration files to allow users to easily and quickly refine [Alpaca-CoT](https://huggingface.co/QingyiSi/Alpaca-CoT). + +## Preprocess +The raw data files can be downloaded from [Alpaca-CoT](https://huggingface.co/QingyiSi/Alpaca-CoT) on HuggingFace. + +### Convert raw Alpaca-CoT data to jsonl +Use [raw_alpaca_cot_merge_add_meta.py](../../tools/preprocess/raw_alpaca_cot_merge_add_meta.py) to select `instruction`, `input` and `output` columns and merge them to `text` field with a space, and add extra [ META ]( #meta_info) info to dataset: + +```shell +python tools/preprocess/raw_alpaca_cot_merge_add_meta.py \ + --src_dir \ + --target_dir \ + --num_proc +``` + +### Split datasets to sub-datasets by language +Use [dataset_split_by_language.py](../../tools/preprocess/dataset_split_by_language.py) to split the dataset to EN and ZH sub-datasets: + +```shell +python tools/preprocess/dataset_split_by_language.py \ + --src_dir \ + --target_dir \ + --suffixes jsonl \ + --num_proc +``` + +## Process +After preprocess, modify the dataset path in [alpaca-cot-en-refine](alpaca-cot-en-refine].yaml) and [alpaca-cot-zh-refine](alpaca-cot-zh-refine.yaml), and then execute the following command to reproduce the processing flow of refined Alpaca-CoT. +```shell +# refine English dataset +python tools/process_data.py --config configs/refine_recipe/alpaca_cot/alpaca-cot-en-refine].yaml + +# refine Chinese dataset +python tools/process_data.py --config configs/refine_recipe/alpaca_cot/alpaca-cot-zh-refine].yaml +``` + +### Meta Info + +Each sample in refined data of Alpaca-CoT contains meta info listed as below: + +#### Alpaca-CoT original meta info +* Language Tags: + - EN: Instruction datasets in English + - CN: Instruction datasets in Chinese + - ML: [Multi-lingual] Instruction datasets in multiple languages +* Task Tags + - MT: [Multi-task] Datasets containing multiple tasks + - TS: [Task-specific] Datasets tailored for specific tasks +* Generation-method: + - HG: [Human Generated Dataset] Datasets created by humans + - SI: [Self-Instruct] Datasets generated using self-instruct methods + - MIX: [Mixed Dataset] Dataset contains both human and machine generated data + - COL: [Collection of Dataset] Dataset made from a collection of other datasets + +#### Data-Juicer Meta info +* Dataset: Dataset in Alpaca-CoT + +* Multi-round Dialog (MRD): Multi-round Dialog datasets + +* IFT: Instruction Fine-Tuning datasets + +* SFT: Supervised Fine-Tuning datasets + +* Preference: Preference datasets + +* origin_path: origin file path in in Alpaca-CoT + + +#### Refined Alpaca-CoT dataset Meta info +| | Task | Gen | Lang | Dataset | MRD | IFT | SFT | Preference | +|:---------------------|:-------|:------|:-------|:---------------------|:----:|:---:|:---:|:---:| +| Chain-of-Thought | MT | HG | EN/CN | Chain-of-Thought | | ✅ | | | +| GPT4all | MT | COL | EN | GPT4all | | ✅ | ✅ | | +| GPTeacher | MT | SI | EN | GPTeacher | | | ✅ | | +| Guanaco | MT | SI | ML | Guanaco | | | ✅ | | +| HC3 | TS | MIX | EN/CN | HC3 | | | ✅ | ✅ | +| alpaca | MT | SI | EN | alpaca | | | ✅ | | +| Natural-Instructions | MT | COL | ML | Natural-Instructions | | ✅ | | | +| belle_cn | TS/MT | SI | CN | belle_cn | | | ✅ | | +| instinwild | MT | SI | EN/CN | instinwild | | | ✅ | | +| prosocial-dialog | TS | MIX | EN | prosocial-dialog | | | ✅ | | +| finance | TS | COL | EN | finance | | | ✅ | | +| xP3 | MT | COL | ML | xP3 | | ✅ | | | +| firefly | MT | COL | CN | firefly | | ✅ | | | +| instruct | MT | COL | EN | instruct | | | ✅ | | +| CodeAlpaca | TS | SI | EN | CodeAlpaca | | ✅ | | | +| alpacaGPT4 | MT | SI | EN/CN | alpacaGPT4 | | | ✅ | ✅ | +| webGPT | TS | MIX | EN | webGPT | | ✅ | | ✅ | +| dolly | TS | HG | EN | dolly | | | ✅ | | +| baize | MT | COL | EN | baize | | | ✅ | | +| hh-rlhf | TS | MIX | EN | hh-rlhf | ✅ | | ✅ | ✅ | +| OIG | MT | COL | EN | OIG | | | ✅ | | +| GAOKAO | MT | COL | CN | GAOKAO | | ✅ | | | +| camel | MT | SI | EN | camel | | ✅ | | | +| FLAN-Muffin | MT | COL | EN | FLAN-Muffin | | ✅ | | | +| COIG | MT | COL | CN | COIG | | | ✅ | | +| gpt4tools | MT | SI | EN | gpt4tools | | ✅ | | | +| ShareGPT | MT | MIX | EN | ShareGPT | ✅ | | ✅ | | +| Auto-CoT | MT | COL | EN | Auto-CoT | | ✅ | | | +| MOSS | TS | SI | EN/CN | MOSS | | | ✅ | | +| ultrachat | TS | SI | EN | ultrachat | | | ✅ | | +| Chinese-medical | TS | COL | CN | Chinese-medical | | | ✅ | | +| CSL | MT | COL | CN | CSL | | ✅ | | | +| pCLUE | MT | COL | CN | pCLUE | | ✅ | | | +| news_commentary | TS | COL | CN | news_commentary | | ✅ | | | +| StackExchange | MT | COL | EN | StackExchange | | | ✅ | ✅ | +| ConvAI2 | TS | HG | EN | ConvAI2 | | | ✅ | | +| FastChat | MT | SI | EN | FastChat | | | ✅ | | +| Tabular-LLM-Data | MT | COL | EN/CN | Tabular-LLM-Data | | ✅ | | | +| ThoughtSource | MT | COL | EN | ThoughtSource | | ✅ | | | \ No newline at end of file diff --git a/configs/refine_recipe/alpaca_cot/README_ZH.md b/configs/refine_recipe/alpaca_cot/README_ZH.md new file mode 100644 index 000000000..3461132d6 --- /dev/null +++ b/configs/refine_recipe/alpaca_cot/README_ZH.md @@ -0,0 +1,114 @@ +# Redpajama Config Files + +该文件夹包含的配置文件能够让用户轻松快速地完善 [Alpaca-CoT](https://huggingface.co/QingyiSi/Alpaca-CoT)。 + +## 预处理 + +原始数据文件在 HuggingFace 中的 [Alpaca-CoT](https://huggingface.co/QingyiSi/Alpaca-CoT) 下载。 + +### 将 Alpaca-CoT 转换为 jsonl 文件 +使用 [raw_alpaca_cot_merge_add_meta.py](../../tools/preprocess/raw_alpaca_cot_merge_add_meta.py) 选择数据集的 `instruction`, `input` 和 `output` 3个字段,并使用空格将它们合并到 `text`,同时在数据集中增加额外的[元信息]( #meta_info) : + +```shell +python tools/preprocess/raw_alpaca_cot_merge_add_meta.py \ + --src_dir \ + --target_dir \ + --num_proc +``` + +### 按照语言将数据集拆分子数据集 +使用 [dataset_split_by_language.py](../../tools/preprocess/dataset_split_by_language.py) 将数据集拆分为中文和英文: + +```shell +python tools/preprocess/dataset_split_by_language.py \ + --src_dir \ + --target_dir \ + --suffixes jsonl \ + --num_proc +``` + +## 处理 +在预处理完成之后,修改 [alpaca-cot-en-refine](alpaca-cot-en-refine].yaml) 和 [alpaca-cot-zh-refine](alpaca-cot-zh-refine.yaml) 中的数据集路径,然后执行以下命令来复现完善过的 Alpaca-CoT 的处理流程。 + +```shell +# refine English dataset +python tools/process_data.py --config configs/refine_recipe/alpaca_cot/alpaca-cot-en-refine].yaml + +# refine Chinese dataset +python tools/process_data.py --config configs/refine_recipe/alpaca_cot/alpaca-cot-zh-refine].yaml +``` + +### 元信息 + +在完善后的 Alpaca-CoT 的数据集中每个样本都包含元信息,标签说明如下: + +#### Alpaca-CoT 元信息 +* Language 标签: + - EN: 英文数据集 + - CN: 中文数据集 + - ML: 多语言数据集 +* Task 标签: + - MT: 多任务数据集 + - TS: 特定任务数据集 +* 产生方法: + - HG: 人工产出数据集 + - SI: 机器产出数据集 + - MIX: 人工和机器混合数据集 + - COL: 从其他数据集合成的数据集 + +#### Data-Juicer 元信息 +* Dataset: Alpaca-CoT 中的数据集 + +* Multi-round Dialog (MRD): 多轮对话数据集 + +* IFT: 指令微调数据集 + +* SFT: 有监督微调数据集 + +* Preference: 偏好数据集 + +* origin_path: Alpaca-CoT 中的原始文件路径 + + +#### 完善的 Alpaca-CoT 数据集元信息 +| | 任务 | 产生方法 | 语言 | 数据集 | 多轮对话 | 指令跟随 | 监督微调 | 偏好 | +|:---------------------|:-------|:------|:-------|:---------------------|:---:|:---:|:----:|:----:| +| Chain-of-Thought | MT | HG | EN/CN | Chain-of-Thought | | ✅ | | | +| GPT4all | MT | COL | EN | GPT4all | | ✅ | ✅ | | +| GPTeacher | MT | SI | EN | GPTeacher | | | ✅ | | +| Guanaco | MT | SI | ML | Guanaco | | | ✅ | | +| HC3 | TS | MIX | EN/CN | HC3 | | | ✅ | ✅ | +| alpaca | MT | SI | EN | alpaca | | | ✅ | | +| Natural-Instructions | MT | COL | ML | Natural-Instructions | | ✅ | | | +| belle_cn | TS/MT | SI | CN | belle_cn | | | ✅ | | +| instinwild | MT | SI | EN/CN | instinwild | | | ✅ | | +| prosocial-dialog | TS | MIX | EN | prosocial-dialog | | | ✅ | | +| finance | TS | COL | EN | finance | | | ✅ | | +| xP3 | MT | COL | ML | xP3 | | ✅ | | | +| firefly | MT | COL | CN | firefly | | ✅ | | | +| instruct | MT | COL | EN | instruct | | | ✅ | | +| CodeAlpaca | TS | SI | EN | CodeAlpaca | | ✅ | | | +| alpacaGPT4 | MT | SI | EN/CN | alpacaGPT4 | | | ✅ | ✅ | +| webGPT | TS | MIX | EN | webGPT | | ✅ | | ✅ | +| dolly | TS | HG | EN | dolly | | | ✅ | | +| baize | MT | COL | EN | baize | | | ✅ | | +| hh-rlhf | TS | MIX | EN | hh-rlhf | ✅ | | ✅ | ✅ | +| OIG | MT | COL | EN | OIG | | | ✅ | | +| GAOKAO | MT | COL | CN | GAOKAO | | ✅ | | | +| camel | MT | SI | EN | camel | | ✅ | | | +| FLAN-Muffin | MT | COL | EN | FLAN-Muffin | | ✅ | | | +| COIG | MT | COL | CN | COIG | | | ✅ | | +| gpt4tools | MT | SI | EN | gpt4tools | | ✅ | | | +| ShareGPT | MT | MIX | EN | ShareGPT | ✅ | | ✅ | | +| Auto-CoT | MT | COL | EN | Auto-CoT | | ✅ | | | +| MOSS | TS | SI | EN/CN | MOSS | | | ✅ | | +| ultrachat | TS | SI | EN | ultrachat | | | ✅ | | +| Chinese-medical | TS | COL | CN | Chinese-medical | | | ✅ | | +| CSL | MT | COL | CN | CSL | | ✅ | | | +| pCLUE | MT | COL | CN | pCLUE | | ✅ | | | +| news_commentary | TS | COL | CN | news_commentary | | ✅ | | | +| StackExchange | MT | COL | EN | StackExchange | | | ✅ | ✅ | +| ConvAI2 | TS | HG | EN | ConvAI2 | | | ✅ | | +| FastChat | MT | SI | EN | FastChat | | | ✅ | | +| Tabular-LLM-Data | MT | COL | EN/CN | Tabular-LLM-Data | | ✅ | | | +| ThoughtSource | MT | COL | EN | ThoughtSource | | ✅ | | | \ No newline at end of file diff --git a/configs/refine_recipe/alpaca_cot/alpaca-cot-en-refine.yaml b/configs/refine_recipe/alpaca_cot/alpaca-cot-en-refine.yaml new file mode 100644 index 000000000..c1a43712e --- /dev/null +++ b/configs/refine_recipe/alpaca_cot/alpaca-cot-en-refine.yaml @@ -0,0 +1,37 @@ +# global parameters +project_name: 'Data-Juicer-recipes-alpaca-cot-en' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - document_deduplicator: # 104636705 + lowercase: true + ignore_non_character: true + + - alphanumeric_filter: # 104636381 + tokenization: false + min_ratio: 0.1 + - character_repetition_filter: # 104630030 + rep_len: 10 + max_ratio: 0.6 + - flagged_words_filter: # 104576967 + lang: en + tokenization: true + max_ratio: 0.017 + - maximum_line_length_filter: # 104575811 + min_len: 20 + - text_length_filter: #92673211 + min_len: 30 + + - document_simhash_deduplicator: # coming soon + tokenization: space + window_size: 3 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 9 + hamming_distance: 7 diff --git a/configs/refine_recipe/alpaca_cot/alpaca-cot-zh-refine.yaml b/configs/refine_recipe/alpaca_cot/alpaca-cot-zh-refine.yaml new file mode 100644 index 000000000..a563b203f --- /dev/null +++ b/configs/refine_recipe/alpaca_cot/alpaca-cot-zh-refine.yaml @@ -0,0 +1,36 @@ +# global parameters +project_name: 'Data-Juicer-recipes-alpaca-cot-zh' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - document_deduplicator: # 16957516 + lowercase: true # whether to convert text to lower case + ignore_non_character: true + + - alphanumeric_filter: # 16957388 + tokenization: false + min_ratio: 0.10 + - character_repetition_filter: # 16956845 + rep_len: 10 + max_ratio: 0.6 + - flagged_words_filter: # 16954629 + lang: zh + tokenization: true + use_words_aug: true + max_ratio: 0.017 + - text_length_filter: # 16954317 + min_len: 10 + + - document_simhash_deduplicator: # 9873214 + tokenization: character + window_size: 4 # small window size for short texts + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 10 + hamming_distance: 8 # larger hamming distance threshold for short texts diff --git a/configs/refine_recipe/github_code/redpajama-code-refine.yaml b/configs/refine_recipe/github_code/redpajama-code-refine.yaml new file mode 100644 index 000000000..afb31e1aa --- /dev/null +++ b/configs/refine_recipe/github_code/redpajama-code-refine.yaml @@ -0,0 +1,55 @@ +# global parameters +project_name: 'Data-Juicer-recipes-code-rp' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + - clean_copyright_mapper: + + - alphanumeric_filter: + tokenization: False + min_ratio: 0.4 + max_ratio: 0.8 + - alphanumeric_filter: + tokenization: True + min_ratio: 1.5 + max_ratio: 3 + - average_line_length_filter: + min_len: 15 + max_len: 100 + - character_repetition_filter: + rep_len: 10 + min_ratio: 0.05 + max_ratio: 0.3 + - maximum_line_length_filter: + min_len: 50 + max_len: 500 + - text_length_filter: + min_len: 300 + - words_num_filter: + lang: en + tokenization: False + min_num: 30 + max_num: 5000 + - word_repetition_filter: + lang: en + tokenization: False + rep_len: 10 + max_ratio: 0.1 + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/github_code/redpajama-stack-code-deduplicate.yaml b/configs/refine_recipe/github_code/redpajama-stack-code-deduplicate.yaml new file mode 100644 index 000000000..df960322a --- /dev/null +++ b/configs/refine_recipe/github_code/redpajama-stack-code-deduplicate.yaml @@ -0,0 +1,18 @@ +project_name: 'Data-Juicer-recipes-code' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' # path to your dataset result file + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments + +process: + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/github_code/stack-code-refine.yaml b/configs/refine_recipe/github_code/stack-code-refine.yaml new file mode 100644 index 000000000..3736bb04f --- /dev/null +++ b/configs/refine_recipe/github_code/stack-code-refine.yaml @@ -0,0 +1,51 @@ +# global parameters +project_name: 'Data-Juicer-recipes-the-stack' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +text_key: 'content' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + - clean_copyright_mapper: + + - alphanumeric_filter: # 18766 + tokenization: false + min_ratio: 0.2 # < 3sigma (0.3791) + max_ratio: 0.9163 # 3sigma + - alphanumeric_filter: # 146432 + tokenization: true + min_ratio: 0.546 # 3sigma + max_ratio: 3.65 # 3sigma + - average_line_length_filter: # for code + min_len: 10 # > 3sigma (0) -- 48790 + max_len: 150 # < 3sigma (15603) -- 233275 + - character_repetition_filter: + max_ratio: 0.36 # 3sigma -- 346875 + - maximum_line_length_filter: # for code + max_len: 1000 # remove 256670 samples + - text_length_filter: + max_len: 96714 # 3sigma -- 190006 + - words_num_filter: + min_num: 20 # remove 1504958 samples + max_num: 6640 # 3sigma -- remove 179847 samples + - word_repetition_filter: + rep_len: 10 + max_ratio: 0.357 # 3sigma -- 598462 + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/pile-europarl-refine.yaml b/configs/refine_recipe/pile-europarl-refine.yaml new file mode 100644 index 000000000..9f3b60d4d --- /dev/null +++ b/configs/refine_recipe/pile-europarl-refine.yaml @@ -0,0 +1,58 @@ +# global parameters +project_name: 'Data-Juicer-recipes-EuroParl' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.75 # <3sigma (0.779) + max_ratio: 0.90 # >3sigma(0.878) + - average_line_length_filter: # for code + max_len: 588 # 3sigma + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.16 # >3sigma (0.114) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0007 # 3sigma + - language_id_score_filter: + min_score: 0.7 + - maximum_line_length_filter: # for code + max_len: 4000 # >3sigma (3104) + - perplexity_filter: + lang: en + max_ppl: 7596 #(3sigma) + - special_characters_filter: + max_ratio: 0.3 # > 3sigma (0.243) + - text_length_filter: + max_len: 2e5 + - words_num_filter: + tokenization: true + min_num: 20 + max_num: 1e5 # 3sigma + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.2 # > 3sigma (0.185) + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/pile-freelaw-refine.yaml b/configs/refine_recipe/pile-freelaw-refine.yaml new file mode 100644 index 000000000..071044f40 --- /dev/null +++ b/configs/refine_recipe/pile-freelaw-refine.yaml @@ -0,0 +1,62 @@ +# global parameters +project_name: 'Data-Juicer-recipes-freelaw' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.3 # <3sigma (0.436) + - average_line_length_filter: # for code + max_len: 697 # 3sigma TBD + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.4 # >3sigma (0.350) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0053 # 3sigma + - language_id_score_filter: + min_score: 0.5 # < 3sigma (0.583) + - maximum_line_length_filter: # for code + max_len: 4229 # 3sigma + - perplexity_filter: + lang: en + max_ppl: 5322 # 3sigma + - special_characters_filter: + max_ratio: 0.7 # > 3sigma (0.626) + - stopwords_filter: # not use + lang: en + tokenization: true + min_ratio: 0.1 # > 3sigma (0.07) + - text_length_filter: + max_len: 84026 # 3sigma + - words_num_filter: + lang: en + tokenization: true + min_num: 100 + max_num: 15208 # 3sigma + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.155 # 3sigma + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/pile-hackernews-refine.yaml b/configs/refine_recipe/pile-hackernews-refine.yaml new file mode 100644 index 000000000..59e2f113e --- /dev/null +++ b/configs/refine_recipe/pile-hackernews-refine.yaml @@ -0,0 +1,57 @@ +# global parameters +project_name: 'Data-Juicer-recipes-HackerNews' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 48 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + #- clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.2 #<3sigma + - average_line_length_filter: + min_len: 15 # >3sigma + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 # >3sigma + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.05 # >3sigma + - language_id_score_filter: + min_score: 0.2 # <3sigma + - maximum_line_length_filter: + min_len: 20 # >3sigma + - perplexity_filter: + lang: en + max_ppl: 10000 # >3sigma + - special_characters_filter: + max_ratio: 0.7 # >3sigma + - text_length_filter: + min_len: 100 # > 3sigma + - words_num_filter: + lang: en + tokenization: true + min_num: 30 # > 3sigma + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.8 # > 3sigma + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/pile-nih-refine.yaml b/configs/refine_recipe/pile-nih-refine.yaml new file mode 100644 index 000000000..9088efa46 --- /dev/null +++ b/configs/refine_recipe/pile-nih-refine.yaml @@ -0,0 +1,54 @@ +# global parameters +project_name: 'Data-Juicer-recipes-Hin' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.75 # <3sigma (0.800) + max_ratio: 0.866 + - average_line_length_filter: + max_len: 10000 # >3sigma (5425) + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.2 # >3sigma (0.127) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0003 # 3sigma + - language_id_score_filter: + min_score: 0.7 + - perplexity_filter: + lang: en + max_ppl: 1669 #(3sigma) + - special_characters_filter: + max_ratio: 0.3 # > 3sigma (0.218) + - words_num_filter: + tokenization: true + min_num: 20 + max_num: 2000 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.104 # 3sigma + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/pile-philpaper-refine.yaml b/configs/refine_recipe/pile-philpaper-refine.yaml new file mode 100644 index 000000000..0ef83148c --- /dev/null +++ b/configs/refine_recipe/pile-philpaper-refine.yaml @@ -0,0 +1,56 @@ +# global parameters +project_name: 'Data-Juicer-recipes-Philpaper' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.7 # <3sigma (0.72) + - average_line_length_filter: # for code + max_len: 5e5 # >3sigma (406006) + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.2 # >3sigma (0.145) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0007 # 3sigma + - language_id_score_filter: + min_score: 0.6 + - maximum_line_length_filter: # for code + max_len: 1e6 # 3sigma + - perplexity_filter: + lang: en + max_ppl: 5000 + - special_characters_filter: + max_ratio: 0.4 # > 3sigma (0.302) + - words_num_filter: + lang: en + tokenization: true + min_num: 1000 + max_num: 2e5 # 3sigma + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.3 # > 3sigma (0.249) + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/pile-pubmed-abstract-refine.yaml b/configs/refine_recipe/pile-pubmed-abstract-refine.yaml new file mode 100644 index 000000000..2395acbd3 --- /dev/null +++ b/configs/refine_recipe/pile-pubmed-abstract-refine.yaml @@ -0,0 +1,59 @@ +# global parameters +project_name: 'Data-Juicer-recipes-pubmed-abstract' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: # 4068 + tokenization: false + min_ratio: 0.7 # < 3sigma (0.773) + max_ratio: 0.881 # 3sigma + - average_line_length_filter: # for code + max_len: 2100 # > 3sigma (1471) -- 7410 + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.2 # > 3sigma (0.1458) -- 6060 + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.00232 # 3sigma + - language_id_score_filter: # remove language filter + min_score: 0.5 + - maximum_line_length_filter: # for code + max_len: 4000 # remove 8202 samples + - perplexity_filter: + lang: en + max_ppl: 4000 # remove 10284 samples + - special_characters_filter: + max_ratio: 0.38 # remove 5532 samples + - text_length_filter: + max_len: 4000 # > 3sigma -- 10873 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 # remove 10790 samples + max_num: 700 # remove 22709 samples + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.0887 # 3sigma + + - document_simhash_deduplicator: + tokenization: space + window_size: 3 # small window size for short texts + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 10 + hamming_distance: 8 # larger hamming distance threshold for short texts diff --git a/configs/refine_recipe/pile-pubmed-central-refine.yaml b/configs/refine_recipe/pile-pubmed-central-refine.yaml new file mode 100644 index 000000000..418403291 --- /dev/null +++ b/configs/refine_recipe/pile-pubmed-central-refine.yaml @@ -0,0 +1,58 @@ +# global parameters +project_name: 'Data-Juicer-recipes-pubmed-central' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: # 89217 + tokenization: false + min_ratio: 0.2787 # 3sigma + - average_line_length_filter: # for code + max_len: 1200 # < 3sigma (1478) -- 7410 + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3741 # 3sigma -- 65849 + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.00195 # 3sigma -- 8305 + - language_id_score_filter: # remove language filter + min_score: 0.5 # 272359 + - maximum_line_length_filter: # for code + max_len: 7328 # remove 23808 samples + - perplexity_filter: + lang: en + max_ppl: 8000 # remove 173883 samples + - special_characters_filter: + max_ratio: 0.842 # remove 87661 samples + - text_length_filter: + max_len: 136028 # 3sigma -- 15118 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 # remove 176537 samples + max_num: 23305 # remove 15016 samples + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.5981 # 3sigma -- 93843 + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/pile-uspto-refine.yaml b/configs/refine_recipe/pile-uspto-refine.yaml new file mode 100644 index 000000000..aa42f9354 --- /dev/null +++ b/configs/refine_recipe/pile-uspto-refine.yaml @@ -0,0 +1,58 @@ +# global parameters +project_name: 'Data-Juicer-recipes-uspto' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' # path to your dataset result file + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.7 # <3sigma (0.758) + - average_line_length_filter: # for code + max_len: 2000 # >3sigma (1307) + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.2 # >3sigma (0.189) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0016 # 3sigma + - language_id_score_filter: + min_score: 0.6 + - maximum_line_length_filter: # for code + max_len: 3061 # 3sigma + - perplexity_filter: + lang: en + max_ppl: 4000 # 3sigma + - special_characters_filter: + max_ratio: 0.3 # > 3sigma (0.274) + - text_length_filter: + max_len: 21556 # 3sigma + - words_num_filter: + lang: en + tokenization: true + min_num: 100 + max_num: 6000 # 3sigma + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.169 # 3sigma + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/redpajama-arxiv-refine.yaml b/configs/refine_recipe/redpajama-arxiv-refine.yaml new file mode 100644 index 000000000..5ca00cc75 --- /dev/null +++ b/configs/refine_recipe/redpajama-arxiv-refine.yaml @@ -0,0 +1,58 @@ +# global parameters +project_name: 'Data-Juicer-arxivrecipes-arxiv' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.516 # 3sigma + max_ratio: 0.915 # 3sigma + - average_line_length_filter: # for code + max_len: 682 # 3sigma + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.00076 # 3sigma + #- language_id_score_filter: # remove language filter + - maximum_line_length_filter: # for code + max_len: 4000 + - perplexity_filter: + lang: en + max_ppl: 8000 + - special_characters_filter: + max_ratio: 0.6 + - text_length_filter: + max_len: 350000 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 + max_num: 100000 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.574 # 3sigma + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/redpajama-book-refine.yaml b/configs/refine_recipe/redpajama-book-refine.yaml new file mode 100644 index 000000000..719ebed06 --- /dev/null +++ b/configs/refine_recipe/redpajama-book-refine.yaml @@ -0,0 +1,57 @@ +# global parameters +project_name: 'Data-Juicer-recipes-book' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.55 # <3sigma (0.697) + max_ratio: 0.854 # 3sigma + - average_line_length_filter: # for code + max_len: 500 # >3sigma (364) + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.2 # >3sigma (0.12) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.00047 # 3sigma + - language_id_score_filter: # remove language filter + min_score: 0.2 + - maximum_line_length_filter: # for code + max_len: 13381 # 3sigma + - perplexity_filter: + lang: en + max_ppl: 6000 # <3sigma (16516) + - special_characters_filter: + max_ratio: 0.5 # >3sigma (0.32) + - words_num_filter: + lang: en + tokenization: true + min_num: 1000 + max_num: 539754 # 3sigma + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.194 # 3sigma + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/configs/refine_recipe/redpajama-c4-refine.yaml b/configs/refine_recipe/redpajama-c4-refine.yaml new file mode 100644 index 000000000..1e1c9cc8c --- /dev/null +++ b/configs/refine_recipe/redpajama-c4-refine.yaml @@ -0,0 +1,44 @@ +# global parameters +project_name: 'Data-Juicer-recipes-c4' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' # path to your dataset result file + +np: 50 # number of subprocess to process your dataset +open_tracer: True + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.65 # <3sigma (0.740) + max_ratio: 0.9 # >3sigma (0.867) + - average_line_length_filter: # for code + max_len: 3000 # >3sigma (1277) + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 # >3sigma (0.168) + - language_id_score_filter: + min_score: 0.6 + - maximum_line_length_filter: # for code + max_len: 4000 # >3sigma (2017) + - perplexity_filter: + lang: en + max_ppl: 6000 #(>3sigma 4543) + - special_characters_filter: + max_ratio: 0.4 # > 3sigma (0.303) + - words_num_filter: + tokenization: true + min_num: 20 + max_num: 10000 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.231 # 3sigma diff --git a/configs/refine_recipe/redpajama-cc-2019-30-refine.yaml b/configs/refine_recipe/redpajama-cc-2019-30-refine.yaml new file mode 100644 index 000000000..b600181f5 --- /dev/null +++ b/configs/refine_recipe/redpajama-cc-2019-30-refine.yaml @@ -0,0 +1,60 @@ +# global parameters +project_name: 'Data-Juicer-recipes-cc-2019-30' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 + + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: # 770218 + tokenization: false + min_ratio: 0.7489 # 3sigma + max_ratio: 0.8585 # 3sigma + - average_line_length_filter: # for code + max_len: 1500 # < 3sigma (2689) -- 177520 + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 # > 3sigma (0.1491) -- 151703 + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0025 # 3sigma -- 101540 + - language_id_score_filter: # remove language filter + min_score: 0.788 # 3sigma -- 1622574 + - maximum_line_length_filter: # for code + max_len: 5000 # < 3sigma (8775) -- 485806 + - perplexity_filter: + lang: en + max_ppl: 5000 # < 3sigma (6723) -- 676914 + - special_characters_filter: + min_ratio: 0.15 # > 3sigma (0.104) + max_ratio: 0.35 # > 3sigma (0.322) -- 859797 + - text_length_filter: + max_len: 65589 # 3sigma -- 975142 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 # > 3sigma -- 196 + max_num: 13030 # 3sigma -- 989078 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.279 # 3sigma -- 1716308 diff --git a/configs/refine_recipe/redpajama-cc-2020-05-refine.yaml b/configs/refine_recipe/redpajama-cc-2020-05-refine.yaml new file mode 100644 index 000000000..3d62e6a59 --- /dev/null +++ b/configs/refine_recipe/redpajama-cc-2020-05-refine.yaml @@ -0,0 +1,60 @@ +# global parameters +project_name: 'Data-Juicer-recipes-cc-2020-05' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 + + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.7469 # 3sigma + max_ratio: 0.8609 # 3sigma + - average_line_length_filter: # for code + max_len: 1500 # < 3sigma -- 332621 + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 # > 3sigma -- 170501 + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.002 # 3sigma -- 167167 + - language_id_score_filter: # remove language filter + min_score: 0.774 # 3sigma -- 1943513 + - maximum_line_length_filter: # for code + max_len: 5000 # < 3sigma -- 845490 + - perplexity_filter: + lang: en + max_ppl: 5000 # < 3sigma -- 909218 + - special_characters_filter: + min_ratio: 0.15 # > 3sigma + max_ratio: 0.35 # > 3sigma -- 1134347 + - text_length_filter: + max_len: 68161 # 3sigma -- 1145902 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 # remove 7913 samples + max_num: 13644 # 3sigma -- 1148810 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.328 # 3sigma -- 2125070 diff --git a/configs/refine_recipe/redpajama-cc-2021-04-refine.yaml b/configs/refine_recipe/redpajama-cc-2021-04-refine.yaml new file mode 100644 index 000000000..5a48b7b07 --- /dev/null +++ b/configs/refine_recipe/redpajama-cc-2021-04-refine.yaml @@ -0,0 +1,60 @@ +# global parameters +project_name: 'Data-Juicer-recipes-cc-2021-04' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 + + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.7494 # 3sigma + max_ratio: 0.8595 # 3sigma -- 1001790 + - average_line_length_filter: # for code + max_len: 1500 # < 3sigma (2817) -- 541131 + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 # > 3sigma (0.1463) -- 159152 + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0019 # 3sigma -- 184714 + - language_id_score_filter: # remove language filter + min_score: 0.786 # 3sigma -- 1995115 + - maximum_line_length_filter: # for code + max_len: 5000 # < 3sigma -- 1076085 + - perplexity_filter: + lang: en + max_ppl: 5000 # < 3sigma -- 906649 + - special_characters_filter: + min_ratio: 0.15 # > 3sigma + max_ratio: 0.35 # > 3sigma -- 1046590 + - text_length_filter: + max_len: 61592 # 3sigma -- 1114727 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 # > 3sigma + max_num: 12241 # 3sigma -- 1120334 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.3105 # 3sigma -- 2234933 diff --git a/configs/refine_recipe/redpajama-cc-2022-05-refine.yaml b/configs/refine_recipe/redpajama-cc-2022-05-refine.yaml new file mode 100644 index 000000000..07495ea7b --- /dev/null +++ b/configs/refine_recipe/redpajama-cc-2022-05-refine.yaml @@ -0,0 +1,60 @@ +# global parameters +project_name: 'Data-Juicer-recipes-cc-2022-05' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 + + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.7514 # 3sigma + max_ratio: 0.8577 # 3sigmai -- 888003 + - average_line_length_filter: # for code + max_len: 1500 # < 3sigma -- 447069 + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 # > 3sigma -- 145890 samples + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0012 # 3sigma -- 319395 + - language_id_score_filter: # remove language filter + min_score: 0.791 # 3sigma -- 1823528 + - maximum_line_length_filter: # for code + max_len: 5000 # < 3sigma -- 791612 + - perplexity_filter: + lang: en + max_ppl: 5000 # < 3sigma -- 654459 + - special_characters_filter: + min_ratio: 0.15 # > 3sigma + max_ratio: 0.35 # > 3sigma + - text_length_filter: + max_len: 59265 # 3sigma -- 1046590 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 # > 3sigma + max_num: 11860 # 3sigma -- 1036780 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.3117 # 3sigma -- 2089703 diff --git a/configs/refine_recipe/redpajama-cc-2023-06-refine.yaml b/configs/refine_recipe/redpajama-cc-2023-06-refine.yaml new file mode 100644 index 000000000..f6a9996a8 --- /dev/null +++ b/configs/refine_recipe/redpajama-cc-2023-06-refine.yaml @@ -0,0 +1,60 @@ +# global parameters +project_name: 'Data-Juicer-recipes-cc-2013-06' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 + + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.7508 # 3sigma + max_ratio: 0.8591 # 3sigma -- 1036821 + - average_line_length_filter: # for code + max_len: 1500 # < 3sigma -- 395868 + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.3 # > 3sigma -- 195026 + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0015 # 3sigma -- 287896 + - language_id_score_filter: # remove language filter + min_score: 0.793 # 3sigma -- 2173246 + - maximum_line_length_filter: # for code + max_len: 5000 # < 3sigma -- 797111 + - perplexity_filter: + lang: en + max_ppl: 5000 # 3sigma -- 942162 + - special_characters_filter: + min_ratio: 0.15 # > 3sigma + max_ratio: 0.35 # > 3sigma -- 1155090 + - text_length_filter: + max_len: 58187 # 3sigma -- 1165902 + - words_num_filter: + lang: en + tokenization: true + min_num: 20 + max_num: 11529 # 3sigma -- 1185363 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.2962 # 3sigma -- 2407282 diff --git a/configs/refine_recipe/redpajama-pile-stackexchange-refine.yaml b/configs/refine_recipe/redpajama-pile-stackexchange-refine.yaml new file mode 100644 index 000000000..d6400da8a --- /dev/null +++ b/configs/refine_recipe/redpajama-pile-stackexchange-refine.yaml @@ -0,0 +1,59 @@ +# global parameters +project_name: 'Data-Juicer-stack-exchange' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.35 # <3sigma + max_ratio: 0.943 # 3sigma + - average_line_length_filter: # for code + min_len: 20 # >3sigma + max_len: 400 # >3sigma + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.4 # >3sigma (0.12) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.01 # >3sigma + - language_id_score_filter: # remove language filter + min_score: 0.1 # <3sigma + - maximum_line_length_filter: # for code + min_len: 80 + - perplexity_filter: + lang: en + max_ppl: 10000 # >3sigma + - special_characters_filter: + min_ratio: 0.232 # 3sigma + max_ratio: 0.7 # >3sigma + - text_length_filter: + min_len: 200 + - words_num_filter: + lang: en + tokenization: true + min_num: 100 + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.8 # >3sigma + - document_simhash_deduplicator: #26309203 left + tokenization: space + window_size: 3 + lowercase: true + ignore_pattern: '\n\n' + num_blocks: 9 + hamming_distance: 7 diff --git a/configs/refine_recipe/redpajama-wiki-refine.yaml b/configs/refine_recipe/redpajama-wiki-refine.yaml new file mode 100644 index 000000000..36a41b88d --- /dev/null +++ b/configs/refine_recipe/redpajama-wiki-refine.yaml @@ -0,0 +1,59 @@ +# global parameters +project_name: 'Data-Juicer-recipes-wiki' +dataset_path: '/path/to/your/dataset' # path to your dataset directory or file +export_path: '/path/to/your/dataset.jsonl' + +np: 50 # number of subprocess to process your dataset +open_tracer: true + +# process schedule +# a list of several process operators with their arguments +process: + - clean_email_mapper: + - clean_links_mapper: + - fix_unicode_mapper: + - punctuation_normalization_mapper: + - whitespace_normalization_mapper: + + - alphanumeric_filter: + tokenization: false + min_ratio: 0.6 # <3sigma (0.735) + max_ratio: 0.884 # 3sigma + - average_line_length_filter: # for code + max_len: 192 # 3sigma + - character_repetition_filter: + rep_len: 10 + max_ratio: 0.4 # >3sigma (0.197) + - flagged_words_filter: + lang: en + tokenization: true + max_ratio: 0.0019 # 3sigma + - language_id_score_filter: + min_score: 0.689 # 3sigma + - maximum_line_length_filter: # for code + max_len: 1630 # 3sigma tbd + - perplexity_filter: + lang: en + max_ppl: 6887 # 3sigma + - special_characters_filter: + max_ratio: 0.5 # >3sigma (0.34) + - text_length_filter: + max_len: 18221 # 3sigma + - words_num_filter: + lang: en + tokenization: true + min_num: 20 + max_num: 6086 # 3sigma + - word_repetition_filter: + lang: en + tokenization: true + rep_len: 10 + max_ratio: 0.3 # 3sigma (0.194) + + - document_simhash_deduplicator: + tokenization: space + window_size: 6 + lowercase: true + ignore_pattern: '\p{P}' + num_blocks: 6 + hamming_distance: 4 diff --git a/data_juicer/.DS_Store b/data_juicer/.DS_Store new file mode 100644 index 000000000..a400607d0 Binary files /dev/null and b/data_juicer/.DS_Store differ diff --git a/data_juicer/__init__.py b/data_juicer/__init__.py new file mode 100644 index 000000000..b794fd409 --- /dev/null +++ b/data_juicer/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' diff --git a/data_juicer/analysis/__init__.py b/data_juicer/analysis/__init__.py new file mode 100644 index 000000000..78db975a1 --- /dev/null +++ b/data_juicer/analysis/__init__.py @@ -0,0 +1,2 @@ +from .column_wise_analysis import ColumnWiseAnalysis +from .overall_analysis import OverallAnalysis diff --git a/data_juicer/analysis/column_wise_analysis.py b/data_juicer/analysis/column_wise_analysis.py new file mode 100644 index 000000000..6d109aeb7 --- /dev/null +++ b/data_juicer/analysis/column_wise_analysis.py @@ -0,0 +1,286 @@ +import math +import os + +import matplotlib.pyplot as plt +import pandas as pd + +from .overall_analysis import OverallAnalysis + + +def get_row_col(total_num, factor=2): + """ + Given the total number of stats figures, get the "best" number of rows and + columns. This function is needed when we need to store all stats figures + into one image. + + :param total_num: Total number of stats figures + :param factor: Number of sub-figure types in each figure. In + default, it's 2, which means there are histogram and box plot + for each stat figure + :return: "best" number of rows and columns, and the grid list + """ + n = total_num * factor # actual number of figures + now_col = factor # search from the minimum number of columns + now_row = total_num + for col in range(factor, n + 1, factor): + row = n * 1.0 / col + if row != int(row): # skip non-integer results + continue + if col > row: + # object: minimum the difference between number of columns and rows + if abs(col - row) > abs(now_col - now_row): + break + else: + now_row = row + now_col = col + break + now_row = row + now_col = col + + # different sub-figures of the same stats should be in the same row + now_col = now_col // factor + + # get grid indexes + grids = [] + for i in range(total_num): + grids.append((i // now_col, i % now_col)) + + return int(now_row), int(now_col), grids + + +class ColumnWiseAnalysis: + """Apply analysis on each column of stats respectively.""" + + def __init__(self, + dataset, + output_path, + overall_result=None, + save_stats_in_one_file=True): + """ + Initialization method + :param dataset: the dataset to be analysed + :param output_path: path to store the analysis results + :param overall_result: optional precomputed overall stats result + :param save_stats_in_one_file: whether save all analysis figures of all + stats into one image file + """ + self.stats = pd.DataFrame(dataset['stats']) + self.output_path = output_path + if not os.path.exists(self.output_path): + os.makedirs(self.output_path) + + # if no overall description provided, analyse it from scratch + if overall_result is None: + oa = OverallAnalysis(dataset, output_path) + overall_result = oa.analyse() + self.overall_result = overall_result + + self.save_stats_in_one_file = save_stats_in_one_file + + def analyse(self, show_percentiles=False, show=False): + """ + Apply analysis and draw the analysis figure for stats. + + :param show_percentiles: whether to show the percentile line in + each sub-figure. If it's true, there will be several red + lines to indicate the quantiles of the stats distributions + :param show: whether to show in a single window after drawing + :return: + """ + # number of sub-figures for each stat. There are histogram and box plot + # for now, so it's 2. + num_subcol = 2 + + # Default width and height unit for each sub-figure + width_unit = 4 + height_unit = 6 + + columns = self.stats.columns + num = len(columns) + + # get the recommended "best" number of columns and rows + rec_row, rec_col, grid_indexes = get_row_col(num, num_subcol) + + if self.save_stats_in_one_file: + # if save_stats_in_one_file is opened, use recommended "best" + # number of columns and rows to initialize the image panel. + rec_width = rec_col * num_subcol * width_unit + rec_height = rec_row * height_unit + fig = plt.figure(figsize=(rec_width, rec_height), + layout='constrained') + subfigs = fig.subfigures(rec_row, rec_col, wspace=0.01) + for i, column_name in enumerate(columns): + data = self.stats[column_name] + grid = grid_indexes[i] + if self.save_stats_in_one_file: + if rec_col == 1: + grid = grid[0] + elif rec_row == 1: + grid = grid[1] + + if rec_col == 1 and rec_row == 1: + subfig = subfigs + else: + subfig = subfigs[grid] + subfig.set_facecolor('0.85') + + # numeric or string via nan. Apply different plot method for them. + if pd.isna(self.overall_result[column_name].get('top')): + # numeric -- draw histogram and box plot for this stat + percentiles = self.overall_result[column_name] \ + if show_percentiles else None + + # get axes for each subplot + if self.save_stats_in_one_file: + axes = subfig.subplots(1, num_subcol) + else: + axes = [None] * num_subcol + + # draw histogram + self.draw_hist(axes[0], + data, + os.path.join(self.output_path, + f'{column_name}-hist.png'), + percentiles=percentiles) + + # draw box + self.draw_box(axes[1], + data, + os.path.join(self.output_path, + f'{column_name}-box.png'), + percentiles=percentiles) + else: + # object (string) -- only draw histogram for this stat + if self.save_stats_in_one_file: + axes = subfig.subplots(1, 1) + else: + axes = None + + self.draw_hist( + axes, data, + os.path.join(self.output_path, f'{column_name}-hist.png')) + + # add a title to the figure of this stat + if self.save_stats_in_one_file: + subfig.suptitle(f'{data.name}', + fontsize='x-large', + fontweight='bold') + + if self.save_stats_in_one_file: + fig = plt.gcf() + fig.savefig(os.path.join(self.output_path, 'all-stats.png')) + if show: + plt.show() + else: + pass + # TODO: (fixme) the saved png sometime are blank + plt.clf() + + def draw_hist(self, ax, data, save_path, percentiles=None, show=False): + """ + Draw the histogram for the data. + + :param ax: the axes to draw + :param data: data to draw + :param save_path: the path to save the histogram figure + :param percentiles: the overall analysis result of the data + including percentile information + :param show: whether to show in a single window after drawing + :return: + """ + # recommended number of bins + data_num = len(data) + if data_num >= 100: + rec_bins = int(math.sqrt(len(data))) + else: + rec_bins = None + + # if ax is None, using plot method in pandas + if ax is None: + ax = data.hist(bins=rec_bins, figsize=(20, 16)) + else: + ax.hist(data, bins=rec_bins) + + # set axes + ax.set_xlabel(data.name) + ax.set_ylabel('Count') + + # draw percentile lines if it's not None + if percentiles is not None: + ymin, ymax = ax.get_ylim() + for percentile in percentiles.keys(): + # skip other information + if percentile in {'count', 'unique', 'top', 'freq', 'std'}: + continue + value = percentiles[percentile] + + ax.vlines(x=value, ymin=ymin, ymax=ymax, colors='r') + ax.text(x=value, y=ymax, s=percentile, rotation=30, color='r') + ax.text(x=value, + y=ymax * 0.97, + s=str(round(value, 3)), + rotation=30, + color='r') + + if not self.save_stats_in_one_file: + # save into file + plt.savefig(save_path) + + if show: + plt.show() + else: + # if no showing, we need to clear this axes to avoid + # accumulated overlapped figures in different draw_xxx function + # calling + ax.clear() + else: + # add a little rotation on labels of x axis to avoid overlapping + ax.tick_params(axis='x', rotation=25) + + def draw_box(self, ax, data, save_path, percentiles=None, show=False): + """ + Draw the box plot for the data. + + :param ax: the axes to draw + :param data: data to draw + :param save_path: the path to save the box figure + :param percentiles: the overall analysis result of the data + including percentile information + :param show: whether to show in a single window after drawing + :return: + """ + # if ax is None, using plot method in pandas + if ax is None: + ax = data.plot.box(figsize=(20, 16)) + else: + ax.boxplot(data) + + # set axes + ax.set_ylabel(data.name) + + # draw percentile lines if it's not None + if percentiles is not None: + xmin, xmax = ax.get_xlim() + for percentile in percentiles.keys(): + # skip other information + if percentile in {'count', 'unique', 'top', 'freq', 'std'}: + continue + value = percentiles[percentile] + + ax.hlines(y=value, xmin=xmin, xmax=xmax, colors='r') + ax.text(y=value, + x=xmin + (xmax - xmin) * 0.6, + s=f'{percentile}: {round(value, 3)}', + color='r') + + if not self.save_stats_in_one_file: + # save into file + plt.savefig(save_path) + + if show: + plt.show() + else: + # if no showing, we need to clear this axes to avoid + # accumulated overlapped figures in different draw_xxx function + # calling + ax.clear() diff --git a/data_juicer/analysis/diversity_analysis.py b/data_juicer/analysis/diversity_analysis.py new file mode 100644 index 000000000..dfddbc2fb --- /dev/null +++ b/data_juicer/analysis/diversity_analysis.py @@ -0,0 +1,187 @@ +import os + +import pandas as pd +import spacy +from loguru import logger + +diversity_models = {} + + +def load_diversity_model(model_name): + """ + Load diversity model according to the model name. + + :param model_name: name of diversity model + :return: loaded diversity model. + """ + nlp = spacy.load(model_name) + return nlp + + +def prepare_diversity_model(lang): + """ + Prepare diversity model for specific language. + + :param lang: language of diversity model. Should be one of ["zh", + "en"] + :return: corresponding diversity model + """ + assert lang in ['zh', 'en'], 'Diversity only support zh and en' + + # return loaded models directly + if lang in diversity_models.keys(): + return diversity_models[lang] + + model_name = lang + '_core_web_md' + diversity_models[lang] = load_diversity_model(model_name) + return diversity_models[lang] + + +# Modify from self_instruct, please refer to +# https://github.com/yizhongw/self-instruct/blob/main/self_instruct/instruction_visualize.ipynb +def find_root_verb_and_its_dobj(tree_root): + """ + Find the verb and its object closest to the root. + + :param tree_root: the root of lexical tree + :return: valid verb and its object. + """ + # first check if the current node and its children satisfy the condition + if tree_root.pos_ == 'VERB': + for child in tree_root.children: + if child.dep_ == 'dobj' and child.pos_ == 'NOUN': + return tree_root.lemma_ if len( + tree_root.lemma_) else tree_root.text, child.lemma_ if len( + child.lemma_) else child.text + return tree_root.lemma_ if len( + tree_root.lemma_) else tree_root.text, None + # if not, check its children + for child in tree_root.children: + return find_root_verb_and_its_dobj(child) + # if no children satisfy the condition, return None + return None, None + + +# Modify from self_instruct, please refer to +# https://github.com/yizhongw/self-instruct/blob/main/self_instruct/instruction_visualize.ipynb +def find_root_verb_and_its_dobj_in_string(nlp, s, first_sent=True): + """ + Find the verb and its object closest to the root of lexical tree of input + string. + + :param nlp: the diversity model to analyse the diversity strings + :param s: the string to be analysed + :param first_sent: whether to analyse the first sentence in the + input string only. If it's true, return the analysis result of + the first sentence no matter it's valid or not. If it's false, + return the first valid result over all sentences + :return: valid verb and its object of this string + """ + doc = nlp(s) + for sent in doc.sents: + verb, noun = find_root_verb_and_its_dobj(sent.root) + if first_sent or (verb is not None and noun is not None): + return verb, noun + return None, None + + +def get_diversity(dataset, top_k_verbs=20, top_k_nouns=4, **kwargs): + """ + Given the lexical tree analysis result, return the diversity results. + + :param dataset: lexical tree analysis result + :param top_k_verbs: only keep the top_k_verbs largest verb groups + :param top_k_nouns: only keep the top_k_nouns largest noun groups + for each verb group + :param kwargs: extra args + :return: the diversity results + """ + phrases = pd.DataFrame(dataset).dropna() + logger.info(f'find valid verb-noun structure \ + {phrases.shape[0]} of {dataset.shape[0]}') + top_verbs = phrases.groupby(['verb' + ]).size().nlargest(top_k_verbs).reset_index() + + df = phrases[phrases['verb'].isin(top_verbs['verb'].tolist())] + df = df.groupby(['verb', 'noun']).size().reset_index().rename(columns={ + 0: 'count' + }).sort_values(by=['count'], ascending=False) + + df = df.groupby('verb').apply(lambda x: x.sort_values( + 'count', ascending=False).head(top_k_nouns)).reset_index(drop=True) + return df + + +class DiversityAnalysis: + """Apply diversity analysis for each sample and get an overall analysis + result.""" + + def __init__(self, dataset, output_path, lang_or_model='en'): + """Initialization method :param dataset: the dataset to be analysed + :param output_path: path to store the analysis results :param + lang_or_model: the diversity model or a specific language used to load + the diversity model.""" + + self.dataset = dataset + self.output_path = output_path + if not os.path.exists(self.output_path): + os.makedirs(self.output_path) + self.lang_or_model = lang_or_model + + def compute(self, lang_or_model=None, column_name='text'): + """ + Apply lexical tree analysis on each sample. + + :param lang_or_model: the diversity model or a specific language + used to load the diversity model + :param column_name: the name of column to be analysed + :return: the analysis result. + """ + # load diversity model + lang_or_model = lang_or_model if lang_or_model else self.lang_or_model + if isinstance(lang_or_model, str): + diversity_model = prepare_diversity_model(lang_or_model) + else: + diversity_model = lang_or_model + + assert isinstance(diversity_model, spacy.Language) + + def find_verb_noun(sample): + try: + verb, noun = find_root_verb_and_its_dobj_in_string( + diversity_model, sample[column_name]) + except Exception as e: + print(str(e)) + verb, noun = None, None + return {'verb': verb, 'noun': noun} + + dataset = self.dataset.map(find_verb_noun) + return pd.DataFrame(dataset) + + def analyse(self, + lang_or_model=None, + column_name='text', + postproc_func=get_diversity, + **postproc_kwarg): + """ + Apply diversity analysis on the whole dataset. + + :param lang_or_model: the diversity model or a specific language + used to load the diversity model + :param column_name: the name of column to be analysed + :param postproc_func: function to analyse diversity. In default, + it's function get_diversity + :param postproc_kwarg: arguments of the postproc_func + :return: + """ + # get the lexical tree analysis result + raw_df = self.compute(lang_or_model=lang_or_model, + column_name=column_name) + # get the result of diversity analysis + df = postproc_func(raw_df, **postproc_kwarg) + + # export to result report file + df.to_csv(os.path.join(self.output_path, 'diversity.csv')) + df.to_markdown(os.path.join(self.output_path, 'diversity.md')) + + return df diff --git a/data_juicer/analysis/overall_analysis.py b/data_juicer/analysis/overall_analysis.py new file mode 100644 index 000000000..26e8073eb --- /dev/null +++ b/data_juicer/analysis/overall_analysis.py @@ -0,0 +1,41 @@ +import os + +import pandas as pd + + +class OverallAnalysis: + """Apply analysis on the overall stats, including mean, std, quantiles, + etc.""" + + def __init__(self, dataset, output_path): + """ + Initialization method. + + :param dataset: the dataset to be analysed + :param output_path: path to store the analysis results. + """ + self.stats = pd.DataFrame(dataset['stats']) + self.output_path = output_path + if not os.path.exists(self.output_path): + os.makedirs(self.output_path) + + # default percentiles to analyse + self.default_percentiles = [0.25, 0.5, 0.75] + + def analyse(self, percentiles=[]): + """ + Apply overall analysis on the whole dataset based on the describe + method of pandas. + + :param percentiles: percentiles to analyse + :return: the overall analysis result. + """ + # merge default and customized percentiles and get overall information + percentiles = list(set(percentiles + self.default_percentiles)) + overall = self.stats.describe(percentiles=percentiles, include='all') + + # export to result report file + overall.to_csv(os.path.join(self.output_path, 'overall.csv')) + overall.to_markdown(os.path.join(self.output_path, 'overall.md')) + + return overall diff --git a/data_juicer/config/__init__.py b/data_juicer/config/__init__.py new file mode 100644 index 000000000..853722dec --- /dev/null +++ b/data_juicer/config/__init__.py @@ -0,0 +1 @@ +from .config import * # noqa: F401,F403 diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py new file mode 100644 index 000000000..c67f13095 --- /dev/null +++ b/data_juicer/config/config.py @@ -0,0 +1,310 @@ +import os +import time +from argparse import ArgumentError +from typing import Dict, List, Optional, Tuple, Union + +from jsonargparse import (ActionConfigFile, ArgumentParser, dict_to_namespace, + namespace_to_dict) +from jsonargparse.typing import NonNegativeInt, PositiveInt +from loguru import logger + +from data_juicer.ops.base_op import OPERATORS +from data_juicer.utils.logger_utils import setup_logger + + +def init_configs(args=None): + """ + initialize the jsonargparse parser and parse configs from one of: + 1. POSIX-style commands line args; + 2. config files in yaml (json and jsonnet supersets); + 3. environment variables + 4. hard-coded defaults + + :param args: list of params, e.g., ['--conifg', 'cfg.yaml'], defaut None. + :return: a global cfg object used by the Executor or Analyser + """ + parser = ArgumentParser(default_env=True, default_config_files=None) + + parser.add_argument('--config', + action=ActionConfigFile, + help='Path to a configuration file.', + required=True) + + # basic global paras with extended type hints + # e.g., files can be mode include flags + # "fr": "path to a file that exists and is readable") + # "fc": "path to a file that can be created if it does not exist") + # "dw": "path to a directory that exists and is writeable") + # "dc": "path to a directory that can be created if it does not exist") + # "drw": "path to a directory that exists and is readable and writeable") + parser.add_argument('--project_name', + type=str, + default='hello_world', + help='Name of your data process project.') + parser.add_argument( + '--dataset_path', + type=str, + help='Path to datasets with optional weights(0.0-1.0), ' + '1.0 as default. Accepted format:' + ' dataset1-path dataset2-path ' + ' dataset3-path ...') + parser.add_argument( + '--export_path', + type=str, + default='./outputs/hello_world.jsonl', + help='Path to export and save the output processed dataset.' + ' The directory to store the processed dataset will be the ' + 'work directory of this process.') + parser.add_argument('--export_shard_size', + type=NonNegativeInt, + default=0, + help='Shard size of exported dataset in Byte. In ' + 'default, it\'s 0, which means export the whole ' + 'dataset into only one file. If it\'s set a ' + 'positive number, the exported dataset will be ' + 'split into several sub-dataset shards, and the max ' + 'size of each shard won\'t larger than the ' + 'export_shard_size') + parser.add_argument('--np', + type=PositiveInt, + default=4, + help='Number of processes to process dataset.') + parser.add_argument('--text_key_to_process', + type=Optional[str], + default='text', + help='Key name of field where the sample ' + 'texts to be processed, e.g., ' + '`text`, `text.instruction`, `text.output`, ...' + 'Note: currently, we support specify only ONE key for ' + 'each op, for cases requiring multiple keys, users can' + ' specify the op multiple times') + parser.add_argument('--text_keys_to_load', + type=Union[List[str], Tuple[str]], + default=['text'], + help='Key name of field where the sample ' + 'texts stored in the original data') + parser.add_argument('--suffixes', + type=Union[str, List[str], Tuple[str]], + default=[], + help='Suffixes of files that will be find and loaded. ' + 'If not set, we will find all suffix files, and select' + 'a suitable formatter with the most files as default.') + parser.add_argument('--use_cache', + type=bool, + default=True, + help='Whether to use the cache management of hugging' + 'face datasets. It might take up lots of disk ' + 'space when using cache') + parser.add_argument('--ds_cache_dir', + type=str, + default='~/.cache/huggingface/datasets', + help='Cache dir for HuggingFace datasets. In default ' + 'it\'s the default cache dir "~/.cache/huggingface/dat' + 'asets". If this argument is reset by users, it will ' + 'override the default cache dir.') + parser.add_argument('--use_checkpoint', + type=bool, + default=False, + help='Whether to use the checkpoint management to ' + 'save the latest version of dataset to work dir when ' + 'processing. Rerun the same config will reload the ' + 'checkpoint and skip ops before it. Cache will be ' + 'disabled when it is true . If args of ops ' + 'before the checkpoint are changed, all ops will be ' + 'rerun from the beginning.') + parser.add_argument('--temp_dir', + type=str, + default=None, + help='Path to the temp directory to store ' + 'intermediate caches when cache is disabled. In ' + 'default it\'s None, so the temp dir will be ' + 'specified by system. NOTICE: you should be caution ' + 'when setting this argument because it might cause ' + 'unexpected program behaviors when this path is set ' + 'to an unsafe directory.') + parser.add_argument('--open_tracer', + type=bool, + default=False, + help='Whether to open the tracer to trace samples' + 'changed during process. It might take more ' + 'time when opening tracer.') + parser.add_argument( + '--op_list_to_trace', + type=List[str], + default=[], + help='Which ops will be traced by tracer. ' + 'If it\'s empty, all ops in cfg.process will be traced. Only ' + 'available when open_tracer is true.') + parser.add_argument('--trace_num', + type=int, + default=10, + help='Number of samples extracted by tracer to show ' + 'the dataset difference before and after a op. ' + 'Only available when open_tracer is true.') + parser.add_argument( + '--process', + type=List[Dict], + help='List of several operators with their ' + 'arguments, these ops will be applied to dataset in order') + parser.add_argument('--save_stats_in_one_file', + type=bool, + default=False, + help='Whether to save all stats to only one file. Only' + ' used in Analysis.') + + # add all parameters of the registered ops class to the parser, + # and these op parameters can be modified through the command line, + ops_sorted_by_types = sort_op_by_types_and_names(OPERATORS.modules.items()) + _collect_config_info_from_class_docs(ops_sorted_by_types, parser) + + try: + cfg = parser.parse_args(args=args) + option_in_commands = [ + ''.join(arg.split('--')[1].split('.')[0]) for arg in parser.args + if '--' in arg and 'config' not in arg + ] + + full_option_in_commands = list( + set([ + ''.join(arg.split('--')[1].split('=')[0]) + for arg in parser.args if '--' in arg and 'config' not in arg + ])) + + if cfg.process is None: + cfg.process = [] + + # check and update every op params in `cfg.process` + # e.g. + # `python demo.py --config demo.yaml + # --language_id_score_filter.lang en` + for i, op_in_process in enumerate(cfg.process): + op_in_process_name = list(op_in_process.keys())[0] + + temp_cfg = cfg + if op_in_process_name not in option_in_commands: + + # update op params to temp cfg if set + if op_in_process[op_in_process_name]: + temp_cfg = parser.merge_config( + dict_to_namespace(op_in_process), cfg) + else: + + # args in the command line override the ones in `cfg.process` + for full_option_in_command in full_option_in_commands: + + key = full_option_in_command.split('.')[1] + if op_in_process[ + op_in_process_name] and key in op_in_process[ + op_in_process_name].keys(): + op_in_process[op_in_process_name].pop(key) + + if op_in_process[op_in_process_name]: + temp_cfg = parser.merge_config( + dict_to_namespace(op_in_process), temp_cfg) + + # update op params of cfg.process + internal_op_para = temp_cfg.get(op_in_process_name) + + cfg.process[i] = { + op_in_process_name: + None if internal_op_para is None else + namespace_to_dict(internal_op_para) + } + + cfg = init_setup_from_cfg(cfg) + return cfg + except ArgumentError: + logger.error('Config initialization failed') + + +def init_setup_from_cfg(cfg): + """ + Do some extra setup tasks after parsing config file or command line. + + 1. create working directory and a log directory + 2. update cache directory + 3. update checkpoint and `temp_dir` of tempfile + + :param cfg: a original cfg + :param cfg: a updated cfg + """ + + export_path = cfg.export_path + cfg.work_dir = os.path.dirname(export_path) + log_dir = os.path.join(cfg.work_dir, 'log') + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + logfile_name = time.strftime('%Y%m%d%H%M%S', time.localtime( + time.time())) + '.txt' + setup_logger(save_dir=log_dir, filename=logfile_name) + + # whether or not to use cache management + # disabling the cache or using checkpoint explicitly will turn off the + # cache management. + if not cfg.use_cache or cfg.use_checkpoint: + logger.warning('Cache management of datasets is disabled.') + from datasets import disable_caching + disable_caching() + + # when disabling cache, enable the temp_dir argument + logger.warning(f'Set temp directory to store temp files to ' + f'[{cfg.temp_dir}].') + import tempfile + if cfg.temp_dir is not None and not os.path.exists(cfg.temp_dir): + os.makedirs(cfg.temp_dir, exist_ok=True) + tempfile.tempdir = cfg.temp_dir + + # reset huggingface datasets cache directory + from datasets import config + config.HF_DATASETS_CACHE = cfg.ds_cache_dir + + # if there is suffix_filter op, turn on the add_suffix flag + cfg.add_suffix = False + for op in cfg.process: + op_name, _ = list(op.items())[0] + if op_name == 'suffix_filter': + cfg.add_suffix = True + break + + return cfg + + +def _collect_config_info_from_class_docs(configurable_ops, parser): + """ + Add ops and its params to parser for command line. + + :param configurable_ops: a list of ops to be to added, each item is + a pair of op_name and op_class + :param parser: jsonargparse parser need to update + """ + + for op_name, op_class in configurable_ops: + parser.add_class_arguments( + theclass=op_class, + nested_key=op_name, + fail_untyped=False, + instantiate=False, + ) + + +def sort_op_by_types_and_names(op_name_classes): + """ + Split ops items by op type and sort them to sub-ops by name, then concat + together. + + :param op_name_classes: a list of op modules + :return: sorted op list , each item is a pair of op_name and + op_class + """ + + mapper_ops = [(name, c) for (name, c) in op_name_classes + if 'mapper' in name] + filter_ops = [(name, c) for (name, c) in op_name_classes + if 'filter' in name] + deduplicator_ops = [(name, c) for (name, c) in op_name_classes + if 'deduplicator' in name] + selector_ops = [(name, c) for (name, c) in op_name_classes + if 'selector' in name] + ops_sorted_by_types = sorted(mapper_ops) + sorted(filter_ops) + sorted( + deduplicator_ops) + sorted(selector_ops) + return ops_sorted_by_types diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py new file mode 100644 index 000000000..cf712d21a --- /dev/null +++ b/data_juicer/core/__init__.py @@ -0,0 +1,5 @@ +from .analyser import Analyser +from .data import NestedDataset +from .executor import Executor +from .exporter import Exporter +from .tracer import Tracer diff --git a/data_juicer/core/analyser.py b/data_juicer/core/analyser.py new file mode 100644 index 000000000..e88be1d72 --- /dev/null +++ b/data_juicer/core/analyser.py @@ -0,0 +1,118 @@ +import os + +from loguru import logger + +from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis +from data_juicer.config import init_configs +from data_juicer.format import load_formatter +from data_juicer.ops import Filter, load_ops + +from .exporter import Exporter + + +class Analyser: + """ + This Analyser class is used to analyse a specific dataset. + + It will compute stats for all filter ops in the config file, apply + multiple analysis (e.g. OverallAnalysis, ColumnWiseAnalysis, etc.) + on these stats, and generate the analysis results (stats tables, + distribution figures, etc.) to help users understand the input + dataset better. + """ + + def __init__(self, cfg=None): + """ + Initialization method. + + :param cfg: optional config dict. + """ + self.cfg = init_configs() if cfg is None else cfg + + self.work_dir = self.cfg.work_dir + self.ops = None + + # setup formatter + logger.info('Setting up data formatter...') + self.formatter = load_formatter(self.cfg.dataset_path, + self.cfg.text_keys_to_load, + self.cfg.suffixes, self.cfg.add_suffix) + + # prepare exporter and check export path suffix + # NOTICE: no need to export dataset texts for analyser + # (export_ds=False). Instead, only need to export stats + # (export_stats=True). + logger.info('Preparing exporter...') + self.exporter = Exporter(self.cfg.export_path, + self.cfg.export_shard_size, + self.cfg.np, + export_ds=False, + export_stats=True) + + # parsed_res + self.overall_result = None + self.overall_single_plot_path = None + self.analysis_path = os.path.join(self.cfg.work_dir, 'analysis') + + def run(self, load_data_np=None): + """ + Running the dataset analysis pipeline. + + :param load_data_np: number of workers when loading the dataset. + :return: analysed dataset. + """ + # 1. format data + logger.info('Loading dataset from data formatter...') + if load_data_np is None: + load_data_np = self.cfg.np + dataset = self.formatter.load_dataset(load_data_np, self.cfg) + + # extract processes + logger.info('Preparing process operators...') + self.ops = load_ops(self.cfg.process, self.cfg.text_key_to_process) + + # 2. stats precompute only for filter ops + logger.info('Computing the stats of dataset...') + stats_collected = False + for op_cfg, op in zip(self.cfg.process, self.ops): + op_name = list(op_cfg.keys())[0] + if isinstance(op, Filter): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * + dataset.num_rows) + dataset = dataset.map(op.compute_stats, + num_proc=self.cfg.np, + desc=op_name + '_compute_stats') + stats_collected = True + if not stats_collected: + logger.warning('No stats collected. Please add some Filter ops to ' + 'the process list in configs.') + return dataset + + # 3. analysis and output result to the export path + # 3.1. Only consider fields in 'stats' + # 3.2. For string fields, only consider its histogram + # 3.3. For numeric fields, consider its histogram and box + # 3.4. Otherwise, DO NOT analyse + + logger.info('Applying overall analysis on stats...') + overall_analysis = OverallAnalysis(dataset, self.analysis_path) + self.overall_result = overall_analysis.analyse() + + logger.info('Applying column-wise analysis on stats...') + column_wise_analysis = ColumnWiseAnalysis( + dataset, + self.analysis_path, + overall_result=self.overall_result, + save_stats_in_one_file=self.cfg.save_stats_in_one_file) + column_wise_analysis.analyse() + + # 4. data export + logger.info('Exporting dataset to disk...') + self.exporter.export(dataset) + + return dataset diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py new file mode 100644 index 000000000..0b91a290c --- /dev/null +++ b/data_juicer/core/data.py @@ -0,0 +1,210 @@ +import copy +from functools import wraps +from typing import Union + +from datasets import Dataset, DatasetDict +from datasets.formatting.formatting import LazyBatch +from loguru import logger + + +def wrap_func_with_nested_access(f): + """ + Before conducting actual function `f`, wrap its args and kargs into nested + ones. + + :param f: function to be wrapped. + :return: wrapped function + """ + + def wrap_nested_structure(*args, **kargs): + wrapped_args = [nested_obj_factory(arg) for arg in args] + wrapped_kargs = { + k: nested_obj_factory(arg) + for k, arg in kargs.items() + } + return wrapped_args, nested_obj_factory(wrapped_kargs) + + @wraps(f) + def wrapped_f(*args, **kargs): + args, kargs = wrap_nested_structure(*args, **kargs) + # to ensure the args passing to the final calling of f can be nested, + # in case of deeper-order wrapper funcs de-wrap this nesting behavior + args = [ + wrap_func_with_nested_access(arg) if callable(arg) else arg + for arg in args + ] + kargs = { + k: (wrap_func_with_nested_access(arg) if callable(arg) else arg) + for (k, arg) in kargs.items() + } + return f(*args, **kargs) + + return wrapped_f + + +def nested_obj_factory(obj): + """ + Use nested classes to wrap the input object. + + :param obj: object to be nested. + :return: nested object + """ + if isinstance(obj, Dataset): + return NestedDataset(obj) + elif isinstance(obj, DatasetDict): + return NestedDatasetDict(obj) + elif isinstance(obj, dict): + return NestedQueryDict(obj) + elif isinstance(obj, LazyBatch): + obj.data = NestedQueryDict(obj.data) + return obj + elif isinstance(obj, list): + return [nested_obj_factory(item) for item in obj] + else: + return obj + + +class NestedQueryDict(dict): + """Enhanced dict for better usability.""" + + def __init__(self, *args, **kargs): + if len(args) == 1 and isinstance(args[0], Dataset): + # init from another DatasetDict instance + self.__dict__ = copy.copy(args[0].__dict__) + else: + # init from scratch + super().__init__(*args, **kargs) + + # batched sample, (k & v) are organized by list manner + for k, v in self.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict): + self[k] = [NestedQueryDict(item) for item in v] + + def __getitem__(self, key): + return nested_query(self, key) + + +class NestedDatasetDict(DatasetDict): + """Enhanced HuggingFace-DatasetDict for better usability and efficiency.""" + + def __init(self, *args, **kargs): + if len(args) == 1 and isinstance(args[0], Dataset): + # init from another DatasetDict instance + self.__dict__ = copy.copy(args[0].__dict__) + else: + # init from scratch + super().__init__(*args, **kargs) + + def __getitem__(self, key): + return nested_query(self, key) + + def map(self, **args): + """Override the map func, which is called by most common operations, + such that the processed samples can be accessed by nested manner.""" + if args['function'] is None: + args['function'] = lambda x: nested_obj_factory(x) + else: + args['function'] = wrap_func_with_nested_access(args['function']) + + return super().map(**args) + + +class NestedDataset(Dataset): + """Enhanced HuggingFace-Dataset for better usability and efficiency.""" + + def __init__(self, *args, **kargs): + if len(args) == 1 and isinstance(args[0], Dataset): + # init from another Dataset instance + self.__dict__ = copy.copy(args[0].__dict__) + else: + # init from scratch + super().__init__(*args, **kargs) + + def __getitem__(self, key): + if isinstance(key, str): + # to index columns by query as string name(s) + res = nested_query(self, key) + else: + # to index rows by query as integer index, slices, + # or iter of indices or bools + res = super().__getitem__(key) + return nested_obj_factory(res) + + def map(self, *args, **kargs): + """Override the map func, which is called by most common operations, + such that the processed samples can be accessed by nested manner.""" + if args: + args = list(args) + # the first positional para is function + if args[0] is None: + args[0] = lambda x: nested_obj_factory(x) + else: + args[0] = wrap_func_with_nested_access(args[0]) + else: + if kargs['function'] is None: + kargs['function'] = lambda x: nested_obj_factory(x) + else: + kargs['function'] = wrap_func_with_nested_access( + kargs['function']) + + return NestedDataset(super().map(*args, **kargs)) + + @classmethod + def from_dict(cls, *args, **kargs): + """Override the from_dict func, which is called by most from_xx + constructors, such that the constructed dataset object is + NestedDataset.""" + return NestedDataset(super().from_dict(*args, **kargs)) + + def select(self, *args, **kargs): + """Override the select fun, such that selected samples can be accessed + by nested manner.""" + return nested_obj_factory(super().select(*args, **kargs)) + + +def nested_query(root_obj: Union[NestedDatasetDict, NestedDataset, + NestedQueryDict], key): + """ + Find item from a given object, by first checking flatten layer, then + checking nested layers. + + :param root_obj: the object + :param key: the stored item to be queried, e.g., "meta" or + "meta.date" + :return: + """ + subkeys = key.split('.') + + tmp = root_obj + for i in range(len(subkeys)): + try: + key_to_query = '.'.join(subkeys[i:len(subkeys)]) + if isinstance(tmp, + (NestedQueryDict, NestedDataset, NestedDatasetDict)): + # access field using base_class's func to avoid endless loop + res = super(type(tmp), tmp).__getitem__(key_to_query) + elif isinstance(tmp, list): + # NestedDataset may return multiple rows as list + res = [nested_query(item, key_to_query) for item in tmp] + else: + # NestedQueryDict may return single row + res = tmp[key_to_query] + if res is not None: + return res + except Exception as outer_get_error: + exist_in_dict = issubclass(type(tmp), dict) and \ + '.'.join(subkeys[i:i + 1]) in tmp + exist_in_dataset = issubclass(type(tmp), Dataset) and '.'.join( + subkeys[i:i + 1]) in tmp.features + if exist_in_dict or exist_in_dataset: + # dive into next level + tmp = nested_obj_factory(tmp['.'.join(subkeys[i:i + 1])]) + else: + logger.debug( + f'cannot find item given key={key} in dataset=' + f'{root_obj}. For the final caught outer-exception,' + f'type is: {type(outer_get_error)}, ' + f'info is: {outer_get_error}') + return None + + return None diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py new file mode 100644 index 000000000..b3834d061 --- /dev/null +++ b/data_juicer/core/executor.py @@ -0,0 +1,171 @@ +import os + +from loguru import logger + +from data_juicer.config import init_configs +from data_juicer.format.load import load_formatter +from data_juicer.ops import (OPERATORS, Deduplicator, Filter, Mapper, Selector, + load_ops) +from data_juicer.utils.ckpt_utils import CheckpointManager + +from .exporter import Exporter +from .tracer import Tracer + + +class Executor: + """ + This Executor class is used to process a specific dataset. + + It will load the dataset and unify the format, then apply all the + ops in the config file in order and generate a processed dataset. + """ + + def __init__(self, cfg=None): + """ + Initialization method. + + :param cfg: optional config dict. + """ + self.cfg = init_configs() if cfg is None else cfg + + self.work_dir = self.cfg.work_dir + + self.ops = None + + # setup formatter + logger.info('Setting up data formatter...') + self.formatter = load_formatter(self.cfg.dataset_path, + self.cfg.text_keys_to_load, + self.cfg.suffixes, self.cfg.add_suffix) + + # whether to use checkpoint mechanism. If it's true, Executor will + # check if there are existing checkpoints first and try to load the + # checkpoints. If the checkpoints are loaded successfully, ops that + # have been processed will be skipped. + self.process_list = self.cfg.process + if self.cfg.use_checkpoint: + logger.info('Preparing checkpoint manager...') + self.ckpt_dir = os.path.join(self.work_dir, 'ckpt') + self.ckpt_manager = CheckpointManager(self.ckpt_dir, + self.process_list, + self.cfg.np) + if self.ckpt_manager.ckpt_available: + logger.info('Found existed dataset checkpoint.') + self.process_list = self.ckpt_manager.get_left_process_list() + self.cfg.process = self.process_list + + # prepare exporter and check export path suffix + logger.info('Preparing exporter...') + self.exporter = Exporter(self.cfg.export_path, + self.cfg.export_shard_size, self.cfg.np) + + # setup tracer + self.open_tracer = self.cfg.open_tracer + if self.open_tracer: + logger.info('Preparing tracer...') + self.tracer = Tracer(self.work_dir, show_num=self.cfg.trace_num) + self.op_list_to_trace = self.cfg.op_list_to_trace + if len(self.cfg.op_list_to_trace) == 0: + logger.info('Trace for all ops.') + self.op_list_to_trace = set(OPERATORS.modules.keys()) + + def run(self, load_data_np=None): + """ + Running the dataset process pipeline. + + :param load_data_np: number of workers when loading the dataset. + :return: processed dataset. + """ + # 1. format data + if self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available: + logger.info('Loading dataset from checkpoint...') + dataset = self.ckpt_manager.load_ckpt() + else: + logger.info('Loading dataset from data formatter...') + if load_data_np is None: + load_data_np = self.cfg.np + dataset = self.formatter.load_dataset(load_data_np, self.cfg) + + # 2. extract processes + logger.info('Preparing process operators...') + self.ops = load_ops(self.cfg.process, self.cfg.text_key_to_process) + + # 3. data process + # - If tracer is open, trace each op after it's processed + # - If checkpoint is open, clean the cache files after each process + logger.info('Processing data...') + for op_cfg, op in zip(self.process_list, self.ops): + op_name, op_args = list(op_cfg.items())[0] + prev = dataset # record last dataset + try: + if isinstance(op, Mapper): + tmp = dataset.map(op.process, + num_proc=self.cfg.np, + desc=op_name + '_process') + if self.open_tracer and op_name in self.op_list_to_trace: + self.tracer.trace_mapper(op_name, dataset, tmp) + elif isinstance(op, Filter): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * + dataset.num_rows) + if self.cfg.use_checkpoint: + dataset.cleanup_cache_files() + prev = dataset + dataset = dataset.map(op.compute_stats, + num_proc=self.cfg.np, + desc=op_name + '_compute_stats') + if self.cfg.use_checkpoint: + dataset.cleanup_cache_files() + prev = dataset + tmp = dataset.filter(op.process, + num_proc=self.cfg.np, + desc=op_name + '_process') + if self.open_tracer and op_name in self.op_list_to_trace: + self.tracer.trace_filter(op_name, dataset, tmp) + elif isinstance(op, Selector): + tmp = op.process(dataset) + if self.open_tracer and op_name in self.op_list_to_trace: + self.tracer.trace_filter(op_name, dataset, tmp) + elif isinstance(op, Deduplicator): + dataset = dataset.map(op.compute_hash, + num_proc=self.cfg.np, + desc=op_name + '_compute_hash') + if self.cfg.use_checkpoint: + dataset.cleanup_cache_files() + prev = dataset + tmp, dup_pairs = op.process( + dataset, self.tracer.show_num if self.open_tracer + and op_name in self.op_list_to_trace else 0) + if self.open_tracer and op_name in self.op_list_to_trace: + self.tracer.trace_deduplicator(op_name, dup_pairs) + else: + raise NotImplementedError + dataset = tmp + except: # noqa: E722 + logger.error(f'An error occurred during Op [{op_name}].') + import traceback + traceback.print_exc() + if self.cfg.use_checkpoint: + logger.info('Writing checkpoint of dataset processed by ' + 'last op...') + prev.cleanup_cache_files() + self.ckpt_manager.save_ckpt(prev) + exit(1) + + # clean up cache files and record processed ops + if self.cfg.use_checkpoint: + dataset.cleanup_cache_files() + self.ckpt_manager.record(op_name, op_args) + + logger.info(f'Op [{op_name}] Done. Left ' + f'{len(dataset)} samples.') + + # 4. data export + logger.info('Exporting dataset to disk...') + self.exporter.export(dataset) + + return dataset diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py new file mode 100644 index 000000000..d76d6ab78 --- /dev/null +++ b/data_juicer/core/exporter.py @@ -0,0 +1,202 @@ +import os +from multiprocessing import Pool + +from loguru import logger + + +class Exporter: + """The Exporter class is used to export a dataset to files of specific + format.""" + + KiB = 2**10 # 1024 + MiB = 2**20 # 1024*1024 + GiB = 2**30 # 1024*1024*1024 + TiB = 2**40 # 1024*1024*1024*1024 + + def __init__(self, + export_path, + export_shard_size=0, + num_proc=1, + export_ds=True, + export_stats=True): + """ + Initialization method. + + :param export_path: the path to export datasets. + :param export_shard_size: the size of each shard of exported + dataset. In default, it's 0, which means export the dataset + to a single file. + :param num_proc: number of process to export the dataset. + :param export_ds: whether to export the dataset contents. + :param export_stats: whether to export the stats of dataset. + """ + self.export_path = export_path + self.export_shard_size = export_shard_size + self.export_ds = export_ds + self.export_stats = export_stats + self.suffix = self._get_suffix(export_path) + self.num_proc = num_proc + self.max_shard_size_str = '' + + # get the string format of shard size + if self.export_shard_size // Exporter.TiB: + self.max_shard_size_str = '%.2f TiB' % (self.export_shard_size / + Exporter.TiB) + elif self.export_shard_size // Exporter.GiB: + self.max_shard_size_str = '%.2f GiB' % (self.export_shard_size / + Exporter.GiB) + elif self.export_shard_size // Exporter.MiB: + self.max_shard_size_str = '%.2f MiB' % (self.export_shard_size / + Exporter.MiB) + elif self.export_shard_size // Exporter.KiB: + self.max_shard_size_str = '%.2f KiB' % (self.export_shard_size / + Exporter.KiB) + else: + self.max_shard_size_str = '%.2f Bytes' % (self.export_shard_size) + + # we recommend users to set a shard size between MiB and TiB. + if 0 < self.export_shard_size < Exporter.MiB: + logger.warning(f'The export_shard_size [{self.max_shard_size_str}]' + f' is less than 1MiB. If the result dataset is too ' + f'large, there might be too many shard files to ' + f'generate.') + if self.export_shard_size >= Exporter.TiB: + logger.warning(f'The export_shard_size [{self.max_shard_size_str}]' + f' is larger than 1TiB. It might generate large ' + f'single shard file and make loading and exporting ' + f'slower.') + + def _get_suffix(self, export_path): + """ + Get the suffix of export path and check if it's supported. + + We only support ["jsonl", "json", "parquet"] for now. + + :param export_path: the path to export datasets. + :return: the suffix of export_path. + """ + suffix = export_path.split('.')[-1].lower() + support_dict = self._router() + if suffix not in support_dict: + raise NotImplementedError(f'Suffix of export path [' + f'{export_path}] is not supported ' + f'for now. Only support ' + f'{list(support_dict.keys())}.') + return suffix + + def _export_impl(self, dataset, export_path, suffix, export_stats=True): + """ + Export a dataset to specific path. + + :param dataset: the dataset to export. + :param export_path: the path to export the dataset. + :param suffix: suffix of export path. + :param export_stats: whether to export stats of dataset. + :return: + """ + if self.export_ds: + # fetch the corresponding export method according to the suffix + export_method = Exporter._router()[suffix] + if self.export_shard_size <= 0: + # export the whole dataset into one single file. + logger.info('Export dataset into 1 file...') + export_method(dataset, export_path) + else: + # compute the dataset size and number of shards to split + if dataset._indices is not None: + dataset_nbytes = dataset.data.nbytes * len( + dataset._indices) / len(dataset.data) + else: + dataset_nbytes = dataset.data.nbytes + num_shards = int(dataset_nbytes / self.export_shard_size) + 1 + num_shards = min(num_shards, len(dataset)) + + # split the dataset into multiple shards + logger.info(f'Split the dataset to export into {num_shards} ' + f'shards. Size of each shard <= ' + f'{self.max_shard_size_str}') + shards = [ + dataset.shard(num_shards=num_shards, + index=i, + contiguous=True) for i in range(num_shards) + ] + len_num = len(str(num_shards)) + 1 + num_fmt = f'%0{len_num}d' + + # regard the export path as a directory and set file names for + # each shard + dirname = os.path.dirname(os.path.abspath(self.export_path)) + basename = os.path.basename(self.export_path).split('.')[0] + os.makedirs(dirname, exist_ok=True) + filenames = [ + os.path.join( + dirname, f'{basename}-{num_fmt % index}-of-' + f'{num_fmt % num_shards}' + f'.{self.suffix}') for index in range(num_shards) + ] + + # export dataset into multiple shards using multiprocessing + logger.info(f'Start to exporting to {num_shards} shards.') + pool = Pool(self.num_proc) + for i in range(num_shards): + pool.apply_async(export_method, + args=( + shards[i], + filenames[i], + )) + pool.close() + pool.join() + + if 'stats' in dataset.features and export_stats: + # export stats of datasets into a single file. + ds_stats = dataset.select_columns('stats').flatten() + stats_file = export_path.replace('.' + suffix, '_stats.jsonl') + Exporter.to_jsonl(ds_stats, stats_file) + + def export(self, dataset): + """ + Export method for a dataset. + + :param dataset: the dataset to export. + :return: + """ + self._export_impl(dataset, self.export_path, self.suffix, + self.export_stats) + + @staticmethod + def to_jsonl(dataset, export_path, **kwargs): + """ + Export method for json/jsonl target files. + + :param dataset: the dataset to export. + :param export_path: the path to store the exported dataset. + :param kwargs: extra arguments. + :return: + """ + dataset.to_json(export_path, force_ascii=False) + + @staticmethod + def to_parquet(dataset, export_path, **kwargs): + """ + Export method for parquet target files. + + :param dataset: the dataset to export. + :param export_path: the path to store the exported dataset. + :param kwargs: extra arguments. + :return: + """ + dataset.to_parquet(export_path) + + # suffix to export method + @staticmethod + def _router(): + """ + A router from different suffixes to corresponding export methods. + + :return: A dict router. + """ + return { + 'jsonl': Exporter.to_jsonl, + 'json': Exporter.to_jsonl, + 'parquet': Exporter.to_parquet, + } diff --git a/data_juicer/core/tracer.py b/data_juicer/core/tracer.py new file mode 100644 index 000000000..b9ae1f104 --- /dev/null +++ b/data_juicer/core/tracer.py @@ -0,0 +1,183 @@ +import os + +import pandas as pd +from datasets import Dataset +from loguru import logger + + +class Tracer: + """ + The tracer to trace the sample changes before and after an operator + process. + + The comparison results will be stored in the work directory. + """ + + def __init__(self, work_dir, show_num=10): + """ + Initialization method. + + :param work_dir: the work directory to store the comparison + results + :param show_num: the maximum number of samples to show in the + comparison result files. + """ + self.work_dir = os.path.join(work_dir, 'trace') + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + self.show_num = show_num + + def trace_mapper(self, op_name: str, previous_ds: Dataset, + processed_ds: Dataset): + """ + Compare datasets before and after a Mapper. + + This will mainly show the different sample pairs due to the + modification by the Mapper + + :param op_name: the op name of mapper + :param previous_ds: dataset before the mapper process + :param processed_ds: dataset processed by the mapper + :return: + """ + assert len(previous_ds) == len(processed_ds) + dif_dict = [] + num = 0 + + # Find different samples orderly between previous and processed + # datasets until the total number of found sample pairs is enough. + for i in range(len(previous_ds)): + previous_sample = previous_ds[i]['text'] + processed_sample = processed_ds[i]['text'] + if previous_sample != processed_sample: + dif_dict.append({ + 'original text': previous_sample, + 'processed_text': processed_sample, + }) + num += 1 + if num >= self.show_num: + break + + if len(dif_dict) == 0: + logger.warning(f'Datasets before and after op [{op_name}] are all ' + f'the same. Thus no comparison results would be ' + f'generated.') + return + elif len(dif_dict) < self.show_num: + logger.warning(f'There are {len(dif_dict)} different samples ' + f'before and after op [{op_name}] -- less than ' + f'expected {self.show_num} samples.') + + # export the tracer results. + res_name = f'mapper-{op_name}.jsonl' + dif_df = pd.DataFrame(dif_dict) + dif_df.to_json(os.path.join(self.work_dir, res_name), + orient='records', + lines=True, + force_ascii=False) + + def trace_filter(self, op_name: str, previous_ds: Dataset, + processed_ds: Dataset): + """ + Compare datasets before and after a Filter. + + This will mainly show the filtered samples by the Filter + + :param op_name: the op name of filter + :param previous_ds: dataset before the filter process + :param processed_ds: dataset processed by the filter + :return: + """ + if len(previous_ds) == len(processed_ds): + logger.warning(f'Datasets before and after op [{op_name}] are all ' + f'the same. Thus no comparison results would be ' + f'generated.') + return + + # get the number of filtered samples. + total_dif_num = len(previous_ds) - len(processed_ds) + # index of the current sample in the previous dataset + i = 0 + filter_dict = [] + # number of found filtered samples. It's the offset bewteen two + # datasets as well. + num = 0 + while i < len(previous_ds): + if i - num >= len(processed_ds) or \ + previous_ds[i] != processed_ds[i - num]: + # 1. If all samples in processed dataset are checked but there + # still some samples left in the previous dataset, all of these + # left samples are filtered. + # 2. If the corresponding samples in previous and processed + # datasets are different, samples in the previous dataset are + # filtered. + num += 1 + filter_dict.append(previous_ds[i]) + if num >= self.show_num or num >= total_dif_num: + # If the total number of found filtered samples is enough or we + # have found all filtered samples, just stop. + break + i += 1 + if len(filter_dict) == 0: + logger.warning(f'Datasets before and after op [{op_name}] are all ' + f'the same. Thus no comparison results would be ' + f'generated.') + return + elif len(filter_dict) < self.show_num: + logger.warning(f'There are {len(filter_dict)} filtered samples ' + f'before and after op [{op_name}] -- less than ' + f'expected {self.show_num} samples.') + + # export the tracer results. + res_name = f'filter-{op_name}.jsonl' + filter_df = pd.DataFrame(filter_dict) + filter_df.to_json(os.path.join(self.work_dir, res_name), + orient='records', + lines=True, + force_ascii=False) + + def trace_deduplicator(self, op_name: str, dup_pairs: list): + """ + Compare datasets before and after a Deduplicator. + + This will mainly show the near-duplicate sample pairs extracted + by the Deduplicator. Different from the other two trace methods, + the trace process for deduplicator is embedded into the process + method of deduplicator, but the other two trace methods are + independent of the process method of mapper and filter operators + + :param op_name: the op name of deduplicator + :param dup_pairs: duplicate sample pairs obtained from + deduplicator + :return: + """ + if dup_pairs is None: + logger.warning(f'Op [{op_name}] does not generate dup_pairs ' + f'correctly, thus no comparison results can be ' + f'obtained from this op.') + return + if len(dup_pairs) == 0: + logger.warning(f'Datasets before and after op [{op_name}] are all ' + f'the same. Thus no comparison results would be ' + f'generated.') + return + elif len(dup_pairs) < self.show_num: + logger.warning(f'There are {len(dup_pairs)} filtered samples ' + f'before and after op [{op_name}] -- less than ' + f'expected {self.show_num} samples.') + + # reorganize the duplicate pairs + dup_dict = [] + for key in dup_pairs: + dup_dict.append({ + 'dup1': dup_pairs[key][0], + 'dup2': dup_pairs[key][1], + }) + + # export the tracer result. + res_name = f'duplicate-{op_name}.jsonl' + dup_df = pd.DataFrame(dup_dict) + dup_df.to_json(os.path.join(self.work_dir, res_name), + orient='records', + lines=True, + force_ascii=False) diff --git a/data_juicer/format/__init__.py b/data_juicer/format/__init__.py new file mode 100644 index 000000000..cd2e10de0 --- /dev/null +++ b/data_juicer/format/__init__.py @@ -0,0 +1,3 @@ +from . import (csv_formatter, json_formatter, mixture_formatter, + parquet_formatter, text_formatter, tsv_formatter) +from .load import load_formatter diff --git a/data_juicer/format/csv_formatter.py b/data_juicer/format/csv_formatter.py new file mode 100644 index 000000000..24df156e0 --- /dev/null +++ b/data_juicer/format/csv_formatter.py @@ -0,0 +1,26 @@ +from .formatter import FORMATTERS, LocalFormatter + + +@FORMATTERS.register_module() +class CsvFormatter(LocalFormatter): + """ + The class is used to load and format csv-type files. + + Default suffixes is `['.csv']` + """ + SUFFIXES = ['.csv'] + + def __init__(self, dataset_path, suffixes=None, **kwargs): + """ + Initialization method. + + :param dataset_path: a dataset file or a dataset directory + :param suffixes: files with specified suffixes to be processed + :param kwargs: extra args + """ + super().__init__( + dataset_path=dataset_path, + suffixes=suffixes if suffixes else self.SUFFIXES, + type='csv', + **kwargs, + ) diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py new file mode 100644 index 000000000..68dd5036c --- /dev/null +++ b/data_juicer/format/formatter.py @@ -0,0 +1,376 @@ +import os +from typing import List, Tuple, Union + +from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset +from jsonargparse import Namespace +from loguru import logger + +from data_juicer.utils.file_utils import (find_files_with_suffix, + is_absolute_path) +from data_juicer.utils.registry import Registry + +FORMATTERS = Registry('Formatters') + + +class BaseFormatter: + """Base class to load dataset.""" + + def load_dataset(self, *args) -> Dataset: + raise NotImplementedError + + +class LocalFormatter(BaseFormatter): + """The class is used to load a dataset from local files or local + directory.""" + + def __init__( + self, + dataset_path: str, + type: str, + suffixes: Union[str, List[str], Tuple[str]] = None, + text_keys_to_load: List[str] = None, + add_suffix=False, + **kwargs, + ): + """ + Initialization method. + + :param dataset_path: path to a dataset file or a dataset + directory + :param type: a packaged dataset module type (json, csv, etc.) + :param suffixes: files with specified suffixes to be processed + :param text_keys_to_load: key names of field that stores sample + text. + :param add_suffix: whether to add the file suffix to dataset + meta info + :param kwargs: extra args + """ + if text_keys_to_load is None: + text_keys_to_load = ['text'] + self.type = type + self.kwargs = kwargs + self.text_keys_to_load = text_keys_to_load + self.data_files = find_files_with_suffix(dataset_path, suffixes) + self.add_suffix = add_suffix + + def load_dataset(self, + num_proc: int = 1, global_cfg: Namespace = None) -> \ + Dataset: + """ + Load a dataset from dataset file or dataset directory, and unify its + format. + + :param num_proc: number of processes when loading the dataset + :param global_cfg: global cfg used in consequent processes, + :return: formatted dataset + """ + datasets = load_dataset(self.type, + data_files={ + key.strip('.'): self.data_files[key] + for key in self.data_files + }, + num_proc=num_proc, + **self.kwargs) + if self.add_suffix: + logger.info('Add suffix info into dataset...') + datasets = add_suffixes(datasets) + else: + datasets = concatenate_datasets([ds for _, ds in datasets.items()]) + ds = unify_format(datasets, + text_keys_to_load=self.text_keys_to_load, + num_proc=num_proc, + global_cfg=global_cfg) + return ds + + +class RemoteFormatter(BaseFormatter): + """The class is used to load a dataset from repository of huggingface + hub.""" + + def __init__(self, + dataset_path: str, + text_keys_to_load: List[str] = None, + **kwargs): + """ + Initialization method. + + :param dataset_path: a dataset file or a dataset directory + :param text_keys_to_load: key names of field that stores sample + text. + :param kwargs: extra args + """ + self.path = dataset_path + self.text_keys_to_load = text_keys_to_load + self.kwargs = kwargs + + def load_dataset(self, + num_proc: int = 1, + global_cfg: Namespace = None) -> Dataset: + """ + Load a dataset from HuggingFace, and unify its format. + + :param num_proc: number of processes when loading the dataset + :param global_cfg: the global cfg used in consequent processes, + :return: formatted dataset + """ + ds = load_dataset(self.path, + split='train', + num_proc=num_proc, + **self.kwargs) + ds = unify_format(ds, + text_keys_to_load=self.text_keys_to_load, + num_proc=num_proc, + global_cfg=global_cfg) + return ds + + +def add_suffixes(datasets: DatasetDict) -> Dataset: + """ + Add suffix filed to datasets. + + :param datasets: a DatasetDict object + :return: datasets with suffix features. + """ + logger.info('Add suffix column for dataset') + for key, ds in datasets.items(): + if 'suffix' in ds.features: + ds = ds.rename_column('suffix', '__original__suffix') + datasets[key] = ds.add_column(name='suffix', + column=['.' + key] * ds.num_rows) + datasets = concatenate_datasets([ds for _, ds in datasets.items()]) + return datasets + + +def rename_ops_args_text_key(cfg, original_text_key, target_text_key): + """ + Rename some ops args in cfg from a source text key to target text key. + + :param cfg: the global cfg used in consequent processes, + :param original_text_key: source text key in op args :param modified + cfg object + """ + if not hasattr(cfg, 'process'): + return + process_list = cfg.process + for i in range(len(process_list)): + for op_name in process_list[i]: + if process_list[i][op_name] and \ + process_list[i][op_name]['text_key'] == original_text_key: + process_list[i][op_name]['text_key'] = target_text_key + + +def unify_format( + dataset: Dataset, + text_keys_to_load: List[str] = None, + num_proc: int = 1, + global_cfg: Namespace = None, +) -> Dataset: + """ + Get an unified internal format, conduct the following modifications. + + 1. based on the given keys, unifying the key name of sample text to + + 1.1. 'text' (for single column case) + + 1.2. 'text.keys[0]', 'text.keys[i]', ... (for multiple column case) + + 2. filter out those samples with empty or None text + + 3. combining all remaining fields except 'stats' into meta related fields, + users can access them by ds['meta'], ds['meta']['xx'], or ds['meta.xx'] + + 4. add 'stats' field into dataset + + As a result, the dataset will being with the unified format such as: + + >>> { + >>> 'text': 'hello-world', + >>> 'text.instruction': "Let's think step by step.", + >>> "meta": {"date": 2012} + >>> 'meta.src": "customized", + >>> "meta.version": "0.1", + >>> 'stats': { + >>> "lang": "en", + >>> "lang_score": 0.965 + >>> } + >>> } + + :param dataset: input dataset + :param text_keys_to_load: original text key(s) of dataset + :param num_proc: number of processes for mapping + :param global_cfg: the global cfg used in consequent processes, + since cfg.text_key_to_process may need to be modified after unifying + + :return: unified_format_dataset + """ + if isinstance(dataset, DatasetDict): + datasets = list(dataset.values()) + assert len(datasets) == 1, 'Please make sure the passed datasets ' \ + 'contains only 1 dataset' + dataset = datasets[0] + assert isinstance(dataset, Dataset), 'Currently we only support ' \ + 'processing data with ' \ + "'huggingface-Dataset format'" + + if text_keys_to_load is None: + text_keys_to_load = ['text'] + logger.info('Unifying the input dataset formats...') + final_text_related_keys = set() + + from data_juicer.core.data import NestedDataset + dataset = NestedDataset(dataset) + + # 1. unify text related keys + for key in text_keys_to_load: + if key not in dataset.features: + err_msg = f'There is no key [{key}] in dataset. You might set ' \ + f'wrong text_key in the config file for your dataset. ' \ + f'Please check and retry!' + logger.error(err_msg) + raise ValueError(err_msg) + # 1.1 (single-column case) + if len(text_keys_to_load) == 1: + # rename the specified key into 'text' + if 'text' not in dataset.features: + dataset = dataset.rename_column(key, 'text') + rename_ops_args_text_key(global_cfg, key, 'text') + logger.info(f'The field `{key}` has been renamed into `text`') + if global_cfg and key == global_cfg.text_key_to_process: + global_cfg.text_key_to_process = 'text' + elif key == 'text': # text' in dataset.features + # There is 'text' field, we regard it as the real text field. + # DO NOT need to unify + pass + else: # if 'text' in dataset.features and keys[0] != 'text' + # There is 'text' field, but we need another field as the + # real text field. + # We need to put the original 'text' field into meta and + # rename this key field to 'text' field + dataset = dataset.rename_column('text', 'text.original') + rename_ops_args_text_key(global_cfg, 'text', 'text.original') + logger.info('The field `text` has been renamed into ' + '`text.original`') + if global_cfg and 'text' == global_cfg.text_key_to_process: + global_cfg.text_key_to_process = 'text.original' + dataset = dataset.rename_column(key, 'text') + rename_ops_args_text_key(global_cfg, key, 'text') + logger.info(f'The field `{key}` has been renamed into `text`') + if global_cfg and key == global_cfg.text_key_to_process: + global_cfg.text_key_to_process = 'text' + final_text_related_keys.add('text.original') + # Finally, the dataset contains a column named 'text' + final_text_related_keys.add('text') + else: + # 1.2 (multiple-column case) + dataset = dataset.rename_column(key, f'text.{key}') + rename_ops_args_text_key(global_cfg, key, f'text.{key}') + logger.info(f'The field `{key}` has been renamed into ' + f'`text.{key}`') + if global_cfg and key == global_cfg.text_key_to_process: + global_cfg.text_key_to_process = f'text.{key}' + final_text_related_keys.add(f'text.{key}') + + # 2. filter out those samples with empty or None text + # TODO: optimize the filtering operation for better efficiency + logger.info(f'There are {len(dataset)} sample(s) in the original dataset.') + + dataset.cleanup_cache_files() + + def non_empty_text(sample, target_keys): + for target_key in target_keys: + # TODO: case for SFT, in which the len(sample[target_key]) == 0 + if sample[target_key] is None: + # we filter out the samples contains at least None column + # since the op can not handle it now + return False + return True + + dataset = dataset.filter( + non_empty_text, + num_proc=num_proc, + fn_kwargs={'target_keys': list(final_text_related_keys)}) + logger.info(f'{len(dataset)} samples left after filtering empty text.') + dataset.cleanup_cache_files() + + # 3. combine other fields with 'meta' prefix + # 3.1 the original 'meta' field will be remained, + # 3.2 the 'stats' field will be reserved as a dict + remain_root_keys = set(dataset.features.keys()) - set( + final_text_related_keys) - {'stats'} - {'meta'} + for key in remain_root_keys: + dataset = dataset.rename_column(key, f'meta.{key}') + logger.info(f'The field `{key}` has been renamed into `meta.{key}`') + dataset.cleanup_cache_files() + + if 'stats' in set(dataset.features.keys()) - set( + final_text_related_keys) and \ + not isinstance(dataset.features['stats'], dict): + # put the original non-dict field into meta + dataset = dataset.rename_column('stats', 'meta.stats') + + dataset.cleanup_cache_files() + # 4. add 'stats' field + # TODO: + # this is a temp solution, + # it will occur errors when only call mapper ops + # dataset = dataset.add_column( \ + # name='stats', column=[{}] * dataset.num_rows) + + return dataset + + +def load_formatter(dataset_path, + keys_to_load=None, + suffixes=None, + add_suffix=False, + **kwargs) -> BaseFormatter: + """ + Load the appropriate formatter for different types of data formats. + + :param dataset_path: Path to dataset file or dataset directory + :param keys_to_load: key names of field that stores sample text. + Default: ['text'] + :param suffixes: the suffix of files that will be read. Default: + None + :return: a dataset formatter. + """ + if keys_to_load is None: + keys_to_load = ['text'] + if suffixes is None: + suffixes = [] + ext_num = {} + if os.path.isdir(dataset_path) or os.path.isfile(dataset_path): + file_dict = find_files_with_suffix(dataset_path, suffixes) + if not file_dict: + raise IOError( + 'Unable to find files matching the suffix from {}'.format( + dataset_path)) + for ext in file_dict: + ext_num[ext] = len(file_dict[ext]) + + # local dataset + if ext_num: + formatter_num = {} + for name, formatter in FORMATTERS.modules.items(): + formatter_num[name] = 0 + for ext in ext_num: + if ext in formatter.SUFFIXES: + formatter_num[name] += ext_num[ext] + formatter = max(formatter_num, key=lambda x: formatter_num[x]) + target_suffixes = set(ext_num.keys()).intersection( + set(FORMATTERS.modules[formatter].SUFFIXES)) + return FORMATTERS.modules[formatter](dataset_path, + text_keys_to_load=keys_to_load, + suffixes=target_suffixes, + add_suffix=add_suffix, + **kwargs) + + # try huggingface dataset hub + elif not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1: + return RemoteFormatter(dataset_path, + text_keys_to_load=keys_to_load, + **kwargs) + + # no data + else: + raise NotImplementedError diff --git a/data_juicer/format/json_formatter.py b/data_juicer/format/json_formatter.py new file mode 100644 index 000000000..8caa3390c --- /dev/null +++ b/data_juicer/format/json_formatter.py @@ -0,0 +1,26 @@ +from .formatter import FORMATTERS, LocalFormatter + + +@FORMATTERS.register_module() +class JsonFormatter(LocalFormatter): + """ + The class is used to load and format json-type files. + + Default suffixes is `['.json', '.jsonl', '.jsonl.zst']` + """ + SUFFIXES = ['.json', '.jsonl', '.jsonl.zst'] + + def __init__(self, dataset_path, suffixes=None, **kwargs): + """ + Initialization method. + + :param dataset_path: a dataset file or a dataset directory + :param suffixes: files with specified suffixes to be processed + :param kwargs: extra args + """ + super().__init__( + dataset_path=dataset_path, + suffixes=suffixes if suffixes else self.SUFFIXES, + type='json', + **kwargs, + ) diff --git a/data_juicer/format/load.py b/data_juicer/format/load.py new file mode 100644 index 000000000..4dc68a512 --- /dev/null +++ b/data_juicer/format/load.py @@ -0,0 +1,29 @@ +from .formatter import BaseFormatter +from .mixture_formatter import MixtureFormatter + + +def load_formatter(dataset_path, + keys_to_load=None, + suffixes=[], + add_suffix=False, + **kwargs) -> BaseFormatter: + """ + Load mixture formatter for multiple different data formats with an optional + weight(default 1.0) according to their formats. + + :param dataset_path: path to a dataset file or a dataset directory + :param keys_to_load: key names of field that stores sample text. + Default: ['text'] + :param suffixes: files with specified suffixes to be processed. + :param add_suffix: whether to add the file suffix to dataset meta + info + :return: a dataset formatter. + """ + if keys_to_load is None: + keys_to_load = ['text'] + formatter = MixtureFormatter(dataset_path=dataset_path, + keys_to_load=keys_to_load, + suffixes=suffixes, + add_suffix=add_suffix, + **kwargs) + return formatter diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py new file mode 100644 index 000000000..953eb5a7a --- /dev/null +++ b/data_juicer/format/mixture_formatter.py @@ -0,0 +1,107 @@ +from typing import List, Tuple, Union + +import numpy as np +from datasets import Dataset, concatenate_datasets +from jsonargparse import Namespace +from loguru import logger + +from .formatter import BaseFormatter, load_formatter + + +class MixtureFormatter(BaseFormatter): + """The class mixes multiple datasets by randomly selecting samples from + every dataset and merging them, and then exports the merged datasset as a + new mixed dataset.""" + + def __init__(self, + dataset_path: str, + suffixes: Union[str, List[str], Tuple[str]] = None, + keys_to_load=None, + add_suffix=False, + **kwargs): + """ + Initialization method. + + :param dataset_path: a dataset file or a dataset dir or a list + of them, optional weights, default 1.0 e.g. ` ds.jsonl + ds_dir ds_file.json` + :param suffixes: files with specified suffixes to be processed + :param keys_to_load: key names of field that stores sample text. + :param add_suffix: whether to add the file suffix to dataset + meta info + :param kwargs: extra args + """ + if keys_to_load is None: + keys_to_load = ['text'] + data_prefixes, weights = self._get_weight(data_prefix=dataset_path) + self.weights = weights + self.formatters = [ + load_formatter(dataset_path=data_prefix, + suffixes=suffixes, + keys_to_load=keys_to_load, + add_suffix=add_suffix, + **kwargs) for data_prefix in data_prefixes + ] + + def _get_weight(self, data_prefix): + """ + Split every dataset path and its weight. + + :param data_prefix: a dataset file or a dataset dir or a list of + them, e.g. ` ds1.jsonl ds2_dir ds3_file.json` + :return: list of dataset path and list of weights + """ + data_prefix = data_prefix.split() + weights = [] + prefixes = [] + + for i in range(len(data_prefix)): + try: + value = float(data_prefix[i]) + weights.append(value) + except: # noqa: E722 + value = data_prefix[i].strip() + + # if not set weight, use 1.0 as default + if i == 0 or len(weights) == len(prefixes): + weights.append(1.0) + prefixes.append(value) + return prefixes, weights + + def _random_sample(self, dataset, weight=1.0, seed=None): + """ + Randomly sample a subset from a dataset with weight. + :param dataset: a HuggingFace dataset + :param weight: sample ratio of dataset + :param seed: random sample seed, if None, 42 as default + :return: a subset of dataset + """ + if seed is None: + seed = 42 + num_samples = min(int(np.ceil(dataset.num_rows * weight)), + dataset.num_rows) + if num_samples == dataset.num_rows: + return dataset + return dataset.shuffle(seed=seed).select(range(num_samples)) + + def load_dataset(self, + num_proc: int = 1, + global_cfg: Namespace = None) -> Dataset: + """ + Load a mixed dataset. + + :param num_proc: number of processes when loading the dataset + :param global_cfg: the global cfg used in consequent processes, + :return: mixed dataset + """ + dataset_list = [] + for weight, formatter in zip(self.weights, self.formatters): + dataset = formatter.load_dataset(num_proc, global_cfg) + sampled = self._random_sample(dataset, weight) + logger.info(f'sampled {len(sampled)} from ' + f'{len(dataset)} with weight {weight}') + dataset_list.append(sampled) + + mixed_dataset = concatenate_datasets(dataset_list) + logger.info(f'There are {len(mixed_dataset)} in final dataset') + return mixed_dataset diff --git a/data_juicer/format/parquet_formatter.py b/data_juicer/format/parquet_formatter.py new file mode 100644 index 000000000..0d44a2cee --- /dev/null +++ b/data_juicer/format/parquet_formatter.py @@ -0,0 +1,26 @@ +from .formatter import FORMATTERS, LocalFormatter + + +@FORMATTERS.register_module() +class ParquetFormatter(LocalFormatter): + """ + The class is used to load and format parquet-type files. + + Default suffixes is `['.parquet']` + """ + SUFFIXES = ['.parquet'] + + def __init__(self, dataset_path, suffixes=None, **kwargs): + """ + Initialization method. + + :param dataset_path: a dataset file or a dataset directory + :param suffixes: files with specified suffixes to be processed + :param kwargs: extra args + """ + super().__init__( + dataset_path=dataset_path, + suffixes=suffixes if suffixes else self.SUFFIXES, + type='parquet', + **kwargs, + ) diff --git a/data_juicer/format/text_formatter.py b/data_juicer/format/text_formatter.py new file mode 100644 index 000000000..8fd2e45c8 --- /dev/null +++ b/data_juicer/format/text_formatter.py @@ -0,0 +1,157 @@ +import os +from multiprocessing import Pool + +import pdfplumber +from datasets import Dataset, concatenate_datasets, load_dataset +from docx import Document +from loguru import logger + +from data_juicer.utils.cache_utils import DATA_JUICER_CACHE_HOME +from data_juicer.utils.file_utils import find_files_with_suffix + +from .formatter import FORMATTERS, LocalFormatter, add_suffixes, unify_format + + +def extract_txt_from_docx(fn, tgt_path): + """ + Extract text from a docx file and save to target path. + + :param fn: path to input pdf file + :param tgt_path: path to save text file. + """ + doc = Document(fn) + text = [para.text for para in doc.paragraphs if para.text.strip()] + base_fn = os.path.basename(fn).lower().replace('.docx', '.txt') + with open(os.path.join(tgt_path, base_fn), 'w') as f: + f.write('\n'.join(text)) + + +def extract_txt_from_pdf(fn, tgt_path): + """ + Extract text from a pdf file and save to target path. + + :param fn: path to input pdf file + :param tgt_path: path to save text file. + """ + with pdfplumber.open(fn) as pdf: + text = [] + for page in pdf.pages: + # remove tables from each page extracted by pdfplumber + tables = page.find_tables() + for table in tables: + page = page.outside_bbox(table.bbox) + # remove page number from the end of each page + page_text = page.extract_text() + page_num = str(page.page_number) + if page_text.rstrip().endswith(page_num): + page_text = page_text.rstrip()[:-len(page_num)] + if page_text.strip(): + text.append(page_text) + base_fn = os.path.basename(fn).lower().replace('.pdf', '.txt') + with open(os.path.join(tgt_path, base_fn), 'w') as f: + f.write('\n'.join(text)) + + +@FORMATTERS.register_module() +class TextFormatter(LocalFormatter): + """ + The class is used to load and format text-type files. + + e.g. `['.txt', '.pdf', '.cpp', '.docx']` + """ + + SUFFIXES = [ + '.docx', '.pdf', '.txt', '.md', '.tex', '.asm', '.bat', '.cmd', '.c', + '.h', '.cs', '.cpp', '.hpp', '.c++', '.h++', '.cc', '.hh', '.C', '.H', + '.cmake', '.css', '.dockerfile', '.f90', '.f', '.f03', '.f08', '.f77', + '.f95', '.for', '.fpp', '.go', '.hs', '.html', '.java', '.js', '.jl', + '.lua', '.markdown', '.php', '.php3', '.php4', '.php5', '.phps', + '.phpt', '.pl', '.pm', '.pod', '.perl', '.ps1', '.psd1', '.psm1', + '.py', '.rb', '.rs', '.sql', '.scala', '.sh', '.bash', '.command', + '.zsh', '.ts', '.tsx', '.vb', 'Dockerfile', 'Makefile', '.xml', '.rst', + '.m', '.smali' + ] + + def __init__(self, + dataset_path, + suffixes=None, + add_suffix=False, + **kwargs): + """ + Initialization method. + + :param dataset_path: a dataset file or a dataset directory + :param suffixes: files with specified suffixes to be processed + :param add_suffix: Whether to add file suffix to datase meta + info + :param kwargs: extra args + """ + super().__init__( + dataset_path=dataset_path, + suffixes=suffixes if suffixes else self.SUFFIXES, + type='text', + add_suffix=add_suffix, + **kwargs, + ) + self.dataset_path = dataset_path + self.add_suffix = add_suffix + + def load_dataset(self, num_proc: int = 1) -> Dataset: + """ + Load a dataset from local text-type files. + + :param num_proc: number of processes when loading the dataset + :return: unified_format_dataset. + """ + # extract text to cache directory + extracted_dataset_path = os.path.join( + DATA_JUICER_CACHE_HOME, + os.path.basename(os.path.abspath(self.dataset_path))) + + for file_type in self.data_files: + + # extract text from docx or pdf files, and save as txt type + if file_type == '.docx' or file_type == '.pdf': + extracted_filetype_path = os.path.join(extracted_dataset_path, + file_type.strip('.')) + if not os.path.exists(extracted_filetype_path): + os.makedirs(extracted_filetype_path) + logger.info('Extracting text from {} files...'.format( + file_type.strip('.'))) + + extract_func = extract_txt_from_docx \ + if file_type == '.docx' else extract_txt_from_pdf + pool = Pool(num_proc) + for data_file in self.data_files[file_type]: + pool.apply_async(func=extract_func, + args=( + data_file, + extracted_filetype_path, + )) + pool.close() + pool.join() + logger.info(f'Extracted text files are stored in directory ' + f'{extracted_filetype_path}') + + # look for extracted txt files + self.data_files[file_type] = find_files_with_suffix( + extracted_filetype_path, '.txt')['.txt'] + + # load text dataset, one text file as one sample + datasets = load_dataset('text', + data_files={ + key.strip('.'): self.data_files[key] + for key in self.data_files + }, + sample_by='document', + num_proc=num_proc, + **self.kwargs) + # whether to add file suffix to datase meta info + if self.add_suffix: + logger.info('Add suffix info into dataset...') + datasets = add_suffixes(datasets) + else: + datasets = concatenate_datasets([ds for _, ds in datasets.items()]) + return unify_format(datasets, + text_keys_to_load=self.text_keys_to_load, + num_proc=num_proc) diff --git a/data_juicer/format/tsv_formatter.py b/data_juicer/format/tsv_formatter.py new file mode 100644 index 000000000..eb681f11a --- /dev/null +++ b/data_juicer/format/tsv_formatter.py @@ -0,0 +1,27 @@ +from .formatter import FORMATTERS, LocalFormatter + + +@FORMATTERS.register_module() +class TsvFormatter(LocalFormatter): + """ + The class is used to load and format tsv-type files. + + Default suffixes is `['.tsv']` + """ + SUFFIXES = ['.tsv'] + + def __init__(self, dataset_path, suffixes=None, **kwargs): + """ + Initialization method. + + :param dataset_path: a dataset file or a dataset directory + :param suffixes: files with specified suffixes to be processed + :param kwargs: extra args, e.g. `delimiter = ','` + """ + super().__init__( + dataset_path=dataset_path, + suffixes=suffixes if suffixes else self.SUFFIXES, + type='csv', + delimiter='\t', + **kwargs, + ) diff --git a/data_juicer/ops/.DS_Store b/data_juicer/ops/.DS_Store new file mode 100644 index 000000000..a6bca63bd Binary files /dev/null and b/data_juicer/ops/.DS_Store differ diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py new file mode 100644 index 000000000..c35fc22bb --- /dev/null +++ b/data_juicer/ops/__init__.py @@ -0,0 +1,3 @@ +from . import deduplicator, filter, mapper, selector +from .base_op import OPERATORS, Deduplicator, Filter, Mapper, Selector +from .load import load_ops diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py new file mode 100644 index 000000000..391a18d5c --- /dev/null +++ b/data_juicer/ops/base_op.py @@ -0,0 +1,126 @@ +from data_juicer.utils.registry import Registry + +OPERATORS = Registry('Operators') + + +class Mapper: + + def __init__(self, text_key: str = None): + """ + Base class that conducts text editing. + + :param text_key: the key name of field that stores sample texts + to be processed. + """ + if text_key is None: + text_key = 'text' + self.text_key = text_key + from data_juicer.core.data import wrap_func_with_nested_access + self.process = wrap_func_with_nested_access(self.process) + + def process(self, sample): + """ + For sample level, sample --> sample or sample --> [samples, ...] + + :param sample: sample to process + :return: processed sample + """ + raise NotImplementedError + + +class Filter: + + def __init__(self, text_key: str = None): + """ + Base class that removes specific info. + + :param text_key: the key name of field that stores sample texts + to be processed + """ + if text_key is None: + text_key = 'text' + self.text_key = text_key + from data_juicer.core.data import wrap_func_with_nested_access + self.process = wrap_func_with_nested_access(self.process) + self.compute_stats = wrap_func_with_nested_access(self.compute_stats) + + def compute_stats(self, sample): + """ + Compute stats for the sample which is used as a metric to decide + whether to filter this sample. + + :param sample: input sample + :return: sample with computed stats + """ + raise NotImplementedError + + def process(self, sample): + """ + For sample level, sample --> Boolean. + + :param sample: sample to decide whether to filter + :return: true for keeping and false for filtering + """ + raise NotImplementedError + + +class Deduplicator: + + def __init__(self, text_key: str = None): + """ + Base class that conducts deduplication. + + :param text_key: the key name of field that stores sample texts + to be processed + """ + if text_key is None: + text_key = 'text' + self.text_key = text_key + from data_juicer.core.data import wrap_func_with_nested_access + self.process = wrap_func_with_nested_access(self.process) + self.compute_hash = wrap_func_with_nested_access(self.compute_hash) + + def compute_hash(self, sample): + """ + Compute hash values for the sample. + + :param sample: input sample + :return: sample with computed hash value. + """ + raise NotImplementedError + + def process(self, dataset, show_num=0): + """ + For doc-level, dataset --> dataset. + + :param dataset: input dataset + :param show_num: number of traced samples used when tracer is + open. + :return: deduplicated dataset and the sampled duplicate pairs. + """ + raise NotImplementedError + + +class Selector: + + def __init__(self, text_key: str = None): + """ + Base class that conducts selection in dataset-level. + + :param text_key: the key name of field that stores sample texts + to be processed + """ + if text_key is None: + text_key = 'text' + self.text_key = text_key + from data_juicer.core.data import wrap_func_with_nested_access + self.process = wrap_func_with_nested_access(self.process) + + def process(self, dataset): + """ + Dataset --> dataset. + + :param dataset: input dataset + :return: selected dataset. + """ + raise NotImplementedError diff --git a/data_juicer/ops/common/__init__.py b/data_juicer/ops/common/__init__.py new file mode 100644 index 000000000..59d1ab315 --- /dev/null +++ b/data_juicer/ops/common/__init__.py @@ -0,0 +1,5 @@ +from .helper_func import (get_sentences_from_document, get_words_from_document, + merge_on_whitespace_tab_newline, + split_on_newline_tab_whitespace, split_on_whitespace, + strip, words_augmentation) +from .special_characters import SPECIAL_CHARACTERS diff --git a/data_juicer/ops/common/helper_func.py b/data_juicer/ops/common/helper_func.py new file mode 100644 index 000000000..a317eb2d6 --- /dev/null +++ b/data_juicer/ops/common/helper_func.py @@ -0,0 +1,188 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- +from typing import Dict + +import regex as re + + +class UnionFind: + + def __init__(self): + """Initialization method.""" + self.parent: Dict[int, int] = {} + + def find(self, x): + if x not in self.parent: + self.parent[x] = x + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + px = self.find(x) + py = self.find(y) + self.parent[px] = self.parent[py] = min(px, py) + + +def strip(document, strip_characters): + """ + Way faster than document.strip(strip_characters) since strip_characters is + now a set instead of a str, and it contains a lot of elements (all the + emojis). + + :param document: document to be processed + :param strip_characters: characters uesd for stripping document + :return: stripped document + """ + if not document: + return document + beg_ind = 0 + end_ind = len(document) + for i in range(len(document)): + if document[i] in strip_characters: + beg_ind += 1 + else: + break + for i in range(1, len(document) + 1): + if document[-i] in strip_characters: + end_ind -= 1 + else: + break + document_stripped = document[beg_ind:end_ind] + return document_stripped + + +def split_on_whitespace(document, new_line=False, tab=False): + """ + This method also removes concatenated spaces. + + :param document: document to be splited + :param new_line: whether to split document with '\\\\n' + :param tag: whether to split document with '\\\\t' + :return: word list obtained after splitting document + """ + sep = [' '] + new_line * ['\n'] + tab * ['\t'] + sep = '|'.join(sep) + split_document = re.split(sep, document) + split_document = [word for word in split_document if word] + return split_document + + +def split_on_newline_tab_whitespace(document): + """ + This method is used to split the document into different levels of sub- + sentences. + + First split on "\\\\n", then on "\\\\t", then on " ". + :param document: document to be splited + :return: setence list obtained after splitting document + """ + sentences = document.split('\n') + sentences = [sentence.split('\t') for sentence in sentences] + sentences = [[ + split_on_whitespace(subsentence) for subsentence in sentence + ] for sentence in sentences] + return sentences + + +def merge_on_whitespace_tab_newline(sentences): + """ + This method is used to merge different levels of sub-sentences into one + document. Invert the method split_on_newline_tab_whitespace. Removes + concatenated separators. + + :param sentences: sentence list to be merged + :return: document obtained after merging sub-sentences + """ + sentences = [[ + ' '.join(subsentence) for subsentence in sentence if subsentence + ] for sentence in sentences] + sentences = ['\t'.join(sentence) for sentence in sentences if sentence] + if not sentences: + return '' + document = '\n'.join(sentences) + return document + + +def words_augmentation(words, group_size, join_char): + """ + Augment words, especially for Chinese (without a space between words) and + Vietnamese (with a space between syllables). + + :param word: word list to be augmented + :param group_size: the size of word groups that need to be merged + :param join_char: characters to be added between word group + :return: word list after augment + """ + augmentation = [ + join_char.join(words[i:i + group_size]) + for i in range(len(words) - group_size + 1) + ] + return augmentation + + +def get_words_from_document(document, + token_func=None, + lower_case=True, + new_line=True, + tab=True, + strip_chars=None, + use_words_aug=False, + words_aug_group_sizes=[2], + words_aug_join_char=''): + """ + Get words from a document. Non reversible since the document is split on + multiple characters, words are stripped of special characters and + characters are converted to lower case. Useful to compute ratios, like the + stopwords ratio. + + :param document: document that need to split words + :param token_func: function of tokenizer, if specified, the function + will be used for split document into different tokens. + :param lower_case: whether to convert word to lowercase + :param new_line: whether to use `\\\\n' to split words + :param tab: whether to use '\\\\t' to split words + :param strip_chars: chars that need to be stripped in words + :param use_words_aug: whether to use word augmentation + :param words_aug_group_sizes: the size of word groups that need to + be merged + :param words_aug_join_char: characters to be added between word + group + :return: word list obtained from document + """ + if token_func: + words = token_func(document) + else: + words = split_on_whitespace(document, new_line, tab) + + if lower_case: + words = [word.lower() for word in words] + if strip_chars: + words = [strip(word, strip_chars) for word in words] + words = [word for word in words if word] + if use_words_aug: + augmentation = [ + words_augmentation(words, group_size, words_aug_join_char) + for group_size in words_aug_group_sizes + ] + augmentation = [word for augm in augmentation for word in augm] + words = words + augmentation + return words + + +def get_sentences_from_document(document, model_func=None): + """ + Get sentences from a document. + + :param document: document that need to split sentences + :param model_func: function of sentence model, if specified, the + function will be used for split document into different + sentences. + :return: document with the sentences separated by '\\\\n' + """ + if model_func: + sentences = model_func(document) + else: + sentences = document.splitlines() + return '\n'.join(sentences) \ No newline at end of file diff --git a/data_juicer/ops/common/special_characters.py b/data_juicer/ops/common/special_characters.py new file mode 100644 index 000000000..19d74fb9c --- /dev/null +++ b/data_juicer/ops/common/special_characters.py @@ -0,0 +1,20 @@ +# Most of the code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +import string + +import emoji + +MAIN_SPECIAL_CHARACTERS = string.punctuation + string.digits \ + + string.whitespace +OTHER_SPECIAL_CHARACTERS = ( + "’ “— ™ – •‘œ    ˜ ‚ƒ„’“”–ー一▬…✦�­£​•€«»°·═" + "×士^˘⇓↓↑←→()§″′´¿−±∈¢ø‚„½¼¾¹²³―⁃,ˌ¸‹›ʺˈʻ¦‐⠀‰……‑≤≥‖" + "◆●■►▼▲▴∆▻¡★☆✱ːº。¯˜¥ɪ≈†上ン:∼⁄・♡✓⊕․.⋅÷1‟;،、¨ाাी्े◦˚" + "゜ʼ≖ʼ¤ッツシ℃√!【】‿∞➤~πه۩☛₨➩☻๑٪♥ıॽ《‘©﴿٬?▷Г♫∟™ª₪®「—❖" + "」﴾》" +) +EMOJI = list(emoji.EMOJI_DATA.keys()) +SPECIAL_CHARACTERS = set(MAIN_SPECIAL_CHARACTERS + OTHER_SPECIAL_CHARACTERS) +SPECIAL_CHARACTERS.update(EMOJI) diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py new file mode 100644 index 000000000..b770ee5a6 --- /dev/null +++ b/data_juicer/ops/deduplicator/__init__.py @@ -0,0 +1,2 @@ +from . import (document_deduplicator, document_minhash_deduplicator, + document_simhash_deduplicator) diff --git a/data_juicer/ops/deduplicator/document_deduplicator.py b/data_juicer/ops/deduplicator/document_deduplicator.py new file mode 100644 index 000000000..f90583eb4 --- /dev/null +++ b/data_juicer/ops/deduplicator/document_deduplicator.py @@ -0,0 +1,110 @@ +# Some code here has been modified from: +# https://github.com/bigscience-workshop/data-preparation/blob/main/preprocessing/training/01a_catalogue_cleaning_and_filtering/clean_helpers/deduplication.py +# -------------------------------------------------------- + +import hashlib +import string +from collections import defaultdict +from typing import Dict, Set + +import regex as re + +from ..base_op import OPERATORS, Deduplicator + + +@OPERATORS.register_module('document_deduplicator') +class DocumentDeduplicator(Deduplicator): + """ + Deduplicator to deduplicate samples at document-level using exact matching. + + Using md5 hash to deduplicate samples. + """ + + def __init__(self, + lowercase: bool = False, + ignore_non_character: bool = False, + *args, + **kwargs): + """ + Initialization method. + + :param lowercase: Whether to convert sample text to lower case + :param ignore_non_character: Whether to ignore non-alphabet + characters, including whitespaces, digits, and punctuations + :param args: extra args + :param kwargs: extra args. + """ + super().__init__(*args, **kwargs) + self.lowercase = lowercase + self.remove_non_character_regex = re.compile( + f'\s+|\d+|[{re.escape(string.punctuation)}]' # noqa: W605 + ) if ignore_non_character else None + + def compute_hash(self, sample): + """ + Compute md5 hash values for the sample. + + :param sample: input sample + :return: sample with md5 hash value. + """ + # check if it's computed already + if 'hash' in sample: + return sample + + text = sample[self.text_key] + if self.lowercase: + text = text.lower() + if self.remove_non_character_regex: + text = self.remove_non_character_regex.sub('', text) + + def _get_hash(txt): + return hashlib.md5(txt.strip().encode('utf-8')).hexdigest() + + sample['hash'] = _get_hash(text) + return sample + + def process(self, dataset, show_num=0): + """ + For doc-level, dataset --> dataset. + + :param dataset: input dataset + :param show_num: number of traced samples used when tracer is + open. + :return: deduplicated dataset and the sampled duplicate pairs. + """ + # no need to deduplicate because too few samples + if len(dataset) <= 1: + return dataset, {} + + dup_hashes = None + if show_num > 0: + # sample duplicate pairs + hash2ids: Dict[int, Set[int]] = defaultdict(set) + for sid, hash_val in enumerate(dataset['hash']): + hash2ids[hash_val].add(sid) + dup_samples = sorted(list(hash2ids.items()), + key=lambda x: len(x[1]), + reverse=True) + dup_hashes = set([ + item[0] for item in dup_samples if len(item[1]) > 1 + ][:show_num]) + + def _filter_dup_helper(sample, hashes): + hash = sample['hash'] + if show_num > 0 and hash in dup_hashes \ + and len(dup_pairs[hash]) < 2: + # tracer is open and not enough duplicate sample pairs + dup_pairs[hash].append(sample) + if hash in hashes: + return False + else: + hashes.add(hash) + return True + + hashes = set() + dup_pairs = {hash_v: [] for hash_v in dup_hashes} if dup_hashes else {} + dataset = dataset.filter( + _filter_dup_helper, + fn_kwargs=dict(hashes=hashes), + load_from_cache_file=False if show_num > 0 else True) # num_proc=1 + return dataset, dup_pairs diff --git a/data_juicer/ops/deduplicator/document_minhash_deduplicator.py b/data_juicer/ops/deduplicator/document_minhash_deduplicator.py new file mode 100644 index 000000000..cafa4e717 --- /dev/null +++ b/data_juicer/ops/deduplicator/document_minhash_deduplicator.py @@ -0,0 +1,310 @@ +# Some code here has been modified from: +# https://github.com/bigcode-project/bigcode-dataset/blob/main/near_deduplication/minhash_deduplication.py +# -------------------------------------------------------- + +import hashlib +import struct +from collections import defaultdict + +import numpy as np +import regex +from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from loguru import logger +from scipy.integrate import quad as integrate +from tqdm import tqdm + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import UnionFind, split_on_whitespace + +MERSENNE_PRIME = np.uint64((1 << 61) - 1) +MAX_HASH = np.uint64((1 << 32) - 1) + + +def sha1_hash32(data): + """ + Directly taken from datasketch package to avoid dependency. + + Parameters + ---------- + data : bytes + + Returns + ------- + int + """ + return struct.unpack('= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication + :param num_bands: number of bands in LSH. Default it's None, and + it will be determined by an optimal params computation + algorithm by minimize the weighted sum of probs of False + Positives and False Negatives + :param num_rows_per_band: number of rows in each band in LSH. + Default it's None, and it will be determined by an optimal + params computation algorithm + """ + super().__init__(*args, **kwargs) + # about minhash computation + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + + # about deduplication + self.num_permutation = num_permutations + self.jaccard_threshold = jaccard_threshold + self.num_bands = num_bands + self.num_rows_per_band = num_rows_per_band + + # initialize deduplication parameters + # check number of bands and rows + if self.num_bands is None or self.num_rows_per_band is None: + self.num_bands, self.num_rows_per_band = optimal_param( + self.jaccard_threshold, + self.num_permutation, + ) + + # compute hash ranges and create hash tables + self.hash_ranges = [(i * self.num_rows_per_band, + (i + 1) * self.num_rows_per_band) + for i in range(self.num_bands)] + self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] + + # generate permutations + gen = np.random.RandomState(seed=42) + self.perm_a, self.perm_b = np.array( + [( + gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), + gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), + ) for _ in range(self.num_permutation)], + dtype=np.uint64, + ).T + + def compute_hash(self, sample): + """ + Compute minhash values for the sample. + + :param sample: input sample + :return: sample with minhash value. + """ + # check if it's computed already + if 'minhash' in sample: + return sample + + text = sample[self.text_key] + + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + # get tokens for different tokenization method + tokens = set() + if self.tokenization == 'character': + tokens = { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + tokens = self.punctuation_pattern.split(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + tokens = split_on_whitespace(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + + # compute minhash value + hv = np.array([sha1_hash32(token) for token in tokens], + dtype=np.uint64) + phv = np.bitwise_and( + ((hv * np.tile(self.perm_a, + (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, + MAX_HASH) + hash_values = np.vstack([ + phv, + np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH + ]).min(axis=0) + sample['minhash'] = [ + bytes(hash_values[start:end].byteswap().data) + for start, end in self.hash_ranges + ] + return sample + + def process(self, dataset, show_num=0): + """ + For doc-level, dataset --> dataset. + + :param dataset: input dataset + :param show_num: number of traced samples used when tracer is + open. + :return: deduplicated dataset and the sampled duplicate pairs. + """ + # no need to deduplicate because too few samples + if len(dataset) <= 1: + return dataset, {} + + minhashes = dataset['minhash'] + # remove bytes minhash column otherwise unexpected error would occur + # when exporting the processed dataset + dataset = dataset.remove_columns(['minhash']) + + # make clusters -- construct the minhash lookup tables of seg to ids + logger.info(f'Start clustering for {len(dataset)} samples...') + batch_size = 10000 + for i in tqdm(range(0, len(minhashes), batch_size), + dynamic_ncols=True, + desc='Iterating MinHashes of samples...'): + batch = minhashes[i:i + batch_size] + for idx, hs in enumerate(batch): + for h, hashtable in zip(hs, self.hash_tables): + hashtable[h].add(idx + i) + + # using UnionFind set to union samples within the same clusters + union_find = UnionFind() + for table in tqdm(self.hash_tables, + dynamic_ncols=True, + desc='Clustering'): + for cluster in table.values(): + if len(cluster) <= 1: + continue + idx = min(cluster) + for x in cluster: + union_find.union(x, idx) + logger.info(f'There are {len(set(union_find.parent.values()))} ' + f'clusters that includes multiple near-duplicate samples.') + + # record the duplicate sample pairs + dup_pairs = {} + if show_num > 0: + for i in range(len(dataset)): + cluster_idx = union_find.find(i) + if cluster_idx not in dup_pairs and cluster_idx != i: + dup_pairs[cluster_idx] = [ + dataset[cluster_idx], + dataset[i], + ] + if len(dup_pairs) >= show_num: + break + + # filtering -- only keep those samples whose parent index is itself, + # including: + # 1. samples that form a cluster by themselves + # 2. the first sample in a cluster that includes multiple samples + def _filter_minhash_dup_helper(sample, index): + return union_find.find(index) == index + + dataset = dataset.filter( + _filter_minhash_dup_helper, + with_indices=True, + ) + logger.info(f'Keep {len(dataset)} samples after MinHash dedup.') + + return dataset, dup_pairs diff --git a/data_juicer/ops/deduplicator/document_simhash_deduplicator.py b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py new file mode 100644 index 000000000..397fdd613 --- /dev/null +++ b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py @@ -0,0 +1,265 @@ +# Some code here has been modified from: +# https://github.com/bigscience-workshop/data-preparation +# -------------------------------------------------------- + +from collections import Counter, defaultdict, deque +from typing import Dict, Set + +import numpy as np +import regex +import simhash +from jsonargparse.typing import PositiveInt +from loguru import logger + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace + + +def local_num_differing_bits(hash_a, hash_b): + """ + Local implementation of calculating the number of different bits between + two integers. + + :param hash_a: integer hash value a + :param hash_b: integer hash value b + :return: number of different bits between input hashes. + """ + cnt = 0 + n = hash_a ^ hash_b + while n != 0: + cnt += 1 + n = n & (n - 1) + return cnt + + +def num_differing_bits_selector(): + """ + Select a num_differing_bits method according to the Python version + installed. + + When Python >= 3.9, the original simhash library cannot be compiled + correctly due to some changes in cython. After fixing this + incompatibility, RecursionError occurs sometimes when calling + simhash.num_differing_bits. So we use our implementation when Python + >= 3.9. Otherwise, we use implementation of simhash. + + :return: an available num_differing_bits function. + """ + import platform + a, b, _ = platform.python_version().split('.') + if a == '3' and int(b) >= 9: + # for >= 3.9, use local implementation + return local_num_differing_bits + else: + # for < 3.9, use simhash version + return simhash.num_differing_bits + + +num_differing_bits = num_differing_bits_selector() + + +@OPERATORS.register_module('document_simhash_deduplicator') +class DocumentSimhashDeduplicator(Deduplicator): + """Deduplicator to deduplicate samples at document-level using SimHash.""" + + def __init__(self, + tokenization: str = 'space', + window_size: PositiveInt = 6, + lowercase: bool = True, + ignore_pattern: str = None, + num_blocks: PositiveInt = 6, + hamming_distance: PositiveInt = 4, + *args, + **kwargs): + """ + Initialization method :param tokenization: tokenization method for + sample texts. + + It should be one of [space, punctuation, character]. For + English-like languages, we recommend to use 'space'. And for + Chinese-like languages, we recommend to use 'character' + + :param window_size: window size of shingling + :param lowercase: whether to convert text to lower case first + :param ignore_pattern: whether to ignore sub-strings with + specific pattern when computing simhash + :param num_blocks: number of blocks in simhash computing + :param hamming_distance: the max hamming distance threshold in + near-duplicate detection. When the hamming distance of two + sample texts is <= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication. This threshold should be always less than + num_blocks + """ + # about simhash computation + super().__init__(*args, **kwargs) + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + + # about deduplication + self.num_blocks = num_blocks + self.hamming_distance = hamming_distance + + def compute_hash(self, sample): + """ + Compute simhash values for the sample. + + :param sample: input sample + :return: sample with simhash value. + """ + # check if it's computed already + if 'simhash' in sample: + return sample + + text = sample[self.text_key] + + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + # get tokens for different tokenization method + tokens = [] + if self.tokenization == 'character': + tokens = [ + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + ] + elif self.tokenization == 'punctuation': + tokens = self.punctuation_pattern.split(text) + tokens = [ + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + ] + elif self.tokenization == 'space': + tokens = split_on_whitespace(text) + tokens = [ + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + ] + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + + # compute simhash + sample['simhash'] = np.uint64( + simhash.compute(map(simhash.unsigned_hash, tokens))) + return sample + + def process(self, dataset, show_num=0): + """ + For doc-level, dataset --> dataset. + + :param dataset: input dataset + :param show_num: number of traced samples used when tracer is + open. + :return: deduplicated dataset and the sampled duplicate pairs. + """ + # no need to deduplicate because too few samples + if len(dataset) <= 1: + return dataset, {} + + # find matches + logger.info(f'Start querying {len(dataset)} samples.') + matches = simhash.find_all( + dataset['simhash'], + self.num_blocks, + self.hamming_distance, + ) + logger.info(f'Querying done, found {len(matches)} matches.') + + # compute hash diff distribution + graph = defaultdict(dict) + dist = Counter() + for x, y in matches: + graph[x][y] = graph[y][x] = True + num_diff = num_differing_bits(x, y) + dist[num_diff] += 1 + logger.info(f'Hash diff distribution: {dist}') + + hash2ids: Dict[int, Set[str]] = defaultdict(set) + hashes: Set[int] = set(dataset['simhash']) + hash2cluster: Dict[int, int] = {} + visited: Set[int] = set() + cluster_id: int = 0 + + for sid, hash_val in enumerate(dataset['simhash']): + hash2ids[hash_val].add(str(sid)) + + # clustering + dup_pairs = {} # store duplicate pairs when show_num > 0 + while hashes: + hash_val = hashes.pop() + if hash_val in visited: + continue + + # if this hash value is not in the matches list, it's regarded as a + # single cluster + if hash_val not in graph: + hash2cluster[hash_val] = -1 + continue + + # Otherwise, BFS to find the cluster + q = deque([hash_val]) + visited.add(hash_val) + hash2cluster[hash_val] = cluster_id + if show_num > 0 and len(dup_pairs) < show_num: + dup_pairs[cluster_id] = [] + + while q: + curr = q.popleft() + for neighbor in graph[curr]: + if neighbor in visited: + continue + visited.add(neighbor) + q.append(neighbor) + hash2cluster[neighbor] = cluster_id + + cluster_id += 1 + logger.info(f'Found {cluster_id} clusters and {len(graph)} hashes.') + + # filter duplicated samples + # NOTICE: For now, we only keep the first sample in a cluster. Maybe + # there are some better strategies later. + def _filter_simhash_dup_helper(sample, visited_clusters, + visited_hashes): + sample_hash_val = sample['simhash'] + cluster_num = hash2cluster[sample_hash_val] + if cluster_num == -1: + # single-sample cluster, we need to check hash value still. + if sample_hash_val in visited_hashes: + return False + else: + visited_hashes.add(sample_hash_val) + return True + else: + if show_num > 0 and cluster_num in dup_pairs \ + and len(dup_pairs[cluster_num]) < 2: + dup_pairs[cluster_num].append(sample) + # regular cluster, check cluster number. + if cluster_num in visited_clusters: + return False + else: + visited_clusters.add(cluster_num) + return True + + cluster_record = set() + hash_record = set() + dataset = dataset.filter( + _filter_simhash_dup_helper, + fn_kwargs=dict(visited_clusters=cluster_record, + visited_hashes=hash_record), + load_from_cache_file=False if show_num > 0 else True) + logger.info(f'Keep {len(dataset)} samples after SimHash dedup.') + + return dataset, dup_pairs diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py new file mode 100644 index 000000000..24228326e --- /dev/null +++ b/data_juicer/ops/filter/__init__.py @@ -0,0 +1,7 @@ +from . import (alphanumeric_filter, average_line_length_filter, + character_repetition_filter, flagged_words_filter, + language_id_score_filter, maximum_line_length_filter, + perplexity_filter, special_characters_filter, + specified_field_filter, specified_numeric_field_filter, + stopwords_filter, suffix_filter, text_length_filter, + word_num_filter, word_repetition_filter) diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py new file mode 100644 index 000000000..2f948221b --- /dev/null +++ b/data_juicer/ops/filter/alphanumeric_filter.py @@ -0,0 +1,82 @@ +import sys + +from jsonargparse.typing import PositiveFloat + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ..base_op import OPERATORS, Filter +from ..common import get_words_from_document + + +@OPERATORS.register_module('alphanumeric_filter') +class AlphanumericFilter(Filter): + """Filter to keep samples with alphabet/numeric ratio within a specific + range.""" + + def __init__(self, + tokenization: bool = False, + min_ratio: float = 0.25, + max_ratio: PositiveFloat = sys.maxsize, + *args, + **kwargs): + """ + Initialization method. + + :param tokenization: Whether to count the ratio of alphanumeric + to the total number of tokens. if tokenization=False, it + will count the ratio of alphanumeric to the total number of + characters. + :param min_ratio: The min filter ratio in alphanumeric op, + samples will be filtered if their alphabet/numeric ratio is + below this parameter. + :param max_ratio: The max filter ratio in alphanumeric op, + samples will be filtered if their alphabet/numeric ratio + exceeds this parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.tokenization = tokenization + self.min_ratio = min_ratio + self.max_ratio = max_ratio + self.model_key = None + + if tokenization: + self.model_key = prepare_model( + model_type='huggingface', + model_key='EleutherAI/pythia-6.9b-deduped') + + def compute_stats(self, sample): + if self.tokenization: + if 'alpha_token_ratio' in sample['stats']: + return sample + alpha_count = sum( + map(lambda char: 1 + if char.isalpha() else 0, sample[self.text_key])) + tokenizer = MODEL_ZOO.get(self.model_key, None) + token_count = len( + get_words_from_document( + sample[self.text_key], + token_func=tokenizer.tokenize if tokenizer else None, + lower_case=False)) + sample['stats']['alpha_token_ratio'] = ( + alpha_count / token_count) if token_count != 0 else 0.0 + else: + if 'alnum_ratio' in sample['stats']: + return sample + alnum_count = sum( + map(lambda char: 1 + if char.isalnum() else 0, sample[self.text_key])) + sample['stats']['alnum_ratio'] = ( + alnum_count / len(sample[self.text_key])) if len( + sample[self.text_key]) != 0 else 0.0 + return sample + + def process(self, sample): + ratio = sample['stats'][ + 'alpha_token_ratio'] if self.tokenization else sample['stats'][ + 'alnum_ratio'] + if self.min_ratio <= ratio <= self.max_ratio: + return True + else: + return False \ No newline at end of file diff --git a/data_juicer/ops/filter/average_line_length_filter.py b/data_juicer/ops/filter/average_line_length_filter.py new file mode 100644 index 000000000..893bec785 --- /dev/null +++ b/data_juicer/ops/filter/average_line_length_filter.py @@ -0,0 +1,49 @@ +import sys + +from jsonargparse.typing import PositiveInt + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('average_line_length_filter') +class AverageLineLengthFilter(Filter): + """Filter to keep samples with average line length within a specific + range.""" + + def __init__(self, + min_len: PositiveInt = 10, + max_len: PositiveInt = sys.maxsize, + *args, + **kwargs): + """ + Initialization method. + + :param min_len: The min filter length in this op, samples will + be filtered if their average line length is below this + parameter. + :param max_len: The max filter length in this op, samples will + be filtered if their average line length exceeds this + parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_len = min_len + self.max_len = max_len + + def compute_stats(self, sample): + # check if it's computed already + if 'avg_line_length' in sample['stats']: + return sample + + line_lengths = list(map(len, sample[self.text_key].splitlines())) + sample['stats']['avg_line_length'] = \ + len(sample[self.text_key]) / len(line_lengths) \ + if len(line_lengths) != 0 else 0.0 + return sample + + def process(self, sample): + if self.min_len <= sample['stats']['avg_line_length'] <= self.max_len: + return True + else: + return False \ No newline at end of file diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py new file mode 100644 index 000000000..dddfb0af3 --- /dev/null +++ b/data_juicer/ops/filter/character_repetition_filter.py @@ -0,0 +1,75 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +import numpy as np +from jsonargparse.typing import ClosedUnitInterval, PositiveInt + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('character_repetition_filter') +class CharacterRepetitionFilter(Filter): + """Filter to keep samples with char-level n-gram repetition ratio within a + \ specific range.""" + + def __init__(self, + rep_len: PositiveInt = 10, + min_ratio: ClosedUnitInterval = 0.0, + max_ratio: ClosedUnitInterval = 0.5, + *args, + **kwargs): + """ + Initialization method. + + :param rep_len: Repetition length for char-level n-gram. + :param min_ratio: The min filter ratio in this op, samples will + be filtered if their char-level n-gram repetition ratio is + below this parameter. + :param max_ratio: The max filter ratio in this op, samples will + be filtered if their char-level n-gram repetition ratio + exceeds this parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.n = rep_len + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + def compute_stats(self, sample): + # check if it's computed already + if 'char_rep_ratio' in sample['stats']: + return sample + + char_ngrams = [ + sample[self.text_key][i:i + self.n] + for i in range(len(sample[self.text_key]) - self.n + 1) + ] + freq_char_ngrams = {} + for char_ngram in char_ngrams: + freq_char_ngrams[char_ngram] = ( + freq_char_ngrams.get(char_ngram, 0) + 1) + + if len(freq_char_ngrams) == 0: + sample['stats']['char_rep_ratio'] = 0.0 + return sample + + freq_char_ngrams = sorted(list(freq_char_ngrams.values()), + reverse=True) + rep_more_than_one = len([el for el in freq_char_ngrams if el > 1]) + num_rep_char_ngrams = min( + int(np.sqrt(len(freq_char_ngrams))), + len(freq_char_ngrams) - rep_more_than_one, + ) + sample['stats']['char_rep_ratio'] = (sum( + freq_char_ngrams[:num_rep_char_ngrams]) / sum(freq_char_ngrams)) \ + if sum(freq_char_ngrams) != 0 else 0.0 + return sample + + def process(self, sample): + if self.min_ratio <= sample['stats']['char_rep_ratio'] \ + <= self.max_ratio: + return True + else: + return False \ No newline at end of file diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py new file mode 100644 index 000000000..1a21ef4db --- /dev/null +++ b/data_juicer/ops/filter/flagged_words_filter.py @@ -0,0 +1,93 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +from jsonargparse.typing import ClosedUnitInterval, List + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ...utils.asset_utils import ASSET_DIR, load_words_asset +from ..base_op import OPERATORS, Filter +from ..common import SPECIAL_CHARACTERS, get_words_from_document + + +@OPERATORS.register_module('flagged_words_filter') +class FlaggedWordFilter(Filter): + """Filter to keep samples with flagged-word ratio less than a specific max + value.""" + + def __init__(self, + lang: str = 'en', + tokenization: bool = False, + max_ratio: ClosedUnitInterval = 0.045, + flagged_words_dir: str = ASSET_DIR, + use_words_aug: bool = False, + words_aug_group_sizes: List = [2], + words_aug_join_char: str = '', + *args, + **kwargs): + """ + Initialization method. + + :param lang: Consider flagged words in what language. If lang == + "all", we will adopt the one merged from all the available + languages + :param tokenization: Whether to use model to tokenize documents + :param max_ratio: The max filter ratio in this op. + :param flagged_words_dir: The directory storing the + flagged_words file(s) whose name includes "flagged_words" + and in json format + :param use_words_aug: Whether to augment words, especially for + Chinese and Vietnamese + :param words_aug_group_sizes: The group size of words to augment + :param words_aug_join_char: The join char between words to + augment + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.lang = lang + self.max_ratio = max_ratio + self.use_words_aug = use_words_aug + self.words_aug_group_sizes = words_aug_group_sizes + self.words_aug_join_char = words_aug_join_char + self.model_key = None + + self.FLAGGED_WORDS = load_words_asset(words_dir=flagged_words_dir, + words_type='flagged_words') + + if 'all' not in self.FLAGGED_WORDS: + self.FLAGGED_WORDS['all'] = [ + val for vals in self.FLAGGED_WORDS.values() for val in vals + ] + if tokenization: + self.model_key = prepare_model(lang=lang, + model_type='sentencepiece') + + def compute_stats(self, sample): + # check if it's computed already + if 'flagged_words_ratio' in sample['stats']: + return sample + + tokenizer = MODEL_ZOO.get(self.model_key, None) + words = get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None, + strip_chars=SPECIAL_CHARACTERS, + use_words_aug=self.use_words_aug, + words_aug_group_sizes=self.words_aug_group_sizes, + words_aug_join_char=self.words_aug_join_char) + + flagged_words_ratio = (len( + [word + for word in words if word in self.FLAGGED_WORDS[self.lang]]) / + len(words)) if len(words) != 0 else 0.0 + + if flagged_words_ratio > 1.0: + flagged_words_ratio = 1.0 + + sample['stats']['flagged_words_ratio'] = flagged_words_ratio + return sample + + def process(self, sample): + return sample['stats']['flagged_words_ratio'] <= self.max_ratio diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py new file mode 100644 index 000000000..88ab4b2df --- /dev/null +++ b/data_juicer/ops/filter/language_id_score_filter.py @@ -0,0 +1,58 @@ +from jsonargparse.typing import ClosedUnitInterval +from loguru import logger + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('language_id_score_filter') +class LanguageIDScoreFilter(Filter): + """Filter to keep samples in a specific language with confidence score + larger than a specific min value.""" + + def __init__(self, + lang: str = '', + min_score: ClosedUnitInterval = 0.8, + *args, + **kwargs): + """ + Initialization method. + + :param lang: Samples in which language to keep. + :param min_score: The min language identification confidence + scores of samples to keep. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.lang = lang + self.min_score = min_score + self.model_key = prepare_model(lang=lang, model_type='fasttext') + + def compute_stats(self, sample): + # check if it's computed already + if 'lang' in sample['stats'] and 'lang_score' in sample['stats']: + return sample + + text = sample[self.text_key].lower().replace('\n', ' ') + ft_model = MODEL_ZOO.get(self.model_key, None) + if ft_model is None: + err_msg = 'Model not loaded. Please retry later.' + logger.error(err_msg) + raise ValueError(err_msg) + pred = ft_model.predict(text) + lang_id = pred[0][0].replace('__label__', '') + lang_score = pred[1][0] + + sample['stats']['lang'] = lang_id + sample['stats']['lang_score'] = lang_score + + return sample + + def process(self, sample): + if self.lang: + return sample['stats']['lang'] == self.lang \ + and sample['stats']['lang_score'] >= self.min_score + else: + return sample['stats']['lang_score'] >= self.min_score diff --git a/data_juicer/ops/filter/maximum_line_length_filter.py b/data_juicer/ops/filter/maximum_line_length_filter.py new file mode 100644 index 000000000..695982451 --- /dev/null +++ b/data_juicer/ops/filter/maximum_line_length_filter.py @@ -0,0 +1,48 @@ +import sys + +from jsonargparse.typing import PositiveInt + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('maximum_line_length_filter') +class MaximumLineLengthFilter(Filter): + """Filter to keep samples with maximum line length within a specific + range.""" + + def __init__(self, + min_len: PositiveInt = 10, + max_len: PositiveInt = sys.maxsize, + *args, + **kwargs): + """ + Initialization method. + + :param min_len: The min filter length in this op, samples will + be filtered if their maximum line length is below this + parameter. + :param max_len: The max filter length in this op, samples will + be filtered if their maximum line length exceeds this + parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_len = min_len + self.max_len = max_len + + def compute_stats(self, sample): + # check if it's computed already + if 'max_line_length' in sample['stats']: + return sample + + line_lengths = list(map(len, sample[self.text_key].splitlines())) + sample['stats']['max_line_length'] = max( + line_lengths) if line_lengths else 0.0 + return sample + + def process(self, sample): + if self.min_len <= sample['stats']['max_line_length'] <= self.max_len: + return True + else: + return False diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py new file mode 100644 index 000000000..6f5b5ef39 --- /dev/null +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -0,0 +1,62 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +from jsonargparse.typing import PositiveFloat + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ..base_op import OPERATORS, Filter +from ..common import get_words_from_document + + +@OPERATORS.register_module('perplexity_filter') +class PerplexityFilter(Filter): + """Filter to keep samples with perplexity score less than a specific max + value.""" + + def __init__(self, + lang: str = 'en', + max_ppl: PositiveFloat = 1500, + *args, + **kwargs): + """ + Initialization method. + + :param lang: Compute perplexity for samples in which language. + :param max_ppl: The max filter perplexity in this op, samples + will be filtered if their perplexity exceeds this parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.max_ppl = max_ppl + self.sp_model_key = prepare_model(lang=lang, + model_type='sentencepiece') + self.kl_model_key = prepare_model(lang=lang, model_type='kenlm') + + def compute_stats(self, sample): + # check if it's computed already + if 'perplexity' in sample['stats']: + return sample + + # tokenization + tokenizer = MODEL_ZOO.get(self.sp_model_key, None) + tokens = get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None, + lower_case=False) + text = ' '.join(tokens) + # compute perplexity + logits, length = 0, 0 + kenlm_model = MODEL_ZOO.get(self.kl_model_key, None) + for line in text.splitlines(): + logits += kenlm_model.score(line) + length += (len(line.split()) + 1) + ppl = (10.0**(-logits / length)) if length != 0 else 0.0 + sample['stats']['perplexity'] = round(ppl, 1) + + return sample + + def process(self, sample): + return sample['stats']['perplexity'] <= self.max_ppl diff --git a/data_juicer/ops/filter/special_characters_filter.py b/data_juicer/ops/filter/special_characters_filter.py new file mode 100644 index 000000000..2efaa5955 --- /dev/null +++ b/data_juicer/ops/filter/special_characters_filter.py @@ -0,0 +1,55 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +from jsonargparse.typing import ClosedUnitInterval + +from ..base_op import OPERATORS, Filter +from ..common import SPECIAL_CHARACTERS + + +@OPERATORS.register_module('special_characters_filter') +class SpecialCharactersFilter(Filter): + """Filter to keep samples with special-char ratio within a specific + range.""" + + def __init__(self, + min_ratio: ClosedUnitInterval = 0.0, + max_ratio: ClosedUnitInterval = 0.25, + *args, + **kwargs): + """ + Initialization method. + + :param min_ratio: The min filter ratio in this op, samples will + be filtered if their special-char ratio is below this + parameter. + :param max_ratio: The max filter ratio in this op, samples will + be filtered if their special-char ratio exceeds this + parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + def compute_stats(self, sample): + # check if it's computed already + if 'special_char_ratio' in sample['stats']: + return sample + + # get ratio of special characters + sample['stats']['special_char_ratio'] = ( + len([c + for c in sample[self.text_key] if c in SPECIAL_CHARACTERS]) / + len(sample[self.text_key])) if len( + sample[self.text_key]) != 0 else 0.0 + return sample + + def process(self, sample): + if self.min_ratio <= sample['stats']['special_char_ratio'] \ + <= self.max_ratio: + return True + else: + return False diff --git a/data_juicer/ops/filter/specified_field_filter.py b/data_juicer/ops/filter/specified_field_filter.py new file mode 100644 index 000000000..f8b3887fe --- /dev/null +++ b/data_juicer/ops/filter/specified_field_filter.py @@ -0,0 +1,55 @@ +from typing import List, Tuple, Union + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('specified_field_filter') +class SpecifiedFieldFilter(Filter): + """ + Filter based on specified field information. + + If the specified field information in the sample is not within the + specified target value, the sample will be filtered. + """ + + def __init__(self, + text_key: str = '', + target_value: Union[List, Tuple] = [], + *args, + **kwargs): + """ + Initialization method. + + :param text_key: Filter based on the specified value + corresponding to the target key. The target key + corresponding to multi-level field information need to be + separated by '.'. + :param target_value: The range of specified field information + corresponding to the samples that need to be retained. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.text_key = text_key + self.target_value = target_value + + def compute_stats(self, sample): + return sample + + def process(self, sample): + if not (self.text_key and self.target_value): + return True + + field_value = sample + for key in self.text_key.split('.'): + assert key in field_value.keys(), "'{}' not in {}".format( + key, field_value.keys()) + field_value = field_value[key] + + if not (isinstance(field_value, list) + or isinstance(field_value, tuple)): + field_value = [field_value] + for value in field_value: + if value not in self.target_value: + return False + return True diff --git a/data_juicer/ops/filter/specified_numeric_field_filter.py b/data_juicer/ops/filter/specified_numeric_field_filter.py new file mode 100644 index 000000000..3e61ca101 --- /dev/null +++ b/data_juicer/ops/filter/specified_numeric_field_filter.py @@ -0,0 +1,69 @@ +import sys + +from ..base_op import OPERATORS, Filter + + +def is_number(s): + if s: + try: + float(s) + return True + except ValueError: + pass + return False + + +@OPERATORS.register_module('specified_numeric_field_filter') +class SpecifiedNumericFieldFilter(Filter): + """ + Filter based on specified numeric field information. + + If the specified numeric information in the sample is not within the + specified range, the sample will be filtered. + """ + + def __init__(self, + text_key: str = '', + min_value: float = -sys.maxsize, + max_value: float = sys.maxsize, + *args, + **kwargs): + """ + Initialization method. + + :param text_key: Filter based on the specified numeric value + corresponding to the target key. The target key + corresponding to multi-level field information need to be + separated by '.'. + :param min_value: The min filter value in SpecifiedNumericField + op, samples will be filtered if their specified numeric + field value is below this parameter. + :param max_value: The max filter value in SpecifiedNumericField + op, samples will be filtered if their specified numeric + field value exceeds this parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.text_key = text_key + self.min_value = min_value + self.max_value = max_value + + def compute_stats(self, sample): + return sample + + def process(self, sample): + if not self.text_key: + return True + + field_value = sample + for key in self.text_key.split('.'): + assert key in field_value.keys(), "'{}' not in {}".format( + key, field_value.keys()) + field_value = field_value[key] + + if is_number(field_value): + field_value = float(field_value) + return self.min_value <= field_value <= self.max_value + else: + return False diff --git a/data_juicer/ops/filter/stopwords_filter.py b/data_juicer/ops/filter/stopwords_filter.py new file mode 100644 index 000000000..1dc3289c3 --- /dev/null +++ b/data_juicer/ops/filter/stopwords_filter.py @@ -0,0 +1,92 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +from jsonargparse.typing import ClosedUnitInterval, List + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ...utils.asset_utils import ASSET_DIR, load_words_asset +from ..base_op import OPERATORS, Filter +from ..common import SPECIAL_CHARACTERS, get_words_from_document + + +@OPERATORS.register_module('stopwords_filter') +class StopWordsFilter(Filter): + """Filter to keep samples with stopword ratio larger than a specific min + value.""" + + def __init__(self, + lang: str = 'en', + tokenization: bool = False, + min_ratio: ClosedUnitInterval = 0.3, + stopwords_dir: str = ASSET_DIR, + use_words_aug: bool = False, + words_aug_group_sizes: List = [2], + words_aug_join_char: str = '', + *args, + **kwargs): + """ + Initialization method. + + :param lang: Consider stopwords in what language. If lang == + "all", we will adopt the one merged from all the available + languages + :param tokenization: whether to use model to tokenize documents + :param min_ratio: The min filter ratio in this op. + :param stopwords_dir: The directory storing the stopwords + file(s) whose name includes "stopwords" and in json format + :param use_words_aug: Whether to augment words, especially for + Chinese and Vietnamese + :param words_aug_group_sizes: The group size of words to augment + :param words_aug_join_char: The join char between words to + augment + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.lang = lang + self.min_ratio = min_ratio + self.use_words_aug = use_words_aug + self.words_aug_group_sizes = words_aug_group_sizes + self.words_aug_join_char = words_aug_join_char + self.model_key = None + + self.STOPWORDS = load_words_asset(words_dir=stopwords_dir, + words_type='stopwords') + if 'all' not in self.STOPWORDS: + self.STOPWORDS['all'] = [ + val for vals in self.STOPWORDS.values() for val in vals + ] + if tokenization: + self.model_key = prepare_model(lang=lang, + model_type='sentencepiece') + + def compute_stats(self, sample): + # check if it's computed already + if 'stopwords_ratio' in sample['stats']: + return sample + + tokenizer = MODEL_ZOO.get(self.model_key, None) + words = get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None, + strip_chars=SPECIAL_CHARACTERS, + use_words_aug=self.use_words_aug, + words_aug_group_sizes=self.words_aug_group_sizes, + words_aug_join_char=self.words_aug_join_char) + + stopwords_ratio = ( + len([word for word in words + if word in self.STOPWORDS[self.lang]]) + / len(words)) \ + if len(words) != 0 else 0.0 + + if stopwords_ratio > 1.0: + stopwords_ratio = 1.0 + + sample['stats']['stopwords_ratio'] = stopwords_ratio + return sample + + def process(self, sample): + return sample['stats']['stopwords_ratio'] >= self.min_ratio diff --git a/data_juicer/ops/filter/suffix_filter.py b/data_juicer/ops/filter/suffix_filter.py new file mode 100644 index 000000000..d1ce414d6 --- /dev/null +++ b/data_juicer/ops/filter/suffix_filter.py @@ -0,0 +1,40 @@ +from typing import List, Tuple, Union + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('suffix_filter') +class SuffixFilter(Filter): + """Filter to keep samples with specified suffix.""" + + def __init__(self, + suffixes: Union[str, List[str], Tuple[str]] = [], + *args, + **kwargs): + """ + Initialization method. + + :param suffixes: the suffix of text that will be keep. + For example: '.txt', 'txt' or ['txt', '.pdf', 'docx'] + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + if suffixes is None: + self.suffixes = [] + elif isinstance(suffixes, str): + self.suffixes = [suffixes] + else: + self.suffixes = suffixes + + def compute_stats(self, sample): + return sample + + def process(self, sample): + if self.suffixes: + if sample['meta.suffix'] in self.suffixes: + return True + else: + return False + else: + return True diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py new file mode 100644 index 000000000..96d780473 --- /dev/null +++ b/data_juicer/ops/filter/text_length_filter.py @@ -0,0 +1,46 @@ +import sys + +from jsonargparse.typing import PositiveInt + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('text_length_filter') +class TextLengthFilter(Filter): + """Filter to keep samples with total text length within a specific + range.""" + + def __init__(self, + min_len: PositiveInt = 10, + max_len: PositiveInt = sys.maxsize, + *args, + **kwargs): + """ + Initialization method. + + :param min_len: The min text length in the filtering. samples + will be filtered if their text length is below this + parameter. + :param max_len: The max text length in the filtering. samples + will be filtered if their text length exceeds this + parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_len = min_len + self.max_len = max_len + + def compute_stats(self, sample): + # check if it's computed already + if 'text_len' in sample['stats']: + return sample + + sample['stats']['text_len'] = len(sample[self.text_key]) + return sample + + def process(self, sample): + if self.min_len <= sample['stats']['text_len'] <= self.max_len: + return True + else: + return False diff --git a/data_juicer/ops/filter/word_num_filter.py b/data_juicer/ops/filter/word_num_filter.py new file mode 100644 index 000000000..068291891 --- /dev/null +++ b/data_juicer/ops/filter/word_num_filter.py @@ -0,0 +1,64 @@ +import sys + +from jsonargparse.typing import PositiveInt + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ..base_op import OPERATORS, Filter +from ..common import SPECIAL_CHARACTERS, get_words_from_document + + +@OPERATORS.register_module('words_num_filter') +class WordNumFilter(Filter): + """Filter to keep samples with total words number within a specific + range.""" + + def __init__(self, + lang: str = 'en', + tokenization: bool = False, + min_num: PositiveInt = 10, + max_num: PositiveInt = sys.maxsize, + *args, + **kwargs): + """ + Initialization method. + + :param lang: sample in which language. + :param tokenization: whether to use model to tokenize documents + :param min_num: The min filter word number in this op, samples + will be filtered if their word number is below this + parameter. + :param max_num: The max filter word number in this op, samples + will be filtered if their word number exceeds this + parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_num = min_num + self.max_num = max_num + self.model_key = None + + if tokenization: + self.model_key = prepare_model(lang=lang, + model_type='sentencepiece') + + def compute_stats(self, sample): + # check if it's computed already + if 'num_words' in sample['stats']: + return sample + + tokenizer = MODEL_ZOO.get(self.model_key, None) + sample['stats']['num_words'] = len( + get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None, + lower_case=False, + strip_chars=SPECIAL_CHARACTERS)) + return sample + + def process(self, sample): + if self.min_num <= sample['stats']['num_words'] <= self.max_num: + return True + else: + return False diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py new file mode 100644 index 000000000..3770eb64e --- /dev/null +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -0,0 +1,86 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +from jsonargparse.typing import ClosedUnitInterval, PositiveInt + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ..base_op import OPERATORS, Filter +from ..common import SPECIAL_CHARACTERS, get_words_from_document + + +@OPERATORS.register_module('word_repetition_filter') +class WordRepetitionFilter(Filter): + """Filter to keep samples with word-level n-gram repetition ratio within a + \ specific range.""" + + def __init__(self, + lang: str = 'en', + tokenization: bool = False, + rep_len: PositiveInt = 10, + min_ratio: ClosedUnitInterval = 0.0, + max_ratio: ClosedUnitInterval = 0.5, + *args, + **kwargs): + """ + Initialization method. + + :param lang: sample in which language. + :param tokenization: whether to use model to tokenize documents + :param rep_len: Repetition length for word-level n-gram. + :param min_ratio: The min filter ratio in this op, samples will + be filtered if their word-level n-gram repetition ratio is + below this parameter. + :param max_ratio: The max filter ratio in this op, samples will + be filtered if their word-level n-gram repetition ratio + exceeds this parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.n = rep_len + self.min_ratio = min_ratio + self.max_ratio = max_ratio + self.model_key = None + + if tokenization: + self.model_key = prepare_model(lang=lang, + model_type='sentencepiece') + + def compute_stats(self, sample): + # check if it's computed already + if 'word_rep_ratio' in sample['stats']: + return sample + + tokenizer = MODEL_ZOO.get(self.model_key, None) + words = get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None, + strip_chars=SPECIAL_CHARACTERS) + word_ngrams = [ + ' '.join(words[i:i + self.n]) + for i in range(len(words) - self.n + 1) + ] + freq_word_ngrams = {} + for word_ngram in word_ngrams: + freq_word_ngrams[word_ngram] = ( + freq_word_ngrams.get(word_ngram, 0) + 1) + + if len(freq_word_ngrams) == 0: + sample['stats']['word_rep_ratio'] = 0.0 + return sample + + freq_word_ngrams = list(freq_word_ngrams.values()) + rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1] + sample['stats']['word_rep_ratio'] = ( + sum(rep_more_than_one) / + sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0 + return sample + + def process(self, sample): + if self.min_ratio <= sample['stats']['word_rep_ratio'] \ + <= self.max_ratio: + return True + else: + return False diff --git a/data_juicer/ops/load.py b/data_juicer/ops/load.py new file mode 100644 index 000000000..8e4e1d216 --- /dev/null +++ b/data_juicer/ops/load.py @@ -0,0 +1,25 @@ +from .base_op import OPERATORS + + +def load_ops(process_list, text_key='text'): + """ + Load op list according to the process list from config file. + + :param process_list: A process list. Each item is an op name and its + arguments. + :param text_key: the key name of field that stores sample texts to + be processed + :return: The op instance list. + """ + ops = [] + for process in process_list: + op_name, args = list(process.items())[0] + + # users can freely specify text_key for different ops + if args is None: + args = {'text_key': text_key} + elif args['text_key'] is None: + args['text_key'] = text_key + ops.append(OPERATORS.modules[op_name](**args)) + + return ops diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py new file mode 100644 index 000000000..6510631a0 --- /dev/null +++ b/data_juicer/ops/mapper/__init__.py @@ -0,0 +1,8 @@ +from . import (clean_copyright_mapper, clean_email_mapper, clean_html_mapper, + clean_ip_mapper, clean_links_mapper, expand_macro_mapper, + fix_unicode_mapper, punctuation_normalization_mapper, + remove_bibliography_mapper, remove_comments_mapper, + remove_header_mapper, remove_long_words_mapper, + remove_specific_chars_mapper, remove_table_text_mapper, + remove_words_with_incorrect_substrings_mapper, + sentence_split_mapper, whitespace_normalization_mapper) diff --git a/data_juicer/ops/mapper/clean_copyright_mapper.py b/data_juicer/ops/mapper/clean_copyright_mapper.py new file mode 100644 index 000000000..c5b046d0e --- /dev/null +++ b/data_juicer/ops/mapper/clean_copyright_mapper.py @@ -0,0 +1,55 @@ +# Some code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/ +# -------------------------------------------------------- + +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('clean_copyright_mapper') +class CleanCopyrightMapper(Mapper): + """Mapper to clean copyright comments at the beginning of the text + samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.pat = re.compile('/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/') + self.cpat = re.compile('copyright', re.IGNORECASE) + + def process(self, sample): + + r = self.pat.search(sample[self.text_key]) + if r: + # found one, now see if it contains "copyright", if so strip it + span = r.span() + sub = sample[self.text_key][span[0]:span[1]] + if self.cpat.search(sub): + # cut it + sample[self.text_key] = sample[ + self.text_key][:span[0]] + sample[self.text_key][span[1]:] + + return sample + + lines = sample[self.text_key].split('\n') + skip = 0 + + # Greedy replace any file that begins with comment block, most + # are copyright headers + for k in range(len(lines)): + if (lines[k].startswith('//') or lines[k].startswith('#') + or lines[k].startswith('--') or not lines[k]): + skip = skip + 1 + else: + break + + if skip: + # we skipped, consume it + sample[self.text_key] = '\n'.join(lines[skip:]) + return sample diff --git a/data_juicer/ops/mapper/clean_email_mapper.py b/data_juicer/ops/mapper/clean_email_mapper.py new file mode 100644 index 000000000..4bed01a1d --- /dev/null +++ b/data_juicer/ops/mapper/clean_email_mapper.py @@ -0,0 +1,29 @@ +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('clean_email_mapper') +class CleanEmailMapper(Mapper): + """Mapper to clean email in text samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.pattern = r'[A-Za-z0-9.\-+_]+@[a-z0-9.\-+_]+\.[a-z]+' + + def process(self, sample): + + if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): + return sample + + sample[self.text_key] = re.sub(pattern=self.pattern, + repl=r'', + string=sample[self.text_key], + flags=re.DOTALL) + return sample diff --git a/data_juicer/ops/mapper/clean_html_mapper.py b/data_juicer/ops/mapper/clean_html_mapper.py new file mode 100644 index 000000000..22e092851 --- /dev/null +++ b/data_juicer/ops/mapper/clean_html_mapper.py @@ -0,0 +1,34 @@ +# Some code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/ +# -------------------------------------------------------- + +from selectolax.parser import HTMLParser + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('clean_html_mapper') +class CleanHtmlMapper(Mapper): + """Mapper to clean html code in text samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def process(self, sample): + + def _clean_html(raw_html): + raw_html = raw_html.replace('
  • ', '\n*') + raw_html = raw_html.replace('
  • ', '') + raw_html = raw_html.replace('
      ', '\n*') + raw_html = raw_html.replace('
    ', '') + parser = HTMLParser(raw_html) + return parser.text() + + sample[self.text_key] = _clean_html(sample[self.text_key]) + return sample diff --git a/data_juicer/ops/mapper/clean_ip_mapper.py b/data_juicer/ops/mapper/clean_ip_mapper.py new file mode 100644 index 000000000..8f1ce9684 --- /dev/null +++ b/data_juicer/ops/mapper/clean_ip_mapper.py @@ -0,0 +1,34 @@ +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('clean_ip_mapper') +class CleanIpMapper(Mapper): + """Mapper to clean ipv4 and ipv6 address in text samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + + super().__init__(*args, **kwargs) + self.pattern = r'(?:(?:1[0-9][0-9]\.)|(?:2[0-4][0-9]\.)|' + self.pattern += r'(?:25[0-5]\.)|(?:[1-9][0-9]\.)|(?:[0-9]\.))' + self.pattern += r'{3}(?:(?:1[0-9][0-9])|(?:2[0-4][0-9])|' + self.pattern += r'(?:25[0-5])|(?:[1-9][0-9])|(?:[0-9]))|' + self.pattern += r'([\da-fA-F]{1,4}:){7}[\da-fA-F]{1,4}' # ipv6 + + def process(self, sample): + + if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): + return sample + + sample[self.text_key] = re.sub(pattern=self.pattern, + repl=r'', + string=sample[self.text_key], + flags=re.DOTALL) + return sample diff --git a/data_juicer/ops/mapper/clean_links_mapper.py b/data_juicer/ops/mapper/clean_links_mapper.py new file mode 100644 index 000000000..b8d4945fe --- /dev/null +++ b/data_juicer/ops/mapper/clean_links_mapper.py @@ -0,0 +1,39 @@ +# Some code here has been modified from: +# https://github.com/kallewesterling/CleanText/ +# -------------------------------------------------------- +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('clean_links_mapper') +class CleanLinksMapper(Mapper): + """Mapper to clean links like http/https/ftp in text samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.pattern = r'(?i)\b(' + self.pattern += r'(?:[a-z][\w-]+:(?:\/{1,3}|' + self.pattern += r'[a-z0-9%])|www\d{0,3}[.]|' + self.pattern += r'[a-z0-9.\-]+[.][a-z]{2,4}\/)' + self.pattern += r'(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))' + self.pattern += r'+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|' + self.pattern += r'[^\s`!()\[\]{};:\'\".,<>?«»“”‘’])' + self.pattern += r')' + + def process(self, sample): + + if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): + return sample + + sample[self.text_key] = re.sub(pattern=self.pattern, + repl=r'', + string=sample[self.text_key], + flags=re.DOTALL) + return sample diff --git a/data_juicer/ops/mapper/expand_macro_mapper.py b/data_juicer/ops/mapper/expand_macro_mapper.py new file mode 100644 index 000000000..1792796ca --- /dev/null +++ b/data_juicer/ops/mapper/expand_macro_mapper.py @@ -0,0 +1,80 @@ +# Some code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/blob/main/data_prep/arxiv/arxiv_cleaner.py +# -------------------------------------------------------- + +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('expand_macro_mapper') +class ExpandMacroMapper(Mapper): + """Mapper to expand macro definitions in the document body of Latex + samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def _build_non_arg_macros_dict(self, file_content): + # regex for extracting \newcommand macros without arguments + non_arg_nc_reg = re.compile( + # this regex matches the following: + # \newcommand{\macro_name}{macro_value} + # \newcommand*{\macro_name}{macro_value} + # where macro_name is only allowed to contain letters and numbers; + # macro_value can contain any character. + pattern=r'\\\bnewcommand\b\*?\{(\\[a-zA-Z0-9]+?)\}\{(.*?)\}$', + flags=re.MULTILINE) + + # regex for extracting \def macros without arguments + non_arg_def_reg = re.compile( + # this regex matches the following: + # \def\macro_name{macro_value} + # where macro_name is only allowed to contain letters and numbers; + # macro_value can contain any character. + pattern=r'\\def\s*(\\[a-zA-Z0-9]+?)\s*\{(.*?)\}$', + flags=re.MULTILINE) + + # Extract all user-defined LaTeX macros from the preamble + macros = {} + for reg in [non_arg_nc_reg, non_arg_def_reg]: + for match in reg.finditer(file_content): + # convert the macro name and value to a raw string that can be + # used in re.sub + macro_name = match.group(1).encode('unicode-escape').decode( + 'utf-8') + macro_val = match.group(2).encode('unicode-escape').decode( + 'utf-8') + + macros[macro_name] = macro_val + return macros + + def process(self, sample): + non_arg_macros = self._build_non_arg_macros_dict(sample[self.text_key]) + + # TODO: macros that take arguments are not supported yet + arg_macros = {} + + # inline-expand all non-arg macros + for macro_name, macro_value in non_arg_macros.items(): + sample[self.text_key] = re.sub( + # make pattern grouped to make sure that the macro is not part + # of a longer alphanumeric word + pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])', + # replace the macro with its value and add back the character + # that was matched after the macro + repl=macro_value + r'\2', + string=sample[self.text_key]) + + # inline-expand all macros that use args + # TODO: inline-expand macros with args + for macro_name, macro_value in arg_macros.items(): + pass + + return sample diff --git a/data_juicer/ops/mapper/fix_unicode_mapper.py b/data_juicer/ops/mapper/fix_unicode_mapper.py new file mode 100644 index 000000000..275fbba28 --- /dev/null +++ b/data_juicer/ops/mapper/fix_unicode_mapper.py @@ -0,0 +1,21 @@ +import ftfy + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('fix_unicode_mapper') +class FixUnicodeMapper(Mapper): + """Mapper to fix unicode errors in text samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def process(self, sample): + sample[self.text_key] = ftfy.fix_text(sample[self.text_key]) + return sample diff --git a/data_juicer/ops/mapper/punctuation_normalization_mapper.py b/data_juicer/ops/mapper/punctuation_normalization_mapper.py new file mode 100644 index 000000000..e8cdf3e60 --- /dev/null +++ b/data_juicer/ops/mapper/punctuation_normalization_mapper.py @@ -0,0 +1,62 @@ +# Some code here has been modified from: +# https://github.com/bigscience-workshop/data-preparation +# -------------------------------------------------------- + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('punctuation_normalization_mapper') +class PunctuationNormalizationMapper(Mapper): + """Mapper to normalize unicode punctuations to English punctuations in text + \ samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.punctuation_unicode = { + ',': ',', + '。': '.', + '、': ',', + '„': '"', + '”': '"', + '“': '"', + '«': '"', + '»': '"', + '1': '"', + '」': '"', + '「': '"', + '《': '"', + '》': '"', + '´': "'", + '∶': ':', + ':': ':', + '?': '?', + '!': '!', + '(': '(', + ')': ')', + ';': ';', + '–': '-', + '—': ' - ', + '.': '. ', + '~': '~', + '’': "'", + '…': '...', + '━': '-', + '〈': '<', + '〉': '>', + '【': '[', + '】': ']', + '%': '%', + '►': '-', + } + + def process(self, sample): + sample[self.text_key] = ''.join([ + self.punctuation_unicode.get(c, c) for c in sample[self.text_key] + ]) + return sample diff --git a/data_juicer/ops/mapper/remove_bibliography_mapper.py b/data_juicer/ops/mapper/remove_bibliography_mapper.py new file mode 100644 index 000000000..7a5c815ca --- /dev/null +++ b/data_juicer/ops/mapper/remove_bibliography_mapper.py @@ -0,0 +1,35 @@ +# Some code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/ +# -------------------------------------------------------- + +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('remove_bibliography_mapper') +class RemoveBibliographyMapper(Mapper): + """Mapper to remove bibliography at the end of documents in Latex + samples.""" + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.pattern = r'(\\appendix|' + self.pattern += r'\\begin\{references\}|' + self.pattern += r'\\begin\{REFERENCES\}|' + self.pattern += r'\\begin\{thebibliography\}|' + self.pattern += r'\\bibliography\{.*\}' + self.pattern += r').*$' + + def process(self, sample): + sample[self.text_key] = re.sub(pattern=self.pattern, + repl=r'', + string=sample[self.text_key], + flags=re.DOTALL) + return sample diff --git a/data_juicer/ops/mapper/remove_comments_mapper.py b/data_juicer/ops/mapper/remove_comments_mapper.py new file mode 100644 index 000000000..f49bc9065 --- /dev/null +++ b/data_juicer/ops/mapper/remove_comments_mapper.py @@ -0,0 +1,55 @@ +# Some code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/ +# -------------------------------------------------------- + +from typing import List, Union + +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('remove_comments_mapper') +class RemoveCommentsMapper(Mapper): + """ + Mapper to remove comments in different kinds of documents. + + Only support 'tex' \ for now. + """ + + def __init__(self, + doc_type: Union[str, List[str]] = 'tex', + inline: bool = True, + multiline: bool = True, + *args, + **kwargs): + """ + Initialization method. + + :param doc_type: Type of document to remove comments. + :param inline: Whether to remove inline comments. + :param multiline: Whether to remove multiline comments. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.doc_type = doc_type + self.inline = inline + self.multiline = multiline + + def process(self, sample): + # TODO: remove different comments by sample type + + if self.inline: + # remove all in comments within a line + sample[self.text_key] = re.sub(pattern=r'[^\\]%.+$', + repl=r'', + string=sample[self.text_key], + flags=re.MULTILINE) + + if self.multiline: + sample[self.text_key] = re.sub(pattern=r'(?m)^%.*\n?', + repl=r'', + string=sample[self.text_key], + flags=re.MULTILINE) + return sample diff --git a/data_juicer/ops/mapper/remove_header_mapper.py b/data_juicer/ops/mapper/remove_header_mapper.py new file mode 100644 index 000000000..4c36bde64 --- /dev/null +++ b/data_juicer/ops/mapper/remove_header_mapper.py @@ -0,0 +1,49 @@ +# Some code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/ +# -------------------------------------------------------- + +import regex as re + +from ..base_op import OPERATORS, Mapper + + +# TODO +@OPERATORS.register_module('remove_header_mapper') +class RemoveHeaderMapper(Mapper): + """Mapper to remove headers at the beginning of documents in Latex + samples.""" + + def __init__(self, drop_no_head: bool = True, *args, **kwargs): + """ + Initialization method. + + :param drop_no_head: whether to drop sample texts without + headers. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.pattern = r'^(.*?)(' + self.pattern += r'\\\bchapter\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + self.pattern += r'\\\bpart\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + self.pattern += r'\\\bsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + self.pattern += r'\\\bsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + self.pattern += r'\\\bsubsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|' + self.pattern += r'\\\bparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}' + self.pattern += r'\\\bsubparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}' + self.pattern += r')' + + self.drop_no_head = drop_no_head + + def process(self, sample): + + if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): + if self.drop_no_head: + sample[self.text_key] = '' + return sample + + sample[self.text_key] = re.sub(pattern=self.pattern, + repl=r'\2', + string=sample[self.text_key], + flags=re.DOTALL) + return sample diff --git a/data_juicer/ops/mapper/remove_long_words_mapper.py b/data_juicer/ops/mapper/remove_long_words_mapper.py new file mode 100644 index 000000000..92ac8fe2d --- /dev/null +++ b/data_juicer/ops/mapper/remove_long_words_mapper.py @@ -0,0 +1,53 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- + +import sys + +from jsonargparse.typing import PositiveInt + +from ..base_op import OPERATORS, Mapper +from ..common import (SPECIAL_CHARACTERS, merge_on_whitespace_tab_newline, + split_on_newline_tab_whitespace, strip) + + +@OPERATORS.register_module('remove_long_words_mapper') +class RemoveLongWordsMapper(Mapper): + """Mapper to remove long words within a specific range.""" + + def __init__(self, + min_len: PositiveInt = 1, + max_len: PositiveInt = sys.maxsize, + *args, + **kwargs): + """ + Initialization method. + + :param min_len: The min mapper word length in this op, words + will be filtered if their length is below this parameter. + :param max_len: The max mapper word length in this op, words + will be filtered if their length exceeds this parameter. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_len = min_len + self.max_len = max_len + + def should_keep_long_word(self, word): + if self.min_len <= len(word) <= self.max_len: + return True + elif self.min_len <= len(strip(word, + SPECIAL_CHARACTERS)) <= self.max_len: + return True + else: + return False + + def process(self, sample): + + sentences = split_on_newline_tab_whitespace(sample[self.text_key]) + sentences = [[[ + word for word in subsentence if self.should_keep_long_word(word) + ] for subsentence in sentence] for sentence in sentences] + sample[self.text_key] = merge_on_whitespace_tab_newline(sentences) + return sample diff --git a/data_juicer/ops/mapper/remove_specific_chars_mapper.py b/data_juicer/ops/mapper/remove_specific_chars_mapper.py new file mode 100644 index 000000000..99e15afef --- /dev/null +++ b/data_juicer/ops/mapper/remove_specific_chars_mapper.py @@ -0,0 +1,40 @@ +from typing import List, Union + +import regex as re + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('remove_specific_chars_mapper') +class RemoveSpecificCharsMapper(Mapper): + """Mapper to clean specific chars in text samples.""" + + def __init__(self, + chars_to_remove: Union[str, List[str]] = '◆●■►▼▲▴∆▻▷❖♡□', + *args, + **kwargs): + """ + Initialization method. + + :param chars_to_remove: a list or a string including all + characters that need to be removed from text. + :param args: extra args + :param kwargs: extra args + """ + + super().__init__(*args, **kwargs) + if chars_to_remove: + self.pattern = '[' + '|'.join(chars_to_remove) + ']' + else: + self.pattern = None + + def process(self, sample): + + if self.pattern is None: + return sample + + sample[self.text_key] = re.sub(pattern=self.pattern, + repl=r'', + string=sample[self.text_key], + flags=re.DOTALL) + return sample diff --git a/data_juicer/ops/mapper/remove_table_text_mapper.py b/data_juicer/ops/mapper/remove_table_text_mapper.py new file mode 100644 index 000000000..4f6dfb233 --- /dev/null +++ b/data_juicer/ops/mapper/remove_table_text_mapper.py @@ -0,0 +1,45 @@ +import regex as re +from jsonargparse.typing import restricted_number_type + +from ..base_op import OPERATORS, Mapper + +from_2_to_20 = restricted_number_type('from_2_to_20', int, [('>=', 2), + ('<=', 20)]) + + +@OPERATORS.register_module('remove_table_text_mapper') +class RemoveTableTextMapper(Mapper): + """ + Mapper to remove table texts from text samples. + + Regular expression is used to remove tables in the range of column + number of tables. + """ + + def __init__(self, + min_col: from_2_to_20 = 2, + max_col: from_2_to_20 = 20, + *args, + **kwargs): + """ + Initialization method. + + :param min_col: The min number of columns of table to remove. + :param max_col: The max number of columns of table to remove. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_col = min_col + self.max_col = max_col + self.pattern = r'(?<=\n)((\S+?)([ |\t](\S+?)){%d}\n+){2,}' + + def process(self, sample): + + text = sample[self.text_key] + for i in range(self.min_col - 1, self.max_col): + pattern = re.compile(self.pattern % i) + text = pattern.sub('', text) + + sample[self.text_key] = text + return sample diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py new file mode 100644 index 000000000..4f92b6f43 --- /dev/null +++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py @@ -0,0 +1,66 @@ +from jsonargparse.typing import List + +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ..base_op import OPERATORS, Mapper +from ..common import (SPECIAL_CHARACTERS, get_words_from_document, + merge_on_whitespace_tab_newline, + split_on_newline_tab_whitespace, strip) + + +@OPERATORS.register_module('remove_words_with_incorrect_substrings_mapper') +class RemoveWordsWithIncorrectSubstringsMapper(Mapper): + """Mapper to remove words with incorrect substrings.""" + + def __init__(self, + lang: str = 'en', + tokenization: bool = False, + substrings: List = None, + *args, + **kwargs): + """ + Initialization method. + + :param lang: sample in which language + :param tokenization: whether to use model to tokenize documents + :param substrings: The incorrect substrings in words. + :param args: extra args + :param kwargs: extra args + """ + if substrings is None: + substrings = ['http', 'www', '.com', 'href', '//'] + super().__init__(*args, **kwargs) + self.tokenization = tokenization + self.substrings = substrings + if tokenization: + self.model_key = prepare_model(lang=lang, + model_type='sentencepiece') + + def should_keep_word_with_incorrect_substrings(self, word, substrings): + word = strip(word, SPECIAL_CHARACTERS) + should_keep = all([(i_substr not in word) for i_substr in substrings]) + return should_keep + + def process(self, sample): + if self.tokenization: + tokenizer = MODEL_ZOO.get(self.model_key, None) + sentences = get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None, + lower_case=False) + words = [ + word.replace('▁', '') for word in sentences + if self.should_keep_word_with_incorrect_substrings( + word.replace('▁', ''), self.substrings) + ] + if len(words) != len(sentences): + sample[self.text_key] = ''.join(words) + else: + sentences = split_on_newline_tab_whitespace(sample[self.text_key]) + sentences = [[[ + word for word in subsentence + if self.should_keep_word_with_incorrect_substrings( + word, self.substrings) + ] for subsentence in sentence] for sentence in sentences] + sample[self.text_key] = merge_on_whitespace_tab_newline(sentences) + return sample diff --git a/data_juicer/ops/mapper/sentence_split_mapper.py b/data_juicer/ops/mapper/sentence_split_mapper.py new file mode 100644 index 000000000..368d2ba92 --- /dev/null +++ b/data_juicer/ops/mapper/sentence_split_mapper.py @@ -0,0 +1,28 @@ +from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model + +from ..base_op import OPERATORS, Mapper +from ..common import get_sentences_from_document + + +@OPERATORS.register_module('sentence_split_mapper') +class SentenceSplitMapper(Mapper): + """Mapper to split text samples to sentences.""" + + def __init__(self, lang: str = 'en', *args, **kwargs): + """ + Initialization method. + + :param lang: split sentence of text in which language. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.model_key = prepare_model(lang=lang, model_type='nltk') + + def process(self, sample): + + nltk_model = MODEL_ZOO.get(self.model_key, None) + sample[self.text_key] = get_sentences_from_document( + sample[self.text_key], + model_func=nltk_model.tokenize if nltk_model else None) + return sample diff --git a/data_juicer/ops/mapper/whitespace_normalization_mapper.py b/data_juicer/ops/mapper/whitespace_normalization_mapper.py new file mode 100644 index 000000000..a81d60f0c --- /dev/null +++ b/data_juicer/ops/mapper/whitespace_normalization_mapper.py @@ -0,0 +1,41 @@ +# Most of the code here has been modified from: +# https://github.com/bigscience-workshop/data-preparation +# -------------------------------------------------------- + +from ..base_op import OPERATORS, Mapper + + +@OPERATORS.register_module('whitespace_normalization_mapper') +class WhitespaceNormalizationMapper(Mapper): + """ + Mapper to normalize different kinds of whitespaces to whitespace ' ' (0x20) + in text samples. + + Different kinds of whitespaces can be found here: + https://en.wikipedia.org/wiki/Whitespace_character + """ + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + # whitespaces in unicode can be found here: + # https://en.wikipedia.org/wiki/Whitespace_character + super().__init__(*args, **kwargs) + self.whitespaces = { + ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', + ' ', ' ', ' ', ' ', '​', '‌', '‍', '⁠', '', '„' + } + + def process(self, sample): + # remove whitespaces before and after the main content + text = sample[self.text_key].strip() + + # replace all kinds of whitespaces with ' ' + sample[self.text_key] = ''.join( + [char if char not in self.whitespaces else ' ' for char in text]) + + return sample diff --git a/data_juicer/ops/selector/__init__.py b/data_juicer/ops/selector/__init__.py new file mode 100644 index 000000000..cf0977321 --- /dev/null +++ b/data_juicer/ops/selector/__init__.py @@ -0,0 +1 @@ +from . import frequency_specified_field_selector, topk_specified_field_selector diff --git a/data_juicer/ops/selector/frequency_specified_field_selector.py b/data_juicer/ops/selector/frequency_specified_field_selector.py new file mode 100644 index 000000000..937642642 --- /dev/null +++ b/data_juicer/ops/selector/frequency_specified_field_selector.py @@ -0,0 +1,87 @@ +import numbers + +from jsonargparse.typing import ClosedUnitInterval, PositiveInt + +from ..base_op import OPERATORS, Selector + + +@OPERATORS.register_module('frequency_specified_field_selector') +class FrequencySpecifiedFieldSelector(Selector): + """Selector to select samples based on the sorted frequency of specified + field.""" + + def __init__(self, + text_key: str = '', + top_ratio: ClosedUnitInterval = None, + topk: PositiveInt = None, + reverse: bool = True, + *args, + **kwargs): + """ + Initialization method. + + :param text_key: Selector based on the specified value + corresponding to the target key. The target key + corresponding to multi-level field information need to be + separated by '.'. + :param top_ratio: Ratio of selected top specified field value, + samples will be selected if their specified field values are + within this parameter. When both topk and top_ratio are set, + the value corresponding to the smaller number of samples + will be applied. + :param topk: Number of selected top specified field value, + samples will be selected if their specified field values are + within this parameter. When both topk and top_ratio are set, + the value corresponding to the smaller number of samples + will be applied. + :param reverse: Determine the sorting rule, if reverse=True, + then sort in descending order. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.text_key = text_key + self.top_ratio = top_ratio + self.topk = topk + self.reverse = reverse + + def process(self, dataset): + if len(dataset) <= 1 or not self.text_key: + return dataset + + text_keys = self.text_key.split('.') + assert text_keys[0] in dataset.features.keys( + ), "'{}' not in {}".format(text_keys[0], dataset.features.keys()) + + field_value_dict = {} + for i, item in enumerate(dataset[text_keys[0]]): + field_value = item + for key in text_keys[1:]: + assert key in field_value.keys(), "'{}' not in {}".format( + key, field_value.keys()) + field_value = field_value[key] + assert field_value is None or isinstance( + field_value, str) or isinstance( + field_value, numbers.Number + ), 'The {} item is not String, Numbers or NoneType'.format(i) + if field_value not in field_value_dict.keys(): + field_value_dict[field_value] = [i] + else: + field_value_dict[field_value].append(i) + + select_num = 0 + if not self.top_ratio: + if not self.topk: + return dataset + else: + select_num = self.topk + else: + select_num = self.top_ratio * len(field_value_dict) + if self.topk and self.topk < select_num: + select_num = self.topk + + select_index = sum( + sorted(field_value_dict.values(), + key=lambda x: len(x), + reverse=self.reverse)[:int(select_num)], []) + return dataset.select(select_index) diff --git a/data_juicer/ops/selector/topk_specified_field_selector.py b/data_juicer/ops/selector/topk_specified_field_selector.py new file mode 100644 index 000000000..cdcd425a3 --- /dev/null +++ b/data_juicer/ops/selector/topk_specified_field_selector.py @@ -0,0 +1,97 @@ +import heapq +import sys + +from jsonargparse.typing import ClosedUnitInterval, PositiveInt + +from ..base_op import OPERATORS, Selector + + +def to_number(s, reverse=True): + try: + return float(s) + except Exception: + if reverse: + return -sys.maxsize + else: + return sys.maxsize + + +@OPERATORS.register_module('topk_specified_field_selector') +class TopkSpecifiedFieldSelector(Selector): + """Selector to select top samples based on the sorted specified field + value.""" + + def __init__(self, + text_key: str = '', + top_ratio: ClosedUnitInterval = None, + topk: PositiveInt = None, + reverse: bool = True, + *args, + **kwargs): + """ + Initialization method. + + :param text_key: Selector based on the specified value + corresponding to the target key. The target key + corresponding to multi-level field information need to be + separated by '.'. + :param top_ratio: Ratio of selected top samples, samples will be + selected if their specified field values are within this + parameter. When both topk and top_ratio are set, the value + corresponding to the smaller number of samples will be + applied. + :param topk: Number of selected top sample, samples will be + selected if their specified field values are within this + parameter. When both topk and top_ratio are set, the value + corresponding to the smaller number of samples will be + applied. + :param reverse: Determine the sorting rule, if reverse=True, + then sort in descending order. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.text_key = text_key + self.top_ratio = top_ratio + self.topk = topk + self.reverse = reverse + + def process(self, dataset): + if len(dataset) <= 1 or not self.text_key: + return dataset + + select_num = 0 + if not self.top_ratio: + if not self.topk: + return dataset + else: + select_num = self.topk + else: + select_num = self.top_ratio * len(dataset) + if self.topk and self.topk < select_num: + select_num = self.topk + + text_keys = self.text_key.split('.') + assert text_keys[0] in dataset.features.keys( + ), "'{}' not in {}".format(text_keys[0], dataset.features.keys()) + + if len(text_keys) == 1: + field_value_list = dataset[text_keys[0]] + else: + field_value_list = [] + for item in dataset[text_keys[0]]: + field_value = item + for key in text_keys[1:]: + assert key in field_value.keys(), "'{}' not in {}".format( + key, field_value.keys()) + field_value = field_value[key] + field_value_list.append(to_number(field_value, self.reverse)) + + if self.reverse: + select_index = heapq.nlargest(int(select_num), range(len(dataset)), + field_value_list.__getitem__) + else: + select_index = heapq.nsmallest(int(select_num), + range(len(dataset)), + field_value_list.__getitem__) + return dataset.select(select_index) diff --git a/data_juicer/utils/__init__.py b/data_juicer/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/data_juicer/utils/asset_utils.py b/data_juicer/utils/asset_utils.py new file mode 100644 index 000000000..5577c1fe7 --- /dev/null +++ b/data_juicer/utils/asset_utils.py @@ -0,0 +1,58 @@ +import json +import os + +import requests +from loguru import logger + +from .cache_utils import DATA_JUICER_ASSETS_CACHE + +# Default directory to store auxiliary resources +ASSET_DIR = DATA_JUICER_ASSETS_CACHE + +# Default cached assets links for downloading +ASSET_LINKS = { + 'flagged_words': + 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' + 'data_juicer/flagged_words.json', + 'stopwords': + 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' + 'data_juicer/stopwords.json', +} + + +def load_words_asset(words_dir: str, words_type: str): + """ + Load words from a asset file named `words_type`, if not find a valid asset + file, then download it from ASSET_LINKS cached by data_juicer team. + + :param words_dir: directory that stores asset file(s) + :param words_type: name of target words assets + :return: a dict that stores words assets, whose keys are language + names, and the values are lists of words + """ + words_dict = {} + os.makedirs(words_dir, exist_ok=True) + + # try to load words from `words_type` file + for filename in os.listdir(words_dir): + if filename.endswith('.json') and words_type in filename: + with open(os.path.join(words_dir, filename), 'r') as file: + loaded_words = json.load(file) + for key in loaded_words: + if key in words_dict: + words_dict[key] += loaded_words[key] + else: + words_dict[key] = loaded_words[key] + # if the asset file is not found, then download it from ASSET_LINKS + if not bool(words_dict): + logger.info(f'Specified {words_dir} does not contain ' + f'any {words_type} files in json format, now ' + 'download the one cached by data_juicer team') + response = requests.get(ASSET_LINKS[words_type]) + words_dict = response.json() + # cache the asset file locally + cache_path = os.path.join(words_dir, f'{words_type}.json') + with open(cache_path, 'w') as file: + json.dump(words_dict, file) + + return words_dict diff --git a/data_juicer/utils/cache_utils.py b/data_juicer/utils/cache_utils.py new file mode 100644 index 000000000..8ee05a624 --- /dev/null +++ b/data_juicer/utils/cache_utils.py @@ -0,0 +1,21 @@ +import os + +# Default cache location +DEFAULT_CACHE_HOME = '~/.cache' +CACHE_HOME = os.getenv('CACHE_HOME', DEFAULT_CACHE_HOME) + +# Default data_juicer cache location +DEFAULT_DATA_JUICER_CACHE_HOME = os.path.join(CACHE_HOME, 'data_juicer') +DATA_JUICER_CACHE_HOME = os.path.expanduser( + os.getenv('DATA_JUICER_CACHE_HOME', DEFAULT_DATA_JUICER_CACHE_HOME)) + +# Default assets cache location +DEFAULT_DATA_JUICER_ASSETS_CACHE = os.path.join(DATA_JUICER_CACHE_HOME, + 'assets') +DATA_JUICER_ASSETS_CACHE = os.getenv('DATA_JUICER_ASSETS_CACHE', + DEFAULT_DATA_JUICER_ASSETS_CACHE) +# Default models cache location +DEFAULT_DATA_JUICER_MODELS_CACHE = os.path.join(DATA_JUICER_CACHE_HOME, + 'models') +DATA_JUICER_MODELS_CACHE = os.getenv('DATA_JUICER_MODELS_CACHE', + DEFAULT_DATA_JUICER_MODELS_CACHE) diff --git a/data_juicer/utils/ckpt_utils.py b/data_juicer/utils/ckpt_utils.py new file mode 100644 index 000000000..b4e0e636f --- /dev/null +++ b/data_juicer/utils/ckpt_utils.py @@ -0,0 +1,128 @@ +import json +import os + +from datasets import Dataset +from loguru import logger + + +class CheckpointManager: + """ + This class is used to save the latest version of dataset to checkpint + directory or load it from checkpint directory, a bit like cache management + Rerun the same config will reload the checkpoint and skip ops before it. + + If any args of operator in process list is changed, all ops will be + rerun from the beginning. + """ + + def __init__(self, ckpt_dir, original_process_list, num_proc=1): + """ + Initialization method. + + :param ckpt_dir: path to save and load checkpoint + :param original_process_list: process list in config + :param num_proc: number of process workers when saving dataset + """ + self.ckpt_dir = ckpt_dir + self.ckpt_ds_dir = os.path.join(self.ckpt_dir, 'latest') + self.ckpt_op_record = os.path.join(self.ckpt_dir, 'ckpt_op.json') + self.process_list = original_process_list + self.num_proc = num_proc + self.op_record = [] + + self.ckpt_available = self.check_ckpt() + + def get_left_process_list(self): + """ + Get left process list of ops for processing dataset, when checkpoint is + available, remove some ops from process list, otherwise keep it + unchanged. + + :return: process list of left ops + """ + return self.process_list + + def check_ckpt(self): + """ + Check if checkpoint is available. + + :return: True when checkpoint is available, else False + """ + if os.path.exists(self.ckpt_ds_dir) \ + and os.path.isdir(self.ckpt_ds_dir) \ + and os.path.exists(self.ckpt_op_record) \ + and os.path.isfile(self.ckpt_op_record) \ + and self.check_ops_to_skip(): + return True + else: + os.makedirs(self.ckpt_dir, exist_ok=True) + return False + + def record(self, op_name, op_args): + """Save op name and args to op record, which is used to compare with + the process list from config to decide if a checkpoint is available.""" + self.op_record.append({op_name: op_args}) + + def check_ops_to_skip(self): + """ + Check which ops need to be skipped in the process list. + + If op record list from checkpoint are the same as the prefix + part of process list, then skip these ops and start processing + from the checkpoint. Otherwise, process the original dataset + from scratch. + + :return: whether to skip somme ops or not + """ + + # load op records + with open(self.ckpt_op_record, 'r') as fin: + self.op_record = json.load(fin) + + # check whether the op records are exactly the same + # with prefix of process list + # 1. same: remove these ops from process list + # 2. different: cleanup op record, and keep process list unchanged + recorded_op_num = len(self.op_record) + prefix_process = self.process_list[:recorded_op_num] + all_the_same = True + dif1, dif2 = None, None + + for record_op, config_op in zip(self.op_record, prefix_process): + if record_op != config_op: + all_the_same = False + dif1, dif2 = record_op, config_op + break + if all_the_same: + for op in self.op_record: + op_name = list(op.keys())[0] + logger.info(f'Skip op [{op_name}].') + self.process_list = self.process_list[recorded_op_num:] + return True + else: + logger.warning(f'Processed ops of checkpoint are different from ' + f'current configs: checkpoint-{dif1} vs. config-' + f'{dif2}. All ops will be processed from the ' + f'beginning') + self.op_record = [] + return False + + def save_ckpt(self, ds): + """ + Save dataset to checkpoint directory and dump processed ops list. + + :param ds: input dataset to save + """ + ds.save_to_disk(self.ckpt_ds_dir, num_proc=self.num_proc) + + with open(self.ckpt_op_record, 'w') as fout: + json.dump(self.op_record, fout) + + def load_ckpt(self): + """ + Load dataset from a checkpoint file. + + :return: a dataset stored in checkpoint file. + """ + ds = Dataset.load_from_disk(self.ckpt_ds_dir) + return ds diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py new file mode 100644 index 000000000..a78572a53 --- /dev/null +++ b/data_juicer/utils/file_utils.py @@ -0,0 +1,69 @@ +from pathlib import Path +from typing import List, Tuple, Union + +from datasets.utils.extract import ZstdExtractor as Extractor + + +def find_files_with_suffix( + path: Union[str, Path], + suffixes: Union[str, List[str], Tuple[str]] = None) -> List[str]: + """ + Traverse a path to find all files with the specified suffixes. + + :param path: path (str/Path): source path + :param suffixes: specified file suffixes, '.txt' or ['.txt', '.md'] + etc + :return: list of all files with the specified suffixes + """ + path = Path(path) + file_dict = {} + + if suffixes is None: + suffixes = [] + + if isinstance(suffixes, str): + suffixes = [suffixes] + + suffixes = [ + x.lower() if x.startswith('.') else '.' + x.lower() for x in suffixes + ] + + if path.is_file(): + files = [path] + else: + searched_files = path.rglob('*') + files = [file for file in searched_files if file.is_file()] + + extractor = Extractor + + # only keep the file with the specified suffixes + for file in files: + suffix = file.suffix.lower() + + if extractor.is_extractable(file): + + # TODO + # hard code + # only support zstd-format file now, + # and use the last 2 sub-suffixes as the final suffix + # just like '.jsonl.zst' + file_suffixes = [suffix.lower() for suffix in file.suffixes] + suffix = ''.join(file_suffixes[-2:]) + + if not suffixes or (suffix in suffixes): + if suffix not in file_dict: + file_dict[suffix] = [str(file)] + else: + file_dict[suffix].append(str(file)) + return file_dict + + +def is_absolute_path(path: Union[str, Path]) -> bool: + """ + Check whether input path is a absolute path. + + :param path: input path + :return: True means input path is absolute path, False means input + path is a relative path. + """ + return Path(path).is_absolute() diff --git a/data_juicer/utils/logger_utils.py b/data_juicer/utils/logger_utils.py new file mode 100644 index 000000000..d930c6a21 --- /dev/null +++ b/data_juicer/utils/logger_utils.py @@ -0,0 +1,132 @@ +# Some codes here are adapted from +# https://github.com/MegEngine/YOLOX/blob/main/yolox/utils/logger.py + +# Copyright 2021 Megvii, Base Detection +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import sys + +from loguru import logger +from loguru._file_sink import FileSink + +LOGGER_SETUP = False + + +def get_caller_name(depth=0): + """ + Get caller name by depth. + + :param depth: depth of caller context, use 0 for caller depth. + :return: module name of the caller + """ + # the following logic is a little bit faster than inspect.stack() logic + frame = inspect.currentframe().f_back + for _ in range(depth): + frame = frame.f_back + + return frame.f_globals['__name__'] + + +class StreamToLoguru: + """Stream object that redirects writes to a logger instance.""" + + def __init__(self, level='INFO', caller_names=('datasets', 'logging')): + """ + Initialization method. + + :param level: log level string of loguru. Default value: "INFO". + :param caller_names: caller names of redirected module. + Default value: (apex, pycocotools). + """ + self.level = level + self.linebuf = '' + self.caller_names = caller_names + + def write(self, buf): + full_name = get_caller_name(depth=1) + module_name = full_name.rsplit('.', maxsplit=-1)[0] + if module_name in self.caller_names: + for line in buf.rstrip().splitlines(): + # use caller level log + logger.opt(depth=2).log(self.level, line.rstrip()) + else: + # sys.__stdout__.write(buf) + logger.opt(raw=True).info(buf) + + def flush(self): + pass + + +def redirect_sys_output(log_level='INFO'): + """ + Redirect stdout/stderr to loguru with log level. + + :param log_level: log level string of loguru. Default value: "INFO". + """ + redirect_logger = StreamToLoguru(log_level) + sys.stderr = redirect_logger + sys.stdout = redirect_logger + + +def get_log_file_path(): + """ + Get the path to the location of the log file. + + :return: a location of log file. + """ + for _, handler in logger._core.handlers.items(): + if isinstance(handler._sink, FileSink): + return handler._sink._file.name + + +def setup_logger(save_dir, distributed_rank=0, filename='log.txt', mode='o'): + """ + Setup logger for training and testing. + + :param save_dir: location to save log file + :param distributed_rank: device rank when multi-gpu environment + :param filename: log file name to save + :param mode: log file write mode, `append` or `override`. default is `o`. + :return: logger instance. + """ + global LOGGER_SETUP + + if LOGGER_SETUP: + return + + loguru_format = ( + '{time:YYYY-MM-DD HH:mm:ss} | ' + '{level: <8} | ' + '{name}:{line} - {message}') + + logger.remove() + save_file = os.path.join(save_dir, filename) + if mode == 'o' and os.path.exists(save_file): + os.remove(save_file) + + # only keep logger in rank0 process + if distributed_rank == 0: + logger.add( + sys.stderr, + format=loguru_format, + level='INFO', + enqueue=True, + ) + logger.add(save_file) + + # redirect stdout/stderr to loguru + redirect_sys_output('INFO') + LOGGER_SETUP = True diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py new file mode 100644 index 000000000..8460010b1 --- /dev/null +++ b/data_juicer/utils/model_utils.py @@ -0,0 +1,204 @@ +import os + +import wget +from loguru import logger + +from .cache_utils import DATA_JUICER_MODELS_CACHE + +# Default directory to store models +MODEL_PATH = DATA_JUICER_MODELS_CACHE + +# Default backup cached models links for downloading +BACKUP_MODEL_LINKS = { + # language identification model from fasttext + 'lid.176.bin': + 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/', + + # tokenizer and language model for English from sentencepiece and KenLM + '%s.sp.model': + 'https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/', + '%s.arpa.bin': + 'https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/', + + # sentence split model from nltk punkt + 'punkt.%s.pickle': + 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' + 'data_juicer/models/' +} + +# Default cached models links for downloading +MODEL_LINKS = 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' \ + 'data_juicer/models/' + +MODEL_ZOO = {} + + +def check_model(model_name, args=(), force=False): + """ + Check whether a model exists in MODEL_PATH. If exists, return its full path + Else, download it from cached models links. + + :param model_name: a specified model name + :param args: optional extra args of model. + :param force: Whether to download model forcefully or not, Sometimes + the model file maybe incomplete for some reason, so need to + download again forcefully. + """ + if not os.path.exists(MODEL_PATH): + os.makedirs(MODEL_PATH) + + # check if the specified model exists. If it does not exist, download it + true_model_name = model_name % args + mdp = os.path.join(MODEL_PATH, true_model_name) + if force: + if os.path.exists(mdp): + os.remove(mdp) + logger.info( + f'Model [{true_model_name}] invalid, force to downloading...') + else: + logger.info( + f'Model [{true_model_name}] not found . Downloading...') + + try: + model_link = os.path.join(MODEL_LINKS, true_model_name) + wget.download(model_link, mdp, bar=None) + except: # noqa: E722 + try: + backup_model_link = os.path.join( + BACKUP_MODEL_LINKS[model_name], true_model_name) + wget.download(backup_model_link, mdp, bar=None) + except: # noqa: E722 + logger.error( + f'Downloading model [{true_model_name}] error. ' + f'Please retry later or download it into {MODEL_PATH} ' + f'manually from {model_link} or {backup_model_link} ') + exit(1) + return mdp + + +def prepare_fasttext_model(model_name): + """ + Prepare and load a fasttext model. + + :param model_name: input model name + :return: model instance. + """ + import fasttext + logger.info('Loading fasttext language identification model...') + try: + ft_model = fasttext.load_model(check_model(model_name)) + except: # noqa: E722 + ft_model = fasttext.load_model(check_model(model_name, force=True)) + return ft_model + + +def prepare_sentencepiece_model(model_name, lang): + """ + Prepare and load a sentencepiece model. + + :param model_name: input model name in formatting syntax + :param lang: language to render model name + :return: model instance. + """ + import sentencepiece + logger.info('Loading sentencepiece model...') + sentencepiece_model = sentencepiece.SentencePieceProcessor() + try: + sentencepiece_model.load(check_model(model_name, lang)) + except: # noqa: E722 + sentencepiece_model.load(check_model(model_name, lang, force=True)) + return sentencepiece_model + + +def prepare_kenlm_model(model_name, lang): + """ + Prepare and load a kenlm model. + + :param model_name: input model name in formatting syntax. + :param lang: language to render model name + :return: model instance. + """ + import kenlm + logger.info('Loading kenlm language model...') + try: + kenlm_model = kenlm.Model(check_model(model_name, lang)) + except: # noqa: E722 + kenlm_model = kenlm.Model(check_model(model_name, lang, force=True)) + return kenlm_model + + +def prepare_nltk_model(model_name, lang): + """ + Prepare and load a nltk punkt model. + + :param model_name: input model name in formatting syntax + :param lang: language to render model name + :return: model instance. + """ + + nltk_to_punkt = { + 'en': 'english', + 'fr': 'french', + 'pt': 'portuguese', + 'es': 'spanish' + } + assert lang in nltk_to_punkt.keys( + ), 'lang must be one of the following: {}'.format( + list(nltk_to_punkt.keys())) + + from nltk.data import load + logger.info('Loading nltk punkt split model...') + try: + nltk_model = load(check_model(model_name, nltk_to_punkt[lang])) + except: # noqa: E722 + nltk_model = load( + check_model(model_name, nltk_to_punkt[lang], force=True)) + return nltk_model + + +def prepare_huggingface_tokenizer(tokenizer_name): + """ + Prepare and load a tokenizer from HuggingFace. + + :param tokenizer_name: input tokenizer name + :return: a tokenizer instance. + """ + from transformers import AutoTokenizer + logger.info('Loading tokenizer from HuggingFace...') + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + return tokenizer + + +def prepare_model(lang='en', model_type='sentencepiece', model_key=None): + """ + Prepare and load a model or a tokenizer from MODEL_ZOO. + + :param lang: which lang model to load + :param model_type: model or tokenizer type + :param model_key: tokenizer name, only used when prepare HuggingFace + tokenizer + :return: a model or tokenizer instance + """ + + type_to_name = { + 'fasttext': ('lid.176.bin', prepare_fasttext_model), + 'sentencepiece': ('%s.sp.model', prepare_sentencepiece_model), + 'kenlm': ('%s.arpa.bin', prepare_kenlm_model), + 'nltk': ('punkt.%s.pickle', prepare_nltk_model), + 'huggingface': ('%s', prepare_huggingface_tokenizer) + } + assert model_type in type_to_name.keys( + ), 'model_type must be one of the following: {}'.format( + list(type_to_name.keys())) + + if model_key is None: + model_key = model_type + '_' + lang + if model_key not in MODEL_ZOO.keys(): + model_name, model_func = type_to_name[model_type] + if model_type == 'fasttext': + MODEL_ZOO[model_key] = model_func(model_name) + elif model_type == 'huggingface': + MODEL_ZOO[model_key] = model_func(model_key) + else: + MODEL_ZOO[model_key] = model_func(model_name, lang) + return model_key diff --git a/data_juicer/utils/registry.py b/data_juicer/utils/registry.py new file mode 100644 index 000000000..8847ae2d4 --- /dev/null +++ b/data_juicer/utils/registry.py @@ -0,0 +1,133 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -------------------------------------------------------- +# Most of the code here has been modified from: +# https://github.com/modelscope/modelscope/blob/master/modelscope/utils/registry.py +# -------------------------------------------------------- + +from loguru import logger + + +class Registry(object): + """This class is used to register some modules to registry by a repo + name.""" + + def __init__(self, name: str): + """ + Initialization method. + + :param name: a registry repo name + """ + self._name = name + self._modules = {} + + @property + def name(self): + """ + Get name of current registry. + + :return: name of current registry. + """ + return self._name + + @property + def modules(self): + """ + Get all modules in current registry. + + :return: a dict storing modules in current registry. + """ + return self._modules + + def list(self): + """Logging the list of module in current registry.""" + for m in self._modules.keys(): + logger.info(f'{self._name}\t{m}') + + def get(self, module_key): + """ + Get module named module_key from in current registry. If not found, + return None. + + :param module_key: specified module name + :return: module named module_key + """ + return self._modules.get(module_key, None) + + def _register_module(self, module_name=None, module_cls=None, force=False): + """ + Register module to registry. + + :param module_name: module name + :param module_cls: module class object + :param force: Whether to override an existing class with the + same name. Default: False. + """ + + if module_name is None: + module_name = module_cls.__name__ + + if module_name in self._modules and not force: + raise KeyError( + f'{module_name} is already registered in {self._name}') + + self._modules[module_name] = module_cls + + def register_module(self, + module_name: str = None, + module_cls: type = None, + force=False): + """ + Register module class object to registry with the specified modulename. + + :param module_name: module name + :param module_cls: module class object + :param force: Whether to override an existing class with + the same name. Default: False. + + Example: + >>> registry = Registry() + >>> @registry.register_module() + >>> class TextFormatter: + >>> pass + + >>> class TextFormatter2: + >>> pass + >>> registry.register_module( module_name='text_formatter2', + module_cls=TextFormatter2) + """ + if not (module_name is None or isinstance(module_name, str)): + raise TypeError(f'module_name must be either of None, str,' + f'got {type(module_name)}') + if module_cls is not None: + self._register_module(module_name=module_name, + module_cls=module_cls, + force=force) + return module_cls + + # if module_cls is None, should return a decorator function + def _register(module_cls): + """ + Register module class object to registry. + + :param module_cls: module class object + :return: module class object. + """ + self._register_module(module_name=module_name, + module_cls=module_cls, + force=force) + return module_cls + + return _register diff --git a/demos/.DS_Store b/demos/.DS_Store new file mode 100644 index 000000000..4075bd989 Binary files /dev/null and b/demos/.DS_Store differ diff --git a/demos/README.md b/demos/README.md new file mode 100644 index 000000000..e91dd42cf --- /dev/null +++ b/demos/README.md @@ -0,0 +1,40 @@ +# Demos + +This folder contains some demos developed with streamlit to allow users to easily experience the basic functions and tools of Data-Juicer. + +## Usage + +```shell +cd xxx +streamlit run xxx/app.py +``` + +## Categories + +### Data + +This folder contains some demo of datasets. + +### Data visualization diversity + +This demo analyze Verb-Noun structures of SFT dataset, and draw its diversity in sunburst format. + +### Data visualization op effect + +This demo analyze the statistics of dataset, and display every Filter op effect by setting different thresholds. + +### Data visualization statistics +This demo analyze the statistics (up to 13 for now) of dataset. + +### Tool quality classifier +This demo supply 3 text quality classifier, and score dataset + +## Coming Soon +- Overview scan +- Auto evaluation helm +- Data process loop +- Data mixture +- SFT data zh +- Process sci data +- Process code data +- Data process hpo diff --git a/demos/README_ZH.md b/demos/README_ZH.md new file mode 100644 index 000000000..097f3436f --- /dev/null +++ b/demos/README_ZH.md @@ -0,0 +1,42 @@ +# 示例文件 + +此文件夹包含一些示例,帮助用户轻松体验 Data-Juicer各种功能和工具。 + +## 用法 + +```shell +cd xxx +streamlit run xxx/app.py +``` + +## 目录 + +### Data + +该文件夹包含一些样例数据集。 + +### Data visualization diversity + +该示例可以用来分析 SFT 数据集的动词-名词结构, 并绘制成sunburst层级环形图表。 + +### Data visualization op effect + +该示例可以分析数据集的统计信息,并根据这些统计信息可以显示出每个 `Filter` 算子的在不用阈值下的效果。 + +### Data visualization statistics + +该示例可以分析数据集,并获得多达13种统计信息。 + +### Tool quality classifier +该示例提供了3种文本质量打分器, 对数据集进行打分评估。 + +## Coming Soon +- Overview scan | 初体验 +- Auto evaluation helm | 自动HELM评测 +- Data process loop | 数据分析处理迭代 +- Data mixture | 数据混合 +- SFT data zh | 中文指令微调数据处理 +- Process sci data | 科学文献数据处理 +- Process code data | 代码数据处理 +- Data process hpo | 数据混合超参自动优化 + diff --git a/demos/data/demo-dataset-content.jsonl b/demos/data/demo-dataset-content.jsonl new file mode 100644 index 000000000..07871df3a --- /dev/null +++ b/demos/data/demo-dataset-content.jsonl @@ -0,0 +1,6 @@ +{"content": "Today is Sunday and it's a happy day!", "src": "Arxiv", "date": "2023-04-27", "version": "1.0"} +{"content": "Do you need a cup of coffee?", "src": "code", "author": "xxx"} +{"content": "你好,请问你是谁", "src": "customized", "author": "xxx"} +{"content": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément.", "src": "Oscar", "version": "2.0"} +{"content": "欢迎来到阿里巴巴!", "src": "customized", "version": "0.1", "author": "xxx"} +{"content": "This paper proposed a novel method on LLM pretraining.", "src": "customized", "author": "xxx"} diff --git a/demos/data/demo-dataset-deduplication.jsonl b/demos/data/demo-dataset-deduplication.jsonl new file mode 100644 index 000000000..d2590cec3 --- /dev/null +++ b/demos/data/demo-dataset-deduplication.jsonl @@ -0,0 +1,14 @@ +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} +{"text": "你好,请问你是谁", "meta": {"src": "customized", "author": "xxx"}} +{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément.", "meta": {"src": "Oscar", "version": "2.0"}} +{"text": "欢迎来到阿里巴巴!", "meta": {"src": "customized", "version": "0.1", "author": "xxx"}} +{"text": "Today is sunday and it's really a happy day!", "meta": {"src": "Arxiv", "date": "2023-05-15", "version": "1.1"}} +{"text": "This paper proposed a novel method on LLM pretraining.", "meta": {"src": "customized", "author": "xxx"}} +{"text": "Smithfield employs 3,700 people at its plant in Sioux Falls, South Dakota. The plant slaughters 19,500 pigs a day — 5 percent of U.S. pork. Most of the workers are immigrants from Ethiopia, Mexico, South Sudan, Honduras, Myanmar, Somalia, Guatemala, and other poor countries.\\n\\nInevitably workers must pass within one foot of hundreds of colleagues in the hallways, locker rooms, cafeterias, and cutting lines. The same conditions have spurred Covid-19 outbreaks at meat plants from Minnesota and Wisconsin to Colorado, Nebraska, Missouri, Iowa, Pennsylvania, North Carolina, and Georgia.\\n\\n801 workers at the Sioux Falls plant have tested positive, together with 206 people close to them. The outbreak has killed Agustín Rodríguez Martínez, aged 64, an employee with two decades of experience originally from El Salvador, and Craig Allen Franken, 61, who worked for Smithfield his entire adult life.\\n\\nThe company knew of its first infection on March 24 or earlier. The virus spread exponentially for several weeks. Ahead of Easter Sunday and Monday (April 12-13), Smithfield promised to “completely shutter” to sanitize and put up cardboard and plastic sheet dividers. This would not end transmission, as potentially hundreds of staff were already carrying the virus. But even during this “shutdown,” many cars were seen in the parking lot. The mayor admits that the company lied, and the local AFL-CIO alleges the plant ran 60 percent production. On Easter, with 238 known infections, Smithfield finally agreed to shut down indefinitely after a request from the mayor and the governor. Yet the company insisted on waiting three more days to actually halt production.\\n\\nSmithfield denied contributing to the outbreak, saying it took a “very proactive approach.” Relying on racism, the company blamed workers for getting themselves sick. A spokesperson said the outbreak was so severe because of the plant’s “large immigrant population,” claming “Living circumstances in certain cultures are different than they are with your traditional American family.” They slandered the workers as dirty, ignorant, and untrustworthy with help from governor Kristi Noem, who claimed, “99 percent of what’s going on today wasn’t happening inside the facility. It was more at home, where these employees were going home and spreading some of the virus” by living too close together.\\n\\nOne sick worker, Michael Bul Gayo Gatluak, 22 and originally from South Sudan, says, “With how we work on the line, I would say I got sick because of them not taking safety measures.” His job is “really, really close” to other workers chopping fresh-killed pigs. “The job is so heavy. You have to breathe so hard.”\\n\\nIn early March, union officials requested masks, overcoats, entrance checking for fevers, and less crowding in 500-capacity cafeterias. But Smithfield waited on most safety measures until early April. Only April 6 did they start checking for fevers. Instead of protective masks, they gave out beard nets.\\n\\nSmithfield concealed infections with a policy of informing only employees whose work stations were in the same area as a person who tested positive. The fact that workers are required to move around was willfully ignored. One worker who tested positive said, “I clearly would have gotten it at the factory. This week I have worked on three different floors. I’ve eaten in two different cafeterias … I’ve been walking through the whole place.” Employees from the eighth floor of the plant were quarantined, but everyone else was told to keep working.\\n\\nWhat Is Really Going On?\\n\\nAverage plant wages are around $16 an hour. Smithfield never raised them. Instead, they offered $500 to employees who could go all of April without an unapproved day off. The company says their “Responsibility Bonuses” show their “immense gratefulness” to employees “for their selfless sacrifices.”\\n\\nMeanwhile, the local Argus Leader wrote union members wanted essential-worker hazard pay, which “would be considered hourly compensation about 1.5 or two times their normal pay.” One worker said, “I feel like they’re bribing us with [the bonus] to come to work sick. That’s how you know they don’t care.”\\n\\nBoth Sioux Falls workers killed by Covid-19 were in their sixties. It is unconscionable that they were still working. All meatpackers over 50 should be on paid leave. Agustín Rodríguez, 64, had a rough job sawing the legs off dead pigs. He mopped floors with a fever shortly before he was hospitalized.\\n\\nWhen CEO Kenneth Sullivan closed the plant, he claimed, “We have continued to run our facilities for one reason: to sustain our nation’s food supply.” This is an effort to sweep Smithfield’s abuses under the rug, as if the company were operating for public benefit. This patriotic propaganda that all Americans are in it together is like a drug to keep workers from getting organized.\\n\\nThe major union in the industry, including at Smithfield, is the United Food and Commercial Workers union (UFCW). What union leaders have done is ultimately troubling.\\n\\nCan Workers Fight?\\n\\nLocal AFL-CIO president Kooper Caraway has publicly said management delayed safety action as long as possible for profit. But while some workers were demanding a two-week shutdown, Caraway told the Argus Leader that was unrealistic because the government considers the plant essential. He suggested the union would be happy with minimal safety measures: “Even if 10 people get exposed in a day rather than 11. If you can implement a program where even one or two less people get exposed during a shift, that’s one or two less people.” Of course reducing infections is good, but suggesting workers would be satisfied if the company allowed 90% of the contagion to continue is horrifying.\\n\\nThe response of UFCW leadership was worse. As the disease was exploding, they told the Argus Leader, “We applaud [Smithfield’s] decision to temporarily close the plant [over Easter weekend] to push for an even safer work environment.” What does “even safer” mean in this context?\\n\\nThe union bureaucracy has taken weak action elsewhere. In Pennsylvania, the UFCW negotiated $2 hazard pay for two months with Cargill Meat — the same pandemic premium Amazon gave workers without a union. In Nebraska, the UFCW negotiated $4 hazard pay for one month with meat giant JBS.\\n\\nThe union has said nothing about forcing companies to send older workers home with pay, even though a 70-year-old shop steward and a 78-year-old grandfather working at JBS plants were killed by Covid-19. Smithfield workers were promised only two weeks of shutdown pay. For many, this compensation is half their normal paycheck because they routinely put in 66 hour weeks — overtime that costs exhaustion and chronic pain.\\n\\nUnion officials endeavor to cooperate with the meat companies. An Iowa UFCW president actually suggested it might be impossible for plants to move workers a full six feet apart and told the Des Moines Register, “We can’t stop the plants. If we stop the plants from running, we stop feeding the country. We want to do everything we can to make sure the employees are safe to keep the plant running.”\\n\\nEvery part of this explanation directly overlaps with what the Smithfield CEO said. Unfortunately, it amounts to accepting the company’s excuses.\\n\\nThey claim that workers who do hard physical labor, waking up at 4 a.m. and often working six days a week for years, would be guilty of taking food away from the people and hurting America if they dared to fight for their human needs. But nothing is said about the company raking in profits and even murdering workers to increase them.\\n\\nSmithfield’s parent company W.H. Group, which slaughters around 30 million pigs per year in plants in both the United States and China, saw its profits skyrocket by about one third in 2019 to $1.38 billion. It is disturbing that UFCW officials do not bring up these soaring profits in their response to the outbreaks. Reuters published a report on the corporation’s financial success in late March. The head of W.H. Group had touted to the media that it got through the pandemic in China with very limited impact on production.\\n\\nIt is true that many Smithfield workers are reasonably afraid for their jobs and want to keep working. A 25-year-old employee explained, “I have a lot of bills. My baby’s coming soon — I have to work.” At the same time, he was afraid of infecting his pregnant wife. His spouse, a former employee, said bitterly, “Smithfield— they don’t care about employees. They only care about their money.”\\n\\nWorkers are pressured in these two painful directions. Nonetheless, work can mean solidarity. Before Smithfield even checked temperatures, there was a “sick-out” strike without union support by 800 to 1,000 workers at a JBS meat factory in Colorado. Hundreds of workers also called in sick days at a Nebraska JBS plant.\\n\\nTrade union leaders won’t even whisper the word “strike” when thousands of workers are thinking about it. They are limiting themselves to polite requests. We need a workers’ movement that asks who controls the factory, that threatens to disrupt the bosses’ profits, and that allows workers to use their immense power — this could change the meat industry and the world.", "meta": {"src": "mine", "author": "xx"}} +{"text": "Smithfield employs 3,700 people at its plants in Sioux Falls, South Dakota. The plant slaughters 19,500 pig a day — 5 percent of U.S. pork. Most of the workers are immigrants from Ethiopia, Mexico, South Sudan, Honduras, Myanmar, Somalia, Guatemala, and other poor countries.\\n\\nInevitably workers must pass within one foot of hundreds of colleagues in the hallways, locker rooms, cafeterias, and cutting lines. The same conditions have spurred Covid-19 outbreaks at meat plants from Minnesota and Wisconsin to Colorado, Nebraska, Missouri, Iowa, Pennsylvania, North Carolina, and Georgia.\\n\\n801 workers at the Sioux Falls plant have tested positive, together with 206 people close to them. The outbreak has killed Agustín Rodríguez Martínez, aged 64, an employee with two decades of experience originally from El Salvador, and Craig Allen Franken, 61, who worked for Smithfield his entire adult life.\\n\\nThe company knew of its first infection on March 24 or earlier. The virus spread exponentially for several weeks. Ahead of Easter Sunday and Monday (April 12-13), Smithfield promised to “completely shutter” to sanitize and put up cardboard and plastic sheet dividers. This would not end transmission, as potentially hundreds of staff were already carrying the virus. But even during this “shutdown,” many cars were seen in the parking lot. The mayor admits that the company lied, and the local AFL-CIO alleges the plant ran 60 percent production. On Easter, with 238 known infections, Smithfield finally agreed to shut down indefinitely after a request from the mayor and the governor. Yet the company insisted on waiting three more days to actually halt production.\\n\\nSmithfield denied contributing to the outbreak, saying it took a “very proactive approach.” Relying on racism, the company blamed workers for getting themselves sick. A spokesperson said the outbreak was so severe because of the plant’s “large immigrant population,” claming “Living circumstances in certain cultures are different than they are with your traditional American family.” They slandered the workers as dirty, ignorant, and untrustworthy with help from governor Kristi Noem, who claimed, “99 percent of what’s going on today wasn’t happening inside the facility. It was more at home, where these employees were going home and spreading some of the virus” by living too close together.\\n\\nOne sick worker, Michael Bul Gayo Gatluak, 22 and originally from South Sudan, says, “With how we work on the line, I would say I got sick because of them not taking safety measures.” His job is “really, really close” to other workers chopping fresh-killed pigs. “The job is so heavy. You have to breathe so hard.”\\n\\nIn early March, union officials requested masks, overcoats, entrance checking for fevers, and less crowding in 500-capacity cafeterias. But Smithfield waited on most safety measures until early April. Only April 6 did they start checking for fevers. Instead of protective masks, they gave out beard nets.\\n\\nSmithfield concealed infections with a policy of informing only employees whose work stations were in the same area as a person who tested positive. The fact that workers are required to move around was willfully ignored. One worker who tested positive said, “I clearly would have gotten it at the factory. This week I have worked on three different floors. I’ve eaten in two different cafeterias … I’ve been walking through the whole place.” Employees from the eighth floor of the plant were quarantined, but everyone else was told to keep working.\\n\\nWhat Is Really Going On?\\n\\nAverage plant wages are around $16 an hour. Smithfield never raised them. Instead, they offered $500 to employees who could go all of April without an unapproved day off. The company says their “Responsibility Bonuses” show their “immense gratefulness” to employees “for their selfless sacrifices.”\\n\\nMeanwhile, the local Argus Leader wrote union members wanted essential-worker hazard pay, which “would be considered hourly compensation about 1.5 or two times their normal pay.” One worker said, “I feel like they’re bribing us with [the bonus] to come to work sick. That’s how you know they don’t care.”\\n\\nBoth Sioux Falls workers killed by Covid-19 were in their sixties. It is unconscionable that they were still working. All meatpackers over 50 should be on paid leave. Agustín Rodríguez, 64, had a rough job sawing the legs off dead pigs. He mopped floors with a fever shortly before he was hospitalized.\\n\\nWhen CEO Kenneth Sullivan closed the plant, he claimed, “We have continued to run our facilities for one reason: to sustain our nation’s food supply.” This is an effort to sweep Smithfield’s abuses under the rug, as if the company were operating for public benefit. This patriotic propaganda that all Americans are in it together is like a drug to keep workers from getting organized.\\n\\nThe major union in the industry, including at Smithfield, is the United Food and Commercial Workers union (UFCW). What union leaders have done is ultimately troubling.\\n\\nCan Workers Fight?\\n\\nLocal AFL-CIO president Kooper Caraway has publicly said management delayed safety action as long as possible for profit. But while some workers were demanding a two-week shutdown, Caraway told the Argus Leader that was unrealistic because the government considers the plant essential. He suggested the union would be happy with minimal safety measures: “Even if 10 people get exposed in a day rather than 11. If you can implement a program where even one or two less people get exposed during a shift, that’s one or two less people.” Of course reducing infections is good, but suggesting workers would be satisfied if the company allowed 90% of the contagion to continue is horrifying.\\n\\nThe response of UFCW leadership was worse. As the disease was exploding, they told the Argus Leader, “We applaud [Smithfield’s] decision to temporarily close the plant [over Easter weekend] to push for an even safer work environment.” What does “even safer” mean in this context?\\n\\nThe union bureaucracy has taken weak action elsewhere. In Pennsylvania, the UFCW negotiated $2 hazard pay for two months with Cargill Meat — the same pandemic premium Amazon gave workers without a union. In Nebraska, the UFCW negotiated $4 hazard pay for one month with meat giant JBS.\\n\\nThe union has said nothing about forcing companies to send older workers home with pay, even though a 70-year-old shop steward and a 78-year-old grandfather working at JBS plants were killed by Covid-19. Smithfield workers were promised only two weeks of shutdown pay. For many, this compensation is half their normal paycheck because they routinely put in 66 hour weeks — overtime that costs exhaustion and chronic pain.\\n\\nUnion officials endeavor to cooperate with the meat companies. An Iowa UFCW president actually suggested it might be impossible for plants to move workers a full six feet apart and told the Des Moines Register, “We can’t stop the plants. If we stop the plants from running, we stop feeding the country. We want to do everything we can to make sure the employees are safe to keep the plant running.”\\n\\nEvery part of this explanation directly overlaps with what the Smithfield CEO said. Unfortunately, it amounts to accepting the company’s excuses.\\n\\nThey claim that workers who do hard physical labor, waking up at 4 a.m. and often working six days a week for years, would be guilty of taking food away from the people and hurting America if they dared to fight for their human needs. But nothing is said about the company raking in profits and even murdering workers to increase them.\\n\\nSmithfield’s parent company W.H. Group, which slaughters around 30 million pigs per year in plants in both the United States and China, saw its profits skyrocket by about one third in 2019 to $1.38 billion. It is disturbing that UFCW officials do not bring up these soaring profits in their response to the outbreaks. Reuters published a report on the corporation’s financial success in late March. The head of W.H. Group had touted to the media that it got through the pandemic in China with very limited impact on production.\\n\\nIt is true that many Smithfield workers are reasonably afraid for their jobs and want to keep working. A 25-year-old employee explained, “I have a lot of bills. My baby’s coming soon — I have to work.” At the same time, he was afraid of infecting his pregnant wife. His spouse, a former employee, said bitterly, “Smithfield— they don’t care about employees. They only care about their money.”\\n\\nWorkers are pressured in these two painful directions. Nonetheless, work can mean solidarity. Before Smithfield even checked temperatures, there was a “sick-out” strike without union support by 800 to 1,000 workers at a JBS meat factory in Colorado. Hundreds of workers also called in sick days at a Nebraska JBS plant.\\n\\nTrade union leaders won’t even whisper the word “strike” when thousands of workers are thinking about it. They are limiting themselves to polite requests. We need a workers’ movement that asks who controls the factory, that threatens to disrupt the bosses’ profits, and that allows workers to use their immense power — this could change the meat industry and the world.", "meta": {"src": "customized", "author": "x"}} +{"text": "Smithfield employs 3,700 people at its plant in Sioux Falls, South Dakota. The plant slaughters 19,500 pigs a day — 5 percent of U.S. pork. Most of the workers are immigrants from Ethiopia, Mexico, South Sudan, Honduras, Myanmar, Somalia, Guatemala, and other poor countries.\\n\\nInevitably workers must pass within one foot of hundreds of colleagues in the hallways, locker rooms, cafeterias, and cutting lines. The same conditions have spurred Covid-19 outbreaks at meat plants from Minnesota and Wisconsin to Colorado, Nebraska, Missouri, Iowa, Pennsylvania, North Carolina, and Georgia.\\n\\n801 workers at the Sioux Falls plant have tested positive, together with 206 people close to them. The outbreak has killed Agustín Rodríguez Martínez, aged 64, an employee with two decades of experience originally from El Salvador, and Craig Allen Franken, 61, who worked for Smithfield his entire adult life.\\n\\nThe company knew of its first infection on March 24 or earlier. The virus spread exponentially for several weeks. Ahead of Easter Sunday and Monday (April 12-13), Smithfield promised to “completely shutter” to sanitize and put up cardboard and plastic sheet dividers. This would not end transmission, as potentially hundreds of staff were already carrying the virus. But even during this “shutdown,” many cars were seen in the parking lot. The mayor admits that the company lied, and the local AFL-CIO alleges the plant ran 60 percent production. On Easter, with 238 known infections, Smithfield finally agreed to shut down indefinitely after a request from the mayor and the governor. Yet the company insisted on waiting three more days to actually halt production.\\n\\nSmithfield denied contributing to the outbreak, saying it took a “very proactive approach.” Relying on racism, the company blamed workers for getting themselves sick. A spokesperson said the outbreak was so severe because of the plant’s “large immigrant population,” claming “Living circumstances in certain cultures are different than they are with your traditional American family.” They slandered the workers as dirty, ignorant, and untrustworthy with help from governor Kristi Noem, who claimed, “99 percent of what’s going on today wasn’t happening inside the facility. It was more at home, where these employees were going home and spreading some of the virus” by living too close together.\\n\\nOne sick worker, Michael Bul Gayo Gatluak, 22 and originally from South Sudan, says, “With how we work on the line, I would say I got sick because of them not taking safety measures.” His job is “really, really close” to other workers chopping fresh-killed pigs. “The job is so heavy. You have to breathe so hard.”\\n\\nIn early March, union officials requested masks, overcoats, entrance checking for fevers, and less crowding in 500-capacity cafeterias. But Smithfield waited on most safety measures until early April. Only April 6 did they start checking for fevers. Instead of protective masks, they gave out beard nets.\\n\\nSmithfield concealed infections with a policy of informing only employees whose work stations were in the same area as a person who tested positive. The fact that workers are required to move around was willfully ignored. One worker who tested positive said, “I clearly would have gotten it at the factory. This week I have worked on three different floors. I’ve eaten in two different cafeterias … I’ve been walking through the whole place.” Employees from the eighth floor of the plant were quarantined, but everyone else was told to keep working.\\n\\nWhat Is Really Going On?\\n\\nAverage plant wages are around $16 an hour. Smithfield never raised them. Instead, they offered $500 to employees who could go all of April without an unapproved day off. The company says their “Responsibility Bonuses” show their “immense gratefulness” to employees “for their selfless sacrifices.”\\n\\nMeanwhile, the local Argus Leader wrote union members wanted essential-worker hazard pay, which “would be considered hourly compensation about 1.5 or two times their normal pay.” One worker said, “I feel like they’re bribing us with [the bonus] to come to work sick. That’s how you know they don’t care.”\\n\\nBoth Sioux Falls workers killed by Covid-19 were in their sixties. It is unconscionable that they were still working. All meatpackers over 50 should be on paid leave. Agustín Rodríguez, 64, had a rough job sawing the legs off dead pigs. He mopped floors with a fever shortly before he was hospitalized.\\n\\nWhen CEO Kenneth Sullivan closed the plant, he claimed, “We have continued to run our facilities for one reason: to sustain our nation’s food supply.” This is an effort to sweep Smithfield’s abuses under the rug, as if the company were operating for public benefit. This patriotic propaganda that all Americans are in it together is like a drug to keep workers from getting organized.", "meta": {"src": "mine", "author": "xx"}} +{"text": "Smithfield employs 3,700 people at its plants in Sioux Falls, South Dakota. The plant slaughters 19,500 pig a day — 5 percent of U.S. pork. Most of the workers are immigrants from Ethiopia, Mexico, South Sudan, Honduras, Myanmar, Somalia, Guatemala, and other poor countries.\\n\\nInevitably workers must pass within one foot of hundreds of colleagues in the hallways, locker rooms, cafeterias, and cutting lines. The same conditions have spurred Covid-19 outbreaks at meat plants from Minnesota and Wisconsin to Colorado, Nebraska, Missouri, Iowa, Pennsylvania, North Carolina, and Georgia.\\n\\n801 workers at the Sioux Falls plant have tested positive, together with 206 people close to them. The outbreak has killed Agustín Rodríguez Martínez, aged 64, an employee with two decades of experience originally from El Salvador, and Craig Allen Franken, 61, who worked for Smithfield his entire adult life.\\n\\nThe company knew of its first infection on March 24 or earlier. The virus spread exponentially for several weeks. Ahead of Easter Sunday and Monday (April 12-13), Smithfield promised to “completely shutter” to sanitize and put up cardboard and plastic sheet dividers. This would not end transmission, as potentially hundreds of staff were already carrying the virus. But even during this “shutdown,” many cars were seen in the parking lot. The mayor admits that the company lied, and the local AFL-CIO alleges the plant ran 60 percent production. On Easter, with 238 known infections, Smithfield finally agreed to shut down indefinitely after a request from the mayor and the governor. Yet the company insisted on waiting three more days to actually halt production.\\n\\nSmithfield denied contributing to the outbreak, saying it took a “very proactive approach.” Relying on racism, the company blamed workers for getting themselves sick. A spokesperson said the outbreak was so severe because of the plant’s “large immigrant population,” claming “Living circumstances in certain cultures are different than they are with your traditional American family.” They slandered the workers as dirty, ignorant, and untrustworthy with help from governor Kristi Noem, who claimed, “99 percent of what’s going on today wasn’t happening inside the facility. It was more at home, where these employees were going home and spreading some of the virus” by living too close together.\\n\\nOne sick worker, Michael Bul Gayo Gatluak, 22 and originally from South Sudan, says, “With how we work on the line, I would say I got sick because of them not taking safety measures.” His job is “really, really close” to other workers chopping fresh-killed pigs. “The job is so heavy. You have to breathe so hard.”\\n\\nIn early March, union officials requested masks, overcoats, entrance checking for fevers, and less crowding in 500-capacity cafeterias. But Smithfield waited on most safety measures until early April. Only April 6 did they start checking for fevers. Instead of protective masks, they gave out beard nets.\\n\\nSmithfield concealed infections with a policy of informing only employees whose work stations were in the same area as a person who tested positive. The fact that workers are required to move around was willfully ignored. One worker who tested positive said, “I clearly would have gotten it at the factory. This week I have worked on three different floors. I’ve eaten in two different cafeterias … I’ve been walking through the whole place.” Employees from the eighth floor of the plant were quarantined, but everyone else was told to keep working.\\n\\nWhat Is Really Going On?\\n\\nAverage plant wages are around $16 an hour. Smithfield never raised them. Instead, they offered $500 to employees who could go all of April without an unapproved day off. The company says their “Responsibility Bonuses” show their “immense gratefulness” to employees “for their selfless sacrifices.”\\n\\nMeanwhile, the local Argus Leader wrote union members wanted essential-worker hazard pay, which “would be considered hourly compensation about 1.5 or two times their normal pay.” One worker said, “I feel like they’re bribing us with [the bonus] to come to work sick. That’s how you know they don’t care.”\\n\\nBoth Sioux Falls workers killed by Covid-19 were in their sixties. It is unconscionable that they were still working. All meatpackers over 50 should be on paid leave. Agustín Rodríguez, 64, had a rough job sawing the legs off dead pigs. He mopped floors with a fever shortly before he was hospitalized.\\n\\nWhen CEO Kenneth Sullivan closed the plant, he claimed, “We have continued to run our facilities for one reason: to sustain our nation’s food supply.” This is an effort to sweep Smithfield’s abuses under the rug, as if the company were operating for public benefit. This patriotic propaganda that all Americans are in it together is like a drug to keep workers from getting organized.", "meta": {"src": "customized", "author": "x"}} +{"text": "This paper proposed a novel method on LLM pretraining.", "meta": {"src": "customized", "author": "xxx"}} +{"text": "第九届会议\\n2003年7月28日至8月8日\\n牙买加金斯敦\\n为来自发展中国家的法律和技术委员会以及财务委员会成员\\n参加委员会会议支付费用的方式\\n1. 国际海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员会成员参加委员会会议的费用。\\n2. 由于秘书长向会员国发出为该信托基金捐款的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账户。\\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会议的方式,包括审查是否可能从管理局行政预算中提供经费。\\n4. 自愿信托基金迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至215 000美元。\\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率提供。 购买机票的所有安排均由联合国秘书处执行。\\n6. 虽然已经设立了临时性的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\\n(a) 从管理局一般行政经费累计利息中拨出一定数额的经费;\\n(b) 每年从上一年预算未动用部分中拨出规定的数额;\\n(c) 从先驱投资者基金利息中拨出规定的数额。\\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财务委员会提出一份报告。\\n附件\\n资助来自发展中国家的法律和技术委员会以及财务\\n委员会成员出席会议的指示性费用(美元)\\n成员\\n机票\\n机场\\n费用\\n金斯敦每日生活\\n津贴\\n转机途中每日生活\\n7日\\n共计\\n14日\\n经济舱\\n公务舱\\n7天=(8天每日生活\\n津贴)\\n14天= (15天每日生活津贴)\\n商务舱\\n法律和技术委员会\\n印度尼西亚\\n(纽约)\\n黎巴嫩\\n巴基斯坦\\n阿根廷\\n喀麦隆\\n墨西哥\\n巴西\\n塞内加尔\\n莫桑比克\\n埃及(纽约)\\n大韩民国\\n印度\\n斐济\\n智利\\n中国\\n纳米比亚\\n小计\\n财务委员会\\n缅甸\\n乌干达\\n牙买加\\n印度(纽约)\\n尼日利亚\\n总计\\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000美元至215 000美元(四舍五入)。", "meta": {"src": "wiki", "version": "0.1", "author": "xyz"}} +{"text": "第九届会议\\n时间:2003年7月28日至8月8日\\n牙买加金斯敦\\n为来自发展中国家的法律和技术委员会以及财务委员会成员\\n参加委员会会议支付费用的方式\\n1. 国际海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员会成员参加委员会会议的费用。\\n2. 由于秘书长向会员国发出为该信托基金捐款的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账户。\\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会议的方式,包括审查是否可能从管理局行政预算中提供经费。\\n4. 自愿信托基金迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至215 000美元。\\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率提供。 购买机票的所有安排均由联合国秘书处执行。\\n6. 虽然已经设立了临时性的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\\n(a) 从管理局一般行政经费累计利息中拨出一定数额的经费;\\n(b) 每年从上一年预算未动用部分中拨出规定的数额;\\n(c) 从先驱投资者基金利息中拨出规定的数额。\\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财务委员会提出一份报告。\\n附件\\n资助来自发展中国家的法律和技术委员会以及财务\\n委员会成员出席会议的指示性费用(美元)\\n成员\\n机票\\n机场\\n费用\\n金斯敦每日生活\\n津贴\\n转机途中每日生活\\n7日\\n共计\\n14日\\n经济舱\\n公务舱\\n7天=(8天每日生活\\n津贴)\\n14天= (15天每日生活津贴)\\n商务舱\\n法律和技术委员会\\n印度尼西亚\\n(纽约)\\n黎巴嫩\\n巴基斯坦\\n阿根廷\\n喀麦隆\\n墨西哥\\n巴西\\n塞内加尔\\n莫桑比克\\n埃及(纽约)\\n大韩民国\\n印度\\n斐济\\n智利\\n中国\\n纳米比亚\\n小计\\n财务委员会\\n缅甸\\n乌干达\\n牙买加\\n印度(纽约)\\n尼日利亚\\n总计\\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000美元至215 000美元(四舍五入)。", "meta": {"src": "wiki", "version": "0.1", "author": "xy"}} diff --git a/demos/data/demo-dataset.jsonl b/demos/data/demo-dataset.jsonl new file mode 100644 index 000000000..4d8cdadfd --- /dev/null +++ b/demos/data/demo-dataset.jsonl @@ -0,0 +1,6 @@ +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} +{"text": "你好,请问你是谁", "meta": {"src": "customized", "author": "xxx"}} +{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément.", "meta": {"src": "Oscar", "version": "2.0"}} +{"text": "欢迎来到阿里巴巴!", "meta": {"src": "customized", "version": "0.1", "author": "xxx"}} +{"text": "This paper proposed a novel method on LLM pretraining.", "meta": {"src": "customized", "author": "xxx"}} diff --git a/demos/data_visualization_diversity/.DS_Store b/demos/data_visualization_diversity/.DS_Store new file mode 100644 index 000000000..e4ae7825a Binary files /dev/null and b/demos/data_visualization_diversity/.DS_Store differ diff --git a/demos/data_visualization_diversity/app.py b/demos/data_visualization_diversity/app.py new file mode 100644 index 000000000..acf4596f5 --- /dev/null +++ b/demos/data_visualization_diversity/app.py @@ -0,0 +1,236 @@ +import os + +import plotly.express as px +import streamlit as st +import yaml +from loguru import logger + +from data_juicer.analysis.diversity_analysis import (DiversityAnalysis, + get_diversity, + prepare_diversity_model) +from data_juicer.config import init_configs +from data_juicer.core import Analyser +from data_juicer.ops.base_op import OPERATORS + + +def convert_csv(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_csv().encode('utf-8') + + +@st.cache_data +def convert_jsonl(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_json(orient='records', lines=True).encode('utf-8') + + +@st.cache_data +def get_diversity_model(lang): + diversity_model = prepare_diversity_model(lang) + return diversity_model + + +@st.cache_data +def postproc_diversity(dataframe, **kwargs): + df = get_diversity(dataframe, **kwargs) + return df + + +def pretty_out(d): + res = '' + process = '' + op_names = set(OPERATORS.modules.keys()) + for key, value in d.items(): + if key == 'process': + process = yaml.dump(value, + allow_unicode=True, + default_flow_style=False) + elif key == 'config' or key.split('.')[0] in op_names: + continue + else: + res += f'{key}:\n \t {value}\n' + res += 'process:\n' + \ + '\n'.join(['\t' + line for line in process.splitlines()]) + + return res + + +def parse_cfg(): + + cfg_cmd = '--config configs/demo.yaml' + + args_in_cmd = cfg_cmd.split() + + if len(args_in_cmd) >= 2 and args_in_cmd[0] == '--config': + cfg_f_name = args_in_cmd[1] + else: + st.warning('Please specify a config command or upload a config file.') + st.stop() + + if not os.path.exists(cfg_f_name): + st.warning('do not parse' + f'config file does not exist with cfg_f_name={cfg_f_name}') + st.stop() + + with open(cfg_f_name, 'r') as cfg_f: + specified_cfg = yaml.safe_load(cfg_f) + + try: + parsed_cfg = init_configs(args=args_in_cmd) + st.session_state.cfg = parsed_cfg + + return pretty_out(parsed_cfg), pretty_out(specified_cfg), parsed_cfg + except Exception as e: + return str(e), pretty_out(specified_cfg), None + + +def load_dataset(dataset_file): + + cfg = st.session_state.get('cfg', parse_cfg()[2]) + if cfg is None: + raise ValueError('you have not specify valid cfg') + # force generating separate figures + cfg['save_stats_in_one_file'] = True + + del_file = False + if dataset_file is not None: + + file_contents = dataset_file.getvalue() + with open(dataset_file.name, 'wb') as f: + f.write(file_contents) + cfg.dataset_path = dataset_file.name + del_file = True + + logger.info('=========Stage: analyze original data=========') + analyzer = Analyser(cfg) + + dataset = analyzer.formatter.load_dataset() + if del_file: + os.remove(dataset_file.name) + return dataset + + +class Visualize: + + @staticmethod + def setup(): + st.set_page_config( + page_title='Juicer', + page_icon=':smile', + layout='wide', + # initial_sidebar_state="expanded", + ) + + readme_link = 'https://code.alibaba-inc.com/DAIL-DATA/' \ + 'data_juicer/blob/master/README.md' + st.markdown( + '
    Data-Juicer \ +
    ', + unsafe_allow_html=True, + ) + st.markdown( + f'
    ', + unsafe_allow_html=True, + ) + + @staticmethod + def draw_sunburst(df, path, values): + + fig = px.sunburst(df, path=path, values=values) + fig.update_layout(margin=dict(l=0, r=0, t=0, b=0), + font_family='Times New Roman', + font=dict(size=40)) + st.plotly_chart(fig, use_container_width=True) + + @staticmethod + def diversity(): + col1, col2 = st.columns(2) + with col1: + dataset_file = st.file_uploader( + label='Upload you custom dataset(jsonl/csv)', + type=['json', 'jsonl', 'csv']) + + with col2: + st.text_area(label='Default Demo dataset', + disabled=True, + value='data/demo-dataset.jsonl') + + with st.expander('Set diversity params', expanded=True): + + col1, col2, col3, col4 = st.columns(4) + with col1: + label = 'Which language of your dataset' + options = ['en', 'zh'] + lang_select = st.selectbox( + label=label, + options=options, + ) + with col2: + top_k_verbs = st.number_input('Set the top_k of verbs', + value=20) + with col3: + top_k_nouns = st.number_input('Set the top_k of nouns', + value=4) + with col4: + threshold = st.slider('Count threshold', + min_value=0, + value=0, + max_value=100, + step=1) + diversity_btn = st.button('Start to analyse Verb-Noun diversity', + use_container_width=True) + + with st.expander('Diversity Results ', expanded=True): + + cfg = st.session_state.get('cfg', parse_cfg()[2]) + output_path = os.path.join(os.path.dirname(cfg.export_path), + 'analysis') + raw_df = None + if diversity_btn: + try: + with st.spinner('Wait for analyze diversity...'): + dataset = load_dataset(dataset_file) + + diversity_analysis = DiversityAnalysis( + dataset, output_path) + + raw_df = diversity_analysis.compute( + lang_or_model=get_diversity_model(lang_select)) + + st.session_state[f'diversity{lang_select}'] = raw_df + + except Exception as e: + st.warning(f'Error {str(e)} in {lang_select}') + else: + raw_df = st.session_state.get(f'diversity{lang_select}', None) + + if raw_df is not None: + df = postproc_diversity(raw_df, + top_k_verbs=top_k_verbs, + top_k_nouns=top_k_nouns) + df = df[df['count'] >= threshold] + Visualize.draw_sunburst(df, + path=['verb', 'noun'], + values='count') + + st.download_button( + label='Download diversity data as CSV', + data=convert_csv(df), + file_name='diversity.csv', + mime='text/csv', + ) + + @staticmethod + def visualize(): + Visualize.setup() + Visualize.diversity() + + +def main(): + Visualize.visualize() + + +if __name__ == '__main__': + main() diff --git a/demos/data_visualization_diversity/configs/demo.yaml b/demos/data_visualization_diversity/configs/demo.yaml new file mode 100644 index 000000000..d71266901 --- /dev/null +++ b/demos/data_visualization_diversity/configs/demo.yaml @@ -0,0 +1,66 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo' +dataset_path: './data/demo-dataset.jsonl' # path to your dataset directory or file +np: 1 # number of subprocess to process your dataset + +export_path: './outputs/demo/demo-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + # Filter ops + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. + tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens. + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.9 # the max ratio of filter range + - average_line_length_filter: # filter text with the average length of lines out of specific range. + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - character_repetition_filter: # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range + - flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value + lang: en # consider flagged words in what language + tokenization: false # whether to use model to tokenize documents + max_ratio: 0.0045 # the max ratio to filter text + flagged_words_dir: ./assets # directory to store flagged words dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value + lang: en # keep text in what language + min_score: 0.8 # the min language scores to filter text + - maximum_line_length_filter: # filter text with the maximum length of lines out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - perplexity_filter: # filter text with perplexity score out of specific range + lang: en # compute perplexity in what language + max_ppl: 1500 # the max perplexity score to filter text + - special_characters_filter: # filter text with special-char ratio out of specific range + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.25 # the max ratio of filter range + - stopwords_filter: # filter text with stopword ratio smaller than a specific min value + lang: en # consider stopwords in what language + tokenization: false # whether to use model to tokenize documents + min_ratio: 0.3 # the min ratio to filter text + stopwords_dir: ./assets # directory to store stopwords dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - text_length_filter: # filter text with length out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - words_num_filter: # filter text with number of words out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + min_num: 10 # the min number of filter range + max_num: 10000 # the max number of filter range + - word_repetition_filter: # filter text with the word repetition ratio out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + rep_len: 10 # repetition length for word-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range diff --git a/demos/data_visualization_diversity/data/demo-dataset.jsonl b/demos/data_visualization_diversity/data/demo-dataset.jsonl new file mode 100644 index 000000000..3f9e7640b --- /dev/null +++ b/demos/data_visualization_diversity/data/demo-dataset.jsonl @@ -0,0 +1,147 @@ +[ + { + "text": "Give three tips for staying healthy.", + "input": "", + "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule." + }, + { + "text": "What are the three primary colors?", + "input": "", + "output": "The three primary colors are red, blue, and yellow." + }, + { + "text": "Describe the structure of an atom.", + "input": "", + "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom." + }, + { + "text": "How can we reduce air pollution?", + "input": "", + "output": "There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances." + }, + { + "text": "Describe a time when you had to make a difficult decision.", + "input": "", + "output": "I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client\u2019s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team\u2019s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client\u2019s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities." + }, + { + "text": "Identify the odd one out.", + "input": "Twitter, Instagram, Telegram", + "output": "Telegram" + }, + { + "text": "Explain why the following fraction is equivalent to 1/4", + "input": "4/16", + "output": "The fraction 4/16 is equivalent to 1/4 because both numerators and denominators are divisible by 4. Dividing both the top and bottom numbers by 4 yields the fraction 1/4." + }, + { + "text": "Write a short story in third person narration about a protagonist who has to make an important career decision.", + "input": "", + "output": "John was at a crossroads in his life. He had just graduated college and was now facing the big decision of what career to pursue. After much deliberation, he decided that he wanted to be an accountant and help the financially disadvantaged. He had always been good with numbers and enjoyed seeing the tangible results of his work. \n\nJohn enrolled in accounting courses and initially found it quite challenging. He had to learn multiple systems and regulations quickly, but he worked hard and eventually excelled in his studies. After a few years, John started working at an accounting firm in his city. He was eager to put his knowledge of taxes and accounting to use in a real-world setting.\n\nJohn loved his job, as it let him express his creativity in finding strategies to save his clients money. After a few years at the firm, he became a senior accountant and was asked to manage bigger and more challenging cases. He was now a respected figure in the financial industry, but he still remembers when he was just a recent college graduate, unsure of the direction in which his life would take him." + }, + { + "text": "Render a 3D model of a house", + "input": "", + "output": " This type of text cannot be fulfilled by a GPT model." + }, + { + "text": "Evaluate this sentence for spelling and grammar mistakes", + "input": "He finnished his meal and left the resturant", + "output": "He finished his meal and left the restaurant." + }, + { + "text": "How did Julius Caesar die?", + "input": "", + "output": "Julius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC." + }, + { + "text": "What is the capital of France?", + "input": "", + "output": "The capital of France is Paris." + }, + { + "text": "Generate a list of ten items a person might need for a camping trip", + "input": "", + "output": "1. Tent\n2. Sleeping bags\n3. Flashlight\n4. Matches/lighter\n5. Insect repellent\n6. Extra clothes\n7. First aid kit\n8. Camp stove\n9. Cooking utensils\n10. Biodegradable soaps/shampoos" + }, + { + "text": "Discuss the causes of the Great Depression", + "input": "", + "output": "The Great Depression was a severe global economic downturn that lasted from 1929 to 1939. It was caused by several factors, including the stock market crash of 1929, structural weaknesses in the economy, low consumer spending, high unemployment, and a lack of confidence in the banking system. Other contributing factors included overproduction of commodities, reliance on borrowing to finance investments, agricultural overproduction, and a decrease in international trade." + }, + { + "text": "Classify the following into animals, plants, and minerals", + "input": "Oak tree, copper ore, elephant", + "output": "Oak tree: Plant\n Copper ore: Mineral\n Elephant: Animal" + }, + { + "text": "Explain the use of word embeddings in Natural Language Processing", + "input": "", + "output": "Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies. They allow for words to be processed as numerical values, giving machines an easier way to perform NLP tasks." + }, + { + "text": "Describe the function of a computer motherboard", + "input": "", + "output": "A computer motherboard is the main circuit board in a computer. It is the backbone of the system, connecting all other components such as the processor, RAM, storage, graphics card and peripheral ports. It provides communication pathways between all of these components, allowing them to interact and work together. The motherboard is responsible for regulating power, and also houses slots for expansion cards, as well as ports for USB, audio, Ethernet, and other connection types." + }, + { + "text": "Reverse engineer this code to create a new version", + "input": "def factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial", + "output": "def factorialize(num):\n factorial = 1\n for i in range(num, 0, -1):\n factorial *= i\n \n return factorial" + }, + { + "text": "Propose an ethical solution to the problem of data privacy", + "input": "", + "output": "One ethical solution to the problem of data privacy is to implement an opt-in data collection policy. This means that companies cannot collect and use personal data without the explicit permission of the data owners. This approach gives users the power to determine which data companies can use and for what purpose. Companies should also be transparent about how they use the collected data, so that users can make informed decisions and have control over their own data." + }, + { + "text": "Generate three verbs that mean the same as \"to apologize\"", + "input": "", + "output": "Acknowledge, express regret, make amends." + }, + { + "text": "What does DNA stand for?", + "input": "", + "output": "DNA stands for deoxyribonucleic acid." + }, + { + "text": "Compare and contrast the Cuban Missile Crisis and the Vietnam War.", + "input": "", + "output": "The Cuban Missile Crisis and the Vietnam War were two of the largest and deadliest conflicts of the 20th century. The Cuban Missile Crisis involved a standoff between the United States and the Soviet Union over Soviet nuclear missile bases in Cuba, while the Vietnam War was a prolonged military conflict in Indochina between the North, which was backed by the Soviet Union and the South, which was backed by the United States. The Cuban Missile Crisis unfolded over a period of 13 days in October 1962, with the resolution coming in the form of a diplomatic agreement. The Vietnam War lasted for 19 years, beginning in 1954 and ending in 1973, with the signing of the Paris Peace Accords. While both of these conflicts were Cold War related, they were also very different in terms of causes, objectives, strategic objectives and outcomes. The Cuban Missile Crisis did not result in any direct military conflict and was resolved primarily through diplomacy and the decision to withdraw nuclear weapons from Cuba. On the other hand, the Vietnam War was a prolonged series of battles resulting in significant losses for both sides. The Vietnam War also had a significant impact on the United States, as it highlighted the realities of a prolonged conflict and the limits of U.S. military intervention. Additionally, the Vietnam War also drastically changed the way public opinion viewed the military and the use of force." + }, + { + "text": "Generate a list of random words.", + "input": "", + "output": "sneeze, conflict, ancestor, thunder, companion, amulet, forge, serpent, magnify, jester, cascade, universe, surge, nausea, oasis, anomaly, pursue, rapture, pagoda, mosaic." + }, + { + "text": "Transcribe the recording into text.", + "input": "Recording about the history of Waterloo", + "output": "The Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years." + }, + { + "text": "Who is the world's most famous painter?", + "input": "", + "output": "The world's most famous painter is generally considered to be the Italian renaissance artist Leonardo da Vinci. His work includes the Mona Lisa and The Last Supper, both of which are considered masterpieces. Da Vinci was also a master of science, inventing the helicopter and the tank, among other inventions. He was also a prolific musician, sculptor, and architect. His work has inspired generations of artists and continues to be revered today." + }, + { + "text": "Explain the concept of cogging torque.", + "input": "", + "output": "Cogging torque is the torque that is applied to a rotating device when the device experiences resistance from the teeth of the gears that are connected to the device. Cogging torque occurs due to the reaction force generated when the teeth of the gear drive mesh. This force pushes against each other, resulting in a drag force that is transmitted from the gears to the shaft. The magnitude of the cogging torque can vary depending on the size and number of teeth, as well as the gear's material and design. Cogging torque is undesirable in most motorized applications, as it can lead to loss of efficiency. To reduce cogging torque, gear designs are often symmetrical or have certain tooth profiles which reduce the force of contact between the teeth when they come into contact with each other." + }, + { + "text": "Look up the boiling point of water.", + "input": "", + "output": "The boiling point of water is 100\u00b0C (212\u00b0F), at standard atmospheric pressure. When heated at higher pressures, water's boiling point increases. For example, water will boil at a temperature of 115\u00b0C (239\u00b0F) at an atmospheric pressure of 1 bar. At higher pressures, such as those experienced at the bottom of the ocean, the boiling point of water can be as high as 374\u00b0C (705\u00b0F)." + }, + { + "text": "Describe the life and reign of King Charles II.", + "input": "", + "output": "King Charles II ascended to the throne of England following the English Civil War and the death of his father, Charles I. He was crowned in 1661 and his reign lasted until his death in 1685. Charles II was also a patron of the arts, working with some of the greatest painters and sculptors of his day. He also encouraged theater and sports and had an interest in sciences such as astronomy and natural history. Charles II was a moderate ruler who respected the civil and religious liberties of his citizens. He was tolerant of other religions and gave Jews and Catholics more rights in the kingdom. Charles II was also a strong supporter of the Royal Navy, which saw England increase its power in the world. He also promoted the development of industry and commerce, which helped to modernize the country. Charles II's reign was an important step in restoring the monarchy of England after the turmoil of the English Civil War." + }, + { + "text": "Find the area of a circle given its radius.", + "input": "Radius = 4", + "output": "The area of a circle with a radius of 4 is equal to 12.5664 square units. This is calculated by using the formula A = \u03c0r2, where A is the area, \u03c0 is roughly equal to 3.1416 and r is the radius of the circle." + } +] \ No newline at end of file diff --git a/demos/data_visualization_op_effect/.DS_Store b/demos/data_visualization_op_effect/.DS_Store new file mode 100644 index 000000000..02a008699 Binary files /dev/null and b/demos/data_visualization_op_effect/.DS_Store differ diff --git a/demos/data_visualization_op_effect/app.py b/demos/data_visualization_op_effect/app.py new file mode 100644 index 000000000..2d075444f --- /dev/null +++ b/demos/data_visualization_op_effect/app.py @@ -0,0 +1,526 @@ +# Some code here has been modified from: +# https://huggingface.co/spaces/huggingface/text-data-filtering +# -------------------------------------------------------- +import copy +import math +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import streamlit as st +import yaml + +from data_juicer.config import init_configs +from data_juicer.core import Analyser +from data_juicer.ops.base_op import OPERATORS + + +@st.cache_data +def convert_csv(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_csv().encode('utf-8') + + +@st.cache_data +def convert_jsonl(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_json(orient='records', lines=True).encode('utf-8') + + +def pretty_out(d): + res = '' + process = '' + op_names = set(OPERATORS.modules.keys()) + for key, value in d.items(): + if key == 'process': + process = yaml.dump(value, + allow_unicode=True, + default_flow_style=False) + elif key == 'config' or key.split('.')[0] in op_names: + continue + else: + res += f'{key}:\n \t {value}\n' + res += 'process:\n' + \ + '\n'.join(['\t' + line for line in process.splitlines()]) + + return res + + +def parse_cfg(): + + lang_select = st.session_state.get('lang_select', 'en') + + if lang_select == 'zh': + cfg_cmd = '--config configs/demo_zh.yaml' + else: + cfg_cmd = '--config configs/demo_en.yaml' + + args_in_cmd = cfg_cmd.split() + + if len(args_in_cmd) >= 2 and args_in_cmd[0] == '--config': + cfg_f_name = args_in_cmd[1] + else: + st.warning('Please specify a config command or upload a config file.') + st.stop() + + if not os.path.exists(cfg_f_name): + st.warning('do not parse' + f'config file does not exist with cfg_f_name={cfg_f_name}') + st.stop() + + with open(cfg_f_name, 'r') as cfg_f: + specified_cfg = yaml.safe_load(cfg_f) + + try: + parsed_cfg = init_configs(args=args_in_cmd) + st.session_state.cfg = parsed_cfg + + return pretty_out(parsed_cfg), pretty_out(specified_cfg), parsed_cfg + except Exception as e: + return str(e), pretty_out(specified_cfg), None + + +def analyze_and_show_res(dataset_file): + images_ori = [] + cfg = st.session_state.get('cfg', parse_cfg()[2]) + if cfg is None: + raise ValueError('you have not specify valid cfg') + # force generating separate figures + cfg['save_stats_in_one_file'] = True + + del_file = False + if dataset_file is not None: + + file_contents = dataset_file.getvalue() + with open(dataset_file.name, 'wb') as f: + f.write(file_contents) + cfg.dataset_path = dataset_file.name + del_file = True + + analyzer = Analyser(cfg) + dataset = analyzer.run() + + analysis_res_ori = pd.read_csv( + os.path.join(analyzer.analysis_path, 'overall.csv')) + for f_path in os.listdir(analyzer.analysis_path): + if '.png' in f_path and 'all-stats' in f_path: + images_ori.append(os.path.join(analyzer.analysis_path, f_path)) + + st.session_state.dataset = dataset + st.session_state.original_overall = analysis_res_ori + st.session_state.original_imgs = images_ori + + if del_file: + os.remove(dataset_file.name) + + +def get_min_max_step(data): + max_value = np.max(data) + if max_value > 2.0: + min_value = 0 + max_value = int(max_value + 1) + step = 1 + else: + min_value = 0.0 + max_value = max(1.0, max_value) + step = 0.01 + return min_value, max_value, step + + +op_stats_dict = { + 'alphanumeric_filter': ['alpha_token_ratio', 'alnum_ratio'], + 'average_line_length_filter': ['avg_line_length'], + 'character_repetition_filter': ['char_rep_ratio'], + 'flagged_words_filter': ['flagged_words_ratio'], + 'language_id_score_filter': ['lang', 'lang_score'], + 'maximum_line_length_filter': ['max_line_length'], + 'perplexity_filter': ['perplexity'], + 'special_characters_filter': ['special_char_ratio'], + 'stopwords_filter': ['stopwords_ratio'], + 'text_length_filter': ['text_len'], + 'words_num_filter': ['num_words'], + 'word_repetition_filter': ['word_rep_ratio'], +} + + +class Visualize: + + @staticmethod + def setup(): + st.set_page_config( + page_title='Juicer', + page_icon=':smile', + layout='wide', + # initial_sidebar_state="expanded", + ) + + readme_link = 'https://code.alibaba-inc.com/DAIL-DATA/' \ + 'data_juicer/blob/master/README.md' + st.markdown( + '
    Data-Juicer \ +
    ', + unsafe_allow_html=True, + ) + st.markdown( + f'
    A Data-Centric Text Processing System for \ + Large Language Models, \ + see more details in Document
    ', + unsafe_allow_html=True, + ) + + @staticmethod + def draw_stack_bar(bar_sizes, bar_labels, total_num, title=''): + filtered_size = [ + k / total_num * 100 for i, k in enumerate(bar_sizes[::-1]) + if i % 2 == 0 + ] + retain_size = [ + k / total_num * 100 for i, k in enumerate(bar_sizes[::-1]) + if i % 2 != 0 + ] + plt.clf() + plt.title(title) + bar_labels = bar_labels[::-1] + # retained + r_bars = plt.barh(bar_labels, + retain_size, + label='Retained', + height=0.5, + color='limegreen') + + # filtered + f_bars = plt.barh(bar_labels, + filtered_size, + label='Filtered', + left=retain_size, + height=0.5, + color='orangered') + + for idx, bar in enumerate(r_bars): + width = bar.get_width() + plt.text(bar.get_x() + width / 2, + bar.get_y() + bar.get_height() / 2, + f'{retain_size[idx]:.2f}%', + ha='center', + va='center') + + for idx, bar in enumerate(f_bars): + width = bar.get_width() + plt.text(bar.get_x() + width / 2, + bar.get_y() + bar.get_height() / 2, + f'{filtered_size[idx]:.2f}%', + ha='center', + va='center') + + plt.legend() + plt.gcf() + st.pyplot(plt, use_container_width=True) + + @staticmethod + def display_discarded_ratio(cond, key): + if len(cond) > 0: + st.caption( + f':red[{(len(cond) - np.sum(1*cond)) / len(cond) * 100:.2f}%] \ + of the total (:red[{len(cond)}]) is discarded with {key}.') + else: + st.caption(f':red[{0:.2f}%] \ + of the total (:red[0]) is discarded with {key}.') + + @staticmethod + def display_dataset(dataframe, cond, show_num, desp, type, all=True): + examples = dataframe.loc[cond] + if all or len(examples) > 0: + st.subheader( + f'{desp}: :red[{len(examples)}] of ' + f'{len(dataframe.index)} {type} ' + f'(:red[{len(examples)/len(dataframe.index) * 100:.2f}%])') + + # st.markdown('Click on a column to sort by it, \ + # place the cursor on the text to display it.') + st.dataframe(examples[:show_num], use_container_width=True) + + @staticmethod + def draw_hist(data, cutoff=None): + + fig, ax = plt.subplots() + data_num = len(data) + if data_num >= 100: + rec_bins = int(math.sqrt(len(data))) + else: + rec_bins = 50 + + if data_num > 0: + ax.hist(data, bins=rec_bins, density=True) + if hasattr(data, 'name'): + ax.set_title(data.name) + + if isinstance(cutoff, (float, int)): + ax.axvline(x=cutoff, color='r', linestyle='dashed') + elif isinstance(cutoff, tuple) and len(cutoff) == 2: + ax.axvline(x=cutoff[0], color='r', linestyle='dashed') + ax.axvline(x=cutoff[1], color='r', linestyle='dashed') + st.pyplot(fig) + + @staticmethod + def op_effect_analyze(): + col1, col2, col3 = st.columns(3) + + with col1: + label = 'Which language of your dataset' + options = ['en', 'zh'] + lang_select = st.selectbox( + label=label, + options=options, + ) + st.session_state.lang_select = lang_select + + with col2: + dataset_file = st.file_uploader( + label='Upload you custom dataset(jsonl/csv)', + type=['json', 'jsonl', 'csv']) + + with col3: + st.text_area(label='Default Demo dataset', + disabled=True, + value='data/demo-dataset.jsonl') + + start_btn = st.button('Start to analyze data (per filter op)', + use_container_width=True) + + if start_btn: + with st.spinner('Wait for analyze...'): + analyze_and_show_res(dataset_file) + + with st.expander('Data Analysis Results', expanded=False): + original_overall = st.session_state.get('original_overall', None) + original_imgs = st.session_state.get('original_imgs', []) + + st.dataframe(original_overall, use_container_width=True) + for img in original_imgs: + st.image(img, output_format='png') + with st.expander('Effect of Filter OPs', expanded=True): + dataset = st.session_state.get('dataset', None) + if dataset: + Visualize.filter_dataset(dataset) + else: + st.warning('Please analyze data first') + + @staticmethod + def filter_dataset(dataset): + text = dataset['text'] + if 'stats' not in dataset.features: + stats = pd.DataFrame(dataset['stats.meta']) + else: + stats = pd.DataFrame(dataset['stats']) + stats['text'] = text + + non_num_list = ['lang'] + min_cutoff_list = [ + 'lang_score', + 'stopwords_ratio', + ] + max_cutoff_list = [ + 'flagged_words_ratio', + 'max_ppl', + ] + mask_list = ['text'] + + cfg = st.session_state.get('cfg', None) + if cfg is None: + return + + def set_sliders(total_stats, ordered): + stats = copy.deepcopy(total_stats) + conds = list() + index = 1 + for op_cfg in cfg.process: + op_name = list(op_cfg.keys())[0] + op_stats = op_stats_dict.get(op_name, []) + + cutoff_ratio = None + + with st.sidebar.expander(f'{index} {op_name}'): + + for column_name in op_stats: + if column_name not in stats: + continue + data = stats[column_name] + + if column_name in non_num_list: + options = ['all'] + list(set(data)) + label = f'Which {column_name} would \ + you like to keep?' + + selected = st.selectbox( + label=label, + options=options, + ) + if selected == 'all': + cond = [True] * len(data) + else: + cond = data == selected + Visualize.display_discarded_ratio( + cond, column_name) + + elif column_name in min_cutoff_list: + label = f'If the {column_name} of a document \ + is lower than this number, \ + the document is removed.' + + low, high, step = get_min_max_step(data) + + cutoff_ratio = st.slider(label, + low, + high, + low, + step=step) + cond = data >= cutoff_ratio + Visualize.display_discarded_ratio( + cond, column_name) + + elif column_name in max_cutoff_list: + label = f'If the {column_name} of a document \ + is higher than this number, \ + the document is removed.' + + low, high, step = get_min_max_step(data) + cutoff_ratio = st.slider(label, + low, + high, + high, + step=step) + cond = data <= cutoff_ratio + + Visualize.display_discarded_ratio( + cond, column_name) + elif column_name not in mask_list: + # lower + label = f'If the {column_name} of a document \ + is lower than this number, \ + the document is removed.' + + low, high, step = get_min_max_step(data) + + cutoff_ratio_l = st.slider(label, + low, + high, + low, + step=step) + cond_l = data >= cutoff_ratio_l + + Visualize.display_discarded_ratio( + cond_l, column_name) + + # higher + label = f'If the {column_name} of a document \ + is higher than this number, \ + the document is removed.' + + cutoff_ratio_h = st.slider(label, + low, + high, + high, + step=step) + + cond_h = data <= cutoff_ratio_h + Visualize.display_discarded_ratio( + cond_h, column_name) + cond = [ + low & high + for low, high in zip(cond_l, cond_h) + ] + + cutoff_ratio = (cutoff_ratio_l, cutoff_ratio_h) + + if column_name not in mask_list: + Visualize.draw_hist(data, cutoff_ratio) + conds.append({ + (' '.join([str(index), op_name]), column_name): + cond + }) + + if ordered: + stats = stats.loc[cond] + index += 1 + return conds, stats + + st.subheader('How many samples do you want to show?') + show_num = st.number_input( + label='How many samples do you want to show?', + value=5, + label_visibility='hidden') + + st.sidebar.subheader('Parameters of filter ops') + ordered = st.sidebar.checkbox('Process by op order') + conds, filtered_stats = set_sliders(stats, ordered) + + if ordered: + all_conds = [ + True if i in filtered_stats.index else False + for i in range(len(stats)) + ] + else: + all_conds = np.all([list(cond.values())[0] for cond in conds], + axis=0) + + ds = pd.DataFrame(dataset) + Visualize.display_dataset(ds, all_conds, show_num, 'Retained sampels', + 'docs') + st.download_button('Download Retained data as JSONL', + data=convert_jsonl(ds.loc[all_conds]), + file_name='retained.jsonl') + Visualize.display_dataset(ds, np.invert(all_conds), show_num, + 'Discarded sampels', 'docs') + st.download_button('Download Discarded data as JSONL', + data=convert_jsonl(ds.loc[np.invert(all_conds)]), + file_name='discarded.jsonl') + display_discarded_details = st.checkbox( + 'Display discarded documents by filter details') + + show_stats = copy.deepcopy(stats) + bar_labels = [] + bar_sizes = [] + for item in conds: + for op_key, cond in item.items(): + op_name, column_name = op_key + if column_name not in mask_list: + sub_stats = show_stats[[column_name, 'text']] + if display_discarded_details: + Visualize.display_dataset( + sub_stats, + np.invert(cond) if len(cond) > 0 else [], + show_num, + # f'Discarded documents for the filter on \ + f'{op_name} {column_name} filtered ', + 'docs', + ) + before_filtered_num = len(show_stats.index) + if ordered: + show_stats = show_stats.loc[cond] + retained = np.sum(1 * cond) + filtered = before_filtered_num - len(show_stats.index) + else: + retained = np.sum(1 * cond) + filtered = before_filtered_num - retained + + bar_sizes.append(retained) + bar_sizes.append(filtered) + bar_labels.append(f'{op_name}\n{column_name}') + + bar_title = 'Effect of Filter OPs' + Visualize.draw_stack_bar(bar_sizes, bar_labels, len(stats.index), + bar_title) + + @staticmethod + def visualize(): + Visualize.setup() + Visualize.op_effect_analyze() + + +def main(): + Visualize.visualize() + + +if __name__ == '__main__': + main() diff --git a/demos/data_visualization_op_effect/configs/demo_en.yaml b/demos/data_visualization_op_effect/configs/demo_en.yaml new file mode 100644 index 000000000..d71266901 --- /dev/null +++ b/demos/data_visualization_op_effect/configs/demo_en.yaml @@ -0,0 +1,66 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo' +dataset_path: './data/demo-dataset.jsonl' # path to your dataset directory or file +np: 1 # number of subprocess to process your dataset + +export_path: './outputs/demo/demo-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + # Filter ops + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. + tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens. + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.9 # the max ratio of filter range + - average_line_length_filter: # filter text with the average length of lines out of specific range. + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - character_repetition_filter: # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range + - flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value + lang: en # consider flagged words in what language + tokenization: false # whether to use model to tokenize documents + max_ratio: 0.0045 # the max ratio to filter text + flagged_words_dir: ./assets # directory to store flagged words dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value + lang: en # keep text in what language + min_score: 0.8 # the min language scores to filter text + - maximum_line_length_filter: # filter text with the maximum length of lines out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - perplexity_filter: # filter text with perplexity score out of specific range + lang: en # compute perplexity in what language + max_ppl: 1500 # the max perplexity score to filter text + - special_characters_filter: # filter text with special-char ratio out of specific range + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.25 # the max ratio of filter range + - stopwords_filter: # filter text with stopword ratio smaller than a specific min value + lang: en # consider stopwords in what language + tokenization: false # whether to use model to tokenize documents + min_ratio: 0.3 # the min ratio to filter text + stopwords_dir: ./assets # directory to store stopwords dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - text_length_filter: # filter text with length out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - words_num_filter: # filter text with number of words out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + min_num: 10 # the min number of filter range + max_num: 10000 # the max number of filter range + - word_repetition_filter: # filter text with the word repetition ratio out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + rep_len: 10 # repetition length for word-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range diff --git a/demos/data_visualization_op_effect/configs/demo_zh.yaml b/demos/data_visualization_op_effect/configs/demo_zh.yaml new file mode 100644 index 000000000..c07ea1cf9 --- /dev/null +++ b/demos/data_visualization_op_effect/configs/demo_zh.yaml @@ -0,0 +1,66 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo' +dataset_path: './data/demo-dataset.jsonl' # path to your dataset directory or file +np: 1 # number of subprocess to process your dataset + +export_path: './outputs/demo/demo-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + # Filter ops + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. + tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens. + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.9 # the max ratio of filter range + - average_line_length_filter: # filter text with the average length of lines out of specific range. + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - character_repetition_filter: # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range + - flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value + lang: zh # consider flagged words in what language + tokenization: true # whether to use model to tokenize documents + max_ratio: 0.0045 # the max ratio to filter text + flagged_words_dir: ./assets # directory to store flagged words dictionaries + use_words_aug: true # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value + lang: zh # keep text in what language + min_score: 0.8 # the min language scores to filter text + - maximum_line_length_filter: # filter text with the maximum length of lines out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - perplexity_filter: # filter text with perplexity score out of specific range + lang: zh # compute perplexity in what language + max_ppl: 1500 # the max perplexity score to filter text + - special_characters_filter: # filter text with special-char ratio out of specific range + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.25 # the max ratio of filter range + - stopwords_filter: # filter text with stopword ratio smaller than a specific min value + lang: zh # consider stopwords in what language + tokenization: true # whether to use model to tokenize documents + min_ratio: 0.3 # the min ratio to filter text + stopwords_dir: ./assets # directory to store stopwords dictionaries + use_words_aug: true # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - text_length_filter: # filter text with length out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - words_num_filter: # filter text with number of words out of specific range + lang: zh # sample in which language + tokenization: true # whether to use model to tokenize documents + min_num: 10 # the min number of filter range + max_num: 10000 # the max number of filter range + - word_repetition_filter: # filter text with the word repetition ratio out of specific range + lang: zh # sample in which language + tokenization: true # whether to use model to tokenize documents + rep_len: 10 # repetition length for word-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range diff --git a/demos/data_visualization_op_effect/data/demo-dataset.jsonl b/demos/data_visualization_op_effect/data/demo-dataset.jsonl new file mode 100644 index 000000000..4d8cdadfd --- /dev/null +++ b/demos/data_visualization_op_effect/data/demo-dataset.jsonl @@ -0,0 +1,6 @@ +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} +{"text": "你好,请问你是谁", "meta": {"src": "customized", "author": "xxx"}} +{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément.", "meta": {"src": "Oscar", "version": "2.0"}} +{"text": "欢迎来到阿里巴巴!", "meta": {"src": "customized", "version": "0.1", "author": "xxx"}} +{"text": "This paper proposed a novel method on LLM pretraining.", "meta": {"src": "customized", "author": "xxx"}} diff --git a/demos/data_visualization_statistics/.DS_Store b/demos/data_visualization_statistics/.DS_Store new file mode 100644 index 000000000..6a4643e38 Binary files /dev/null and b/demos/data_visualization_statistics/.DS_Store differ diff --git a/demos/data_visualization_statistics/app.py b/demos/data_visualization_statistics/app.py new file mode 100644 index 000000000..e019b290b --- /dev/null +++ b/demos/data_visualization_statistics/app.py @@ -0,0 +1,176 @@ +import os + +import pandas as pd +import streamlit as st +import yaml +from loguru import logger + +from data_juicer.config import init_configs +from data_juicer.core import Analyser +from data_juicer.ops.base_op import OPERATORS + + +@st.cache_data +def convert_csv(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_csv().encode('utf-8') + + +@st.cache_data +def convert_jsonl(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_json(orient='records', lines=True).encode('utf-8') + + +def pretty_out(d): + res = '' + process = '' + op_names = set(OPERATORS.modules.keys()) + for key, value in d.items(): + if key == 'process': + process = yaml.dump(value, + allow_unicode=True, + default_flow_style=False) + elif key == 'config' or key.split('.')[0] in op_names: + continue + else: + res += f'{key}:\n \t {value}\n' + res += 'process:\n' + \ + '\n'.join(['\t' + line for line in process.splitlines()]) + + return res + + +def parse_cfg(): + + cfg_cmd = '--config configs/demo.yaml' + + args_in_cmd = cfg_cmd.split() + + if len(args_in_cmd) >= 2 and args_in_cmd[0] == '--config': + cfg_f_name = args_in_cmd[1] + else: + st.warning('Please specify a config command or upload a config file.') + st.stop() + + if not os.path.exists(cfg_f_name): + st.warning('do not parse' + f'config file does not exist with cfg_f_name={cfg_f_name}') + st.stop() + + with open(cfg_f_name, 'r') as cfg_f: + specified_cfg = yaml.safe_load(cfg_f) + + try: + parsed_cfg = init_configs(args=args_in_cmd) + st.session_state.cfg = parsed_cfg + + return pretty_out(parsed_cfg), pretty_out(specified_cfg), parsed_cfg + except Exception as e: + return str(e), pretty_out(specified_cfg), None + + +def analyze_and_show_res(dataset_file): + + images_ori = [] + cfg = st.session_state.get('cfg', parse_cfg()[2]) + if cfg is None: + raise ValueError('you have not specify valid cfg') + # force generating separate figures + cfg['save_stats_in_one_file'] = True + + del_file = False + logger.info('=========Stage: analyze original data=========') + if dataset_file is not None: + + file_contents = dataset_file.getvalue() + with open(dataset_file.name, 'wb') as f: + f.write(file_contents) + cfg.dataset_path = dataset_file.name + del_file = True + + analyzer = Analyser(cfg) + dataset = analyzer.run() + + analysis_res_ori = pd.read_csv( + os.path.join(analyzer.analysis_path, 'overall.csv')) + for f_path in os.listdir(analyzer.analysis_path): + if '.png' in f_path and 'all-stats' in f_path: + images_ori.append(os.path.join(analyzer.analysis_path, f_path)) + + st.session_state.dataset = dataset + st.session_state.original_overall = analysis_res_ori + st.session_state.original_imgs = images_ori + if del_file: + os.remove(dataset_file.name) + + +class Visualize: + + @staticmethod + def setup(): + st.set_page_config( + page_title='Data-Juicer', + page_icon=':smile', + layout='wide', + # initial_sidebar_state="expanded", + ) + + readme_link = 'https://code.alibaba-inc.com/DAIL-DATA/' \ + 'data_juicer/blob/master/README.md' + st.markdown( + '
    Data-Juicer \ +
    ', + unsafe_allow_html=True, + ) + st.markdown( + f'
    A Data-Centric Text Processing System for \ + Large Language Models, \ + see more details in Document
    ', + unsafe_allow_html=True, + ) + + @staticmethod + def analyze_process(): + col1, col2 = st.columns(2) + with col1: + dataset_file = st.file_uploader( + label='Upload your custom dataset csv or jsonl', + type=['csv', 'json', 'jsonl']) + with col2: + st.text_area(label='Default Demo dataset', + disabled=True, + value='demo/demo-dataset.jsonl') + + start_btn = st.button( + '2. Start to analyze original data (per filter op)', + use_container_width=True) + + with st.expander('Data Analysis Results', expanded=True): + + if start_btn: + with st.spinner('Wait for analyze...'): + analyze_and_show_res(dataset_file) + + original_overall = st.session_state.get('original_overall', None) + original_imgs = st.session_state.get('original_imgs', []) + + st.header('Statistics') + st.dataframe(original_overall, use_container_width=True) + if len(original_imgs) > 0: + st.header('Histograms') + for img in original_imgs: + st.image(img, output_format='png', use_column_width = True) + + @staticmethod + def visualize(): + Visualize.setup() + Visualize.analyze_process() + + +def main(): + Visualize.visualize() + + +if __name__ == '__main__': + main() diff --git a/demos/data_visualization_statistics/configs/demo.yaml b/demos/data_visualization_statistics/configs/demo.yaml new file mode 100644 index 000000000..d71266901 --- /dev/null +++ b/demos/data_visualization_statistics/configs/demo.yaml @@ -0,0 +1,66 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo' +dataset_path: './data/demo-dataset.jsonl' # path to your dataset directory or file +np: 1 # number of subprocess to process your dataset + +export_path: './outputs/demo/demo-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + # Filter ops + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. + tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens. + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.9 # the max ratio of filter range + - average_line_length_filter: # filter text with the average length of lines out of specific range. + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - character_repetition_filter: # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range + - flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value + lang: en # consider flagged words in what language + tokenization: false # whether to use model to tokenize documents + max_ratio: 0.0045 # the max ratio to filter text + flagged_words_dir: ./assets # directory to store flagged words dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - language_id_score_filter: # filter text in specific language with language scores larger than a specific max value + lang: en # keep text in what language + min_score: 0.8 # the min language scores to filter text + - maximum_line_length_filter: # filter text with the maximum length of lines out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - perplexity_filter: # filter text with perplexity score out of specific range + lang: en # compute perplexity in what language + max_ppl: 1500 # the max perplexity score to filter text + - special_characters_filter: # filter text with special-char ratio out of specific range + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.25 # the max ratio of filter range + - stopwords_filter: # filter text with stopword ratio smaller than a specific min value + lang: en # consider stopwords in what language + tokenization: false # whether to use model to tokenize documents + min_ratio: 0.3 # the min ratio to filter text + stopwords_dir: ./assets # directory to store stopwords dictionaries + use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese + words_aug_group_sizes: [2] # the group size of words to augment + words_aug_join_char: "" # the join char between words to augment + - text_length_filter: # filter text with length out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 # the max length of filter range + - words_num_filter: # filter text with number of words out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + min_num: 10 # the min number of filter range + max_num: 10000 # the max number of filter range + - word_repetition_filter: # filter text with the word repetition ratio out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + rep_len: 10 # repetition length for word-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 # the max ratio of filter range diff --git a/demos/data_visualization_statistics/data/demo-dataset.jsonl b/demos/data_visualization_statistics/data/demo-dataset.jsonl new file mode 100644 index 000000000..4d8cdadfd --- /dev/null +++ b/demos/data_visualization_statistics/data/demo-dataset.jsonl @@ -0,0 +1,6 @@ +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} +{"text": "你好,请问你是谁", "meta": {"src": "customized", "author": "xxx"}} +{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément.", "meta": {"src": "Oscar", "version": "2.0"}} +{"text": "欢迎来到阿里巴巴!", "meta": {"src": "customized", "version": "0.1", "author": "xxx"}} +{"text": "This paper proposed a novel method on LLM pretraining.", "meta": {"src": "customized", "author": "xxx"}} diff --git a/demos/tool_quality_classifier/.DS_Store b/demos/tool_quality_classifier/.DS_Store new file mode 100644 index 000000000..5127a64e1 Binary files /dev/null and b/demos/tool_quality_classifier/.DS_Store differ diff --git a/demos/tool_quality_classifier/app.py b/demos/tool_quality_classifier/app.py new file mode 100644 index 000000000..6282dc133 --- /dev/null +++ b/demos/tool_quality_classifier/app.py @@ -0,0 +1,175 @@ +import os + +import streamlit as st +from loguru import logger + +from quality_classifier.qc_utils import (init_spark, load_dataset, predict, + prepare_model) + + +@st.cache_data +def install_jdk(): + + os.system('apt update') + os.system('apt install -y default-jre') + os.system('apt install -y default-jdk') + os.system('export JAVA_HOME=/usr/lib/jvm/default-java') + + +@st.cache_data +def convert_csv(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_csv().encode('utf-8') + + +@st.cache_data +def convert_jsonl(df): + # IMPORTANT: Cache the conversion to prevent computation on every rerun + return df.to_json(orient='records', lines=True).encode('utf-8') + + +@st.cache_resource +def st_init_spark(): + return init_spark() + + +@st.cache_resource +def st_prepare_model(model_name): + return prepare_model(model_name) + + +def st_load_dataset(spark, ds_path, text_key='text', only_text=False): + return load_dataset(spark=spark, + ds_path=ds_path, + text_key=text_key, + only_text=only_text) + + +def st_predict(model, ds, tokenizer=None, keep_method='label'): + return predict(model=model, + ds=ds, + tokenizer=tokenizer, + keep_method=keep_method) + +def quality_classifier(dataset_file, model): + + del_file = False + + logger.info('=========Stage: analyze original data=========') + if dataset_file is not None: + file_contents = dataset_file.getvalue() + with open(dataset_file.name, 'wb') as f: + f.write(file_contents) + dataset_path = dataset_file.name + del_file = True + else: + dataset_path = st.session_state.get('default_demo_dataset') + + if model == 'chinese': + tokenizer = 'zh.sp.model' + keep_method = 'label' + if model == 'code': + tokenizer = 'code.sp.model' + keep_method = 'label' + if model == 'gpt3': + tokenizer = None + keep_method = 'gpt3' + + spark = st_init_spark() + model = st_prepare_model(model_name=model) + ds = st_load_dataset(spark, dataset_path) + + pred = st_predict(model, ds, tokenizer=tokenizer, keep_method=keep_method) + overall = pred.select('doc_score').toPandas().describe(include='all') + + st.session_state.dataset = pred + st.session_state.original_overall = overall + if del_file: + os.remove(dataset_file.name) + + +class Visualize: + + @staticmethod + def setup(): + st.set_page_config( + page_title='Data-Juicer', + page_icon=':smile', + layout='wide', + # initial_sidebar_state="expanded", + ) + + install_jdk() + + readme_link = 'https://code.alibaba-inc.com/DAIL-DATA/' \ + 'data_juicer/blob/master/README.md' + st.markdown( + '
    Data-Juicer \ +
    ', + unsafe_allow_html=True, + ) + st.markdown( + f'
    A Data-Centric Text Processing System for \ + Large Language Models, \ + see more details in Document
    ', + unsafe_allow_html=True, + ) + + @staticmethod + def quality(): + col1, col2 = st.columns(2) + with col1: + dataset_file = st.file_uploader( + label='Upload you custom dataset(jsonl/parquet)', + type=['json', 'jsonl', 'parquet']) + + st.text_input(label='Default Demo dataset', + disabled=True, + key = 'default_demo_dataset', + value='data/demo-dataset.jsonl') + with col2: + label = 'Select a quality classifier' + quality_model_map = { + 'Chinese quality classifier': 'chinese', + 'Code quality classifier': 'code', + 'GPT3 quality classifier': 'gpt3' + } + + selected_model = st.selectbox(label=label, + options=list( + quality_model_map.keys())) + model_name = quality_model_map[selected_model] + + start_btn = st.button( + f'2. Start to analyze dataset with {selected_model}', + use_container_width=True) + + with st.expander(f'{selected_model} Results', expanded=True): + + if start_btn: + with st.spinner('Wait for analyze...'): + quality_classifier(dataset_file, model_name) + + col1, col2 = st.columns(2) + with col1: + original_overall = st.session_state.get( + 'original_overall', None) + st.header('Statistics') + st.dataframe(original_overall, use_container_width=True) + with col2: + pred = st.session_state.get('dataset', None) + st.header('Details') + st.dataframe(pred, use_container_width=True) + + @staticmethod + def visualize(): + Visualize.setup() + Visualize.quality() + + +def main(): + Visualize.visualize() + + +if __name__ == '__main__': + main() diff --git a/demos/tool_quality_classifier/data/demo-dataset.jsonl b/demos/tool_quality_classifier/data/demo-dataset.jsonl new file mode 100644 index 000000000..14aa71f9a --- /dev/null +++ b/demos/tool_quality_classifier/data/demo-dataset.jsonl @@ -0,0 +1,11 @@ +{"text":"What’s one thing you wish everyone knew about the brain?\nibble\nWhat’s one thing you wish everyone knew about the brain?\nThe place to have real conversations and understand each other better. Join a community or build and grow your own with groups, threads, and conversations.\nSee this content immediately after install\nGet The App\n"} +{"text":"JavaScript must be enabled to use the system\n"} +{"text":"中国企业又建成一座海外三峡工程!-科技-高清完整正版视频在线观看-优酷\n"} +{"text":"Skip to content\nPOLIDEPORTES\nPeriodismo especialzado en deportes\nPrimary Menu\nPOLIDEPORTES\nPolideportes\n¿Quiénes somos?\nNoticia\nEntrevistas\nReportaje\nEquipos de Época\nOpinión\nEspeciales\nCopa Poli\nBuscar:\nSteven Villegas Ceballos patinador\nShare this...\nFacebook\nTwitter\nLinkedin\nWhatsapp\nEmail\nSeguir leyendo\nAnterior El imparable campeón Steven Villegas\nTe pueden interesar\nDeportes\nNoticia\nPiezas filatélicas llegan al Museo Olímpico Colombiano\nmarzo 17, 2023"} +{"text":"Redirect Notice\nRedirect Notice\nThe previous page is sending you to http:\/\/sieuthikhoavantay.vn\/chi-tiet\/khoa-van-tay-dessmann-s710fp-duc.\nIf you do not want to visit that page, you can return to the previous page.\n"} +{"text": "Do you need a cup of coffee?"} +{"text": ".cv域名是因特网域名管理机构ICANN为佛得角共和国(The Republic of Cape Verde República de Cabo Verde)国家及地区分配的顶级域(ccTLD),作为其国家及地区因特网顶级域名。- 奇典网络\n专业的互联网服务提供商 登录 注册 控制中心 新闻中心 客户支持 交费方式 联系我们\n首页\n手机AI建站\n建站\n推广\n域名\n主机\n安全\n企业服务\n加盟\nICANN与CNNIC双认证顶级注册商 在中国,奇典网络是域名服务提供商\n.cv\n.cv域名是ICANN为佛得角共和国国家及地区分配的顶级域名,注册期限1年到10年不等。\n价格: 845 元\/1年\n注册要求: 无要求\n.cv\/.com.cv注册要求\n更多国别域名\n更多NewG域名\n相关资质\n1.什么是 .cv\/.com.cv域名?有什么优势?\n.cv域名是因特网域名管理机构ICANN为佛得角共和国(The Republic of Cape Verde República de Cabo Verde)国家及地区分配的顶级域(ccTLD),作为其国家及地区因特网顶级域名。\n2.cv\/.com.cv域名长度为多少?有什么注册规则?"} +{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément."} +{"text": "欢迎来到阿里巴巴!"} +{"text": "This paper proposed a novel method on LLM pretraining."} +{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} \ No newline at end of file diff --git a/demos/tool_quality_classifier/quality_classifier/__init__.py b/demos/tool_quality_classifier/quality_classifier/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/demos/tool_quality_classifier/quality_classifier/eval.py b/demos/tool_quality_classifier/quality_classifier/eval.py new file mode 100644 index 000000000..06eb72069 --- /dev/null +++ b/demos/tool_quality_classifier/quality_classifier/eval.py @@ -0,0 +1,92 @@ +# This tool is used for evaluating a quality classifier on your own datasets +# based on PySpark. +# +# We provide several trained models for you. Please refer to the comments at +# the beginning of predict tool for more details. +# +# This tool needs several arguments: +# - positive_datasets: the paths to the positive datasets. It could be a +# string for a single dataset, e.g. 'pos.parquet', or a list of strings +# for several datasets, e.g. '["pos1.parquet", "pos2.parquet"]'. +# - negative_datasets: the paths to the negative datasets. It could be a +# string for a single dataset, e.g. 'neg.parquet', or a list of strings +# for several datasets, e.g. '["neg1.parquet", "neg2.parquet"]'. +# - model: quality classifier name to apply. It's "gpt3" in default. You can +# use one of ["gpt3", "chinese", "code"] we provided, or you can set it +# to the path to your own model trained using the train.py tool. +# - tokenizer: what tokenizer to use to tokenize texts. It's None in default, +# which means using the standard Tokenizer of PySpark. You can use one of +# ["zh.sp.model", "code.sp.model"] we provided, or you can set it to the +# path to your own sentencepiece model. +# - text_key: the field key name to hold texts to be classified. It's "text" +# in default. + +import fire +from loguru import logger + +from qc_utils import eval, init_spark, load_datasets + + +@logger.catch +def main(positive_datasets=None, + negative_datasets=None, + model='my_quality_model', + tokenizer=None, + text_key='text'): + """ + + :param positive_datasets: the paths to the positive datasets. It could be a + string for a single dataset, e.g. 'pos.parquet', or a list of strings + for several datasets, e.g. '["pos1.parquet", "pos2.parquet"]'. + :param negative_datasets: the paths to the negative datasets. It could be a + string for a single dataset, e.g. 'neg.parquet', or a list of strings + for several datasets, e.g. '["neg1.parquet", "neg2.parquet"]'. + :param model: quality classifier name to apply. It's "gpt3" in default. You + can use one of ["gpt3", "chinese", "code"] we provided, or you can set + it to the path to your own model trained using the train.py tool. + :param tokenizer: what tokenizer to use to tokenize texts. It's None in + default, which means using the standard Tokenizer of PySpark. You can + use one of ["zh.sp.model", "code.sp.model"] we provided, or you can set + it to the path to your own sentencepiece model. + :param text_key: the field key name to hold texts to be classified. It's + "text" in default. + :return: + """ + # convert a single dataset to a dataset list + if positive_datasets is None: + positive_datasets = [] + if negative_datasets is None: + negative_datasets = [] + if isinstance(positive_datasets, str): + positive_datasets = [positive_datasets] + if isinstance(negative_datasets, str): + negative_datasets = [negative_datasets] + + spark = init_spark() + + pos = load_datasets(spark, + positive_datasets, + text_key=text_key, + label=1, + only_text=True) + neg = load_datasets(spark, + negative_datasets, + text_key=text_key, + label=0, + only_text=True) + # merge pos and neg samples + if pos is not None and neg is not None: + ds = pos.unionAll(neg) + elif pos is not None: + ds = pos + elif neg is not None: + ds = neg + else: + logger.error('Empty dataset.') + exit(0) + logger.info(f'Number of samples: {ds.count()}') + eval(model, ds, tokenizer) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/demos/tool_quality_classifier/quality_classifier/predict.py b/demos/tool_quality_classifier/quality_classifier/predict.py new file mode 100644 index 000000000..ddbb084b7 --- /dev/null +++ b/demos/tool_quality_classifier/quality_classifier/predict.py @@ -0,0 +1,120 @@ +# This tool is used for predicting a document score for text samples using +# quality classifier models we provided, including: +# - gpt3: A GPT3 quality classifier reproduced from scratch by us based on +# PySpark. It's trained over CC as negative samples and Wikipedia-en, +# Books, OpenWebText as positive samples. +# - chinese: A quality classifier for Chinese. It's trained over Chinese +# texts sampled from CC as negative samples and Wudao, Wikipedia-zh as +# positive samples. +# - code: A quality classifier for codes. It's trained over code samples that +# have stars >= 1372 as positive samples and random samples from left +# data as negative samples. Stars count 1372 splits a nearly 700w subset +# with most stars. +# All these 3 classifiers are trained using the same training pipeline as GPT3 +# based on PySpark but with different tokenizers and keeping methods: +# - gpt3: standard Tokenizer from spark & GPT3 keeping method based on pareto +# - chinese: sentencepiece tokenizer for Chinese & label +# - code: sentencepiece tokenizer for code & label +# +# This tool needs several arguments: +# - dataset_path: the path to the dataset you want to predict doc_scores for. +# - result_path: the path to store the predicted result dataset. +# - model: quality classifier name to apply. It's "gpt3" in default. You can +# use one of ["gpt3", "chinese", "code"] we provided, or you can set it +# to the path to your own model trained using the train.py tool. +# - tokenizer: what tokenizer to use to tokenize texts. It's None in default, +# which means using the standard Tokenizer of PySpark. You can use one of +# ["zh.sp.model", "code.sp.model"] we provided, or you can set it to the +# path to your own sentencepiece model. +# - keep_method: the method to label should_keep field for each sample. It's +# "gpt3" in default. Should be one of ["gpt3", "label"]. +# - text_key: the field key name to hold texts to be classified. It's "text" +# in default. +# - overall_statics: whether to output an overall statics report on predicted +# document scores. It's False in default. +# +# Recommended arguments for provided trained models: +# - gpt3: +# - model: gpt3 +# - tokenizer: None +# - keep_method: gpt3 +# - chinese: +# - model: chinese +# - tokenizer: zh.sp.model +# - keep_method: label +# - code: +# - model: code +# - tokenizer: code.sp.model +# - keep_method: label +# +# Notice: +# 1. The configs of SparkSession in function init_spark can be modified to be +# more suitable for your own machine. See function init_spark in +# qc_utils.py. +# 2. Random factors are involved in "gpt3" model. So you might get different +# should_keep label in different running processes. But you should get +# same doc_score predictions in different running processes. + +import os + +import fire +from loguru import logger + +from qc_utils import (export_result, init_spark, load_dataset, predict, + prepare_model) + + +@logger.catch +def main(dataset_path, + result_path, + model='gpt3', + tokenizer=None, + keep_method='gpt3', + text_key='text', + overall_statics=False): + """ + Apply quality classifier for your dataset. + :param dataset_path: the path to the dataset you want to predict for. + :param result_path: the path to store the predicted result dataset. + :param model: quality classifier name to apply. It's "gpt3" in default. You + can use one of ["gpt3", "chinese", "code"] we provided, or you can set + it to the path to your own model trained using the train.py tool. + :param tokenizer: what tokenizer to use to tokenize texts. It's None in + default, which means using the standard Tokenizer of PySpark. You can + use one of ["zh.sp.model", "code.sp.model"] we provided, or you can set + it to the path to your own sentencepiece model. + :param keep_method: the method to label should_keep field for each sample. + It's "gpt3" in default. Should be one of ["gpt3", "label"]. + :param text_key: the field key name to hold texts to be classified. It's + "text" in default. + :param overall_statics: whether to output an overall statics report on + predicted document scores. It's False in default. + :return: + """ + # set default tokenizers for default models + if model == 'chinese': + tokenizer = 'zh.sp.model' + keep_method = 'label' + if model == 'code': + tokenizer = 'code.sp.model' + keep_method = 'label' + if model == 'gpt3': + tokenizer = None + keep_method = 'gpt3' + + spark = init_spark() + model = prepare_model(model_name=model) + ds = load_dataset(spark, dataset_path, text_key=text_key) + pred = predict(model, ds, tokenizer=tokenizer, keep_method=keep_method) + export_result(pred, result_path) + + # generate overall statistics on doc scores + if overall_statics: + overall = pred.select('doc_score').toPandas().describe(include='all') + # export to result report file + overall.to_csv(os.path.join(result_path, 'overall.csv')) + overall.to_markdown(os.path.join(result_path, 'overall.md')) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/demos/tool_quality_classifier/quality_classifier/qc_utils.py b/demos/tool_quality_classifier/quality_classifier/qc_utils.py new file mode 100644 index 000000000..862e6f1bd --- /dev/null +++ b/demos/tool_quality_classifier/quality_classifier/qc_utils.py @@ -0,0 +1,214 @@ +import os +import zipfile + +import numpy as np +import sentencepiece as spm +import wget +from loguru import logger +from pyspark.ml import Pipeline, PipelineModel +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import SparkSession +from pyspark.sql.functions import col, rand, udf +from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType + +from data_juicer.utils.cache_utils import DATA_JUICER_MODELS_CACHE +from data_juicer.utils.model_utils import (MODEL_LINKS, + prepare_sentencepiece_model) + + +def init_spark(): + """ + Initialize a spark session. You can set parameters such as memory, number + of partitions, timeout and so on here. + :return: A spark session instance. + """ + spark = (SparkSession.builder.config('spark.driver.memory', '64g').config( + 'spark.executor.memory', + '64g').config('spark.sql.shuffle.partitions', '300').config( + 'spark.sql.execution.arrow.pyspark.enabled', + 'true').config('spark.executor.memoryOverhead', '20000').config( + 'spark.network.timeout', + '10000s').config('spark.executor.heartbeatInterval', + '3600s').getOrCreate()) + logger.info('Spark initialization done.') + return spark + + +def prepare_model(model_name, model_path=DATA_JUICER_MODELS_CACHE): + udm = False + if model_name not in ['gpt3', 'chinese', 'code']: + # use user-specific mdoel + real_model_path = model_name + udm = True + else: + # use prepared models we provided + model_name = '%s_quality_model' % model_name + real_model_path = os.path.join(model_path, model_name) + logger.info(f'Preparing scorer model in [{real_model_path}]...') + if os.path.exists(real_model_path) and os.path.isdir(real_model_path): + return PipelineModel.load(real_model_path) + if udm: + logger.error(f'Customized model [{real_model_path}] cannot be loaded.') + exit(0) + # No specific models in local file systems. Download them from remote. + os.makedirs(model_path, exist_ok=True) + wget.download(os.path.join(MODEL_LINKS, f'{model_name}.zip'), + os.path.join(model_path, f'{model_name}.zip'), + bar=None) + with zipfile.ZipFile(os.path.join(model_path, f'{model_name}.zip')) as zip: + zip.extractall(os.path.join(model_path)) + return PipelineModel.load(real_model_path) + + +def load_dataset(spark, ds_path, text_key='text', only_text=False): + logger.info(f'Loading dataset from [{ds_path}]...') + if ds_path.endswith('.json') or ds_path.endswith('.jsonl'): + df = spark.read.json(ds_path) + elif ds_path.endswith('.parquet'): + df = spark.read.parquet(ds_path) + else: + raise NotImplementedError('Dataset type is not supported for now. ' + 'Suffix of dataset file should be one of ' + '[.json, .jsonl, .parquet]') + if text_key != 'text': + df = df.withColumnRenamed(text_key, 'text') + if only_text: + return df.select('text') + else: + return df + + +def load_datasets(spark, + ds_paths, + text_key='text', + label=None, + only_text=True): + if len(ds_paths) == 0: + logger.warning('No dataset path provided.') + return None + base_ds = load_dataset(spark, ds_paths[0], text_key, only_text) + for i in range(1, len(ds_paths)): + base_ds = base_ds.unionAll( + load_dataset(spark, ds_paths[i], text_key, only_text)) + if label is not None: + # add labels for training pipelines + return base_ds.selectExpr('text', '%d as label' % label) + else: + return base_ds + + +def shuffle(df): + temp_df = df.withColumn('rand', rand(seed=42)) + df_rnd = temp_df.orderBy(temp_df.rand) + return df_rnd.drop(df_rnd.rand) + + +def export_result(ds, res_path): + logger.info(f'Exporting predicted result to [{res_path}]') + if res_path.endswith('.json') or res_path.endswith('.jsonl'): + ds.write.mode('overwrite').format('json').save(res_path) + elif res_path.endswith('.parquet'): + ds.write.mode('overwrite').format('parquet').save(res_path) + else: + ds.write.mode('overwrite').save(res_path) + + +def get_keep_method_udf(keep_method): + if keep_method == 'label': + return udf(lambda score: int(score > 0.5), IntegerType()) + elif keep_method == 'gpt3': + pareto = 9 + return udf(lambda score: int(score > 1 - np.random.pareto(pareto)), + IntegerType()) + else: + raise NotImplementedError(f'Keep method [{keep_method}] is not ' + f'implemented for now.') + + +def tokenize_dataset(ds, tokenizer): + if os.path.exists(tokenizer): + # if it's a local model + tkn = spm.SentencePieceProcessor() + tkn.load(tokenizer) + else: + # else, try to load it from our remote model list + tkn = prepare_sentencepiece_model(tokenizer, ()) + tokenizer_udf = udf(lambda text: tkn.encode_as_pieces(text), + ArrayType(StringType())) + logger.info('Tokenize texts using specific tokenizer...') + return ds.withColumn('words', tokenizer_udf(col('text'))) + + +def train(output_model_path, ds, tokenizer=None): + logger.info('Preparing training quality classifier model...') + if tokenizer: + # tokenizer is not standard Tokenizer in PySpark, need to apply it + # explicitly + ds = tokenize_dataset(ds, tokenizer) + + # model + hashingTF = HashingTF(inputCol='words', outputCol='features') + lr = LogisticRegression() + if tokenizer is None: + std_tokenizer = Tokenizer(inputCol='text', outputCol='words') + pipeline = Pipeline(stages=[std_tokenizer, hashingTF, lr]) + else: + pipeline = Pipeline(stages=[hashingTF, lr]) + + logger.info('Start training...') + model = pipeline.fit(ds) + + logger.info('Trained model saving...') + model.write().overwrite().save(output_model_path) + + +def eval(model_path, ds, tokenizer=None): + logger.info('Preparing to evaluate...') + if tokenizer: + # tokenizer is not standard Tokenizer in PySpark, need to apply it + # explicitly + ds = tokenize_dataset(ds, tokenizer) + + logger.info('Start evaluation...') + model = prepare_model(model_path) + pred = model.transform(ds) + P = pred.filter('label = 1') + N = pred.filter('label = 0') + TP = P.filter('prediction = 1').count() + 1 + FP = N.filter('prediction = 1').count() + 1 + TN = N.filter('prediction = 0').count() + 1 + FN = P.filter('prediction = 0').count() + 1 + precision = 1.0 * TP / (TP + FP) + recall = 1.0 * TP / P.count() + F1 = 2.0 * precision * recall / (precision + recall) + logger.info(f'TP: {TP}, FN: {FN}') + logger.info(f'FP: {FP}, TN: {TN}') + logger.info(f'P: {precision}, R: {recall}, F1: {F1}') + +def predict(model, ds, tokenizer=None, keep_method='label'): + logger.info('Start scoring dataset...') + if tokenizer: + # tokenizer is not standard Tokenizer in PySpark, need to apply it + # explicitly + ds = tokenize_dataset(ds, tokenizer) + + prediction = model.transform(ds) + + # A UDF to extract doc scores from probability vectors + def extract_prob(v): + try: + return float(v[1]) + except ValueError: + return None + + extract_prob_udf = udf(extract_prob, DoubleType()) + doc_score = prediction.withColumn('doc_score', + extract_prob_udf(col('probability'))) + + # A UDF to get the bool value indicating whether this sample should be kept + should_keep_label_udf = get_keep_method_udf(keep_method) + should_keep = doc_score.withColumn('should_keep', + should_keep_label_udf(col('doc_score'))) + return should_keep.drop('words', 'features', 'rawPrediction', + 'probability', 'prediction') diff --git a/demos/tool_quality_classifier/quality_classifier/train.py b/demos/tool_quality_classifier/quality_classifier/train.py new file mode 100644 index 000000000..ea4459c69 --- /dev/null +++ b/demos/tool_quality_classifier/quality_classifier/train.py @@ -0,0 +1,113 @@ +# This tool is used for training a quality classifier for your own datasets +# based on PySpark. +# +# After training, this tool will generate a classifier model in a specific +# directory. You can use it to evaluate or predict on other datasets using eval +# and predict tools. +# +# This tool needs several arguments: +# - positive_datasets: the paths to the positive datasets. It could be a +# string for a single dataset, e.g. 'pos.parquet', or a list of strings +# for several datasets, e.g. '["pos1.parquet", "pos2.parquet"]'. +# - negative_datasets: the paths to the negative datasets. It could be a +# string for a single dataset, e.g. 'neg.parquet', or a list of strings +# for several datasets, e.g. '["neg1.parquet", "neg2.parquet"]'. +# - output_model_path: the path to store the trained quality classifier. It's +# "my_quality_model" in default. +# - num_training_samples: number of samples used to train the model. It's 0 +# in default, which means using all samples in datasets to train. +# - train_test_split_ratio: ratio to split train and test set. It's 0.8 in +# default. +# - tokenizer: what tokenizer to use to tokenize texts. It's None in default, +# which means using the standard Tokenizer of PySpark. You can use one of +# ["zh.sp.model", "code.sp.model"] we provided, or you can set it to the +# path to your own sentencepiece model. +# - evaluation: whether to evaluate the model after training using test set. +# It's True in default. +# - text_key: the field key name to hold texts to be classified. It's "text" +# in default. + +import fire +from loguru import logger + +from qc_utils import eval, init_spark, load_datasets, shuffle, train + + +@logger.catch +def main(positive_datasets, + negative_datasets, + output_model_path='my_quality_model', + num_training_samples=0, + train_test_split_ratio=0.8, + tokenizer=None, + evaluation=True, + text_key='text'): + """ + Train a quality classifier using your own pos/neg datasets. + :param positive_datasets: the paths to the positive datasets. It could be a + string for a single dataset, e.g. 'pos.parquet', or a list of strings + for several datasets, e.g. '["pos1.parquet", "pos2.parquet"]'. + :param negative_datasets: the paths to the negative datasets. It could be a + string for a single dataset, e.g. 'neg.parquet', or a list of strings + for several datasets, e.g. '["neg1.parquet", "neg2.parquet"]'. + :param output_model_path: the path to store the trained quality classifier. + It's "my_quality_model" in default. + :param num_training_samples: number of samples used to train the model. + It's 0 in default, which means using all samples in datasets to train. + :param train_test_split_ratio: ratio to split train and test set. It's 0.8 + in default. + :param tokenizer: what tokenizer to use to tokenize texts. It's None in + default, which means using the standard Tokenizer of PySpark. You can + use one of ["zh.sp.model", "code.sp.model"] we provided, or you can set + it to the path to your own sentencepiece model. + :param evaluation: whether to evaluate the model after training using test + set. It's True in default. + :param text_key: the field key name to hold texts to be classified. It's + "text" in default. + :return: + """ + # convert a single dataset to a dataset list + if isinstance(positive_datasets, str): + positive_datasets = [positive_datasets] + if isinstance(negative_datasets, str): + negative_datasets = [negative_datasets] + + spark = init_spark() + + pos = load_datasets(spark, + positive_datasets, + text_key=text_key, + label=1, + only_text=True) + neg = load_datasets(spark, + negative_datasets, + text_key=text_key, + label=0, + only_text=True) + + if pos is None or neg is None: + logger.error('Empty dataset in positive/negative dataset list...') + exit(1) + + if num_training_samples > 0: + logger.info(f'Only use {num_training_samples} pairs samples to train.') + pos = shuffle(pos).limit(num_training_samples) + neg = shuffle(neg).limit(num_training_samples) + + # merge pos and neg samples + ds = pos.unionAll(neg) + train_set, test_set = ds.randomSplit( + [train_test_split_ratio, 1.0 - train_test_split_ratio], seed=42) + + logger.info(f'Number of training samples: {train_set.count()}, ' + f'test samples: {test_set.count()}') + + # ML pipeline + train(output_model_path, train_set, tokenizer) + + if evaluation: + eval(output_model_path, test_set, tokenizer) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/docs/.DS_Store b/docs/.DS_Store new file mode 100644 index 000000000..f9a28fc79 Binary files /dev/null and b/docs/.DS_Store differ diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md new file mode 100644 index 000000000..d00441133 --- /dev/null +++ b/docs/DeveloperGuide.md @@ -0,0 +1,205 @@ +# How-to Guide for Developers + +* [How-to Guide for Developers](#how-to-guide-for-developers) + * [Coding Style](#coding-style) + * [Build your own ops](#build-your-own-ops) + * [Build your own configs](#build-your-own-configs) + * [Fruitful config sources & Type hints](#fruitful-config-sources--type-hints) + * [Hierarchical configs and helps](#hierarchical-configs-and-helps) + +## Coding Style + +We define our styles in `.pre-commit-config.yaml`. Before committing, +please install `pre-commit` tool to check and modify accordingly: + +```shell +# ===========install pre-commit tool=========== +pip install pre-commit + +cd +# install pre-commit script for data_juicer +pre-commit install + + +# ===========check all files=========== +git add . +pre-commit run --all-files + +# commit after all checking are passed +git commit -m "xxxx" +``` + +## Build your own ops + +- Data-Juicer allows everybody to build their own ops. +- Before implementing a new op, please refer to [Operators](Operators.md) to avoid unnecessary duplication. +- Assuming we want to add a new Filter operator called "TextLengthFilter" to get corpus of expected text length, we can follow these steps to build it. + +1. Create a new op file `text_length_filter.py` in the corresponding `data_juicer/ops/filter/` directory as follows. + - Because it's a Filter op, so the new op needs to inherit from the basic `Filter` class in the `base_op.py`, and be decorated with `OPERATORS` to register itself automatically. + +```python +import sys + +from jsonargparse.typing import PositiveInt + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('text_length_filter') +class TextLengthFilter(Filter): + """ +Filter to keep samples with total text length within a specific range. + """ + + def __init__( + self, + min_len: PositiveInt = 10, + max_len: PositiveInt = sys.maxsize, + *args, + **kwargs + ): + """ + Initialization method. + :param min_len: The min text length in the filtering. + :param max_len: The max text length in the filtering. + """ + super().__init__(*args, **kwargs) + self.min_len = min_len + self.max_len = max_len + + def compute_stats(self, sample): + # check if it's computed already + if 'text_len' in sample['stats']: + return sample + + sample['stats']['text_len'] = len(sample['text']) + return sample + + def process(self, sample): + if self.min_len <= sample['stats']['text_len'] <= self.max_len: + return True + else: + return False +``` + +2. After implemention, add it to the op dictionary in the `__init__.py` file in `data_juicer/ops/filter/` directory. + +```python +from . import (..., # other ops + text_length_filter) # import this new op module +``` + +3. Now you can use this new op with custom arguments in your own config files! + +```yaml +# other configs +... + +# process configs +process: + - text_length_filter: # add this op to your process list and set the parameters + min_len: 10 + max_len: 1000 +``` + +4. (Strongly Recommend) It's better to add corresponding tests for your own ops. For `TextLengthFilter` above, you would like to add `test_text_length_filter.py` into `tests/ops/filter/` directory as below. + +```python +import unittest +from data_juicer.ops.filter.text_length_filter import TextLengthFilter + +class TextLengthFilterTest(unittest.TestCase): + + def test_func1(self): + pass + + def test_func2(self): + pass + + def test_func3(self): + pass +``` + +## Build your own configs +- We provide easy configuration based on [jsonargparse](https://github.com/omni-us/jsonargparse/) to reduce cost for boilerplate codes. + +### Fruitful config sources & Type hints +- A global config object can be initialized via +``` +# core.executor.py +self.cfg = init_configs() +``` +- in which function arguments from diverse sources can be specified and mixed +up, including +1. *hard-coded default values* when registering the config into parser or specified in the classes' `__init__` functions +2. default *config files* in json (yaml or jsonnet supersets) +3. *environment variables* +4. *POSIX-style command line arguments*, such as ``--project_name + my_data_demo`` or ``--project_name=my_data_demo`` , including config files + +- The final parsed values are mixed from these sources. And the override order is the same as the numbers above. + +Besides, many argument types and respective validation are supported. +Including python built-in types, types from [Lib/typing](https://docs.python.org/3/library/typing.html) module, and +extended [types](https://jsonargparse.readthedocs.io/en/stable/#type-hints) +from jsonargparse, such as `restricted types` and `Paths` with customized +limitations. + +### Hierarchical configs and helps +- You can use dot notation in the argument names freely to define the +hierarchy, e.g., `maximum_line_length_filter.min`. +More importantly, by default, we automatically register the configs from +the docstrings of implemented operators. That is, the structure of all +configs are always in sync with codes. + +- You can get the hierarchical help information by running a script that calls +our executor such as +``` +$ python tools/process_data.py --help + +usage: process_data.py [-h] [--config CONFIG] [--print_config[=flags]] [--project_name PROJECT_NAME] [--dataset_path DATASET_PATH] [--dataset_dir DATASET_DIR] [--export_path EXPORT_PATH] [--process PROCESS] + [--np NP] [--text_key TEXT_KEY] [--document_deduplicator CONFIG] [--document_deduplicator.hash_method HASH_METHOD] [--document_deduplicator.lowercase LOWERCASE] + [--document_deduplicator.ignore_non_character IGNORE_NON_CHARACTER] [--language_id_score_filter CONFIG] [--language_id_score_filter.lang LANG] [--words_num_filter CONFIG] [--words_num_filter.min MIN] [--words_num_filter.max MAX] + [--alphanumeric_filter CONFIG] [--alphanumeric_filter.min MIN] [--alphanumeric_filter.max MAX] [--average_line_length_filter CONFIG] [--average_line_length_filter.min MIN] [--average_line_length_filter.max MAX] + [--maximum_line_length_filter CONFIG] [--maximum_line_length_filter.min MIN] [--maximum_line_length_filter.max MAX] [--text_length_filter CONFIG] [--text_length_filter.min MIN] [--text_length_filter.max MAX] + [--remove_comments_mapper CONFIG] [--remove_comments_mapper.type TYPE] [--remove_comments_mapper.inline INLINE] [--remove_comments_mapper.multiline MULTILINE] [--remove_header_mapper CONFIG] + [--remove_header_mapper.before_section BEFORE_SECTION] + +optional arguments: + -h, --help Show this help message and exit. + --config CONFIG Path to a configuration file. + --print_config[=flags] + Print the configuration after applying all other arguments and exit. The optional flags customizes the output and are one or more keywords separated by comma. The supported flags are: comments, skip_default, skip_null. + --project_name PROJECT_NAME + name of your data process project. (type: str, default: null) + --dataset_path DATASET_PATH + path to your dataset file, relative with respect to the config file’s location (type: Path_fr, default: null) + --dataset_dir DATASET_DIR + path to your dataset(s) within a directory, relative with respect to the config file’s location (type: Path_drw, default: null) + --export_path EXPORT_PATH + path to the output processed dataset, relative with respect to the config file’s location (type: Path_fc, default: null) + --process PROCESS, --process+ PROCESS + a list of several process operators with their arguments (type: List[Dict], default: null) + --np NP number of subprocess to process your dataset. (type: PositiveInt, default: null) + --text_key TEXT_KEY the key name of field that stores sample texts (type: Optional[str], default: content) + +: + --alphanumeric_filter CONFIG + Path to a configuration file. + --alphanumeric_filter.min MIN + the min filter rate in alphanumeric op. (type: ClosedUnitInterval, default: 0.0) + --alphanumeric_filter.max MAX + the max filter rate in alphanumeric op. (type: ClosedUnitInterval, default: 0.25) + +: + --text_length_filter CONFIG + Path to a configuration file. + --text_length_filter.min MIN + min text length in the filtering (type: int, default: 10) + --text_length_filter.max MAX + max text length in the filtering (type: int, default: 10000) + +...... + +``` diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md new file mode 100644 index 000000000..080af587e --- /dev/null +++ b/docs/DeveloperGuide_ZH.md @@ -0,0 +1,193 @@ +# 开发者指南 + +[TOC] + +## 编码规范 + +我们将编码规范定义在 `.pre-commit-config.yaml` 中。在向仓库贡献代码之前,请使用 `pre-commit` 工具对代码进行规范化。 + +```shell +# ===========install pre-commit tool=========== +pip install pre-commit + +cd +# install pre-commit script for data_juicer +pre-commit install + + +# ===========check all files=========== +git add . +pre-commit run --all-files + +# commit after all checking are passed +git commit -m "xxxx" +``` + +## 构建自己的算子 + +- Data-Juicer 支持每个人定义自己的算子。 +- 在实现新的算子之前,请参考 [Operators](Operators_ZH.md) 以避免不必要的重复。 +- 假设要添加一个名为 “TextLengthFilter” 的运算符以过滤仅包含预期文本长度的样本语料,可以按照以下步骤进行构建。 + +1. 在 `data_juicer/ops/filter/` 目录下创建一个新的算子文件 `text_length_filter.py`,内容如下: + - 因为它是一个 Filter 算子,所以需要继承 `base_op.py` 中的 `Filter` 基类,并用 `OPERATORS` 修饰以实现自动注册。 + +```python +import sys + +from jsonargparse.typing import PositiveInt + +from ..base_op import OPERATORS, Filter + + +@OPERATORS.register_module('text_length_filter') +class TextLengthFilter(Filter): + """ +Filter to keep samples with total text length within a specific range. + """ + + def __init__( + self, + min_len: PositiveInt = 10, + max_len: PositiveInt = sys.maxsize, + ): + """ + Initialization method. + :param min_len: The min text length in the filtering. + :param max_len: The max text length in the filtering. + """ + self.min_len = min_len + self.max_len = max_len + + def compute_stats(self, sample): + # check if it's computed already + if 'text_len' in sample['stats']: + return sample + + sample['stats']['text_len'] = len(sample['text']) + return sample + + def process(self, sample): + if self.min_len <= sample['stats']['text_len'] <= self.max_len: + return True + else: + return False +``` + +2. 实现后,将其添加到 `data_juicer/ops/filter` 目录下 `__init__.py` 文件中的算子字典中: + +```python +from . import (..., # other ops + text_length_filter) # import this new op module + +``` + +3. 全部完成!现在您可以在自己的配置文件中使用新添加的算子: + +```yaml +# other configs +... + +# process configs +process: + - text_length_filter: # add this op to your process list and set the parameters + min_len: 10 + max_len: 1000 +``` + +4. (强烈推荐)最好为新添加的算子进行单元测试。对于上面的 `TextLengthFilter` 算子,建议在 `tests/ops/filter/` 中实现如 `test_text_length_filter.py` 的测试文件: + +```python +import unittest +from data_juicer.ops.filter.text_length_filter import TextLengthFilter + +class TextLengthFilterTest(unittest.TestCase): + + def test_func1(self): + pass + + def test_func2(self): + pass + + def test_func3(self): + pass +``` + +## 构建自己的配置 + +- 我们提供基于 [jsonargparse](https://github.com/omni-us/jsonargparse/) 的简单配置以降低样板代码的成本。 + +### 丰富的配置源和类型提示 + +- 全局配置对象可以通过以下方式初始化 + +```python +# core.executor.py +self.cfg = init_configs() +``` + +- 其中可以指定和混合来自不同来源的函数参数,包括 +1. *硬编码默认值* 将配置注册到解析器中或在类的 `__init__` 函数中指定 +2. json 格式的默认*配置文件*(yaml 或 jsonnet 超集) +3. *环境变量* +4. *POSIX-style 命令行参数*, 例如 `--project_name my_data_demo` 或 `--project_name=my_data_demo`,包含配置文件 + +- 最终解析的值是来自这些来源的混合。 并且覆盖顺序与上面的数字相同。 + +此外,还支持许多参数类型和相应的验证。 +包含 Python内置类型、来自 [Lib/typing](https://docs.python.org/3/library/typing.html) 的类型,以及来自 jsonargparse 的 [扩展类型](https://jsonargparse.readthedocs.io/en/stable/#type-hints),例如具有自定义限制的 `restricted types` 和 `Paths`。 + +### Hierarchical configs and helps + +- 您可以在参数名称中自由使用点符号来定义层次结构, 例如 `maximum_line_length_filter.min`. +更重要的是,默认情况下,我们自动注册已实现的运算符的 docstring。 也就是说,所有的结构配置始终与代码同步。 +- 您可以通过运行脚本来获取层次化的帮助信息,例如: + +``` +$ python tools/process_data.py --help + +usage: process_data.py [-h] [--config CONFIG] [--print_config[=flags]] [--project_name PROJECT_NAME] [--dataset_path DATASET_PATH] [--dataset_dir DATASET_DIR] [--export_path EXPORT_PATH] [--process PROCESS] + [--np NP] [--text_key TEXT_KEY] [--document_deduplicator CONFIG] [--document_deduplicator.hash_method HASH_METHOD] [--document_deduplicator.lowercase LOWERCASE] + [--document_deduplicator.ignore_non_character IGNORE_NON_CHARACTER] [--language_id_score_filter CONFIG] [--language_id_score_filter.lang LANG] [--words_num_filter CONFIG] [--words_num_filter.min MIN] [--words_num_filter.max MAX] + [--alphanumeric_filter CONFIG] [--alphanumeric_filter.min MIN] [--alphanumeric_filter.max MAX] [--average_line_length_filter CONFIG] [--average_line_length_filter.min MIN] [--average_line_length_filter.max MAX] + [--maximum_line_length_filter CONFIG] [--maximum_line_length_filter.min MIN] [--maximum_line_length_filter.max MAX] [--text_length_filter CONFIG] [--text_length_filter.min MIN] [--text_length_filter.max MAX] + [--remove_comments_mapper CONFIG] [--remove_comments_mapper.type TYPE] [--remove_comments_mapper.inline INLINE] [--remove_comments_mapper.multiline MULTILINE] [--remove_header_mapper CONFIG] + [--remove_header_mapper.before_section BEFORE_SECTION] + +optional arguments: + -h, --help Show this help message and exit. + --config CONFIG Path to a configuration file. + --print_config[=flags] + Print the configuration after applying all other arguments and exit. The optional flags customizes the output and are one or more keywords separated by comma. The supported flags are: comments, skip_default, skip_null. + --project_name PROJECT_NAME + name of your data process project. (type: str, default: null) + --dataset_path DATASET_PATH + path to your dataset file, relative with respect to the config file’s location (type: Path_fr, default: null) + --dataset_dir DATASET_DIR + path to your dataset(s) within a directory, relative with respect to the config file’s location (type: Path_drw, default: null) + --export_path EXPORT_PATH + path to the output processed dataset, relative with respect to the config file’s location (type: Path_fc, default: null) + --process PROCESS, --process+ PROCESS + a list of several process operators with their arguments (type: List[Dict], default: null) + --np NP number of subprocess to process your dataset. (type: PositiveInt, default: null) + --text_key TEXT_KEY the key name of field that stores sample texts (type: Optional[str], default: content) + +: + --alphanumeric_filter CONFIG + Path to a configuration file. + --alphanumeric_filter.min MIN + the min filter rate in alphanumeric op. (type: ClosedUnitInterval, default: 0.0) + --alphanumeric_filter.max MAX + the max filter rate in alphanumeric op. (type: ClosedUnitInterval, default: 0.25) + +: + --text_length_filter CONFIG + Path to a configuration file. + --text_length_filter.min MIN + min text length in the filtering (type: int, default: 10) + --text_length_filter.max MAX + max text length in the filtering (type: int, default: 10000) + +...... + +``` diff --git a/docs/Operators.md b/docs/Operators.md new file mode 100644 index 000000000..df181eb2c --- /dev/null +++ b/docs/Operators.md @@ -0,0 +1,105 @@ +# Operator Schemas + +Operators are a collection of basic processes that assist in data modification, cleaning, filtering, deduplication, etc. We support a wide range of data sources and file formats, and allow for flexible extension to custom datasets. + + +## Overview + +The operators in Data-Juicer are categorized into 5 types. + +| Type | Number | Description | +|-----------------------------------|:------:|-------------| +| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | +| [ Mapper ]( #mapper ) | 17 | Edits and transforms samples | +| [ Filter ]( #filter ) | 15 | Filters out low-quality samples | +| [ Deduplicator ]( #deduplicator ) | 3 | Detects and removes duplicate samples | +| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking | + + +All the specific operators are listed below, each featured with several capability tags. + +* Domain Tags + - General: general purpose + - LaTeX: specific to LaTeX source files + - Code: specific to programming codes + - Financial: closely related to financial sector +* Language Tags + - en: English + - zh: Chinese + + +## Formatter + +| Operator | Domain | Lang | Description | +|-------------------|---------|--------|--------------------------------------------------------------------| +| remote_formatter | General | en, zh | Prepares datasets from remote (e.g., HuggingFace) | +| csv_formatter | General | en, zh | Prepares local `.csv` files | +| tsv_formatter | General | en, zh | Prepares local `.tsv` files | +| json_formatter | General | en, zh | Prepares local `.json`, `.jsonl`, `.jsonl.zst` files | +| parquet_formatter | General | en, zh | Prepares local `.parquet` files | +| text_formatter | General | en, zh | Prepares other local text files ([complete list](data_juicer/format/text_formatter.py#L46,56)) | +| mixture_formatter | General | en, zh | Handles a mixture of all the supported local file types | + + +## Mapper + +| Operator | Domain | Lang | Description | +|-----------------------------------------------|--------------------|--------|----------------------------------------------------------------------------------------------------------------| +| remove_header_mapper | LaTeX | en, zh | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | +| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | +| expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | +| whitespace_normalization_mapper | General | en, zh | Normalizes various Unicode whitespaces to the normal ASCII space (U+0020) | +| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | +| fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | +| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics | +| remove_long_words_mapper | General | en, zh | Removes words with length outside the specified range | +| remove_words_with_incorrect_
    substrings_mapper | General | en, zh | Removes words containing specified substrings | +| clean_email_mapper | General | en, zh | Removes email information | +| clean_ip_mapper | General | en, zh | Removes IP addresses | +| clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp | +| clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes | +| remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile) | +| clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (:warning: must contain the word *copyright*) | +| remove_specific_chars_mapper | General | en, zh | Removes any user-specified characters or substrings | + + +## Filter
    + +| Operator | Domain | Lang | Description | +|--------------------------------|---------|--------|--------------------------------------------------------------------------------------------| +| word_num_filter | General | en, zh | Keeps samples with word count within the specified range | +| stopwords_filter | General | en, zh | Keeps samples with stopword ratio above the specified threshold | +| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold | +| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range | +| word_repetition_filter | General | en, zh | Keeps samples with word-level n-gram repetition ratio within the specified range | +| special_characters_filter | General | en, zh | Keeps samples with special-char ratio within the specified range | +| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | +| perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold | +| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | +| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range | +| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range | +| text_length_filter | General | en, zh | Keeps samples with total text length within the specified range | +| suffix_filter | General | en, zh | Keeps samples with specified suffixes | +| specified_field_filter | General | en, zh | Filters samples based on field, with value lies in the specified targets | +| specified_numeric_field_filter | General | en, zh | Filters samples based on field, with value lies in the specified range (for numeric types) | + + +## Deduplicator + +| Operator | Domain | Lang | Description | +|-------------------------------|---------|--------|-------------------------------------------------------------| +| document_deduplicator | General | en, zh | Deduplicate samples at document-level by comparing MD5 hash | +| document_minhash_deduplicator | General | en, zh | Deduplicate samples at document-level using MinHashLSH | +| document_simhash_deduplicator | General | en, zh | Deduplicate samples at document-level using SimHash | + + +## Selector + +| Operator | Domain | Lang | Description | +|------------------------------------|---------|--------|-----------------------------------------------------------------------| +| topk_specified_field_selector | General | en, zh | Selects top samples by comparing the values of the specified field | +| frequency_specified_field_selector | General | en, zh | Selects top samples by comparing the frequency of the specified field | + + +## Contributing +We welcome contributions of adding new operators. Please refer to [How-to Guide for Developers](DeveloperGuide.md). diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md new file mode 100644 index 000000000..b76bc944d --- /dev/null +++ b/docs/Operators_ZH.md @@ -0,0 +1,98 @@ +# 算子提要 + +算子 (Operator) 是协助数据修改、清理、过滤、去重等基本流程的集合。我们支持广泛的数据来源和文件格式,并支持对自定义数据集的灵活扩展。 + +## 概览 + +Data-Juicer 中的算子分为以下 5 种类型。 + +| 类型 | 数量 | 描述 | +|-----------------------------------|:------:|-------------| +| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | +| [ Mapper ]( #mapper ) | 17 | 对数据样本进行编辑和转换 | +| [ Filter ]( #filter ) | 15 | 过滤低质量样本 | +| [ Deduplicator ]( #deduplicator ) | 3 | 识别、删除重复样本 | +| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 | + +下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 + +* Domain 标签 + - General: 一般用途 + - LaTeX: 专用于 LaTeX 源文件 + - Code: 专用于编程代码 + - Financial: 与金融领域相关 +* Language 标签 + - en: 英文 + - zh: 中文 + + +## Formatter + +| 算子 | 场景 | 语言 | 描述 | +|-------------------|---------|--------|--------------------------------------------------------------------| +| remote_formatter | General | en, zh | 准备远端数据集 (如 HuggingFace) | +| csv_formatter | General | en, zh | 准备本地 `.csv` 文件 | +| tsv_formatter | General | en, zh | 准备本地 `.tsv` 文件 | +| json_formatter | General | en, zh | 准备本地 `.json`, `.jsonl`, `.jsonl.zst` 文件 | +| parquet_formatter | General | en, zh | 准备本地 `.parquet` 文件 | +| text_formatter | General | en, zh | 准备其他本地文本文件([完整的支持列表](data_juicer/format/text_formatter.py#L46,56)) | +| mixture_formatter | General | en, zh | 处理可支持本地文件的混合 | + +## Mapper + +| 算子 | 场景 | 语言 | 描述 | +|-----------------------------------------------|--------------------|--------|-------------------------------------------------------------------| +| remove_header_mapper | LaTeX | en, zh | 删除 TeX 文档头,例如标题、章节数字/名称等 | +| remove_bibliography_mapper | LaTeX | en, zh | 删除 TeX 文档的参考文献 | +| expand_macro_mapper | LaTeX | en, zh | 扩展通常在 TeX 文档顶部定义的宏 | +| whitespace_normalization_mapper | General | en, zh | 将各种 Unicode 空白标准化为常规 ASCII 空格 (U+0020) | +| punctuation_normalization_mapper | General | en, zh | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | +| fix_unicode_mapper | General | en, zh | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | +| sentence_split_mapper | General | en | 根据语义拆分和重组句子 | +| remove_long_words_mapper | General | en, zh | 删除长度超出指定范围的单词 | +| remove_words_with_incorrect_
    substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 | +| clean_email_mapper | General | en, zh | 删除邮箱信息 | +| clean_ip_mapper | General | en, zh | 删除 IP 地址 | +| clean_links_mapper | General, Code | en, zh | 删除链接,例如以 http 或 ftp 开头的 | +| clean_html_mapper | General | en, zh | 删除 HTML 标签并返回所有节点的纯文本 | +| remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) | +| clean_copyright_mapper | Code | en, zh | 删除代码文件开头的版权声明 (:warning: 必须包含单词 *copyright*) | +| remove_specific_chars_mapper | General | en, zh | 删除任何用户指定的字符或子字符串 | + +## Filter
    + +| 算子 | 场景 | 语言 | 描述 | +|--------------------------------|---------|--------|---------------------------------------------------------| +| word_num_filter | General | en, zh | 保留字数在指定范围内的样本 | +| stopwords_filter | General | en, zh | 保留停用词比率高于指定阈值的样本 | +| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 | +| character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 | +| word_repetition_filter | General | en, zh | 保留 word-level n-gram 重复比率在指定范围内的样本 | +| special_characters_filter | General | en, zh | 保留 special-char 比率的在指定范围内的样本 | +| language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 | +| perplexity_filter | General | en, zh | 保留困惑度低于指定阈值的样本 | +| maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 | +| average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 | +| alphanumeric_filter | General | en, zh | 保留字母数字比例在指定范围内的样本 | +| text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 | +| suffix_filter | General | en, zh | 保留包含特定后缀的样本 | +| specified_field_filter | General | en, zh | 根据字段过滤样本,要求字段的值处于指定目标中 | +| specified_numeric_field_filter | General | en, zh | 根据字段过滤样本,要求字段的值处于指定范围(针对数字类型) | + +## Deduplicator + +| 算子 | 场景 | 语言 | 描述 | +|-------------------------------|---------|--------|----------------------------------------------| +| document_deduplicator | General | en, zh | 通过比较 MD5 哈希值在文档级别对样本去重 | +| document_minhash_deduplicator | General | en, zh | 使用 MinHashLSH 在文档级别对样本去重 | +| document_simhash_deduplicator | General | en, zh | 使用 SimHash 在文档级别对样本去重 | + +## Selector + +| 算子 | 场景 | 语言 | 描述 | +|------------------------------------|---------|--------|-----------------------------------------------| +| topk_specified_field_selector | General | en, zh | 通过比较指定字段的值选出前 k 个样本 | +| frequency_specified_field_selector | General | en, zh | 通过比较指定字段的频率选出前 k 个样本 | + +## 贡献 +我们欢迎社区贡献新的算子,具体请参考[开发者指南](DeveloperGuide_ZH.md)。 diff --git a/docs/config_def.png b/docs/config_def.png new file mode 100644 index 000000000..2ff77ef65 Binary files /dev/null and b/docs/config_def.png differ diff --git a/docs/imgs/data-juicer.png b/docs/imgs/data-juicer.png new file mode 100644 index 000000000..ffa59a3db Binary files /dev/null and b/docs/imgs/data-juicer.png differ diff --git a/docs/imgs/eval-01.png b/docs/imgs/eval-01.png new file mode 100644 index 000000000..382bf2743 Binary files /dev/null and b/docs/imgs/eval-01.png differ diff --git a/docs/imgs/eval-02.png b/docs/imgs/eval-02.png new file mode 100644 index 000000000..71ca49bb2 Binary files /dev/null and b/docs/imgs/eval-02.png differ diff --git a/docs/sphinx_doc/Makefile b/docs/sphinx_doc/Makefile new file mode 100644 index 000000000..d0c3cbf10 --- /dev/null +++ b/docs/sphinx_doc/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/sphinx_doc/README.md b/docs/sphinx_doc/README.md new file mode 100644 index 000000000..eb4372201 --- /dev/null +++ b/docs/sphinx_doc/README.md @@ -0,0 +1,32 @@ +# Data-Juicer Documentation + +We build our API documentation with help of Sphinx. +To update the generated +doc, please run the following commands: + +```bash +# $~/data_juicer/docs/sphinx_doc +# 1. install the sphinx requirements and init the sphinx-quickstart +pip install sphinx sphinx-autobuild sphinx_rtd_theme recommonmark +# or pip install -r ../../environments/dev_requires +sphinx-quickstart + +# 2. auto generate the doc files for all sub modules (*.rst) from source codes +sphinx-apidoc -o source ../../data_juicer + +# 3. modify the auto-generated files according to your requirements +vim source/modules.rst + +# 4. finalize the doc, which is stored in the `build/html` directory +make clean +make html +mv build/html position_to_publish +``` + +- For convenience (you don’t have to compile from scratch again), the built + directory (including the html files) can be download as follows: +```bash +# cd docs/sphinx_doc +wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/sphinx_API_build_0801.zip +unzip sphinx_API_build_0801.zip +``` diff --git a/docs/sphinx_doc/README_ZH.md b/docs/sphinx_doc/README_ZH.md new file mode 100644 index 000000000..4f57ea167 --- /dev/null +++ b/docs/sphinx_doc/README_ZH.md @@ -0,0 +1,31 @@ +# Data-Juicer 文档 + +Data-Juicer 借助 Sphinx 构建 API 文档。 +如需更新生成的文档,请运行以下命令: + +```bash +# $~/data_juicer/docs/sphinx_doc +# 1.安装 sphinx 的依赖并初始化 sphinx-quickstart +pip install sphinx sphinx-autobuild sphinx_rtd_theme recommonmark +# or pip install -r ../../environments/dev_requires +sphinx-quickstart + +# 2. 从源代码自动生成所有子模块(*.rst)的文档文件 +sphinx-apidoc -o source ../../data_juicer + +# 3. 根据您的要求修改自动生成的文件 +vim source/modules.rst + +# 4. 完成文档的构建,文档存储目录为 `build/html` +make clean +make html +mv build/html position_to_publish +``` + +- 为了方便起见(不必再次从头开始编译),可以按如下方式下载构建的目录(包括 html 文件): + +```bash +# cd docs/sphinx_doc +wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/sphinx_API_build_0801.zip +unzip sphinx_API_build_0801.zip +``` diff --git a/docs/sphinx_doc/make.bat b/docs/sphinx_doc/make.bat new file mode 100644 index 000000000..dc1312ab0 --- /dev/null +++ b/docs/sphinx_doc/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/sphinx_doc/source/conf.py b/docs/sphinx_doc/source/conf.py new file mode 100644 index 000000000..37aee1f02 --- /dev/null +++ b/docs/sphinx_doc/source/conf.py @@ -0,0 +1,42 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'data_juicer' +copyright = '2023, Data-Juicer Team' +author = 'Data-Juicer Team' + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +import sphinx_rtd_theme + +from data_juicer import __version__ as version + +release = version + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'sphinx.ext.napoleon', +] + +templates_path = ['_templates'] +exclude_patterns = ['build'] + +# -- Options for HTML output ------------------------------------------------- + +html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] diff --git a/docs/sphinx_doc/source/data_juicer.analysis.rst b/docs/sphinx_doc/source/data_juicer.analysis.rst new file mode 100644 index 000000000..e8a6c97a7 --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.analysis.rst @@ -0,0 +1,37 @@ +data\_juicer.analysis package +============================= + +Submodules +---------- + +data\_juicer.analysis.column\_wise\_analysis module +--------------------------------------------------- + +.. automodule:: data_juicer.analysis.column_wise_analysis + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.analysis.diversity\_analysis module +------------------------------------------------ + +.. automodule:: data_juicer.analysis.diversity_analysis + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.analysis.overall\_analysis module +---------------------------------------------- + +.. automodule:: data_juicer.analysis.overall_analysis + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.analysis + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.config.rst b/docs/sphinx_doc/source/data_juicer.config.rst new file mode 100644 index 000000000..9b7293596 --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.config.rst @@ -0,0 +1,21 @@ +data\_juicer.config package +=========================== + +Submodules +---------- + +data\_juicer.config.config module +--------------------------------- + +.. automodule:: data_juicer.config.config + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.config + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.core.rst b/docs/sphinx_doc/source/data_juicer.core.rst new file mode 100644 index 000000000..858d271ca --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.core.rst @@ -0,0 +1,53 @@ +data\_juicer.core package +========================= + +Submodules +---------- + +data\_juicer.core.analyser module +--------------------------------- + +.. automodule:: data_juicer.core.analyser + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.core.data module +----------------------------- + +.. automodule:: data_juicer.core.data + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.core.executor module +--------------------------------- + +.. automodule:: data_juicer.core.executor + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.core.exporter module +--------------------------------- + +.. automodule:: data_juicer.core.exporter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.core.tracer module +------------------------------- + +.. automodule:: data_juicer.core.tracer + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.core + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.format.rst b/docs/sphinx_doc/source/data_juicer.format.rst new file mode 100644 index 000000000..575a5b16a --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.format.rst @@ -0,0 +1,77 @@ +data\_juicer.format package +=========================== + +Submodules +---------- + +data\_juicer.format.csv\_formatter module +----------------------------------------- + +.. automodule:: data_juicer.format.csv_formatter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.format.formatter module +------------------------------------ + +.. automodule:: data_juicer.format.formatter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.format.json\_formatter module +------------------------------------------ + +.. automodule:: data_juicer.format.json_formatter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.format.load module +------------------------------- + +.. automodule:: data_juicer.format.load + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.format.mixture\_formatter module +--------------------------------------------- + +.. automodule:: data_juicer.format.mixture_formatter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.format.parquet\_formatter module +--------------------------------------------- + +.. automodule:: data_juicer.format.parquet_formatter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.format.text\_formatter module +------------------------------------------ + +.. automodule:: data_juicer.format.text_formatter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.format.tsv\_formatter module +----------------------------------------- + +.. automodule:: data_juicer.format.tsv_formatter + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.format + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.ops.common.rst b/docs/sphinx_doc/source/data_juicer.ops.common.rst new file mode 100644 index 000000000..be34ff5bf --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.ops.common.rst @@ -0,0 +1,29 @@ +data\_juicer.ops.common package +=============================== + +Submodules +---------- + +data\_juicer.ops.common.helper\_func module +------------------------------------------- + +.. automodule:: data_juicer.ops.common.helper_func + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.common.special\_characters module +-------------------------------------------------- + +.. automodule:: data_juicer.ops.common.special_characters + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.ops.common + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.ops.deduplicator.rst b/docs/sphinx_doc/source/data_juicer.ops.deduplicator.rst new file mode 100644 index 000000000..d30ce1dad --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.ops.deduplicator.rst @@ -0,0 +1,37 @@ +data\_juicer.ops.deduplicator package +===================================== + +Submodules +---------- + +data\_juicer.ops.deduplicator.document\_deduplicator module +----------------------------------------------------------- + +.. automodule:: data_juicer.ops.deduplicator.document_deduplicator + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.deduplicator.document\_minhash\_deduplicator module +-------------------------------------------------------------------- + +.. automodule:: data_juicer.ops.deduplicator.document_minhash_deduplicator + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.deduplicator.document\_simhash\_deduplicator module +-------------------------------------------------------------------- + +.. automodule:: data_juicer.ops.deduplicator.document_simhash_deduplicator + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.ops.deduplicator + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.ops.filter.rst b/docs/sphinx_doc/source/data_juicer.ops.filter.rst new file mode 100644 index 000000000..64e449177 --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.ops.filter.rst @@ -0,0 +1,133 @@ +data\_juicer.ops.filter package +=============================== + +Submodules +---------- + +data\_juicer.ops.filter.alphanumeric\_filter module +--------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.alphanumeric_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.average\_line\_length\_filter module +------------------------------------------------------------ + +.. automodule:: data_juicer.ops.filter.average_line_length_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.character\_repetition\_filter module +------------------------------------------------------------ + +.. automodule:: data_juicer.ops.filter.character_repetition_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.flagged\_words\_filter module +----------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.flagged_words_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.language\_id\_score\_filter module +---------------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.language_id_score_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.maximum\_line\_length\_filter module +------------------------------------------------------------ + +.. automodule:: data_juicer.ops.filter.maximum_line_length_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.perplexity\_filter module +------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.perplexity_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.special\_characters\_filter module +---------------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.special_characters_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.specified\_field\_filter module +------------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.specified_field_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.specified\_numeric\_field\_filter module +---------------------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.specified_numeric_field_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.stopwords\_filter module +------------------------------------------------ + +.. automodule:: data_juicer.ops.filter.stopwords_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.suffix\_filter module +--------------------------------------------- + +.. automodule:: data_juicer.ops.filter.suffix_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.text\_length\_filter module +--------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.text_length_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.word\_num\_filter module +------------------------------------------------ + +.. automodule:: data_juicer.ops.filter.word_num_filter + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.filter.word\_repetition\_filter module +------------------------------------------------------- + +.. automodule:: data_juicer.ops.filter.word_repetition_filter + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.ops.filter + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.ops.mapper.rst b/docs/sphinx_doc/source/data_juicer.ops.mapper.rst new file mode 100644 index 000000000..c8688614b --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.ops.mapper.rst @@ -0,0 +1,149 @@ +data\_juicer.ops.mapper package +=============================== + +Submodules +---------- + +data\_juicer.ops.mapper.clean\_copyright\_mapper module +------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.clean_copyright_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.clean\_email\_mapper module +--------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.clean_email_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.clean\_html\_mapper module +-------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.clean_html_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.clean\_ip\_mapper module +------------------------------------------------ + +.. automodule:: data_juicer.ops.mapper.clean_ip_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.clean\_links\_mapper module +--------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.clean_links_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.expand\_macro\_mapper module +---------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.expand_macro_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.fix\_unicode\_mapper module +--------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.fix_unicode_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.punctuation\_normalization\_mapper module +----------------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.punctuation_normalization_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.remove\_bibliography\_mapper module +----------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.remove_bibliography_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.remove\_comments\_mapper module +------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.remove_comments_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.remove\_header\_mapper module +----------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.remove_header_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.remove\_long\_words\_mapper module +---------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.remove_long_words_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.remove\_specific\_chars\_mapper module +-------------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.remove_specific_chars_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.remove\_table\_text\_mapper module +---------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.remove_table_text_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.remove\_words\_with\_incorrect\_substrings\_mapper module +--------------------------------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.remove_words_with_incorrect_substrings_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.sentence\_split\_mapper module +------------------------------------------------------ + +.. automodule:: data_juicer.ops.mapper.sentence_split_mapper + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.mapper.whitespace\_normalization\_mapper module +---------------------------------------------------------------- + +.. automodule:: data_juicer.ops.mapper.whitespace_normalization_mapper + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.ops.mapper + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.ops.rst b/docs/sphinx_doc/source/data_juicer.ops.rst new file mode 100644 index 000000000..f25068b50 --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.ops.rst @@ -0,0 +1,41 @@ +data\_juicer.ops package +======================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + data_juicer.ops.common + data_juicer.ops.deduplicator + data_juicer.ops.filter + data_juicer.ops.mapper + data_juicer.ops.selector + +Submodules +---------- + +data\_juicer.ops.base\_op module +-------------------------------- + +.. automodule:: data_juicer.ops.base_op + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.load module +---------------------------- + +.. automodule:: data_juicer.ops.load + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.ops + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.ops.selector.rst b/docs/sphinx_doc/source/data_juicer.ops.selector.rst new file mode 100644 index 000000000..266b47408 --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.ops.selector.rst @@ -0,0 +1,29 @@ +data\_juicer.ops.selector package +================================= + +Submodules +---------- + +data\_juicer.ops.selector.frequency\_specified\_field\_selector module +---------------------------------------------------------------------- + +.. automodule:: data_juicer.ops.selector.frequency_specified_field_selector + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.ops.selector.topk\_specified\_field\_selector module +----------------------------------------------------------------- + +.. automodule:: data_juicer.ops.selector.topk_specified_field_selector + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.ops.selector + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.rst b/docs/sphinx_doc/source/data_juicer.rst new file mode 100644 index 000000000..c305d1dd0 --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.rst @@ -0,0 +1,23 @@ +data\_juicer package +==================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + data_juicer.analysis + data_juicer.config + data_juicer.core + data_juicer.format + data_juicer.ops + data_juicer.utils + +Module contents +--------------- + +.. automodule:: data_juicer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/data_juicer.utils.rst b/docs/sphinx_doc/source/data_juicer.utils.rst new file mode 100644 index 000000000..65b8d1208 --- /dev/null +++ b/docs/sphinx_doc/source/data_juicer.utils.rst @@ -0,0 +1,69 @@ +data\_juicer.utils package +========================== + +Submodules +---------- + +data\_juicer.utils.asset\_utils module +-------------------------------------- + +.. automodule:: data_juicer.utils.asset_utils + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.utils.cache\_utils module +-------------------------------------- + +.. automodule:: data_juicer.utils.cache_utils + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.utils.ckpt\_utils module +------------------------------------- + +.. automodule:: data_juicer.utils.ckpt_utils + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.utils.file\_utils module +------------------------------------- + +.. automodule:: data_juicer.utils.file_utils + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.utils.logger\_utils module +--------------------------------------- + +.. automodule:: data_juicer.utils.logger_utils + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.utils.model\_utils module +-------------------------------------- + +.. automodule:: data_juicer.utils.model_utils + :members: + :undoc-members: + :show-inheritance: + +data\_juicer.utils.registry module +---------------------------------- + +.. automodule:: data_juicer.utils.registry + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: data_juicer.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst new file mode 100644 index 000000000..9c098d834 --- /dev/null +++ b/docs/sphinx_doc/source/index.rst @@ -0,0 +1,21 @@ +.. data-juicer documentation master file, created by + sphinx-quickstart on Mon May 22 16:16:12 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to data-juicer's documentation! +======================================= + +.. toctree:: + :maxdepth: 2 + :caption: References: + +.. include:: modules.rst + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/sphinx_doc/source/modules.rst b/docs/sphinx_doc/source/modules.rst new file mode 100644 index 000000000..2845759f3 --- /dev/null +++ b/docs/sphinx_doc/source/modules.rst @@ -0,0 +1,7 @@ +data_juicer +=========== + +.. toctree:: + :maxdepth: 4 + + data_juicer diff --git a/environments/dev_requires.txt b/environments/dev_requires.txt new file mode 100644 index 000000000..ff091a304 --- /dev/null +++ b/environments/dev_requires.txt @@ -0,0 +1,5 @@ +pre-commit +sphinx +sphinx-autobuild +sphinx_rtd_theme +recommonmark diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt new file mode 100644 index 000000000..f1917aad3 --- /dev/null +++ b/environments/minimal_requires.txt @@ -0,0 +1,17 @@ +datasets==2.11.0 +loguru +tqdm +jsonargparse[signatures] +matplotlib +pandas +requests +wget +zstandard +pdfplumber +python-docx +streamlit +spacy==3.5.0 +en_core_web_md @ https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/en_core_web_md-3.5.0-py3-none-any.whl +zh_core_web_md @ https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/zh_core_web_md-3.5.0-py3-none-any.whl +multiprocess==0.70.12 +dill==0.3.4 diff --git a/environments/preprocess_requires.txt b/environments/preprocess_requires.txt new file mode 100644 index 000000000..657e1936b --- /dev/null +++ b/environments/preprocess_requires.txt @@ -0,0 +1,2 @@ +fire +jsonlines diff --git a/environments/quality_classifier_requires.txt b/environments/quality_classifier_requires.txt new file mode 100644 index 000000000..e7b76ed45 --- /dev/null +++ b/environments/quality_classifier_requires.txt @@ -0,0 +1,3 @@ +pyspark +fire +wget diff --git a/environments/science_requires.txt b/environments/science_requires.txt new file mode 100644 index 000000000..d0410bd57 --- /dev/null +++ b/environments/science_requires.txt @@ -0,0 +1,13 @@ +fasttext +kenlm @ http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/kenlm-master.zip +sentencepiece +scipy +tabulate +pandas +ftfy +emoji==2.2.0 +regex +simhash-py +selectolax +nltk +transformers diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..ebc6dcad3 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[flake8] +per-file-ignores = + */__init__.py: F401 diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..69a2bb3b1 --- /dev/null +++ b/setup.py @@ -0,0 +1,69 @@ +import logging +import os.path +import re + +import setuptools + + +def get_package_dir(): + pkg_dir = { + 'data_juicer.tools': 'tools', + } + return pkg_dir + + +def get_install_requirements(require_f_paths, env_dir='environments'): + reqs = [] + for path in require_f_paths: + target_f = os.path.join(env_dir, path) + if not os.path.exists(target_f): + logging.warning(f'target file does not exist: {target_f}') + else: + with open(target_f, 'r', encoding='utf-8') as fin: + reqs += [x.strip() for x in fin.read().splitlines()] + reqs = [x for x in reqs if not x.startswith('#')] + return reqs + + +# allowing selective installment based on users' needs +# TODO: The specific taxonomy and dependencies will be determined +# after implementing some preliminary operators and detailed discussions +min_requires = get_install_requirements( + ['minimal_requires.txt', 'science_requires.txt']) +extra_requires = { + 'mini': + min_requires, + 'dev': + get_install_requirements(['dev_requires.txt']), + 'tools': + get_install_requirements( + ['preprocess_requires.txt', 'quality_classifier_requires.txt']), +} +extra_requires['all'] = [v for v in extra_requires.values()] + +with open('data_juicer/__init__.py', 'r') as f: + version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), + re.MULTILINE).group(1) + +with open('README.md', encoding='utf-8') as f: + readme_md = f.read() + +setuptools.setup( + name='data_juicer', + version=version, + author='SysML team of Alibaba DAMO Academy', + description='A Data-Centric Text Processing System for Large Language ' + 'Models.', + long_description=readme_md, + long_description_content_type='text/markdown', + license='Apache License 2.0', + packages=setuptools.find_packages(), + package_dir=get_package_dir(), + install_requires=min_requires, + extras_require=extra_requires, + classifiers=[ + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3', + 'Operating System :: OS Independent' + ], +) diff --git a/tests/.DS_Store b/tests/.DS_Store new file mode 100644 index 000000000..35f1b0501 Binary files /dev/null and b/tests/.DS_Store differ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/analysis/__init__.py b/tests/analysis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/config/demo_4_test.yaml b/tests/config/demo_4_test.yaml new file mode 100644 index 000000000..39d11fd8f --- /dev/null +++ b/tests/config/demo_4_test.yaml @@ -0,0 +1,18 @@ +# Process config example for Arxiv dataset + +# global parameters +project_name: 'test_demo' +dataset_path: './demo/demo-dataset.jsonl' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +export_path: './outputs/demo/demo-processed.parquet' + +# process schedule +# a list of several process operators with their arguments +process: + - whitespace_normalization_mapper: + - language_id_score_filter: + lang: 'zh' + - document_deduplicator: # deduplicate text samples using md5 hashing exact matching method + lowercase: false # whether to convert text to lower case + ignore_non_character: false diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py new file mode 100644 index 000000000..74996cc6a --- /dev/null +++ b/tests/config/test_config_funcs.py @@ -0,0 +1,115 @@ +import os +import unittest +from contextlib import redirect_stdout +from io import StringIO + +from jsonargparse import Namespace + +from data_juicer.config import init_configs +from data_juicer.ops import load_ops + +test_yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + 'demo_4_test.yaml') + + +class ConfigTest(unittest.TestCase): + + def test_help_info(self): + out = StringIO() + with redirect_stdout(out), self.assertRaises(SystemExit): + _ = init_configs(args=['--help']) + out_str = out.getvalue() + self.assertIn('usage:', out_str, 'lacks message for command beginning') + self.assertIn('--config CONFIG', out_str, + 'lacks message for positional argument') + self.assertIn('[--project_name PROJECT_NAME]', out_str, + 'lacks message for optional argument') + self.assertIn( + 'Number of processes to process dataset. (type:', out_str, + 'the help message of `np` argument does not show as expected') + + def test_yaml_cfg_file(self): + out = StringIO() + with redirect_stdout(out): + cfg = init_configs(args=f'--config {test_yaml_path}'.split()) + self.assertIsInstance(cfg, Namespace) + self.assertEqual(cfg.project_name, 'test_demo') + self.assertDictEqual( + cfg.process[0], + {'whitespace_normalization_mapper': { + 'text_key': None + }}, 'nested dict load fail, for nonparametric op') + self.assertDictEqual( + cfg.process[1], { + 'language_id_score_filter': { + 'lang': 'zh', + 'min_score': 0.8, + 'text_key': None + } + }, 'nested dict load fail, un-expected internal value') + + op_from_cfg = load_ops(cfg.process, cfg.text_key_to_process) + self.assertTrue(len(op_from_cfg) == 3) + + def test_mixture_cfg(self): + out = StringIO() + with redirect_stdout(out): + ori_cfg = init_configs(args=f'--config {test_yaml_path}'.split()) + mixed_cfg_1 = init_configs( + args=f'--config {test_yaml_path} ' + '--language_id_score_filter.lang en'.split()) + mixed_cfg_2 = init_configs( + args=f'--config {test_yaml_path} ' + '--language_id_score_filter.lang=fr'.split()) + mixed_cfg_3 = init_configs( + args=f'--config {test_yaml_path} ' + '--language_id_score_filter.lang zh ' + '--language_id_score_filter.min_score 0.6'.split()) + mixed_cfg_4 = init_configs( + args=f'--config {test_yaml_path} ' + '--language_id_score_filter.lang=en ' + '--language_id_score_filter.min_score=0.5'.split()) + self.assertDictEqual( + ori_cfg.process[1], { + 'language_id_score_filter': { + 'lang': 'zh', + 'min_score': 0.8, + 'text_key': None + } + }) + self.assertDictEqual( + mixed_cfg_1.process[1], { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': None + } + }) + self.assertDictEqual( + mixed_cfg_2.process[1], { + 'language_id_score_filter': { + 'lang': 'fr', + 'min_score': 0.8, + 'text_key': None + } + }) + self.assertDictEqual( + mixed_cfg_3.process[1], { + 'language_id_score_filter': { + 'lang': 'zh', + 'min_score': 0.6, + 'text_key': None + } + }) + self.assertDictEqual( + mixed_cfg_4.process[1], { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.5, + 'text_key': None + } + }) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/format/__init__.py b/tests/format/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/format/data/structured/demo-dataset.csv b/tests/format/data/structured/demo-dataset.csv new file mode 100644 index 000000000..3ead57857 --- /dev/null +++ b/tests/format/data/structured/demo-dataset.csv @@ -0,0 +1,7 @@ +text,meta +Today is Sunday and it's a happy day!,"{'src': 'Arxiv', 'date': datetime.datetime(2023, 4, 27, 0, 0), 'version': '1.0', 'author': None}" +Do you need a cup of coffee?,"{'src': 'code', 'date': None, 'version': None, 'author': 'xxx'}" +你好,请问你是谁,"{'src': 'customized', 'date': None, 'version': None, 'author': 'xxx'}" +"Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément.","{'src': 'Oscar', 'date': None, 'version': '2.0', 'author': None}" +欢迎来到阿里巴巴!,"{'src': 'customized', 'date': None, 'version': '0.1', 'author': 'xxx'}" +This paper proposed a novel method on LLM pretraining.,"{'src': 'customized', 'date': None, 'version': None, 'author': 'xxx'}" diff --git a/tests/format/data/structured/demo-dataset.jsonl b/tests/format/data/structured/demo-dataset.jsonl new file mode 100644 index 000000000..707f802b0 --- /dev/null +++ b/tests/format/data/structured/demo-dataset.jsonl @@ -0,0 +1,2 @@ +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} diff --git a/tests/format/data/structured/demo-dataset.parquet b/tests/format/data/structured/demo-dataset.parquet new file mode 100644 index 000000000..57ea0c38e Binary files /dev/null and b/tests/format/data/structured/demo-dataset.parquet differ diff --git a/tests/format/data/structured/demo-dataset.tsv b/tests/format/data/structured/demo-dataset.tsv new file mode 100644 index 000000000..2cc07067a --- /dev/null +++ b/tests/format/data/structured/demo-dataset.tsv @@ -0,0 +1,7 @@ +text meta +Today is Sunday and it's a happy day! {'src': 'Arxiv', 'date': datetime.datetime(2023, 4, 27, 0, 0), 'version': '1.0', 'author': None} +Do you need a cup of coffee? {'src': 'code', 'date': None, 'version': None, 'author': 'xxx'} +你好,请问你是谁 {'src': 'customized', 'date': None, 'version': None, 'author': 'xxx'} +Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément. {'src': 'Oscar', 'date': None, 'version': '2.0', 'author': None} +欢迎来到阿里巴巴! {'src': 'customized', 'date': None, 'version': '0.1', 'author': 'xxx'} +This paper proposed a novel method on LLM pretraining. {'src': 'customized', 'date': None, 'version': None, 'author': 'xxx'} diff --git a/tests/format/data/text/sample1.txt b/tests/format/data/text/sample1.txt new file mode 100644 index 000000000..698ad7c54 --- /dev/null +++ b/tests/format/data/text/sample1.txt @@ -0,0 +1 @@ +Today is Sunday and it's a happy day! diff --git a/tests/format/data/text/sample2.txt b/tests/format/data/text/sample2.txt new file mode 100644 index 000000000..5d6227b09 --- /dev/null +++ b/tests/format/data/text/sample2.txt @@ -0,0 +1 @@ +Do you need a cup of coffee? diff --git a/tests/format/data/text/sample3.txt b/tests/format/data/text/sample3.txt new file mode 100644 index 000000000..78dc2d5ad --- /dev/null +++ b/tests/format/data/text/sample3.txt @@ -0,0 +1 @@ +你好,请问你是谁 diff --git a/tests/format/data/text/sample4.txt b/tests/format/data/text/sample4.txt new file mode 100644 index 000000000..704306740 --- /dev/null +++ b/tests/format/data/text/sample4.txt @@ -0,0 +1 @@ +Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément. diff --git a/tests/format/data/text/sample5.txt b/tests/format/data/text/sample5.txt new file mode 100644 index 000000000..0390b9676 --- /dev/null +++ b/tests/format/data/text/sample5.txt @@ -0,0 +1 @@ +欢迎来到阿里巴巴! diff --git a/tests/format/data/text/sample6.txt b/tests/format/data/text/sample6.txt new file mode 100644 index 000000000..ea375cee5 --- /dev/null +++ b/tests/format/data/text/sample6.txt @@ -0,0 +1 @@ +This paper proposed a novel method on LLM pretraining. diff --git a/tests/format/test_csv_formatter.py b/tests/format/test_csv_formatter.py new file mode 100644 index 000000000..7bb99d978 --- /dev/null +++ b/tests/format/test_csv_formatter.py @@ -0,0 +1,28 @@ +import os +import unittest + +from data_juicer.format.csv_formatter import CsvFormatter + + +class CsvFormatterTest(unittest.TestCase): + + def setUp(self): + self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'structured') + self._file = os.path.join(self._path, 'demo-dataset.csv') + print(self._file) + + def test_csv_file(self): + formatter = CsvFormatter(self._file) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_csv_path(self): + formatter = CsvFormatter(self._path) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/format/test_parquet_formatter.py b/tests/format/test_parquet_formatter.py new file mode 100644 index 000000000..ddd7b80d7 --- /dev/null +++ b/tests/format/test_parquet_formatter.py @@ -0,0 +1,28 @@ +import os +import unittest + +from data_juicer.format.parquet_formatter import ParquetFormatter + + +class CsvFormatterTest(unittest.TestCase): + + def setUp(self): + self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'structured') + self._file = os.path.join(self._path, 'demo-dataset.parquet') + print(self._file) + + def test_parquet_file(self): + formatter = ParquetFormatter(self._file) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_parquet_path(self): + formatter = ParquetFormatter(self._path) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/format/test_tsv_formatter.py b/tests/format/test_tsv_formatter.py new file mode 100644 index 000000000..5f6da8a78 --- /dev/null +++ b/tests/format/test_tsv_formatter.py @@ -0,0 +1,28 @@ +import os +import unittest + +from data_juicer.format.tsv_formatter import TsvFormatter + + +class TsvFormatterTest(unittest.TestCase): + + def setUp(self): + self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'structured') + self._file = os.path.join(self._path, 'demo-dataset.tsv') + print(self._file) + + def test_tsv_file(self): + formatter = TsvFormatter(self._file) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_tsv_path(self): + formatter = TsvFormatter(self._path) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/format/test_unify_format.py b/tests/format/test_unify_format.py new file mode 100644 index 000000000..03a55e391 --- /dev/null +++ b/tests/format/test_unify_format.py @@ -0,0 +1,447 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.format.formatter import load_dataset, unify_format + + +class UnifyFormatTest(unittest.TestCase): + + def run_test(self, sample, args=None): + if args is None: + args = {} + ds = Dataset.from_list(sample['source']) + ds = unify_format(ds, **args) + self.assertEqual(ds.to_list(), sample['target']) + + def test_text_key(self): + samples = [ + { + 'source': [{ + 'text': 'This is a test text', + 'outer_key': 1, + }], + 'target': [{ + 'text': 'This is a test text', + 'meta.outer_key': 1, + }] + }, + { + 'source': [{ + 'content': 'This is a test text', + 'outer_key': 1, + }], + 'target': [{ + 'text': 'This is a test text', + 'meta.outer_key': 1, + }] + }, + { + 'source': [{ + 'input': 'This is a test text, input part', + 'instruction': 'This is a test text, instruction part', + 'outer_key': 1, + }], + 'target': [{ + 'text.input': 'This is a test text, input part', + 'text.instruction': + 'This is a test text, instruction part', + 'meta.outer_key': 1, + }] + }, + ] + self.run_test(samples[0]) + self.run_test(samples[1], args={'text_keys_to_load': ['content']}) + self.run_test(samples[2], + args={'text_keys_to_load': ['input', 'instruction']}) + + def test_empty_text(self): + # filter out samples containing None field, but '' is OK + samples = [ + { + 'source': [{ + 'text': '', + 'outer_key': 1, + }], + 'target': [{ + 'text': '', + 'meta.outer_key': 1, + }], + }, + { + 'source': [{ + 'text': None, + 'outer_key': 1, + }], + 'target': [], + }, + ] + for sample in samples: + self.run_test(sample) + + def test_no_extra_fields(self): + samples = [{ + 'source': [{ + 'text': 'This is a test text.', + 'stats': { + 'lang': 'en' + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'stats': { + 'lang': 'en' + }, + }], + }, { + 'source': [{ + 'text': 'This is a test text.', + }], + 'target': [{ + 'text': 'This is a test text.', + }], + }] + for sample in samples: + self.run_test(sample) + + def test_no_extra_fields_except_meta(self): + samples = [{ + 'source': [{ + 'text': 'This is a test text.', + 'meta': { + 'version': 1 + }, + 'stats': { + 'lang': 'en' + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': { + 'version': 1 + }, + 'stats': { + 'lang': 'en' + }, + }], + }, { + 'source': [{ + 'text': 'This is a test text.', + 'meta': { + 'version': 1 + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': { + 'version': 1 + }, + }], + }] + for sample in samples: + self.run_test(sample) + + def test_invalid_stats(self): + # non-dict stats will be unified into meta.stats + samples = [{ + 'source': [{ + 'text': 'This is a test text.', + 'stats': 'nice', + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta.stats': 'nice' + }], + }, { + 'source': [{ + 'text': 'This is a test text.', + 'stats': { + 'version': 1 + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'stats': { + 'version': 1 + }, + }], + }] + for sample in samples: + self.run_test(sample) + + def test_outer_fields(self): + samples = [ + { + 'source': [{ + 'text': 'This is a test text.', + 'meta': { + 'meta_inner': 'nice' + }, + 'outer_field': 'value' + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': { + 'meta_inner': 'nice', + }, + 'meta.outer_field': 'value', + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'outer_key': 'nice', + 'outer_field': 'value' + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta.outer_key': 'nice', + 'meta.outer_field': 'value', + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'meta': 'nice', + 'outer_field': 'value' + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': 'nice', + 'meta.outer_field': 'value', + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'meta': { + 'meta_inner': 'nice' + }, + 'outer_field': 'value', + 'stats': { + 'lang': 'en' + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': { + 'meta_inner': 'nice' + }, + 'meta.outer_field': 'value', + 'stats': { + 'lang': 'en' + }, + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'outer_key': 'nice', + 'outer_field': 'value', + 'stats': { + 'lang': 'en' + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta.outer_key': 'nice', + 'meta.outer_field': 'value', + 'stats': { + 'lang': 'en' + }, + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'meta': 'nice', + 'outer_field': 'value', + 'stats': { + 'lang': 'en' + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': 'nice', + 'meta.outer_field': 'value', + 'stats': { + 'lang': 'en' + }, + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'meta': { + 'meta_inner': 'nice' + }, + 'outer_field': 'value', + 'stats': 'en', + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': { + 'meta_inner': 'nice' + }, + 'meta.outer_field': 'value', + 'meta.stats': 'en' + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'outer_key': 'nice', + 'outer_field': 'value', + 'stats': 'en', + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta.outer_key': 'nice', + 'meta.outer_field': 'value', + 'meta.stats': 'en' + }], + }, + { + 'source': [{ + 'text': 'This is a test text.', + 'meta': 'nice', + 'outer_field': 'value', + 'stats': 'en', + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': 'nice', + 'meta.outer_field': 'value', + 'meta.stats': 'en' + }], + }, + ] + for sample in samples: + self.run_test(sample) + + def test_recursive_meta(self): + samples = [{ + 'source': [{ + 'text': 'This is a test text.', + 'outer_field': { + 'rec1': { + 'rec2': 'value' + } + }, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta.outer_field': { + 'rec1': { + 'rec2': 'value' + } + }, + }], + }] + for sample in samples: + self.run_test(sample) + + def test_hetero_meta(self): + cur_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'structured') + file_path = os.path.join(cur_dir, 'demo-dataset.jsonl') + ds = load_dataset('json', data_files=file_path) + ds = unify_format(ds) + import datetime + # the 'None' fields are missing fields after merging + sample = [{ + 'text': "Today is Sunday and it's a happy day!", + 'meta': { + 'src': 'Arxiv', + 'date': datetime.datetime(2023, 4, 27, 0, 0), + 'version': '1.0', + 'author': None + } + }, { + 'text': 'Do you need a cup of coffee?', + 'meta': { + 'src': 'code', + 'date': None, + 'version': None, + 'author': 'xxx' + } + }] + unified_sample_list = ds.to_list() + self.assertEqual(unified_sample_list, sample) + # test nested and missing field for the following cases: + # 1. first row, then column + unified_sample_first = ds[0] + unified_sample_second = ds[1] + self.assertEqual(unified_sample_first['meta.src'], 'Arxiv') + self.assertEqual(unified_sample_first['meta.author'], None) + self.assertEqual(unified_sample_second['meta.date'], None) + # 2. first column, then row + self.assertEqual(ds['meta.src'][0], 'Arxiv') + self.assertEqual(ds['meta.src'][1], 'code') + self.assertEqual(ds['meta.author'][0], None) + self.assertEqual(ds['meta.date'][1], None) + # 3. first partial rows, then column, final row + unified_ds_first = ds.select([0]) + unified_ds_second = ds.select([1]) + self.assertEqual(unified_ds_first['meta.src'][0], 'Arxiv') + self.assertEqual(unified_ds_first['meta.author'][0], None) + self.assertEqual(unified_ds_second['meta.date'][0], None) + + def test_empty_meta(self): + samples = [{ + 'source': [{ + 'text': 'This is a test text.', + 'meta': {}, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': {}, + }], + }] + for sample in samples: + self.run_test(sample) + + def test_empty_stats(self): + samples = [{ + 'source': [{ + 'text': 'This is a test text.', + 'meta': {}, + 'stats': {}, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': {}, + 'stats': {}, + }], + }] + for sample in samples: + self.run_test(sample) + + def test_empty_outer_fields(self): + samples = [{ + 'source': [{ + 'text': 'This is a test text.', + 'meta': {}, + 'out_field': {}, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta': {}, + 'meta.out_field': {}, + }], + }, { + 'source': [{ + 'text': 'This is a test text.', + 'out_field': {}, + }], + 'target': [{ + 'text': 'This is a test text.', + 'meta.out_field': {}, + }], + }] + for sample in samples: + self.run_test(sample) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/.DS_Store b/tests/ops/.DS_Store new file mode 100644 index 000000000..2220be4b6 Binary files /dev/null and b/tests/ops/.DS_Store differ diff --git a/tests/ops/__init__.py b/tests/ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/common/__init__.py b/tests/ops/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/deduplicator/__init__.py b/tests/ops/deduplicator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/deduplicator/test_document_deduplicator.py b/tests/ops/deduplicator/test_document_deduplicator.py new file mode 100644 index 000000000..740caae18 --- /dev/null +++ b/tests/ops/deduplicator/test_document_deduplicator.py @@ -0,0 +1,100 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.deduplicator.document_deduplicator import \ + DocumentDeduplicator + + +class DocumentDeduplicatorTest(unittest.TestCase): + + def _run_doc_dedup(self, dataset: Dataset, target_list, op): + dataset = dataset.map(op.compute_hash) + dataset, _ = op.process(dataset) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_english_deduplication(self): + ds_list = [ + { + 'text': 'Today is Sunday and it\'s a happy day!' + }, + { + 'text': 'Do you need a cup of coffee?' + }, + { + 'text': 'Today is sunday and it\'s a happy day!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + ] + tgt_list = [{ + 'text': 'Today is Sunday and it\'s a happy day!' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'Today is sunday and it\'s a happy day!' + }, { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }] + dataset = Dataset.from_list(ds_list) + op = DocumentDeduplicator(lowercase=False, ignore_non_character=False) + self._run_doc_dedup(dataset, tgt_list, op) + + def test_chinese_deduplication(self): + ds_list = [ + { + 'text': '你好,请问你是谁' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + '第九届会议\n2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1.' + }, + { + 'text': + '第九届会议\n2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1.' + }, + { + 'text': + '第九届会议\n时间:2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1.' + }, + ] + tgt_list = [ + { + 'text': '你好,请问你是谁' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + '第九届会议\n2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1.' + }, + { + 'text': + '第九届会议\n时间:2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1.' + }, + ] + dataset = Dataset.from_list(ds_list) + op = DocumentDeduplicator(lowercase=False, ignore_non_character=False) + self._run_doc_dedup(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/deduplicator/test_document_minhash_deduplicator.py b/tests/ops/deduplicator/test_document_minhash_deduplicator.py new file mode 100644 index 000000000..b60209e8b --- /dev/null +++ b/tests/ops/deduplicator/test_document_minhash_deduplicator.py @@ -0,0 +1,962 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.deduplicator.document_minhash_deduplicator import \ + DocumentMinhashDeduplicator + + +class DocumentMinhashDeduplicatorTest(unittest.TestCase): + + def _run_minhash_dedup(self, dataset: Dataset, target_list, op): + dataset = dataset.map(op.compute_hash) + dataset, _ = op.process(dataset) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_english_deduplication(self): + ds_list = [ + { + 'text': 'Today is Sunday and it\'s a happy day!' + }, + { + 'text': 'Do you need a cup of coffee?' + }, + { + 'text': 'Today is sunday and it\'s really a happy day!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux Falls, ' + 'South Dakota. The plant slaughters 19,500 pigs a day — 5 ' + 'percent of U.S. pork. Most of the workers are immigrants ' + 'from Ethiopia, Mexico, South Sudan, Honduras, Myanmar, ' + 'Somalia, Guatemala, and other poor countries.\n\nInevitably ' + 'workers must pass within one foot of hundreds of colleagues ' + 'in the hallways, locker rooms, cafeterias, and cutting ' + 'lines. The same conditions have spurred Covid-19 outbreaks ' + 'at meat plants from Minnesota and Wisconsin to Colorado, ' + 'Nebraska, Missouri, Iowa, Pennsylvania, North Carolina, and ' + 'Georgia.\n\n801 workers at the Sioux Falls plant have tested ' + 'positive, together with 206 people close to them. The ' + 'outbreak has killed Agustín Rodríguez Martínez, aged 64, an ' + 'employee with two decades of experience originally from El ' + 'Salvador, and Craig Allen Franken, 61, who worked for ' + 'Smithfield his entire adult life.\n\nThe company knew of its ' + 'first infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting ' + 'organized.\n\nThe major union in the industry, including at ' + 'Smithfield, is the United Food and Commercial Workers union ' + '(UFCW). What union leaders have done is ultimately ' + 'troubling.\n\nCan Workers Fight?\n\nLocal AFL-CIO president ' + 'Kooper Caraway has publicly said management delayed safety ' + 'action as long as possible for profit. But while some ' + 'workers were demanding a two-week shutdown, Caraway told ' + 'the Argus Leader that was unrealistic because the ' + 'government considers the plant essential. He suggested the ' + 'union would be happy with minimal safety measures: “Even if ' + '10 people get exposed in a day rather than 11. If you can ' + 'implement a program where even one or two less people get ' + 'exposed during a shift, that’s one or two less people.” Of ' + 'course reducing infections is good, but suggesting workers ' + 'would be satisfied if the company allowed 90% of the ' + 'contagion to continue is horrifying.\n\nThe response of ' + 'UFCW leadership was worse. As the disease was exploding, ' + 'they told the Argus Leader, “We applaud [Smithfield’s] ' + 'decision to temporarily close the plant [over Easter ' + 'weekend] to push for an even safer work environment.” What ' + 'does “even safer” mean in this context?\n\nThe union ' + 'bureaucracy has taken weak action elsewhere. In ' + 'Pennsylvania, the UFCW negotiated $2 hazard pay for two ' + 'months with Cargill Meat — the same pandemic premium Amazon ' + 'gave workers without a union. In Nebraska, the UFCW ' + 'negotiated $4 hazard pay for one month with meat giant ' + 'JBS.\n\nThe union has said nothing about forcing companies ' + 'to send older workers home with pay, even though a ' + '70-year-old shop steward and a 78-year-old grandfather ' + 'working at JBS plants were killed by Covid-19. Smithfield ' + 'workers were promised only two weeks of shutdown pay. For ' + 'many, this compensation is half their normal paycheck ' + 'because they routinely put in 66 hour weeks — overtime that ' + 'costs exhaustion and chronic pain.\n\nUnion officials ' + 'endeavor to cooperate with the meat companies. An Iowa UFCW ' + 'president actually suggested it might be impossible for ' + 'plants to move workers a full six feet apart and told the ' + 'Des Moines Register, “We can’t stop the plants. If we stop ' + 'the plants from running, we stop feeding the country. We ' + 'want to do everything we can to make sure the employees are ' + 'safe to keep the plant running.”\n\nEvery part of this ' + 'explanation directly overlaps with what the Smithfield CEO ' + 'said. Unfortunately, it amounts to accepting the company’s ' + 'excuses.\n\nThey claim that workers who do hard physical ' + 'labor, waking up at 4 a.m. and often working six days a ' + 'week for years, would be guilty of taking food away from ' + 'the people and hurting America if they dared to fight for ' + 'their human needs. But nothing is said about the company ' + 'raking in profits and even murdering workers to increase ' + 'them.\n\nSmithfield’s parent company W.H. Group, ' + 'which slaughters around 30 million pigs per year in plants ' + 'in both the United States and China, saw its profits ' + 'skyrocket by about one third in 2019 to $1.38 billion. It ' + 'is disturbing that UFCW officials do not bring up these ' + 'soaring profits in their response to the outbreaks. Reuters ' + 'published a report on the corporation’s financial success ' + 'in late March. The head of W.H. Group had touted to the ' + 'media that it got through the pandemic in China with very ' + 'limited impact on production.\n\nIt is true that many ' + 'Smithfield workers are reasonably afraid for their jobs and ' + 'want to keep working. A 25-year-old employee explained, ' + '“I have a lot of bills. My baby’s coming soon — I have to ' + 'work.” At the same time, he was afraid of infecting his ' + 'pregnant wife. His spouse, a former employee, ' + 'said bitterly, “Smithfield— they don’t care about ' + 'employees. They only care about their money.”\n\nWorkers ' + 'are pressured in these two painful directions. Nonetheless, ' + 'work can mean solidarity. Before Smithfield even checked ' + 'temperatures, there was a “sick-out” strike without union ' + 'support by 800 to 1,000 workers at a JBS meat factory in ' + 'Colorado. Hundreds of workers also called in sick days at a ' + 'Nebraska JBS plant.\n\nTrade union leaders won’t even ' + 'whisper the word “strike” when thousands of workers are ' + 'thinking about it. They are limiting themselves to polite ' + 'requests. We need a workers’ movement that asks who ' + 'controls the factory, that threatens to disrupt the bosses’ ' + 'profits, and that allows workers to use their immense power ' + '— this could change the meat industry and the world. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plants in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pig a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting ' + 'organized.\n\nThe major union in the industry, including at ' + 'Smithfield, is the United Food and Commercial Workers union ' + '(UFCW). What union leaders have done is ultimately ' + 'troubling.\n\nCan Workers Fight?\n\nLocal AFL-CIO president ' + 'Kooper Caraway has publicly said management delayed safety ' + 'action as long as possible for profit. But while some ' + 'workers were demanding a two-week shutdown, Caraway told ' + 'the Argus Leader that was unrealistic because the ' + 'government considers the plant essential. He suggested the ' + 'union would be happy with minimal safety measures: “Even if ' + '10 people get exposed in a day rather than 11. If you can ' + 'implement a program where even one or two less people get ' + 'exposed during a shift, that’s one or two less people.” Of ' + 'course reducing infections is good, but suggesting workers ' + 'would be satisfied if the company allowed 90% of the ' + 'contagion to continue is horrifying.\n\nThe response of ' + 'UFCW leadership was worse. As the disease was exploding, ' + 'they told the Argus Leader, “We applaud [Smithfield’s] ' + 'decision to temporarily close the plant [over Easter ' + 'weekend] to push for an even safer work environment.” What ' + 'does “even safer” mean in this context?\n\nThe union ' + 'bureaucracy has taken weak action elsewhere. In ' + 'Pennsylvania, the UFCW negotiated $2 hazard pay for two ' + 'months with Cargill Meat — the same pandemic premium Amazon ' + 'gave workers without a union. In Nebraska, the UFCW ' + 'negotiated $4 hazard pay for one month with meat giant ' + 'JBS.\n\nThe union has said nothing about forcing companies ' + 'to send older workers home with pay, even though a ' + '70-year-old shop steward and a 78-year-old grandfather ' + 'working at JBS plants were killed by Covid-19. Smithfield ' + 'workers were promised only two weeks of shutdown pay. For ' + 'many, this compensation is half their normal paycheck ' + 'because they routinely put in 66 hour weeks — overtime that ' + 'costs exhaustion and chronic pain.\n\nUnion officials ' + 'endeavor to cooperate with the meat companies. An Iowa UFCW ' + 'president actually suggested it might be impossible for ' + 'plants to move workers a full six feet apart and told the ' + 'Des Moines Register, “We can’t stop the plants. If we stop ' + 'the plants from running, we stop feeding the country. We ' + 'want to do everything we can to make sure the employees are ' + 'safe to keep the plant running.”\n\nEvery part of this ' + 'explanation directly overlaps with what the Smithfield CEO ' + 'said. Unfortunately, it amounts to accepting the company’s ' + 'excuses.\n\nThey claim that workers who do hard physical ' + 'labor, waking up at 4 a.m. and often working six days a ' + 'week for years, would be guilty of taking food away from ' + 'the people and hurting America if they dared to fight for ' + 'their human needs. But nothing is said about the company ' + 'raking in profits and even murdering workers to increase ' + 'them.\n\nSmithfield’s parent company W.H. Group, ' + 'which slaughters around 30 million pigs per year in plants ' + 'in both the United States and China, saw its profits ' + 'skyrocket by about one third in 2019 to $1.38 billion. It ' + 'is disturbing that UFCW officials do not bring up these ' + 'soaring profits in their response to the outbreaks. Reuters ' + 'published a report on the corporation’s financial success ' + 'in late March. The head of W.H. Group had touted to the ' + 'media that it got through the pandemic in China with very ' + 'limited impact on production.\n\nIt is true that many ' + 'Smithfield workers are reasonably afraid for their jobs and ' + 'want to keep working. A 25-year-old employee explained, ' + '“I have a lot of bills. My baby’s coming soon — I have to ' + 'work.” At the same time, he was afraid of infecting his ' + 'pregnant wife. His spouse, a former employee, ' + 'said bitterly, “Smithfield— they don’t care about ' + 'employees. They only care about their money.”\n\nWorkers ' + 'are pressured in these two painful directions. Nonetheless, ' + 'work can mean solidarity. Before Smithfield even checked ' + 'temperatures, there was a “sick-out” strike without union ' + 'support by 800 to 1,000 workers at a JBS meat factory in ' + 'Colorado. Hundreds of workers also called in sick days at a ' + 'Nebraska JBS plant.\n\nTrade union leaders won’t even ' + 'whisper the word “strike” when thousands of workers are ' + 'thinking about it. They are limiting themselves to polite ' + 'requests. We need a workers’ movement that asks who ' + 'controls the factory, that threatens to disrupt the bosses’ ' + 'profits, and that allows workers to use their immense power ' + '— this could change the meat industry and the world. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pigs a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting organized. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plants in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pig a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting organized. ' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + ] + tgt_list = [ + { + 'text': 'Today is Sunday and it\'s a happy day!' + }, + { + 'text': 'Do you need a cup of coffee?' + }, + { + 'text': 'Today is sunday and it\'s really a happy day!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pigs a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting ' + 'organized.\n\nThe major union in the industry, including at ' + 'Smithfield, is the United Food and Commercial Workers union ' + '(UFCW). What union leaders have done is ultimately ' + 'troubling.\n\nCan Workers Fight?\n\nLocal AFL-CIO president ' + 'Kooper Caraway has publicly said management delayed safety ' + 'action as long as possible for profit. But while some ' + 'workers were demanding a two-week shutdown, Caraway told ' + 'the Argus Leader that was unrealistic because the ' + 'government considers the plant essential. He suggested the ' + 'union would be happy with minimal safety measures: “Even if ' + '10 people get exposed in a day rather than 11. If you can ' + 'implement a program where even one or two less people get ' + 'exposed during a shift, that’s one or two less people.” Of ' + 'course reducing infections is good, but suggesting workers ' + 'would be satisfied if the company allowed 90% of the ' + 'contagion to continue is horrifying.\n\nThe response of ' + 'UFCW leadership was worse. As the disease was exploding, ' + 'they told the Argus Leader, “We applaud [Smithfield’s] ' + 'decision to temporarily close the plant [over Easter ' + 'weekend] to push for an even safer work environment.” What ' + 'does “even safer” mean in this context?\n\nThe union ' + 'bureaucracy has taken weak action elsewhere. In ' + 'Pennsylvania, the UFCW negotiated $2 hazard pay for two ' + 'months with Cargill Meat — the same pandemic premium Amazon ' + 'gave workers without a union. In Nebraska, the UFCW ' + 'negotiated $4 hazard pay for one month with meat giant ' + 'JBS.\n\nThe union has said nothing about forcing companies ' + 'to send older workers home with pay, even though a ' + '70-year-old shop steward and a 78-year-old grandfather ' + 'working at JBS plants were killed by Covid-19. Smithfield ' + 'workers were promised only two weeks of shutdown pay. For ' + 'many, this compensation is half their normal paycheck ' + 'because they routinely put in 66 hour weeks — overtime that ' + 'costs exhaustion and chronic pain.\n\nUnion officials ' + 'endeavor to cooperate with the meat companies. An Iowa UFCW ' + 'president actually suggested it might be impossible for ' + 'plants to move workers a full six feet apart and told the ' + 'Des Moines Register, “We can’t stop the plants. If we stop ' + 'the plants from running, we stop feeding the country. We ' + 'want to do everything we can to make sure the employees are ' + 'safe to keep the plant running.”\n\nEvery part of this ' + 'explanation directly overlaps with what the Smithfield CEO ' + 'said. Unfortunately, it amounts to accepting the company’s ' + 'excuses.\n\nThey claim that workers who do hard physical ' + 'labor, waking up at 4 a.m. and often working six days a ' + 'week for years, would be guilty of taking food away from ' + 'the people and hurting America if they dared to fight for ' + 'their human needs. But nothing is said about the company ' + 'raking in profits and even murdering workers to increase ' + 'them.\n\nSmithfield’s parent company W.H. Group, ' + 'which slaughters around 30 million pigs per year in plants ' + 'in both the United States and China, saw its profits ' + 'skyrocket by about one third in 2019 to $1.38 billion. It ' + 'is disturbing that UFCW officials do not bring up these ' + 'soaring profits in their response to the outbreaks. Reuters ' + 'published a report on the corporation’s financial success ' + 'in late March. The head of W.H. Group had touted to the ' + 'media that it got through the pandemic in China with very ' + 'limited impact on production.\n\nIt is true that many ' + 'Smithfield workers are reasonably afraid for their jobs and ' + 'want to keep working. A 25-year-old employee explained, ' + '“I have a lot of bills. My baby’s coming soon — I have to ' + 'work.” At the same time, he was afraid of infecting his ' + 'pregnant wife. His spouse, a former employee, ' + 'said bitterly, “Smithfield— they don’t care about ' + 'employees. They only care about their money.”\n\nWorkers ' + 'are pressured in these two painful directions. Nonetheless, ' + 'work can mean solidarity. Before Smithfield even checked ' + 'temperatures, there was a “sick-out” strike without union ' + 'support by 800 to 1,000 workers at a JBS meat factory in ' + 'Colorado. Hundreds of workers also called in sick days at a ' + 'Nebraska JBS plant.\n\nTrade union leaders won’t even ' + 'whisper the word “strike” when thousands of workers are ' + 'thinking about it. They are limiting themselves to polite ' + 'requests. We need a workers’ movement that asks who ' + 'controls the factory, that threatens to disrupt the bosses’ ' + 'profits, and that allows workers to use their immense power ' + '— this could change the meat industry and the world. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pigs a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting organized. ' + }, + ] + dataset = Dataset.from_list(ds_list) + op = DocumentMinhashDeduplicator(ignore_pattern=r'\p{P}') + self._run_minhash_dedup(dataset, tgt_list, op) + + def test_chinese_deduplication(self): + ds_list = [ + { + 'text': '你好,请问你是谁' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + '第九届会议\n2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1. 国际' + '海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,' + '以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员' + '会成员参加委员会会议的费用。\n2. 由于秘书长向会员国发出为该信托基金捐款' + '的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账' + '户。\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会' + '议的方式,包括审查是否可能从管理局行政预算中提供经费。\n4. 自愿信托基金' + '迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身' + '份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履' + '行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而' + '未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经' + '济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,' + '根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至' + '215 000美元。\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现' + '行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关' + '的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基' + '金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会' + '成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法' + '处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率' + '提供。 购买机票的所有安排均由联合国秘书处执行。\n6. 虽然已经设立了临时性' + '的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对' + '信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\n(a) ' + '从管理局一般行政经费累计利息中拨出一定数额的经费;\n(b) 每年从上一年预算' + '未动用部分中拨出规定的数额;\n(c) 从先驱投资者基金利息中拨出规定的数额。' + '\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财' + '务委员会提出一份报告。\n附件\n资助来自发展中国家的法律和技术委员会以及财' + '务\n委员会成员出席会议的指示性费用(美元)\n成员\n机票\n机场\n费用\n金' + '斯敦每日生活\n津贴\n转机途中每日生活\n7日\n共计\n14日\n经济舱\n公务舱' + '\n7天=(8天每日生活\n津贴)\n14天= (15天每日生活津贴)\n商务舱\n法律和技' + '术委员会\n印度尼西亚\n(纽约)\n黎巴嫩\n巴基斯坦\n阿根廷\n喀麦隆\n墨西哥' + '\n巴西\n塞内加尔\n莫桑比克\n埃及(纽约)\n大韩民国\n印度\n斐济\n智利\n' + '中国\n纳米比亚\n小计\n财务委员会\n缅甸\n乌干达\n牙买加\n印度(纽约)\n尼' + '日利亚\n总计\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000' + '美元至215 000美元(四舍五入)。' + }, + { + 'text': + '第九届会议\n时间:2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1. 国际' + '海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,' + '以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员' + '会成员参加委员会会议的费用。\n2. 由于秘书长向会员国发出为该信托基金捐款' + '的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账' + '户。\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会' + '议的方式,包括审查是否可能从管理局行政预算中提供经费。\n4. 自愿信托基金' + '迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身' + '份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履' + '行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而' + '未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经' + '济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,' + '根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至' + '215 000美元。\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现' + '行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关' + '的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基' + '金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会' + '成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法' + '处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率' + '提供。 购买机票的所有安排均由联合国秘书处执行。\n6. 虽然已经设立了临时性' + '的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对' + '信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\n(a) ' + '从管理局一般行政经费累计利息中拨出一定数额的经费;\n(b) 每年从上一年预算' + '未动用部分中拨出规定的数额;\n(c) 从先驱投资者基金利息中拨出规定的数额。' + '\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财' + '务委员会提出一份报告。\n附件\n资助来自发展中国家的法律和技术委员会以及财' + '务\n委员会成员出席会议的指示性费用(美元)\n成员\n机票\n机场\n费用\n金' + '斯敦每日生活\n津贴\n转机途中每日生活\n7日\n共计\n14日\n经济舱\n公务舱' + '\n7天=(8天每日生活\n津贴)\n14天= (15天每日生活津贴)\n商务舱\n法律和技' + '术委员会\n印度尼西亚\n(纽约)\n黎巴嫩\n巴基斯坦\n阿根廷\n喀麦隆\n墨西哥' + '\n巴西\n塞内加尔\n莫桑比克\n埃及(纽约)\n大韩民国\n印度\n斐济\n智利\n' + '中国\n纳米比亚\n小计\n财务委员会\n缅甸\n乌干达\n牙买加\n印度(纽约)\n尼' + '日利亚\n总计\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000' + '美元至215 000美元(四舍五入)。' + }, + ] + tgt_list = [ + { + 'text': '你好,请问你是谁' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + '第九届会议\n2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1. 国际' + '海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,' + '以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员' + '会成员参加委员会会议的费用。\n2. 由于秘书长向会员国发出为该信托基金捐款' + '的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账' + '户。\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会' + '议的方式,包括审查是否可能从管理局行政预算中提供经费。\n4. 自愿信托基金' + '迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身' + '份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履' + '行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而' + '未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经' + '济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,' + '根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至' + '215 000美元。\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现' + '行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关' + '的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基' + '金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会' + '成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法' + '处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率' + '提供。 购买机票的所有安排均由联合国秘书处执行。\n6. 虽然已经设立了临时性' + '的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对' + '信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\n(a) ' + '从管理局一般行政经费累计利息中拨出一定数额的经费;\n(b) 每年从上一年预算' + '未动用部分中拨出规定的数额;\n(c) 从先驱投资者基金利息中拨出规定的数额。' + '\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财' + '务委员会提出一份报告。\n附件\n资助来自发展中国家的法律和技术委员会以及财' + '务\n委员会成员出席会议的指示性费用(美元)\n成员\n机票\n机场\n费用\n金' + '斯敦每日生活\n津贴\n转机途中每日生活\n7日\n共计\n14日\n经济舱\n公务舱' + '\n7天=(8天每日生活\n津贴)\n14天= (15天每日生活津贴)\n商务舱\n法律和技' + '术委员会\n印度尼西亚\n(纽约)\n黎巴嫩\n巴基斯坦\n阿根廷\n喀麦隆\n墨西哥' + '\n巴西\n塞内加尔\n莫桑比克\n埃及(纽约)\n大韩民国\n印度\n斐济\n智利\n' + '中国\n纳米比亚\n小计\n财务委员会\n缅甸\n乌干达\n牙买加\n印度(纽约)\n尼' + '日利亚\n总计\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000' + '美元至215 000美元(四舍五入)。' + }, + ] + dataset = Dataset.from_list(ds_list) + op = DocumentMinhashDeduplicator(tokenization='character', + ignore_pattern=r'\p{P}') + self._run_minhash_dedup(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/deduplicator/test_document_simhash_deduplicator.py b/tests/ops/deduplicator/test_document_simhash_deduplicator.py new file mode 100644 index 000000000..d021423c0 --- /dev/null +++ b/tests/ops/deduplicator/test_document_simhash_deduplicator.py @@ -0,0 +1,962 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.deduplicator.document_simhash_deduplicator import \ + DocumentSimhashDeduplicator + + +class DocumentSimhashDeduplicatorTest(unittest.TestCase): + + def _run_simhash_dedup(self, dataset: Dataset, target_list, op): + dataset = dataset.map(op.compute_hash) + dataset, _ = op.process(dataset) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_english_deduplication(self): + ds_list = [ + { + 'text': 'Today is Sunday and it\'s a happy day!' + }, + { + 'text': 'Do you need a cup of coffee?' + }, + { + 'text': 'Today is sunday and it\'s really a happy day!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux Falls, ' + 'South Dakota. The plant slaughters 19,500 pigs a day — 5 ' + 'percent of U.S. pork. Most of the workers are immigrants ' + 'from Ethiopia, Mexico, South Sudan, Honduras, Myanmar, ' + 'Somalia, Guatemala, and other poor countries.\n\nInevitably ' + 'workers must pass within one foot of hundreds of colleagues ' + 'in the hallways, locker rooms, cafeterias, and cutting ' + 'lines. The same conditions have spurred Covid-19 outbreaks ' + 'at meat plants from Minnesota and Wisconsin to Colorado, ' + 'Nebraska, Missouri, Iowa, Pennsylvania, North Carolina, and ' + 'Georgia.\n\n801 workers at the Sioux Falls plant have tested ' + 'positive, together with 206 people close to them. The ' + 'outbreak has killed Agustín Rodríguez Martínez, aged 64, an ' + 'employee with two decades of experience originally from El ' + 'Salvador, and Craig Allen Franken, 61, who worked for ' + 'Smithfield his entire adult life.\n\nThe company knew of its ' + 'first infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting ' + 'organized.\n\nThe major union in the industry, including at ' + 'Smithfield, is the United Food and Commercial Workers union ' + '(UFCW). What union leaders have done is ultimately ' + 'troubling.\n\nCan Workers Fight?\n\nLocal AFL-CIO president ' + 'Kooper Caraway has publicly said management delayed safety ' + 'action as long as possible for profit. But while some ' + 'workers were demanding a two-week shutdown, Caraway told ' + 'the Argus Leader that was unrealistic because the ' + 'government considers the plant essential. He suggested the ' + 'union would be happy with minimal safety measures: “Even if ' + '10 people get exposed in a day rather than 11. If you can ' + 'implement a program where even one or two less people get ' + 'exposed during a shift, that’s one or two less people.” Of ' + 'course reducing infections is good, but suggesting workers ' + 'would be satisfied if the company allowed 90% of the ' + 'contagion to continue is horrifying.\n\nThe response of ' + 'UFCW leadership was worse. As the disease was exploding, ' + 'they told the Argus Leader, “We applaud [Smithfield’s] ' + 'decision to temporarily close the plant [over Easter ' + 'weekend] to push for an even safer work environment.” What ' + 'does “even safer” mean in this context?\n\nThe union ' + 'bureaucracy has taken weak action elsewhere. In ' + 'Pennsylvania, the UFCW negotiated $2 hazard pay for two ' + 'months with Cargill Meat — the same pandemic premium Amazon ' + 'gave workers without a union. In Nebraska, the UFCW ' + 'negotiated $4 hazard pay for one month with meat giant ' + 'JBS.\n\nThe union has said nothing about forcing companies ' + 'to send older workers home with pay, even though a ' + '70-year-old shop steward and a 78-year-old grandfather ' + 'working at JBS plants were killed by Covid-19. Smithfield ' + 'workers were promised only two weeks of shutdown pay. For ' + 'many, this compensation is half their normal paycheck ' + 'because they routinely put in 66 hour weeks — overtime that ' + 'costs exhaustion and chronic pain.\n\nUnion officials ' + 'endeavor to cooperate with the meat companies. An Iowa UFCW ' + 'president actually suggested it might be impossible for ' + 'plants to move workers a full six feet apart and told the ' + 'Des Moines Register, “We can’t stop the plants. If we stop ' + 'the plants from running, we stop feeding the country. We ' + 'want to do everything we can to make sure the employees are ' + 'safe to keep the plant running.”\n\nEvery part of this ' + 'explanation directly overlaps with what the Smithfield CEO ' + 'said. Unfortunately, it amounts to accepting the company’s ' + 'excuses.\n\nThey claim that workers who do hard physical ' + 'labor, waking up at 4 a.m. and often working six days a ' + 'week for years, would be guilty of taking food away from ' + 'the people and hurting America if they dared to fight for ' + 'their human needs. But nothing is said about the company ' + 'raking in profits and even murdering workers to increase ' + 'them.\n\nSmithfield’s parent company W.H. Group, ' + 'which slaughters around 30 million pigs per year in plants ' + 'in both the United States and China, saw its profits ' + 'skyrocket by about one third in 2019 to $1.38 billion. It ' + 'is disturbing that UFCW officials do not bring up these ' + 'soaring profits in their response to the outbreaks. Reuters ' + 'published a report on the corporation’s financial success ' + 'in late March. The head of W.H. Group had touted to the ' + 'media that it got through the pandemic in China with very ' + 'limited impact on production.\n\nIt is true that many ' + 'Smithfield workers are reasonably afraid for their jobs and ' + 'want to keep working. A 25-year-old employee explained, ' + '“I have a lot of bills. My baby’s coming soon — I have to ' + 'work.” At the same time, he was afraid of infecting his ' + 'pregnant wife. His spouse, a former employee, ' + 'said bitterly, “Smithfield— they don’t care about ' + 'employees. They only care about their money.”\n\nWorkers ' + 'are pressured in these two painful directions. Nonetheless, ' + 'work can mean solidarity. Before Smithfield even checked ' + 'temperatures, there was a “sick-out” strike without union ' + 'support by 800 to 1,000 workers at a JBS meat factory in ' + 'Colorado. Hundreds of workers also called in sick days at a ' + 'Nebraska JBS plant.\n\nTrade union leaders won’t even ' + 'whisper the word “strike” when thousands of workers are ' + 'thinking about it. They are limiting themselves to polite ' + 'requests. We need a workers’ movement that asks who ' + 'controls the factory, that threatens to disrupt the bosses’ ' + 'profits, and that allows workers to use their immense power ' + '— this could change the meat industry and the world. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plants in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pig a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting ' + 'organized.\n\nThe major union in the industry, including at ' + 'Smithfield, is the United Food and Commercial Workers union ' + '(UFCW). What union leaders have done is ultimately ' + 'troubling.\n\nCan Workers Fight?\n\nLocal AFL-CIO president ' + 'Kooper Caraway has publicly said management delayed safety ' + 'action as long as possible for profit. But while some ' + 'workers were demanding a two-week shutdown, Caraway told ' + 'the Argus Leader that was unrealistic because the ' + 'government considers the plant essential. He suggested the ' + 'union would be happy with minimal safety measures: “Even if ' + '10 people get exposed in a day rather than 11. If you can ' + 'implement a program where even one or two less people get ' + 'exposed during a shift, that’s one or two less people.” Of ' + 'course reducing infections is good, but suggesting workers ' + 'would be satisfied if the company allowed 90% of the ' + 'contagion to continue is horrifying.\n\nThe response of ' + 'UFCW leadership was worse. As the disease was exploding, ' + 'they told the Argus Leader, “We applaud [Smithfield’s] ' + 'decision to temporarily close the plant [over Easter ' + 'weekend] to push for an even safer work environment.” What ' + 'does “even safer” mean in this context?\n\nThe union ' + 'bureaucracy has taken weak action elsewhere. In ' + 'Pennsylvania, the UFCW negotiated $2 hazard pay for two ' + 'months with Cargill Meat — the same pandemic premium Amazon ' + 'gave workers without a union. In Nebraska, the UFCW ' + 'negotiated $4 hazard pay for one month with meat giant ' + 'JBS.\n\nThe union has said nothing about forcing companies ' + 'to send older workers home with pay, even though a ' + '70-year-old shop steward and a 78-year-old grandfather ' + 'working at JBS plants were killed by Covid-19. Smithfield ' + 'workers were promised only two weeks of shutdown pay. For ' + 'many, this compensation is half their normal paycheck ' + 'because they routinely put in 66 hour weeks — overtime that ' + 'costs exhaustion and chronic pain.\n\nUnion officials ' + 'endeavor to cooperate with the meat companies. An Iowa UFCW ' + 'president actually suggested it might be impossible for ' + 'plants to move workers a full six feet apart and told the ' + 'Des Moines Register, “We can’t stop the plants. If we stop ' + 'the plants from running, we stop feeding the country. We ' + 'want to do everything we can to make sure the employees are ' + 'safe to keep the plant running.”\n\nEvery part of this ' + 'explanation directly overlaps with what the Smithfield CEO ' + 'said. Unfortunately, it amounts to accepting the company’s ' + 'excuses.\n\nThey claim that workers who do hard physical ' + 'labor, waking up at 4 a.m. and often working six days a ' + 'week for years, would be guilty of taking food away from ' + 'the people and hurting America if they dared to fight for ' + 'their human needs. But nothing is said about the company ' + 'raking in profits and even murdering workers to increase ' + 'them.\n\nSmithfield’s parent company W.H. Group, ' + 'which slaughters around 30 million pigs per year in plants ' + 'in both the United States and China, saw its profits ' + 'skyrocket by about one third in 2019 to $1.38 billion. It ' + 'is disturbing that UFCW officials do not bring up these ' + 'soaring profits in their response to the outbreaks. Reuters ' + 'published a report on the corporation’s financial success ' + 'in late March. The head of W.H. Group had touted to the ' + 'media that it got through the pandemic in China with very ' + 'limited impact on production.\n\nIt is true that many ' + 'Smithfield workers are reasonably afraid for their jobs and ' + 'want to keep working. A 25-year-old employee explained, ' + '“I have a lot of bills. My baby’s coming soon — I have to ' + 'work.” At the same time, he was afraid of infecting his ' + 'pregnant wife. His spouse, a former employee, ' + 'said bitterly, “Smithfield— they don’t care about ' + 'employees. They only care about their money.”\n\nWorkers ' + 'are pressured in these two painful directions. Nonetheless, ' + 'work can mean solidarity. Before Smithfield even checked ' + 'temperatures, there was a “sick-out” strike without union ' + 'support by 800 to 1,000 workers at a JBS meat factory in ' + 'Colorado. Hundreds of workers also called in sick days at a ' + 'Nebraska JBS plant.\n\nTrade union leaders won’t even ' + 'whisper the word “strike” when thousands of workers are ' + 'thinking about it. They are limiting themselves to polite ' + 'requests. We need a workers’ movement that asks who ' + 'controls the factory, that threatens to disrupt the bosses’ ' + 'profits, and that allows workers to use their immense power ' + '— this could change the meat industry and the world. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pigs a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting organized. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plants in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pig a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting organized. ' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + ] + tgt_list = [ + { + 'text': 'Today is Sunday and it\'s a happy day!' + }, + { + 'text': 'Do you need a cup of coffee?' + }, + { + 'text': 'Today is sunday and it\'s really a happy day!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pigs a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting ' + 'organized.\n\nThe major union in the industry, including at ' + 'Smithfield, is the United Food and Commercial Workers union ' + '(UFCW). What union leaders have done is ultimately ' + 'troubling.\n\nCan Workers Fight?\n\nLocal AFL-CIO president ' + 'Kooper Caraway has publicly said management delayed safety ' + 'action as long as possible for profit. But while some ' + 'workers were demanding a two-week shutdown, Caraway told ' + 'the Argus Leader that was unrealistic because the ' + 'government considers the plant essential. He suggested the ' + 'union would be happy with minimal safety measures: “Even if ' + '10 people get exposed in a day rather than 11. If you can ' + 'implement a program where even one or two less people get ' + 'exposed during a shift, that’s one or two less people.” Of ' + 'course reducing infections is good, but suggesting workers ' + 'would be satisfied if the company allowed 90% of the ' + 'contagion to continue is horrifying.\n\nThe response of ' + 'UFCW leadership was worse. As the disease was exploding, ' + 'they told the Argus Leader, “We applaud [Smithfield’s] ' + 'decision to temporarily close the plant [over Easter ' + 'weekend] to push for an even safer work environment.” What ' + 'does “even safer” mean in this context?\n\nThe union ' + 'bureaucracy has taken weak action elsewhere. In ' + 'Pennsylvania, the UFCW negotiated $2 hazard pay for two ' + 'months with Cargill Meat — the same pandemic premium Amazon ' + 'gave workers without a union. In Nebraska, the UFCW ' + 'negotiated $4 hazard pay for one month with meat giant ' + 'JBS.\n\nThe union has said nothing about forcing companies ' + 'to send older workers home with pay, even though a ' + '70-year-old shop steward and a 78-year-old grandfather ' + 'working at JBS plants were killed by Covid-19. Smithfield ' + 'workers were promised only two weeks of shutdown pay. For ' + 'many, this compensation is half their normal paycheck ' + 'because they routinely put in 66 hour weeks — overtime that ' + 'costs exhaustion and chronic pain.\n\nUnion officials ' + 'endeavor to cooperate with the meat companies. An Iowa UFCW ' + 'president actually suggested it might be impossible for ' + 'plants to move workers a full six feet apart and told the ' + 'Des Moines Register, “We can’t stop the plants. If we stop ' + 'the plants from running, we stop feeding the country. We ' + 'want to do everything we can to make sure the employees are ' + 'safe to keep the plant running.”\n\nEvery part of this ' + 'explanation directly overlaps with what the Smithfield CEO ' + 'said. Unfortunately, it amounts to accepting the company’s ' + 'excuses.\n\nThey claim that workers who do hard physical ' + 'labor, waking up at 4 a.m. and often working six days a ' + 'week for years, would be guilty of taking food away from ' + 'the people and hurting America if they dared to fight for ' + 'their human needs. But nothing is said about the company ' + 'raking in profits and even murdering workers to increase ' + 'them.\n\nSmithfield’s parent company W.H. Group, ' + 'which slaughters around 30 million pigs per year in plants ' + 'in both the United States and China, saw its profits ' + 'skyrocket by about one third in 2019 to $1.38 billion. It ' + 'is disturbing that UFCW officials do not bring up these ' + 'soaring profits in their response to the outbreaks. Reuters ' + 'published a report on the corporation’s financial success ' + 'in late March. The head of W.H. Group had touted to the ' + 'media that it got through the pandemic in China with very ' + 'limited impact on production.\n\nIt is true that many ' + 'Smithfield workers are reasonably afraid for their jobs and ' + 'want to keep working. A 25-year-old employee explained, ' + '“I have a lot of bills. My baby’s coming soon — I have to ' + 'work.” At the same time, he was afraid of infecting his ' + 'pregnant wife. His spouse, a former employee, ' + 'said bitterly, “Smithfield— they don’t care about ' + 'employees. They only care about their money.”\n\nWorkers ' + 'are pressured in these two painful directions. Nonetheless, ' + 'work can mean solidarity. Before Smithfield even checked ' + 'temperatures, there was a “sick-out” strike without union ' + 'support by 800 to 1,000 workers at a JBS meat factory in ' + 'Colorado. Hundreds of workers also called in sick days at a ' + 'Nebraska JBS plant.\n\nTrade union leaders won’t even ' + 'whisper the word “strike” when thousands of workers are ' + 'thinking about it. They are limiting themselves to polite ' + 'requests. We need a workers’ movement that asks who ' + 'controls the factory, that threatens to disrupt the bosses’ ' + 'profits, and that allows workers to use their immense power ' + '— this could change the meat industry and the world. ' + }, + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux ' + 'Falls, South Dakota. The plant slaughters 19,500 pigs a day ' + '— 5 percent of U.S. pork. Most of the workers are ' + 'immigrants from Ethiopia, Mexico, South Sudan, Honduras, ' + 'Myanmar, Somalia, Guatemala, and other poor ' + 'countries.\n\nInevitably workers must pass within one foot ' + 'of hundreds of colleagues in the hallways, locker rooms, ' + 'cafeterias, and cutting lines. The same conditions have ' + 'spurred Covid-19 outbreaks at meat plants from Minnesota ' + 'and Wisconsin to Colorado, Nebraska, Missouri, Iowa, ' + 'Pennsylvania, North Carolina, and Georgia.\n\n801 workers ' + 'at the Sioux Falls plant have tested positive, together ' + 'with 206 people close to them. The outbreak has killed ' + 'Agustín Rodríguez Martínez, aged 64, an employee with two ' + 'decades of experience originally from El Salvador, ' + 'and Craig Allen Franken, 61, who worked for Smithfield his ' + 'entire adult life.\n\nThe company knew of its first ' + 'infection on March 24 or earlier. The virus spread ' + 'exponentially for several weeks. Ahead of Easter Sunday and ' + 'Monday (April 12-13), Smithfield promised to “completely ' + 'shutter” to sanitize and put up cardboard and plastic sheet ' + 'dividers. This would not end transmission, as potentially ' + 'hundreds of staff were already carrying the virus. But even ' + 'during this “shutdown,” many cars were seen in the parking ' + 'lot. The mayor admits that the company lied, and the local ' + 'AFL-CIO alleges the plant ran 60 percent production. On ' + 'Easter, with 238 known infections, Smithfield finally ' + 'agreed to shut down indefinitely after a request from the ' + 'mayor and the governor. Yet the company insisted on waiting ' + 'three more days to actually halt production.\n\nSmithfield ' + 'denied contributing to the outbreak, saying it took a “very ' + 'proactive approach.” Relying on racism, the company blamed ' + 'workers for getting themselves sick. A spokesperson said ' + 'the outbreak was so severe because of the plant’s “large ' + 'immigrant population,” claming “Living circumstances in ' + 'certain cultures are different than they are with your ' + 'traditional American family.” They slandered the workers as ' + 'dirty, ignorant, and untrustworthy with help from governor ' + 'Kristi Noem, who claimed, “99 percent of what’s going on ' + 'today wasn’t happening inside the facility. It was more at ' + 'home, where these employees were going home and spreading ' + 'some of the virus” by living too close together.\n\nOne ' + 'sick worker, Michael Bul Gayo Gatluak, 22 and originally ' + 'from South Sudan, says, “With how we work on the line, ' + 'I would say I got sick because of them not taking safety ' + 'measures.” His job is “really, really close” to other ' + 'workers chopping fresh-killed pigs. “The job is so heavy. ' + 'You have to breathe so hard.”\n\nIn early March, ' + 'union officials requested masks, overcoats, entrance ' + 'checking for fevers, and less crowding in 500-capacity ' + 'cafeterias. But Smithfield waited on most safety measures ' + 'until early April. Only April 6 did they start checking for ' + 'fevers. Instead of protective masks, they gave out beard ' + 'nets.\n\nSmithfield concealed infections with a policy of ' + 'informing only employees whose work stations were in the ' + 'same area as a person who tested positive. The fact that ' + 'workers are required to move around was willfully ignored. ' + 'One worker who tested positive said, “I clearly would have ' + 'gotten it at the factory. This week I have worked on three ' + 'different floors. I’ve eaten in two different cafeterias … ' + 'I’ve been walking through the whole place.” Employees from ' + 'the eighth floor of the plant were quarantined, ' + 'but everyone else was told to keep working.\n\nWhat Is ' + 'Really Going On?\n\nAverage plant wages are around $16 an ' + 'hour. Smithfield never raised them. Instead, they offered ' + '$500 to employees who could go all of April without an ' + 'unapproved day off. The company says their “Responsibility ' + 'Bonuses” show their “immense gratefulness” to employees ' + '“for their selfless sacrifices.”\n\nMeanwhile, the local ' + 'Argus Leader wrote union members wanted essential-worker ' + 'hazard pay, which “would be considered hourly compensation ' + 'about 1.5 or two times their normal pay.” One worker said, ' + '“I feel like they’re bribing us with [the bonus] to come to ' + 'work sick. That’s how you know they don’t care.”\n\nBoth ' + 'Sioux Falls workers killed by Covid-19 were in their ' + 'sixties. It is unconscionable that they were still working. ' + 'All meatpackers over 50 should be on paid leave. Agustín ' + 'Rodríguez, 64, had a rough job sawing the legs off dead ' + 'pigs. He mopped floors with a fever shortly before he was ' + 'hospitalized.\n\nWhen CEO Kenneth Sullivan closed the ' + 'plant, he claimed, “We have continued to run our facilities ' + 'for one reason: to sustain our nation’s food supply.” This ' + 'is an effort to sweep Smithfield’s abuses under the rug, ' + 'as if the company were operating for public benefit. This ' + 'patriotic propaganda that all Americans are in it together ' + 'is like a drug to keep workers from getting organized. ' + }, + ] + dataset = Dataset.from_list(ds_list) + op = DocumentSimhashDeduplicator(ignore_pattern=r'\p{P}') + self._run_simhash_dedup(dataset, tgt_list, op) + + def test_chinese_deduplication(self): + ds_list = [ + { + 'text': '你好,请问你是谁' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + '第九届会议\n2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1. 国际' + '海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,' + '以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员' + '会成员参加委员会会议的费用。\n2. 由于秘书长向会员国发出为该信托基金捐款' + '的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账' + '户。\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会' + '议的方式,包括审查是否可能从管理局行政预算中提供经费。\n4. 自愿信托基金' + '迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身' + '份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履' + '行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而' + '未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经' + '济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,' + '根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至' + '215 000美元。\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现' + '行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关' + '的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基' + '金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会' + '成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法' + '处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率' + '提供。 购买机票的所有安排均由联合国秘书处执行。\n6. 虽然已经设立了临时性' + '的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对' + '信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\n(a) ' + '从管理局一般行政经费累计利息中拨出一定数额的经费;\n(b) 每年从上一年预算' + '未动用部分中拨出规定的数额;\n(c) 从先驱投资者基金利息中拨出规定的数额。' + '\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财' + '务委员会提出一份报告。\n附件\n资助来自发展中国家的法律和技术委员会以及财' + '务\n委员会成员出席会议的指示性费用(美元)\n成员\n机票\n机场\n费用\n金' + '斯敦每日生活\n津贴\n转机途中每日生活\n7日\n共计\n14日\n经济舱\n公务舱' + '\n7天=(8天每日生活\n津贴)\n14天= (15天每日生活津贴)\n商务舱\n法律和技' + '术委员会\n印度尼西亚\n(纽约)\n黎巴嫩\n巴基斯坦\n阿根廷\n喀麦隆\n墨西哥' + '\n巴西\n塞内加尔\n莫桑比克\n埃及(纽约)\n大韩民国\n印度\n斐济\n智利\n' + '中国\n纳米比亚\n小计\n财务委员会\n缅甸\n乌干达\n牙买加\n印度(纽约)\n尼' + '日利亚\n总计\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000' + '美元至215 000美元(四舍五入)。' + }, + { + 'text': + '第九届会议\n时间:2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1. 国际' + '海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,' + '以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员' + '会成员参加委员会会议的费用。\n2. 由于秘书长向会员国发出为该信托基金捐款' + '的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账' + '户。\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会' + '议的方式,包括审查是否可能从管理局行政预算中提供经费。\n4. 自愿信托基金' + '迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身' + '份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履' + '行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而' + '未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经' + '济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,' + '根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至' + '215 000美元。\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现' + '行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关' + '的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基' + '金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会' + '成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法' + '处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率' + '提供。 购买机票的所有安排均由联合国秘书处执行。\n6. 虽然已经设立了临时性' + '的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对' + '信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\n(a) ' + '从管理局一般行政经费累计利息中拨出一定数额的经费;\n(b) 每年从上一年预算' + '未动用部分中拨出规定的数额;\n(c) 从先驱投资者基金利息中拨出规定的数额。' + '\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财' + '务委员会提出一份报告。\n附件\n资助来自发展中国家的法律和技术委员会以及财' + '务\n委员会成员出席会议的指示性费用(美元)\n成员\n机票\n机场\n费用\n金' + '斯敦每日生活\n津贴\n转机途中每日生活\n7日\n共计\n14日\n经济舱\n公务舱' + '\n7天=(8天每日生活\n津贴)\n14天= (15天每日生活津贴)\n商务舱\n法律和技' + '术委员会\n印度尼西亚\n(纽约)\n黎巴嫩\n巴基斯坦\n阿根廷\n喀麦隆\n墨西哥' + '\n巴西\n塞内加尔\n莫桑比克\n埃及(纽约)\n大韩民国\n印度\n斐济\n智利\n' + '中国\n纳米比亚\n小计\n财务委员会\n缅甸\n乌干达\n牙买加\n印度(纽约)\n尼' + '日利亚\n总计\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000' + '美元至215 000美元(四舍五入)。' + }, + ] + tgt_list = [ + { + 'text': '你好,请问你是谁' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + '第九届会议\n2003年7月28日至8月8日\n牙买加金斯敦\n为来自发展中国家的法' + '律和技术委员会以及财务委员会成员\n参加委员会会议支付费用的方式\n1. 国际' + '海底管理局大会第八届会议请秘书长采取一项临时措施,设立一个自愿信托基金,' + '以便支付来自发展中国家的法律和技术委员会成员以及来自发展中国家的财务委员' + '会成员参加委员会会议的费用。\n2. 由于秘书长向会员国发出为该信托基金捐款' + '的请求,已收到三笔捐款,共计10 500美元。 管理局已为基金设立一个单独的账' + '户。\n3. 管理局第八届会议还决定,由财务委员会审查资助参加这两个委员会会' + '议的方式,包括审查是否可能从管理局行政预算中提供经费。\n4. 自愿信托基金' + '迄今收到的捐款数额很小。 这两个委员会成员虽然由缔约国提名,但他们以个人身' + '份当选。 因此,必须确保这些机构的成员在任期内能够参加会议并且持续不断地履' + '行职务。 现已注意到,这两个委员会若干成员因旅费和生活津贴费用方面有困难而' + '未能出席会议。 来自发展中国家成员参加会议的费用估计数见附件,其中比较了经' + '济舱和公务舱机票价格以及适用于金斯敦的每日生活津贴费用。 从表中可以看出,' + '根据不同的人数、机舱等级和会议持续时间,每年平均需要捐款120 000美元至' + '215 000美元。\n5. 为了指导委员会确定提供经费的方式,对某些国际组织的现' + '行办法作了一次简要调查。 为支付参加会议的旅费和生活费而设立信托基金最相关' + '的实例是2000年大会为来自发展中国家的大陆架界限委员会成员设立的自愿信托基' + '金。 目前这一基金正在运作,但现有资源有限。 联合国制定的程序表明,委员会' + '成员的政府应在规定时间内尽可能提前提出请求。 这种请求按照先到先核可的办法' + '处理。 提供的机票将是最直接路线的经济舱机票,每日生活津贴将按照联合国费率' + '提供。 购买机票的所有安排均由联合国秘书处执行。\n6. 虽然已经设立了临时性' + '的自愿信托基金,但是,对该基金的捐款数额很小,捐款速度很慢。 因此,除了对' + '信托基金提供自愿捐款的办法之外,建议委员会还可以考虑采用下列办法:\n(a) ' + '从管理局一般行政经费累计利息中拨出一定数额的经费;\n(b) 每年从上一年预算' + '未动用部分中拨出规定的数额;\n(c) 从先驱投资者基金利息中拨出规定的数额。' + '\n7. 委员会还不妨建议由管理局秘书处依照行政规则和程序管理该基金,并向财' + '务委员会提出一份报告。\n附件\n资助来自发展中国家的法律和技术委员会以及财' + '务\n委员会成员出席会议的指示性费用(美元)\n成员\n机票\n机场\n费用\n金' + '斯敦每日生活\n津贴\n转机途中每日生活\n7日\n共计\n14日\n经济舱\n公务舱' + '\n7天=(8天每日生活\n津贴)\n14天= (15天每日生活津贴)\n商务舱\n法律和技' + '术委员会\n印度尼西亚\n(纽约)\n黎巴嫩\n巴基斯坦\n阿根廷\n喀麦隆\n墨西哥' + '\n巴西\n塞内加尔\n莫桑比克\n埃及(纽约)\n大韩民国\n印度\n斐济\n智利\n' + '中国\n纳米比亚\n小计\n财务委员会\n缅甸\n乌干达\n牙买加\n印度(纽约)\n尼' + '日利亚\n总计\n注:估计费用表表明每年资助每个机构一次会议需要经费120 000' + '美元至215 000美元(四舍五入)。' + }, + ] + dataset = Dataset.from_list(ds_list) + op = DocumentSimhashDeduplicator(tokenization='character', + ignore_pattern=r'\p{P}') + self._run_simhash_dedup(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/__init__.py b/tests/ops/filter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/filter/test_alphanumeric_filter.py b/tests/ops/filter/test_alphanumeric_filter.py new file mode 100644 index 000000000..8dc1d6733 --- /dev/null +++ b/tests/ops/filter/test_alphanumeric_filter.py @@ -0,0 +1,83 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.alphanumeric_filter import AlphanumericFilter + + +class AlphanumericFilterTest(unittest.TestCase): + + def _run_alphanumeric_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'a v s e e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'a v s e e f g a qkc' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + dataset = Dataset.from_list(ds_list) + op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9) + self._run_alphanumeric_filter(dataset, tgt_list, op) + + def test_token_case(self): + + ds_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'a v s e e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'Do you need a cup of coffee?' + }] + dataset = Dataset.from_list(ds_list) + op = AlphanumericFilter(tokenization=True, min_ratio=1.5) + self._run_alphanumeric_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_average_line_length_filter.py b/tests/ops/filter/test_average_line_length_filter.py new file mode 100644 index 000000000..fa1090f33 --- /dev/null +++ b/tests/ops/filter/test_average_line_length_filter.py @@ -0,0 +1,52 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.average_line_length_filter import \ + AverageLineLengthFilter + + +class AverageLineLengthFilterTest(unittest.TestCase): + + def _run_average_line_length_filter(self, dataset: Dataset, target_list, + op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'a v s e e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + 'text': 'a v s e e f g a qkc' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + dataset = Dataset.from_list(ds_list) + op = AverageLineLengthFilter(min_len=10, max_len=20) + self._run_average_line_length_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_character_repetition_filter.py b/tests/ops/filter/test_character_repetition_filter.py new file mode 100644 index 000000000..f1cf7a8db --- /dev/null +++ b/tests/ops/filter/test_character_repetition_filter.py @@ -0,0 +1,50 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.character_repetition_filter import \ + CharacterRepetitionFilter + + +class CharacterRepetitionFilterTest(unittest.TestCase): + + def _run_character_repetition_filter(self, dataset: Dataset, target_list, + op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': + "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'a v s e c s f e f g a a a a a a a a a a' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': '中文也是一个字算一个长度' + }] + tgt_list = [{ + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': '中文也是一个字算一个长度' + }] + dataset = Dataset.from_list(ds_list) + op = CharacterRepetitionFilter(rep_len=10, + min_ratio=0.0, + max_ratio=0.4) + self._run_character_repetition_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_flagged_words_filter.py b/tests/ops/filter/test_flagged_words_filter.py new file mode 100644 index 000000000..ce10c869e --- /dev/null +++ b/tests/ops/filter/test_flagged_words_filter.py @@ -0,0 +1,76 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.flagged_words_filter import FlaggedWordFilter + + +class FlaggedWordFilterTest(unittest.TestCase): + + def _run_flagged_words_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_en_case(self): + + ds_list = [{ + 'text': 'Today is anal cumshot day' + }, { + 'text': 'Fuck you doggystyle!' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + dataset = Dataset.from_list(ds_list) + op = FlaggedWordFilter(lang='en', max_ratio=0.045) + self._run_flagged_words_filter(dataset, tgt_list, op) + + def test_zh_case(self): + + ds_list = [{ + 'text': '你是个卖淫女' + }, { + 'text': '根据算子使用情况增量安装方案确定' + }, { + 'text': '去除字母、数字、下划线占比过低或过高的代码' + }, { + 'text': '基于前一步结果,除掉打飞机、三级片等敏感词' + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言模型计算每个段落的困惑度得分,由此过滤低质量文本' + }] + tgt_list = [{ + 'text': '根据算子使用情况增量安装方案确定' + }, { + 'text': '去除字母、数字、下划线占比过低或过高的代码' + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言模型计算每个段落的困惑度得分,由此过滤低质量文本' + }] + dataset = Dataset.from_list(ds_list) + op = FlaggedWordFilter(lang='zh', + tokenization=True, + max_ratio=0.045, + use_words_aug=True) + self._run_flagged_words_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_language_id_score_filter.py b/tests/ops/filter/test_language_id_score_filter.py new file mode 100644 index 000000000..78edef8d6 --- /dev/null +++ b/tests/ops/filter/test_language_id_score_filter.py @@ -0,0 +1,113 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.language_id_score_filter import \ + LanguageIDScoreFilter + + +class LanguageIDScoreFilterTest(unittest.TestCase): + + def _run_language_id_score_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_en_case(self): + + ds_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'a v s e e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'Do you need a cup of coffee?' + }] + dataset = Dataset.from_list(ds_list) + op = LanguageIDScoreFilter(lang='en', min_score=0.8) + self._run_language_id_score_filter(dataset, tgt_list, op) + + def test_zh_case(self): + + ds_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': '我出生于2023年12月15日' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—' + }, { + 'text': '他的英文名字叫Harry Potter' + }, { + 'text': '这是一个测试' + }] + tgt_list = [{ + 'text': '我出生于2023年12月15日' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—' + }, { + 'text': '他的英文名字叫Harry Potter' + }, { + 'text': '这是一个测试' + }] + dataset = Dataset.from_list(ds_list) + op = LanguageIDScoreFilter(lang='zh', min_score=0.8) + self._run_language_id_score_filter(dataset, tgt_list, op) + + def test_none_case(self): + + ds_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': '我出生于2023年12月15日' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—' + }, { + 'text': '他的英文名字叫Harry Potter' + }, { + 'text': '这是一个测试' + }] + tgt_list = [{ + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': '我出生于2023年12月15日' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—' + }, { + 'text': '他的英文名字叫Harry Potter' + }, { + 'text': '这是一个测试' + }] + dataset = Dataset.from_list(ds_list) + op = LanguageIDScoreFilter(lang='', min_score=0.8) + self._run_language_id_score_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_maximum_line_length_filter.py b/tests/ops/filter/test_maximum_line_length_filter.py new file mode 100644 index 000000000..aa50fa286 --- /dev/null +++ b/tests/ops/filter/test_maximum_line_length_filter.py @@ -0,0 +1,52 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.maximum_line_length_filter import \ + MaximumLineLengthFilter + + +class MaximumLineLengthFilterTest(unittest.TestCase): + + def _run_maximum_line_length_filter(self, dataset: Dataset, target_list, + op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + 'text': + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + 'text': 'a v s e e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + 'text': 'a v s e e f g a qkc' + }, { + 'text': 'emoji表情测试下😊,😸31231\n' + }] + dataset = Dataset.from_list(ds_list) + op = MaximumLineLengthFilter(min_len=10, max_len=20) + self._run_maximum_line_length_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_perplexity_filter.py b/tests/ops/filter/test_perplexity_filter.py new file mode 100644 index 000000000..cc07992a4 --- /dev/null +++ b/tests/ops/filter/test_perplexity_filter.py @@ -0,0 +1,50 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.perplexity_filter import PerplexityFilter + + +class PerplexityFilterTest(unittest.TestCase): + + def _run_perplexity_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_en_case(self): + + ds_list = [{ + 'text': "Today is Sunday and it's a happy day!" + }, { + 'text': + "Today is Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'a v s e c s f e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231' + }] + tgt_list = [{ + 'text': "Today is Sunday and it's a happy day!" + }, { + 'text': 'Do you need a cup of coffee?' + }] + dataset = Dataset.from_list(ds_list) + op = PerplexityFilter(lang='en', max_ppl=900) + self._run_perplexity_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_special_characters_filter.py b/tests/ops/filter/test_special_characters_filter.py new file mode 100644 index 000000000..1ded47fef --- /dev/null +++ b/tests/ops/filter/test_special_characters_filter.py @@ -0,0 +1,55 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.special_characters_filter import \ + SpecialCharactersFilter + + +class SpecialCharactersFilterTest(unittest.TestCase): + + def _run_special_characters_filter(self, dataset: Dataset, target_list, + op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': "Today is Sunday and it's a happy day!" + }, { + 'text': + "Today is Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'a v s e c s f e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }, { + 'text': 'emoji表情测试下😊,😸31231' + }] + tgt_list = [{ + 'text': "Today is Sunday and it's a happy day!" + }, { + 'text': + "Today is Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'Do you need a cup of coffee?' + }] + dataset = Dataset.from_list(ds_list) + op = SpecialCharactersFilter(min_ratio=0.0, max_ratio=0.25) + self._run_special_characters_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_specified_field_filter.py b/tests/ops/filter/test_specified_field_filter.py new file mode 100644 index 000000000..99ef622e8 --- /dev/null +++ b/tests/ops/filter/test_specified_field_filter.py @@ -0,0 +1,147 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.specified_field_filter import SpecifiedFieldFilter + + +class SpecifiedFieldFilterTest(unittest.TestCase): + + def _run_specified_field_filter(self, dataset: Dataset, target_list, op): + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': 50 + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx', + 'star': 6 + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt', + 'star': 100 + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '', + 'star': 12.51 + } + }, { + 'text': 'dasdasdasdasdasdasdasd', + 'meta': { + 'suffix': None + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': 50 + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt', + 'star': 100 + } + }] + dataset = Dataset.from_list(ds_list) + op = SpecifiedFieldFilter(text_key='meta.suffix', + target_value=['.pdf', '.txt']) + self._run_specified_field_filter(dataset, tgt_list, op) + + def test_list_case(self): + + ds_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': 50, + 'path': { + 'test': ['txt', 'json'], + 'test2': 'asadd' + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx', + 'star': 6, + 'path': { + 'test': ['pdf', 'txt', 'xbs'], + 'test2': '' + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt', + 'star': 100, + 'path': { + 'test': ['docx', '', 'html'], + 'test2': 'abcd' + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '', + 'star': 12.51, + 'path': { + 'test': ['json'], + 'test2': 'aasddddd' + } + } + }, { + 'text': 'dasdasdasdasdasdasdasd', + 'meta': { + 'suffix': None, + 'star': 333, + 'path': { + 'test': ['pdf', 'txt', 'json', 'docx'], + 'test2': None + } + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': 50, + 'path': { + 'test': ['txt', 'json'], + 'test2': 'asadd' + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '', + 'star': 12.51, + 'path': { + 'test': ['json'], + 'test2': 'aasddddd' + } + } + }] + dataset = Dataset.from_list(ds_list) + op = SpecifiedFieldFilter(text_key='meta.path.test', + target_value=['pdf', 'txt', 'json']) + self._run_specified_field_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_specified_numeric_field_filter.py b/tests/ops/filter/test_specified_numeric_field_filter.py new file mode 100644 index 000000000..813c47252 --- /dev/null +++ b/tests/ops/filter/test_specified_numeric_field_filter.py @@ -0,0 +1,204 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.specified_numeric_field_filter import \ + SpecifiedNumericFieldFilter + + +class SpecifiedNumericFieldFilterTest(unittest.TestCase): + + def _run_specified_numeric_field_filter(self, dataset: Dataset, + target_list, op): + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': 50 + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx', + 'star': 6 + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt', + 'star': 100 + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '.html', + 'star': 12.51 + } + }, { + 'text': 'dasdasdasdasdasdasdasd', + 'meta': { + 'suffix': None + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': 50 + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '.html', + 'star': 12.51 + } + }] + dataset = Dataset.from_list(ds_list) + op = SpecifiedNumericFieldFilter(text_key='meta.star', + min_value=10, + max_value=70) + self._run_specified_numeric_field_filter(dataset, tgt_list, op) + + def test_multi_case(self): + + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = SpecifiedNumericFieldFilter(text_key='meta.key1.key2.count', + min_value=10, + max_value=70) + self._run_specified_numeric_field_filter(dataset, tgt_list, op) + + def test_str_case(self): + + ds_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': '36' + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx', + 'star': '13.5' + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt', + 'star': 'asdkc' + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '.html', + 'star': '441' + } + }, { + 'text': 'dasdasdasdasdasdasdasd', + 'meta': { + 'suffix': None + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf', + 'star': '36' + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx', + 'star': '13.5' + } + }] + dataset = Dataset.from_list(ds_list) + op = SpecifiedNumericFieldFilter(text_key='meta.star', + min_value=10, + max_value=70) + self._run_specified_numeric_field_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_stop_words_filter.py b/tests/ops/filter/test_stop_words_filter.py new file mode 100644 index 000000000..5ffcbc3ca --- /dev/null +++ b/tests/ops/filter/test_stop_words_filter.py @@ -0,0 +1,76 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.stopwords_filter import StopWordsFilter + + +class StopWordsFilterTest(unittest.TestCase): + + def _run_stopwords_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_en_case(self): + + ds_list = [{ + 'text': "Today is Sunday and it's a happy day!" + }, { + 'text': + "Today is Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'a v s e c s f e f g a qkc' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': 'Do you need a cup of coffee?' + }] + tgt_list = [{ + 'text': "Today is Sunday and it's a happy day!" + }, { + 'text': + "Today is Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'Do you need a cup of coffee?' + }] + dataset = Dataset.from_list(ds_list) + op = StopWordsFilter(lang='en', min_ratio=0.3) + self._run_stopwords_filter(dataset, tgt_list, op) + + def test_zh_case(self): + + ds_list = [{ + 'text': '你好,请问你是谁' + }, { + 'text': '字母、数字、下划线、占比、代码' + }, { + 'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除' + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言模型计算每个段落的困惑度得分,由此过滤低质量文本' + }] + tgt_list = [{ + 'text': '你好,请问你是谁' + }, { + 'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除' + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言模型计算每个段落的困惑度得分,由此过滤低质量文本' + }] + dataset = Dataset.from_list(ds_list) + op = StopWordsFilter(lang='zh', + tokenization=True, + min_ratio=0.2, + use_words_aug=True) + self._run_stopwords_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_suffix_filter.py b/tests/ops/filter/test_suffix_filter.py new file mode 100644 index 000000000..8d90d1885 --- /dev/null +++ b/tests/ops/filter/test_suffix_filter.py @@ -0,0 +1,119 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.suffix_filter import SuffixFilter + + +class SuffixFilterTest(unittest.TestCase): + + def _run_suffix_filter(self, dataset: Dataset, target_list, op): + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf' + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx' + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt' + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '.html' + } + }, { + 'text': 'dasdasdasdasdasdasdasd', + 'meta': { + 'suffix': '.py' + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf' + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt' + } + }] + dataset = Dataset.from_list(ds_list) + op = SuffixFilter(suffixes=['.txt', '.pdf']) + self._run_suffix_filter(dataset, tgt_list, op) + + def test_none_case(self): + + ds_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf' + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx' + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt' + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '.html' + } + }, { + 'text': 'dasdasdasdasdasdasdasd', + 'meta': { + 'suffix': '.py' + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'meta': { + 'suffix': '.pdf' + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'meta': { + 'suffix': '.docx' + } + }, { + 'text': '中文也是一个字算一个长度', + 'meta': { + 'suffix': '.txt' + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'meta': { + 'suffix': '.html' + } + }, { + 'text': 'dasdasdasdasdasdasdasd', + 'meta': { + 'suffix': '.py' + } + }] + dataset = Dataset.from_list(ds_list) + op = SuffixFilter() + self._run_suffix_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_text_length_filter.py b/tests/ops/filter/test_text_length_filter.py new file mode 100644 index 000000000..fea4f25aa --- /dev/null +++ b/tests/ops/filter/test_text_length_filter.py @@ -0,0 +1,50 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.text_length_filter import TextLengthFilter + + +class TextLengthFilterTest(unittest.TestCase): + + def _run_text_length_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'Today is' + }, { + 'text': + "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'a v s e c s f e f g a a a ' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': '中文也是一个字算一个长度' + }] + tgt_list = [{ + 'text': 'a v s e c s f e f g a a a ' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + 'text': '中文也是一个字算一个长度' + }] + dataset = Dataset.from_list(ds_list) + op = TextLengthFilter(min_len=10, max_len=50) + self._run_text_length_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_word_num_filter.py b/tests/ops/filter/test_word_num_filter.py new file mode 100644 index 000000000..e7ee02415 --- /dev/null +++ b/tests/ops/filter/test_word_num_filter.py @@ -0,0 +1,74 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.word_num_filter import WordNumFilter + + +class WordNumFilterTest(unittest.TestCase): + + def _run_word_num_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_case(self): + + ds_list = [{ + 'text': 'Today is Sun' + }, { + 'text': + "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'a v s e c s f e f g a a a ' + }, { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }] + tgt_list = [{ + 'text': + "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': 'a v s e c s f e f g a a a ' + }] + dataset = Dataset.from_list(ds_list) + op = WordNumFilter(min_num=5, max_num=15) + self._run_word_num_filter(dataset, tgt_list, op) + + def test_zh_case(self): + + ds_list = [{ + 'text': '你好,请问你是谁' + }, { + 'text': '欢迎来到阿里巴巴' + }, { + 'text': '根据算子使用情况增量安装方案确定' + }, { + 'text': '去除字母、数字、下划线占比过低或过高的代码' + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言模型计算每个段落的困惑度得分,由此过滤低质量文本' + }, { + 'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除' + }] + tgt_list = [{ + 'text': '去除字母、数字、下划线占比过低或过高的代码' + }, { + 'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除' + }] + dataset = Dataset.from_list(ds_list) + op = WordNumFilter(lang='zh', + tokenization=True, + min_num=10, + max_num=25) + self._run_word_num_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_word_repetition_filter.py b/tests/ops/filter/test_word_repetition_filter.py new file mode 100644 index 000000000..8678ff5b6 --- /dev/null +++ b/tests/ops/filter/test_word_repetition_filter.py @@ -0,0 +1,85 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.word_repetition_filter import WordRepetitionFilter + + +class WordRepetitionFilterTest(unittest.TestCase): + + def _run_word_repetition_filter(self, dataset: Dataset, target_list, op): + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats) + dataset = dataset.filter(op.process) + dataset = dataset.select_columns(column_names=['text']) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_en_case(self): + + ds_list = [{ + 'text': + "Today is Sunday Sunday Sunday Sunday Sunday and it's a happy day!" + }, { + 'text': + "Today is Sunday Sunday Sunday and it's a happy day!" + }, { + 'text': + "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!" + }, { + 'text': + "plusieurs èrdash@hqbchd.ckd d'accéder à ces wwwasdasd fonc" + }, { + 'text': + 'This proposed a novel proposed pretraining proposed pretraining.' + }] + tgt_list = [{ + 'text': + "Today is Sunday Sunday Sunday and it's a happy day!" + }, { + 'text': + "plusieurs èrdash@hqbchd.ckd d'accéder à ces wwwasdasd fonc" + }, { + 'text': + 'This proposed a novel proposed pretraining proposed pretraining.' + }] + dataset = Dataset.from_list(ds_list) + op = WordRepetitionFilter(rep_len=3, min_ratio=0.0, max_ratio=0.2) + self._run_word_repetition_filter(dataset, tgt_list, op) + + def test_zh_case(self): + + ds_list = [{ + 'text': '去除字母、数字、下划线占比过低或过高的代码' + }, { + 'text': '欢迎来到阿里巴巴巴巴巴巴巴巴' + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言模型计算每个段落的困惑度得分' + }, { + 'text': '根据算子使用使用使用使用安装方案确定' + }, { + 'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除' + }] + tgt_list = [{ + 'text': '去除字母、数字、下划线占比过低或过高的代码' + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言模型计算每个段落的困惑度得分' + }, { + 'text': '基于前一步结果,在同一个聚类中找出那些过长文档为假正例,暂不进行滤除' + }] + dataset = Dataset.from_list(ds_list) + op = WordRepetitionFilter(lang='zh', + tokenization=True, + rep_len=3, + min_ratio=0.0, + max_ratio=0.2) + self._run_word_repetition_filter(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/__init__.py b/tests/ops/mapper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/mapper/test_clean_copyright_mapper.py b/tests/ops/mapper/test_clean_copyright_mapper.py new file mode 100644 index 000000000..302942d26 --- /dev/null +++ b/tests/ops/mapper/test_clean_copyright_mapper.py @@ -0,0 +1,41 @@ +import unittest + +from data_juicer.ops.mapper.clean_copyright_mapper import CleanCopyrightMapper + + +class CleanCopyrightMapperTest(unittest.TestCase): + + def setUp(self): + self.op = CleanCopyrightMapper() + + def _run_clean_copyright(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_clean_copyright(self): + + samples = [{ + 'text': '这是一段 /* 多行注释\n注释内容copyright\n*/ 的文本。另外还有一些 // 单行注释。', + 'target': '这是一段 的文本。另外还有一些 // 单行注释。' + }, { + 'text': '如果多行/*注释中没有\n关键词,那么\n这部分注释也不会\n被清除*/\n会保留下来', + 'target': '如果多行/*注释中没有\n关键词,那么\n这部分注释也不会\n被清除*/\n会保留下来' + }, { + 'text': '//if start with\n//that will be cleand \n envenly', + 'target': ' envenly' + }, { + 'text': 'http://www.nasosnsncc.com', + 'target': 'http://www.nasosnsncc.com' + }, { + 'text': '#if start with\nthat will be cleand \n#envenly', + 'target': 'that will be cleand \n#envenly' + }, { + 'text': '--if start with\n--that will be cleand \n#envenly', + 'target': '' + }] + self._run_clean_copyright(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_clean_email_mapper.py b/tests/ops/mapper/test_clean_email_mapper.py new file mode 100644 index 000000000..2b3d0fa80 --- /dev/null +++ b/tests/ops/mapper/test_clean_email_mapper.py @@ -0,0 +1,35 @@ +import unittest + +from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper + + +class CleanEmailMapperTest(unittest.TestCase): + + def setUp(self): + self.op = CleanEmailMapper() + + def _run_clean_email(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_clean_email(self): + + samples = [{ + 'text': 'happy day euqdh@cjqi.com', + 'target': 'happy day ' + }, { + 'text': '请问你是谁dasoidhao@1264fg.45om', + 'target': '请问你是谁dasoidhao@1264fg.45om' + }, { + 'text': 'ftp://examplema-nièrdash@hqbchd.ckdhnfes.cds', + 'target': 'ftp://examplema-niè' + }, { + 'text': '👊23da44sh12@46hqb12chd.ckdhnfes.comd.dasd.asd.dc', + 'target': '👊' + }] + self._run_clean_email(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_clean_html_mapper.py b/tests/ops/mapper/test_clean_html_mapper.py new file mode 100644 index 000000000..ecab4114d --- /dev/null +++ b/tests/ops/mapper/test_clean_html_mapper.py @@ -0,0 +1,212 @@ +import unittest + +from data_juicer.ops.mapper.clean_html_mapper import CleanHtmlMapper + + +class CleanHtmlMapperTest(unittest.TestCase): + + def setUp(self): + self.op = CleanHtmlMapper() + + def _run_helper(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_complete_html_text(self): + + samples = [ + { + 'text': + '
    ' + '

    Welcome to My Website

    ' + '

    Lorem ipsum dolor sit amet, consectetur adipiscing elit.' + '

    ' + '

    © ' + '2021 My Website. All Rights Reserved.

    ', + 'target': + '测试\n*Home\n*About\n*Services\n*Contact' + 'Welcome to My WebsiteLorem ipsum dolor sit amet, ' + 'consectetur adipiscing elit.' + 'Learn More© 2021 My Website. All Rights Reserved.' + }, + ] + self._run_helper(samples) + + def test_html_en_text(self): + + samples = [ + { + 'text': '

    This is a test

    ', + 'target': 'This is a test' + }, + { + 'text': + 'Test', + 'target': 'Test' + }, + { + 'text': + '

    This is a test

    ' + '

    This is a test

    ' + '

    This is a test

    ' + '

    Test:This is a test

    ' + '

    Test:This is a test

    ' + '

    ', + 'target': + 'This is a test' + 'This is a test' + 'This is a test' + 'Test:This is a test' + 'Test:This is a test' + }, + ] + + self._run_helper(samples) + + def test_html_zh_text(self): + + samples = [ + { + 'text': '

    这是个测试

    ', + 'target': '这是个测试' + }, + { + 'text': + '测试', + 'target': '测试' + }, + { + 'text': + '

    这是1个测试。

    ' + '

    这是2个测试。

    ' + '

    这是3个测试。

    ' + '

    测试:这是4个测试。

    ' + '

    测试:这是5个测试。

    ' + '

    ', + 'target': + '这是1个测试。' + '这是2个测试。' + '这是3个测试。' + '测试:这是4个测试。' + '测试:这是5个测试。' + }, + ] + self._run_helper(samples) + + def test_no_html_text(self): + + samples = [ + { + 'text': 'This is a test', + 'target': 'This is a test' + }, + { + 'text': '这是个测试', + 'target': '这是个测试' + }, + { + 'text': '12345678', + 'target': '12345678' + }, + ] + self._run_helper(samples) + + def test_fake_html_text(self): + + samples = [ + { + 'text': 'This is a test

    ', + 'target': 'This is a test' + }, + { + 'text': '

    这是个测试', + 'target': '这是个测试' + }, + { + 'text': 'hello', + 'target': 'hello' + }, + { + 'text': '这是个测试', + 'target': '这是个测试' + }, + { + 'text': '<测试>这是个测试', + 'target': '<测试>这是个测试' + }, + { + 'text': + 'abc="https://www.example.com/file.html?name=Test\" 测试', + 'target': + 'abc="https://www.example.com/file.html?name=Test" 测试' + }, + { + 'text': + 'href="https://www.example.com/file.html;name=Test">测试', + 'target': + 'href="https://www.example.com/file.html;name=Test">测试' + }, + { + 'text': + '测试', + 'target': '测试' + }, + ] + self._run_helper(samples) + + def test_whitespace_text(self): + + samples = [ + { + 'text': ' ', + 'target': '' + }, + { + 'text': '', + 'target': '' + }, + { + 'text': ' This is a test', + 'target': 'This is a test' + }, + { + 'text': ' This is a test ', + 'target': 'This is a test ' + }, + ] + self._run_helper(samples) + + def test_only_list_text(self): + + samples = [ + { + 'text': '

  • Apple
  • ', + 'target': '*Apple' + }, + { + 'text': '
  • 苹果
  • ', + 'target': '*苹果' + }, + { + 'text': '
      Apple
    ', + 'target': '*Apple' + }, + { + 'text': '
      苹果
    ', + 'target': '*苹果' + }, + ] + + self._run_helper(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_clean_ip_mapper.py b/tests/ops/mapper/test_clean_ip_mapper.py new file mode 100644 index 000000000..e4dbbcba7 --- /dev/null +++ b/tests/ops/mapper/test_clean_ip_mapper.py @@ -0,0 +1,52 @@ +import unittest + +from data_juicer.ops.mapper.clean_ip_mapper import CleanIpMapper + + +class CleanIpMapperTest(unittest.TestCase): + + def setUp(self): + self.op = CleanIpMapper() + + def _run_clean_ip(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_ipv4(self): + + samples = [{ + 'text': 'test of ip 234.128.124.123', + 'target': 'test of ip ' + }, { + 'text': '34.0.124.123', + 'target': '' + }, { + 'text': 'ftp://example.com/188.46.244.216my-page.html', + 'target': 'ftp://example.com/my-page.html' + }, { + 'text': 'ft174.1421.237.246my', + 'target': 'ft174.1421.237.246my' + }] + self._run_clean_ip(samples) + + def test_ipv6(self): + + samples = [{ + 'text': 'dd41:cbaf:d1b4:10a0:b215:72e3:6eaf:3ecb', + 'target': '' + }, { + 'text': 'test of ip 4394:538a:3bf3:61c3:cb0d:d214:526f:70d', + 'target': 'test of ip ' + }, { + 'text': 'com/f770:c52e:ddce:3a9f:8c3b:a7bd:d81f:985cmy-page.html', + 'target': 'com/my-page.html' + }, { + 'text': 'ft1926:43a1:fcb5:ees06:ae63:a2a4:c656:d014my', + 'target': 'ft1926:43a1:fcb5:ees06:ae63:a2a4:c656:d014my' + }] + self._run_clean_ip(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_clean_links_mapper.py b/tests/ops/mapper/test_clean_links_mapper.py new file mode 100644 index 000000000..c3994afbf --- /dev/null +++ b/tests/ops/mapper/test_clean_links_mapper.py @@ -0,0 +1,208 @@ +import unittest + +from data_juicer.ops.mapper.clean_links_mapper import CleanLinksMapper + + +class CleanLinksMapperTest(unittest.TestCase): + + def setUp(self): + self.op = CleanLinksMapper() + + def _run_clean_links(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_lower_ftp_links_text(self): + + samples = [{ + 'text': 'ftp://user:password@ftp.example.com:21/', + 'target': '' + }, { + 'text': 'ftp://www.example.com/path/to/file.txt', + 'target': '' + }, { + 'text': 'ftp://example.com/my-page.html', + 'target': '' + }, { + 'text': 'ftp://example.com', + 'target': '' + }] + self._run_clean_links(samples) + + def test_upper_ftp_links_text(self): + + samples = [{ + 'text': 'FTP://user:password@ftp.example.COm:21/', + 'target': '' + }, { + 'text': 'FTP://www.example.com/path/to/file.txt', + 'target': '' + }, { + 'text': 'Ftp://example.com/my-page.HTMl', + 'target': '' + }, { + 'text': 'FTP://EXAMPLE.COM', + 'target': '' + }] + self._run_clean_links(samples) + + def test_lower_https_links_text(self): + + samples = [{ + 'text': + 'https://www.example.com/file.html?param1=value1¶m2=value2', + 'target': '' + }, { + 'text': + 'https://example.com/my-page.html?param1=value1¶m2=value2', + 'target': '' + }, { + 'text': 'https://example.com', + 'target': '' + }] + self._run_clean_links(samples) + + def test_upper_https_links_text(self): + + samples = [{ + 'text': + 'hTTps://www.example.com/file.html?param1=value1¶m2=value2', + 'target': '' + }, { + 'text': + 'HttpS://example.Com/my-page.HTML?param1=value1¶m2=value2', + 'target': '' + }, { + 'text': 'HTTPS://EXAMPLE.COM', + 'target': '' + }] + self._run_clean_links(samples) + + def test_mixed_https_links_text(self): + + samples = [{ + 'text': 'This is a test,' + 'https://www.example.com/file.html?param1=value1¶m2=value2', + 'target': 'This is a test,' + }, { + 'text': '这是个测试,' + 'https://example.com/my-page.html?param1=value1¶m2=value2', + 'target': '这是个测试,' + }, { + 'text': '这是个测试,https://example.com', + 'target': '这是个测试,' + }] + self._run_clean_links(samples) + + def test_lower_http_links_text(self): + + samples = [{ + 'text': + 'http://example.com/my-page.html?param1=value1¶m2=value2', + 'target': '' + }, { + 'text': + 'http://www.example.com/file.html?param1=value1¶m2=value2', + 'target': '' + }, { + 'text': 'https://example.com', + 'target': '' + }] + self._run_clean_links(samples) + + def test_upper_http_links_text(self): + + samples = [ + { + 'text': + 'HTTP://example.com/my-page.html?param1=value1¶m2=value2', + 'target': '' + }, + { + 'text': + 'hTTp://www.example.com/file.html?param1=value1¶m2=value2', + 'target': '' + }, + { + 'text': 'HTTPS://EXAMPLE.COM', + 'target': '' + }, + ] + self._run_clean_links(samples) + + def test_mixed_http_links_text(self): + + samples = [{ + 'text': 'This is a test,' + 'http://www.example.com/file.html?param1=value1¶m2=value2', + 'target': 'This is a test,' + }, { + 'text': '这是个测试,' + 'http://example.com/my-page.html?param1=value1¶m2=value2', + 'target': '这是个测试,' + }, { + 'text': '这是个测试,https://example.com', + 'target': '这是个测试,' + }] + self._run_clean_links(samples) + + def test_email_text(self): + + samples = [ + { + 'text': 'This is a sample@example for test', + 'target': 'This is a sample@example for test', + }, + { + 'text': '这是一个测试, sample@example', + 'target': '这是一个测试, sample@example', + }, + ] + self._run_clean_links(samples) + + def test_fake_links_text(self): + + samples = [ + { + 'text': 'abcd:/e f is a sample for test', + 'target': 'abcd:/e f is a sample for test', + }, + { + 'text': 'abcd://ef is a sample for test', + 'target': ' is a sample for test', + }, + { + 'text': 'This is a test,' + 'http测试://www.example.com/file.html?param1=value1', + 'target': 'This is a test,' + }, + { + 'text': 'This is a test,' + 'http://www.测试.com/path/file.html?param1=value1¶m2=value2', + 'target': 'This is a test,' + }, + ] + self._run_clean_links(samples) + + def test_no_link_text(self): + + samples = [ + { + 'text': 'This is a sample for test', + 'target': 'This is a sample for test', + }, + { + 'text': '这是一个测试', + 'target': '这是一个测试', + }, + { + 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►', + 'target': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►', + }, + ] + self._run_clean_links(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_exapnd_macro_mapper.py b/tests/ops/mapper/test_exapnd_macro_mapper.py new file mode 100644 index 000000000..3cdc8a0c1 --- /dev/null +++ b/tests/ops/mapper/test_exapnd_macro_mapper.py @@ -0,0 +1,29 @@ +import unittest + +from data_juicer.ops.mapper.expand_macro_mapper import ExpandMacroMapper + + +class ExpandMacroMapperTest(unittest.TestCase): + + def setUp(self): + self.op = ExpandMacroMapper() + + def _run_expand_macro(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_case(self): + + samples = [{ + 'text': + '\\documentclass{article}\n% Recommended, but optional, packages for figures and better typesetting:\n\\usepackage{microtype}\n\\usepackage{graphicx}\n\n% Attempt to make hyperref and algorithmic work together better:\n\\newcommand{\\theHalgorithm}{\\arabic{algorithm}}\n% For theorems and such\n\\usepackage{amsmath}\n\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n% THEOREMS\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n\\theoremstyle{plain}\n\\newtheorem{lemma}[theorem]{Lemma}\n\\newtheorem{corollary}[theorem]{Corollary}\n\\theoremstyle{definition}\n\n\\usepackage[textsize=small]{todonotes}\n\\setuptodonotes{inline}\n\n\\usepackage{makecell}\n\\newcommand{\\cmark}{\\ding{51}\\xspace}%\n\\newcommand{\\xmark}{\\ding{55}\\xspace}%\n\n\\def \\alambic {\\includegraphics[height=1.52ex]{img/alembic-crop.pdf}\\xspace}\n\n\\newcommand\\binke[1]{{\\color{blue} \\footnote{\\color{blue}binke: #1}} }\n\\newcommand\\Zerocost{Zero-cost}\n\\newcommand\\imagenet{ImageNet}\n\n\\begin{document}\n\n\\begin{abstract}\nThe wide\n\\end{abstract}\n\\section{Introduction}\n\\label{introduction}\nThe main contributions are summarized as follows:\n\\section{Background and Related Work}\\label{background}\n\\subsection{One-Shot NAS} In one-shot NAS\n\\section{PreNAS}\\label{method}In this\n\\subsection{One-Shot NAS with Preferred Learning}\nIn the specialization stage, the optimal architectures under given resource constraints can be directly obtained:\n\\begin{equation}\n\\widetilde{\\mathcal{A}}^* = \\widetilde{\\mathcal{A}} .\n\\end{equation}\n\\subsection{Zero-Cost Transformer Selector}\\label{sub:layerNorm}\n\\subsection{Performance Balancing} We discuss\n\\section{Experiments}\\label{experiments}\n\\subsection{Setup}\n\\subsection{Main Results}\\label{sec:sota}\n\\subsection{Analysis and Ablation study}\\label{ablation}\n\\begin{figure}[t]\n\\vskip 0.1in\n \\centering\n \\subfigure[Search spaces]{\\includegraphics[width=0.36\\linewidth]{img/search_space.pdf}\\label{fg:search_space:a}}%\n \\hfil%\n \\subfigure[Error distributions]{\\includegraphics[width=0.58\\linewidth]{img/cumulation.pdf}\\label{fg:search_space:b}}\n \\caption{Model quality}\n\\vskip -0.1in\n\\end{figure}\n\\paragraph{Effect of Performance Balancing} During\n\\subsection{Transfer Learning Results}\n\\subsection{CNN Results} in terms of similar FLOPs.\n\\FloatBarrier\n\\section{Conclusion}\\label{conclusion} In this\n% Acknowledgements should only appear in the accepted version.\n\\bibliography{ref}\n\\bibliographystyle{icml2023}\n\\clearpage\n\\appendix\n\\onecolumn\n\\section{Statistical}\n\\label{appendix:snipAnalysis} We analyze\n\\section{The Greedy Algorithm}\n\\label{appendix:greedy}\n\\section{Regularization \\& Data Augmentation}\\label{appendix:aug}\n\\renewcommand{\\arraystretch}{1.2}\n\\end{document}\n', # noqa: E501 + 'target': + '\\documentclass{article}\n% Recommended, but optional, packages for figures and better typesetting:\n\\usepackage{microtype}\n\\usepackage{graphicx}\n\n% Attempt to make hyperref and algorithmic work together better:\n\\newcommand{\\arabic{algorithm}}{\\arabic{algorithm}}\n% For theorems and such\n\\usepackage{amsmath}\n\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n% THEOREMS\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n\\theoremstyle{plain}\n\\newtheorem{lemma}[theorem]{Lemma}\n\\newtheorem{corollary}[theorem]{Corollary}\n\\theoremstyle{definition}\n\n\\usepackage[textsize=small]{todonotes}\n\\setuptodonotes{inline}\n\n\\usepackage{makecell}\n\\newcommand{\\cmark}{\\ding{51}\\xspace}%\n\\newcommand{\\xmark}{\\ding{55}\\xspace}%\n\n\\def \\includegraphics[height=1.52ex]{img/alembic-crop.pdf}\\xspace {\\includegraphics[height=1.52ex]{img/alembic-crop.pdf}\\xspace}\n\n\\newcommand\\binke[1]{{\\color{blue} \\footnote{\\color{blue}binke: #1}} }\n\\newcommand\\Zerocost{Zero-cost}\n\\newcommand\\imagenet{ImageNet}\n\n\\begin{document}\n\n\\begin{abstract}\nThe wide\n\\end{abstract}\n\\section{Introduction}\n\\label{introduction}\nThe main contributions are summarized as follows:\n\\section{Background and Related Work}\\label{background}\n\\subsection{One-Shot NAS} In one-shot NAS\n\\section{PreNAS}\\label{method}In this\n\\subsection{One-Shot NAS with Preferred Learning}\nIn the specialization stage, the optimal architectures under given resource constraints can be directly obtained:\n\\begin{equation}\n\\widetilde{\\mathcal{A}}^* = \\widetilde{\\mathcal{A}} .\n\\end{equation}\n\\subsection{Zero-Cost Transformer Selector}\\label{sub:layerNorm}\n\\subsection{Performance Balancing} We discuss\n\\section{Experiments}\\label{experiments}\n\\subsection{Setup}\n\\subsection{Main Results}\\label{sec:sota}\n\\subsection{Analysis and Ablation study}\\label{ablation}\n\\begin{figure}[t]\n\\vskip 0.1in\n \\centering\n \\subfigure[Search spaces]{\\includegraphics[width=0.36\\linewidth]{img/search_space.pdf}\\label{fg:search_space:a}}%\n \\hfil%\n \\subfigure[Error distributions]{\\includegraphics[width=0.58\\linewidth]{img/cumulation.pdf}\\label{fg:search_space:b}}\n \\caption{Model quality}\n\\vskip -0.1in\n\\end{figure}\n\\paragraph{Effect of Performance Balancing} During\n\\subsection{Transfer Learning Results}\n\\subsection{CNN Results} in terms of similar FLOPs.\n\\FloatBarrier\n\\section{Conclusion}\\label{conclusion} In this\n% Acknowledgements should only appear in the accepted version.\n\\bibliography{ref}\n\\bibliographystyle{icml2023}\n\\clearpage\n\\appendix\n\\onecolumn\n\\section{Statistical}\n\\label{appendix:snipAnalysis} We analyze\n\\section{The Greedy Algorithm}\n\\label{appendix:greedy}\n\\section{Regularization \\& Data Augmentation}\\label{appendix:aug}\n\\renewcommand{\\arraystretch}{1.2}\n\\end{document}\n' # noqa: E501 + }] + + self._run_expand_macro(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_fix_unicode_mapper.py b/tests/ops/mapper/test_fix_unicode_mapper.py new file mode 100644 index 000000000..f77e53eb7 --- /dev/null +++ b/tests/ops/mapper/test_fix_unicode_mapper.py @@ -0,0 +1,47 @@ +import unittest + +from data_juicer.ops.mapper.fix_unicode_mapper import FixUnicodeMapper + + +class FixUnicodeMapperTest(unittest.TestCase): + + def setUp(self): + self.op = FixUnicodeMapper() + + def _run_fix_unicode(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_bad_unicode_text(self): + + samples = [ + { + 'text': '✔ No problems', + 'target': '✔ No problems' + }, + { + 'text': + 'The Mona Lisa doesn’t have eyebrows.', + 'target': 'The Mona Lisa doesn\'t have eyebrows.' + }, + ] + + self._run_fix_unicode(samples) + + def test_good_unicode_text(self): + samples = [ + { + 'text': 'No problems', + 'target': 'No problems' + }, + { + 'text': '阿里巴巴', + 'target': '阿里巴巴' + }, + ] + self._run_fix_unicode(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_punctuation_normalization_mapper.py b/tests/ops/mapper/test_punctuation_normalization_mapper.py new file mode 100644 index 000000000..a114b83b1 --- /dev/null +++ b/tests/ops/mapper/test_punctuation_normalization_mapper.py @@ -0,0 +1,30 @@ +import unittest + +from data_juicer.ops.mapper.punctuation_normalization_mapper import \ + PunctuationNormalizationMapper + + +class PunctuationNormalizationMapperTest(unittest.TestCase): + + def setUp(self): + self.op = PunctuationNormalizationMapper() + + def _run_punctuation_normalization(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_case(self): + + samples = [{ + 'text': + ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►', + 'target': + ",.,\"\"\"\"\"\"\"\"\"\"'::?!();- - . ~'...-<>[]%-" + }] + + self._run_punctuation_normalization(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_remove_bibliography_mapper.py b/tests/ops/mapper/test_remove_bibliography_mapper.py new file mode 100644 index 000000000..449cb59c7 --- /dev/null +++ b/tests/ops/mapper/test_remove_bibliography_mapper.py @@ -0,0 +1,41 @@ +import unittest + +from data_juicer.ops.mapper.remove_bibliography_mapper import \ + RemoveBibliographyMapper + + +class RemoveBibliographyMapperTest(unittest.TestCase): + + def setUp(self): + self.op = RemoveBibliographyMapper() + + def _run_remove_bibliography(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_bibliography_case(self): + + samples = [{ + 'text': + "%%\n%% This is file `sample-sigconf.tex\\clearpage\n\\bibliographystyle{ACM-Reference-Format}\n\\bibliography{sample-base}\n\\end{document}\n\\endinput\n%%\n%% End of file `sample-sigconf.tex'.\n", # noqa: E501 + 'target': + '%%\n%% This is file `sample-sigconf.tex\\clearpage\n\\bibliographystyle{ACM-Reference-Format}\n' # noqa: E501 + }] + + self._run_remove_bibliography(samples) + + def test_ref_case(self): + + samples = [{ + 'text': + "%%\n%% This is file `sample-sigconf.tex\\clearpage\n\\begin{references}\n\\end{document}\n\\endinput\n%%\n%% End of file `sample-sigconf.tex'.\n", # noqa: E501 + 'target': + '%%\n%% This is file `sample-sigconf.tex\\clearpage\n' # noqa: E501 + }] + + self._run_remove_bibliography(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_remove_comments_mapper.py b/tests/ops/mapper/test_remove_comments_mapper.py new file mode 100644 index 000000000..d61494c14 --- /dev/null +++ b/tests/ops/mapper/test_remove_comments_mapper.py @@ -0,0 +1,27 @@ +import unittest + +from data_juicer.ops.mapper.remove_comments_mapper import RemoveCommentsMapper + + +class RemoveCommentsMapperTest(unittest.TestCase): + + def _run_remove_comments(self, samples, op): + for sample in samples: + result = op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_tex_case(self): + + samples = [{ + 'text': + "%%\n%% This is file `sample-sigconf.tex',\n%% The first command in your LaTeX source must be the \\documentclass command.\n\\documentclass[sigconf,review,anonymous]{acmart}\n%% NOTE that a single column version is required for \n%% submission and peer review. This can be done by changing\n\\input{math_commands.tex}\n%% end of the preamble, start of the body of the document source.\n\\begin{document}\n%% The \"title\" command has an optional parameter,\n\\title{Hierarchical Cross Contrastive Learning of Visual Representations}\n%%\n%% The \"author\" command and its associated commands are used to define\n%% the authors and their affiliations.\n\\author{Hesen Chen}\n\\affiliation{%\n \\institution{Alibaba Group}\n \\city{Beijing}\n \\country{China}}\n\\email{hesen.chs@alibaba-inc.com}\n%% By default, the full list of authors will be used in the page\n\\begin{abstract}The rapid\n\\end{abstract}\n\\begin{CCSXML}\n\\ccsdesc[500]{Computing methodologies~Image representations}\n%% Keywords. The author(s) should pick words that accurately describe\n\\keywords{self-supervised, ontrastive Learning, hierarchical projection, cross-level}\n%% page.\n\\begin{teaserfigure}\n\\end{teaserfigure}\n%% This command processes the author and affiliation and title\n\\maketitle\n\\section{Introduction}\n\\begin{itemize}\n\\end{itemize}\n\\section{Related Work}\n\\label{gen_inst} Self-supervised\n\\section{Method}\n\\label{method}In this section,\n\\subsection{Framework} kkk\n\\subsection{Cross Contrastive Loss}\nSince $\\sZ^n$ are extracted\n\\subsection{Implementation details}\n\\textbf{Image augmentations} We use\n\\textbf{Architecture} We use\n\\textbf{Optimization} We adapt \n\\section{Experiments}\n\\label{experiments}In this section\n\\subsection{Linear and Semi-Supervised Evaluations on ImageNet}\n\\textbf{Linear evaluation on ImageNet} We firs\n\\textbf{Semi-supervised learning on ImageNet} We simply\n\\subsection{Transfer to other datasets and tasks}\n\\textbf{Image classification with fixed features} We follow\n\\section{Ablations} We present\n\\subsection{Influence of hierarchical projection head and cross contrastive loss} get out\n\\subsection{Levels and depth of projector network}\n\\end{center}\n\\caption{\\label{figure3} \\textbf{Different way of cross-correlation on 3 level hierarchical projection head.} '=' denotes stop gradient.}\n\\end{figure}\n\\subsection{Analyze of} In this\n\\textbf{Similarity between} Using SimSiam\n\\textbf{Feature similarity} We extracted\n\\section{Conclusion}\nWe propose HCCL\n\\clearpage\n\\bibliographystyle{ACM-Reference-Format}\n\\bibliography{sample-base}\n\\end{document}\n\\endinput\n%%\n%% End of file `sample-sigconf.tex'.\n", # noqa: E501 + 'target': + "\\documentclass[sigconf,review,anonymous]{acmart}\n\\input{math_commands.tex}\n\\begin{document}\n\\title{Hierarchical Cross Contrastive Learning of Visual Representations}\n\\author{Hesen Chen}\n\\affiliation{%\n \\institution{Alibaba Group}\n \\city{Beijing}\n \\country{China}}\n\\email{hesen.chs@alibaba-inc.com}\n\\begin{abstract}The rapid\n\\end{abstract}\n\\begin{CCSXML}\n\\ccsdesc[500]{Computing methodologies~Image representations}\n\\keywords{self-supervised, ontrastive Learning, hierarchical projection, cross-level}\n\\begin{teaserfigure}\n\\end{teaserfigure}\n\\maketitle\n\\section{Introduction}\n\\begin{itemize}\n\\end{itemize}\n\\section{Related Work}\n\\label{gen_inst} Self-supervised\n\\section{Method}\n\\label{method}In this section,\n\\subsection{Framework} kkk\n\\subsection{Cross Contrastive Loss}\nSince $\\sZ^n$ are extracted\n\\subsection{Implementation details}\n\\textbf{Image augmentations} We use\n\\textbf{Architecture} We use\n\\textbf{Optimization} We adapt \n\\section{Experiments}\n\\label{experiments}In this section\n\\subsection{Linear and Semi-Supervised Evaluations on ImageNet}\n\\textbf{Linear evaluation on ImageNet} We firs\n\\textbf{Semi-supervised learning on ImageNet} We simply\n\\subsection{Transfer to other datasets and tasks}\n\\textbf{Image classification with fixed features} We follow\n\\section{Ablations} We present\n\\subsection{Influence of hierarchical projection head and cross contrastive loss} get out\n\\subsection{Levels and depth of projector network}\n\\end{center}\n\\caption{\\label{figure3} \\textbf{Different way of cross-correlation on 3 level hierarchical projection head.} '=' denotes stop gradient.}\n\\end{figure}\n\\subsection{Analyze of} In this\n\\textbf{Similarity between} Using SimSiam\n\\textbf{Feature similarity} We extracted\n\\section{Conclusion}\nWe propose HCCL\n\\clearpage\n\\bibliographystyle{ACM-Reference-Format}\n\\bibliography{sample-base}\n\\end{document}\n\\endinput\n" # noqa: E501 + }] + + op = RemoveCommentsMapper(doc_type='tex', inline=True, multiline=True) + self._run_remove_comments(samples, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_remove_header_mapper.py b/tests/ops/mapper/test_remove_header_mapper.py new file mode 100644 index 000000000..ea7170fad --- /dev/null +++ b/tests/ops/mapper/test_remove_header_mapper.py @@ -0,0 +1,29 @@ +import unittest + +from data_juicer.ops.mapper.remove_header_mapper import RemoveHeaderMapper + + +class RemoveHeaderMapperTest(unittest.TestCase): + + def setUp(self): + self.op = RemoveHeaderMapper() + + def _run_remove_header(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_case(self): + + samples = [{ + 'text': + "%%\n%% This is file `sample-sigconf.tex',\n%% The first command in your LaTeX source must be the \\documentclass command.\n\\documentclass[sigconf,review,anonymous]{acmart}\n%% NOTE that a single column version is required for \n%% submission and peer review. This can be done by changing\n\\input{math_commands.tex}\n%% end of the preamble, start of the body of the document source.\n\\begin{document}\n%% The \"title\" command has an optional parameter,\n\\title{Hierarchical Cross Contrastive Learning of Visual Representations}\n%%\n%% The \"author\" command and its associated commands are used to define\n%% the authors and their affiliations.\n\\author{Hesen Chen}\n\\affiliation{%\n \\institution{Alibaba Group}\n \\city{Beijing}\n \\country{China}}\n\\email{hesen.chs@alibaba-inc.com}\n%% By default, the full list of authors will be used in the page\n\\begin{abstract}The rapid\n\\end{abstract}\n\\begin{CCSXML}\n\\ccsdesc[500]{Computing methodologies~Image representations}\n%% Keywords. The author(s) should pick words that accurately describe\n\\keywords{self-supervised, ontrastive Learning, hierarchical projection, cross-level}\n%% page.\n\\begin{teaserfigure}\n\\end{teaserfigure}\n%% This command processes the author and affiliation and title\n\\maketitle\n\\section{Introduction}\n\\begin{itemize}\n\\end{itemize}\n\\section{Related Work}\n\\label{gen_inst} Self-supervised\n\\section{Method}\n\\label{method}In this section,\n\\subsection{Framework} kkk\n\\subsection{Cross Contrastive Loss}\nSince $\\sZ^n$ are extracted\n\\subsection{Implementation details}\n\\textbf{Image augmentations} We use\n\\textbf{Architecture} We use\n\\textbf{Optimization} We adapt \n\\section{Experiments}\n\\label{experiments}In this section\n\\subsection{Linear and Semi-Supervised Evaluations on ImageNet}\n\\textbf{Linear evaluation on ImageNet} We firs\n\\textbf{Semi-supervised learning on ImageNet} We simply\n\\subsection{Transfer to other datasets and tasks}\n\\textbf{Image classification with fixed features} We follow\n\\section{Ablations} We present\n\\subsection{Influence of hierarchical projection head and cross contrastive loss} get out\n\\subsection{Levels and depth of projector network}\n\\end{center}\n\\caption{\\label{figure3} \\textbf{Different way of cross-correlation on 3 level hierarchical projection head.} '=' denotes stop gradient.}\n\\end{figure}\n\\subsection{Analyze of} In this\n\\textbf{Similarity between} Using SimSiam\n\\textbf{Feature similarity} We extracted\n\\section{Conclusion}\nWe propose HCCL\n\\clearpage\n\\bibliographystyle{ACM-Reference-Format}\n\\bibliography{sample-base}\n\\end{document}\n\\endinput\n%%\n%% End of file `sample-sigconf.tex'.\n", # noqa: E501 + 'target': + "\\section{Introduction}\n\\begin{itemize}\n\\end{itemize}\n\\section{Related Work}\n\\label{gen_inst} Self-supervised\n\\section{Method}\n\\label{method}In this section,\n\\subsection{Framework} kkk\n\\subsection{Cross Contrastive Loss}\nSince $\\sZ^n$ are extracted\n\\subsection{Implementation details}\n\\textbf{Image augmentations} We use\n\\textbf{Architecture} We use\n\\textbf{Optimization} We adapt \n\\section{Experiments}\n\\label{experiments}In this section\n\\subsection{Linear and Semi-Supervised Evaluations on ImageNet}\n\\textbf{Linear evaluation on ImageNet} We firs\n\\textbf{Semi-supervised learning on ImageNet} We simply\n\\subsection{Transfer to other datasets and tasks}\n\\textbf{Image classification with fixed features} We follow\n\\section{Ablations} We present\n\\subsection{Influence of hierarchical projection head and cross contrastive loss} get out\n\\subsection{Levels and depth of projector network}\n\\end{center}\n\\caption{\\label{figure3} \\textbf{Different way of cross-correlation on 3 level hierarchical projection head.} '=' denotes stop gradient.}\n\\end{figure}\n\\subsection{Analyze of} In this\n\\textbf{Similarity between} Using SimSiam\n\\textbf{Feature similarity} We extracted\n\\section{Conclusion}\nWe propose HCCL\n\\clearpage\n\\bibliographystyle{ACM-Reference-Format}\n\\bibliography{sample-base}\n\\end{document}\n\\endinput\n%%\n%% End of file `sample-sigconf.tex'.\n" # noqa: E501 + }] + + self._run_remove_header(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_remove_long_words_mapper.py b/tests/ops/mapper/test_remove_long_words_mapper.py new file mode 100644 index 000000000..01962e508 --- /dev/null +++ b/tests/ops/mapper/test_remove_long_words_mapper.py @@ -0,0 +1,61 @@ +import unittest + +from data_juicer.ops.mapper.remove_long_words_mapper import \ + RemoveLongWordsMapper + + +class RemoveLongWordsMapperTest(unittest.TestCase): + + def _run_remove_long_words(self, samples, op): + for sample in samples: + result = op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_normal_case(self): + + samples = [{ + 'text': + 'This paper proposed novel method LLM pretraining.', + 'target': + 'This paper proposed novel method LLM pretraining.' + }] + op = RemoveLongWordsMapper(min_len=3, max_len=15) + self._run_remove_long_words(samples, op) + + def test_long_short_words_case(self): + + samples = [{ + 'text': + 'This paper a novel eqeqweqwewqeqwe121e1 method on LLM pretrain.', + 'target': 'This paper novel method LLM pretrain.' + }, { + 'text': + 'Sur la plateforme MT4, manières à ces fonctionnalités sont conçu', + 'target': + 'Sur plateforme MT4, manières ces fonctionnalités sont conçu' + }] + op = RemoveLongWordsMapper(min_len=3, max_len=15) + self._run_remove_long_words(samples, op) + + def test_special_words_case(self): + + samples = [{ + 'text': + 'This paper proposed a novel eqeqweqwewqenhq😊😠 method on LLM.', + 'target': + 'This paper proposed novel eqeqweqwewqenhq😊😠 method LLM.' + }, { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder0123813976125", + 'target': + "Sur plateforme MT4, plusieurs manières d'accéder0123813976125" + }, { + 'text': 'The Mona Lisa doesn’t have eyebrows.', + 'target': 'The Mona Lisa have eyebrows.' + }] + op = RemoveLongWordsMapper(min_len=3, max_len=15) + self._run_remove_long_words(samples, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_remove_specific_chars_mapper.py b/tests/ops/mapper/test_remove_specific_chars_mapper.py new file mode 100644 index 000000000..4073d45df --- /dev/null +++ b/tests/ops/mapper/test_remove_specific_chars_mapper.py @@ -0,0 +1,45 @@ +import unittest + +from data_juicer.ops.mapper.remove_specific_chars_mapper import \ + RemoveSpecificCharsMapper + + +class RemoveSpecificCharsMapperTest(unittest.TestCase): + + def setUp(self): + self.op = RemoveSpecificCharsMapper() + + def _run_helper(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_complete_html_text(self): + + samples = [ + { + 'text': '这是一个干净的文本。Including Chinese and English.', + 'target': '这是一个干净的文本。Including Chinese and English.', + }, + { + 'text': '◆●■►▼▲▴∆▻▷❖♡□', + 'target': '', + }, + { + 'text': '►This is a dirty text ▻ 包括中文和英文', + 'target': 'This is a dirty text 包括中文和英文', + }, + { + 'text': '多个●■►▼这样的特殊字符可以►▼▲▴∆吗?', + 'target': '多个这样的特殊字符可以吗?', + }, + { + 'text': '未指定的●■☛₨➩►▼▲特殊字符会☻▷❖被删掉吗??', + 'target': '未指定的☛₨➩特殊字符会☻被删掉吗??', + }, + ] + self._run_helper(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_remove_table_text_mapper.py b/tests/ops/mapper/test_remove_table_text_mapper.py new file mode 100644 index 000000000..d08585d3e --- /dev/null +++ b/tests/ops/mapper/test_remove_table_text_mapper.py @@ -0,0 +1,101 @@ +import unittest + +from data_juicer.ops.mapper.remove_table_text_mapper import \ + RemoveTableTextMapper + + +class RemoveTableTextMapperTest(unittest.TestCase): + + def setUp(self): + self.op = RemoveTableTextMapper() + + def _run_remove_header(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_single_table_case(self): + + samples = [{ + 'text': + 'This is a table:\n编号 分行 营运资金1 营运资金2 营运资金3 营运资金4 营运资金5\n① 北京分行 495,000,000.00 200,000,000.00 295,000,000.00 - 495,000,000.00\n② 大连分行 440,000,000.00 100,000,000.00 340,000,000.00 - 440,000,000.00\n③ 重庆分行 500,000,000.00 100,000,000.00 400,000,000.00 - 500,000,000.00\n④ 南京分行 430,000,000.00 100,000,000.00 330,000,000.00 - 430,000,000.00\n⑤ 青岛分行 500,000,000.00 - 100,159,277.60 399,840,722.40 500,000,000.00\nThe end of the table.', # noqa: E501 + 'target': 'This is a table:\nThe end of the table.' + }] + + self._run_remove_header(samples) + + def test_one_row_text_case(self): + + samples = [{ + 'text': + 'This is a table:\n编号\n①\n②\n③\n④\n⑤\nThe end of the table.', + 'target': + 'This is a table:\n编号\n①\n②\n③\n④\n⑤\nThe end of the table.' + }] + + self._run_remove_header(samples) + + def test_one_column_text_case(self): + + samples = [{ + 'text': + 'This is a table:\n编号 分行 营运资金1 营运资金2 营运资金3 营运资金4 营运资金5\nThe end of the table.', # noqa: E501 + 'target': + 'This is a table:\n编号 分行 营运资金1 营运资金2 营运资金3 营运资金4 营运资金5\nThe end of the table.' # noqa: E501 + }] + + self._run_remove_header(samples) + + def test_smallest_2x2_table_case(self): + + samples = [{ + 'text': + 'This is a table:\n编号 分行\n① 北京分行\nThe end of the table.', # noqa: E501 + 'target': 'This is a table:\nThe end of the table.' + }] + + self._run_remove_header(samples) + + def test_unaligned_column_num_case(self): + samples = [{ + 'text': + 'This is a table:\n编号 分行 营运资金1 营运资金2 营运资金3 营运资金4\n① 北京分行 495,000,000.00 200,000,000.00 295,000,000.00 - 495,000,000.00\n② 大连分行 440,000,000.00 100,000,000.00 340,000,000.00 -\n③ 重庆分行 500,000,000.00 100,000,000.00 400,000,000.00 - 500,000,000.00\n④ 南京分行 430,000,000.00 100,000,000.00 - 430,000,000.00\n⑤ 青岛分行 500,000,000.00 - 100,159,277.60 399,840,722.40 500,000,000.00\nThe end of the table.', # noqa: E501 + 'target': + 'This is a table:\n编号 分行 营运资金1 营运资金2 营运资金3 营运资金4\n① 北京分行 495,000,000.00 200,000,000.00 295,000,000.00 - 495,000,000.00\n② 大连分行 440,000,000.00 100,000,000.00 340,000,000.00 -\n③ 重庆分行 500,000,000.00 100,000,000.00 400,000,000.00 - 500,000,000.00\n④ 南京分行 430,000,000.00 100,000,000.00 - 430,000,000.00\n⑤ 青岛分行 500,000,000.00 - 100,159,277.60 399,840,722.40 500,000,000.00\nThe end of the table.', # noqa: E501 + }] + + self._run_remove_header(samples) + + def test_false_positive_case(self): + samples = [{ + 'text': + '平安银行股份有限公司\n非公开发行普通股认购资金到位情况的验资报告非公开发行普通股认购资金到位情况的验资报告\n普华永道中天验字(2015)第446号\n(第一页,共二页)\n平安银行股份有限公司:\n平安银行股份有限公司(以下简称“贵公司”)委托中信证券股份有限公司作为主\n承销商非公开发行普通股 598,802,395 股。我们接受委托,审验了贵公司截至\n2015年 5月 5日止由中信证券股份有限公司代收取的向境内合格投资者非公开发行\n普通股认购资金的到位情况。按照国家相关法律、法规的规定以及认购协议、合同\n的要求出资认购,提供真实、合法、完整的验资资料,保护资产的安全、完整是贵\n公司管理层及中信证券股份有限公司的责任。我们的责任是对贵公司由中信证券股\n份有限公司代收取的境内合格投资者本次非公开发行普通股认购资金的到位情况发\n表审验意见。我们的审验是依据《中国注册会计师审计准则第 1602号——验资》进\n行的。在审验过程中,我们结合贵公司的实际情况,实施了检查等必要的审验程\n序。\n经贵公司 2014 年 7 月 15 日第九届董事会第五次会议提议,2014 年 8 月 4 日\n2014 年第二次临时股东大会审议通过《平安银行股份有限公司关于非公开发行普\n通股方案的议案》,贵公司拟向境内合格投资者非公开发行不超过 1,070,663,811\n股普通股。根据中国证券监督管理委员会证监许可[2015]697 号文《关于核准平安\n银行股份有限公司非公开发行股票的批复》,贵公司获准向境内合格投资者非公开\n发行不超过1,070,663,811股普通股。普华永道中天验字(2015)第446号\n(第二页,共二页)\n经我们审验,截至 2015 年 5 月 5 日止,贵公司以每股人民币 16.70 元合计向\n境内合格投资者非公开发行普通股 598,802,395 股,由发行主承销商中信证券股份\n有限公司代贵公司实际收到人民币 9,999,999,996.50元。所有认购资金均以人民币\n现金形式汇入。\n本验资报告仅供贵公司向中国证券监督管理委员会、深圳证券交易所报送资料\n及向中国证券登记结算有限责任公司深圳分公司申请非公开发行普通股登记时使\n用,不应将其视为是对贵公司验资报告日后资本保全、偿债能力和持续经营能力等\n的保证。因使用本验资报告不当造成的后果,与执行本验资业务的注册会计师及会\n计师事务所无关。\n附件一 非公开发行普通股认购资金到位情况明细表\n附件二 验资事项说明\n附件三 普华永道中天会计师事务所(特殊普通合伙)营业执照\n附件四 普华永道中天会计师事务所(特殊普通合伙)执业证书\n附件五 普华永道中天会计师事务所(特殊普通合伙)证券相关业务许可证\n普华永道中天会计师事务所 注册会计师\n(特殊普通合伙) 姚文平\n中国•上海市 注册会计师\n2015年 5月7日 朱丽平\n2附件一\n非公开发行普通股认购资金到位情况明细表\n截至2015年5月 5日止\n被审验单位名称:平安银行股份有限公司\n货币单位:人民币元\n金额\n到位认购资金 9,999,999,996.50\n3附件二\n验资事项说明\n一、 基本情况\n平安银行股份有限公司(以下简称“贵公司”)是中国平安保险(集团)股份有限公司控股的\n一家跨区域经营的股份制商业银行,是原深圳发展银行股份有限公司以吸收合并原平\n安银行股份有限公司的方式完成两行整合并更名的银行,总部位于深圳。原深圳发展\n银行成立于 1987年 12月 22日,并于 1991年 4月 3日在深圳证券交易所上市(股票代\n码:000001)。\n贵公司经中国银行业监督管理委员会批准领有 00386413 号金融许可证,机构编码为\nB0014H144030001,深圳市工商行政管理局批准核发的 440301103098545 号《中华\n人民共和国企业法人执照》。贵公司注册资本为人民币 13,709,873,744 元,实收资本\n(股本)为人民币 13,709,873,744元,其中包括有限售条件股份 1,905,819,165 股,无限\n售条件股份 11,804,054,579 股。贵公司的上述实收资本(股本)已经普华永道中天会计\n师事务所(特殊普通合伙)审验,并已于 2015 年 4 月 13 日出具普华永道中天验字(2015)\n第321号验资报告。\n二、本次非公开发行普通股审批及情况说明\n于 2014 年 7 月 15 日第九届董事会第五次会议《平安银行股份有限公司关于非公开发\n行普通股方案的议案》,同意提议股东大会批准贵公司向境内合格投资者非公开发行\n不超过 1,070,663,811 股普通股。于 2014 年 8 月 4 日 2014 年第二次临时股东大会审\n议通过,批准了董事会的上述提议。中国证券监督管理委员会于 2015 年 4 月 22 日出\n具证监许可[2015]697 号文《关于核准平安银行股份有限公司非公开发行股票的批\n复》核准了贵公司向境内合格投资者非公开发行不超过1,070,663,811股普通股。\n三、 审验结果\n经我们审验,截至 2015 年 5 月 5 日止,贵公司已完成普通股 598,802,395 股的发\n行,每股发行价格为人民币 16.70 元,认购资金合计人民币 9,999,999,996.50 元,全\n部以人民币现金形式汇入,由发行主承销商中信证券股份有限公司代贵公司收缴,已\n全部存入主承销商中信证券股份有限公司于平安银行股份有限公司北京分行营业部开\n立的19014508950004银行账号内。\n4', # noqa: E501 + 'target': + '平安银行股份有限公司\n非公开发行普通股认购资金到位情况的验资报告非公开发行普通股认购资金到位情况的验资报告\n普华永道中天验字(2015)第446号\n(第一页,共二页)\n平安银行股份有限公司:\n平安银行股份有限公司(以下简称“贵公司”)委托中信证券股份有限公司作为主\n普通股认购资金的到位情况。按照国家相关法律、法规的规定以及认购协议、合同\n的要求出资认购,提供真实、合法、完整的验资资料,保护资产的安全、完整是贵\n公司管理层及中信证券股份有限公司的责任。我们的责任是对贵公司由中信证券股\n份有限公司代收取的境内合格投资者本次非公开发行普通股认购资金的到位情况发\n表审验意见。我们的审验是依据《中国注册会计师审计准则第 1602号——验资》进\n行的。在审验过程中,我们结合贵公司的实际情况,实施了检查等必要的审验程\n序。\n经贵公司 2014 年 7 月 15 日第九届董事会第五次会议提议,2014 年 8 月 4 日\n银行股份有限公司非公开发行股票的批复》,贵公司获准向境内合格投资者非公开\n发行不超过1,070,663,811股普通股。普华永道中天验字(2015)第446号\n(第二页,共二页)\n经我们审验,截至 2015 年 5 月 5 日止,贵公司以每股人民币 16.70 元合计向\n境内合格投资者非公开发行普通股 598,802,395 股,由发行主承销商中信证券股份\n有限公司代贵公司实际收到人民币 9,999,999,996.50元。所有认购资金均以人民币\n现金形式汇入。\n本验资报告仅供贵公司向中国证券监督管理委员会、深圳证券交易所报送资料\n及向中国证券登记结算有限责任公司深圳分公司申请非公开发行普通股登记时使\n用,不应将其视为是对贵公司验资报告日后资本保全、偿债能力和持续经营能力等\n的保证。因使用本验资报告不当造成的后果,与执行本验资业务的注册会计师及会\n计师事务所无关。\n2015年 5月7日 朱丽平\n2附件一\n非公开发行普通股认购资金到位情况明细表\n截至2015年5月 5日止\n被审验单位名称:平安银行股份有限公司\n货币单位:人民币元\n金额\n到位认购资金 9,999,999,996.50\n3附件二\n验资事项说明\n一、 基本情况\n平安银行股份有限公司(以下简称“贵公司”)是中国平安保险(集团)股份有限公司控股的\n一家跨区域经营的股份制商业银行,是原深圳发展银行股份有限公司以吸收合并原平\n安银行股份有限公司的方式完成两行整合并更名的银行,总部位于深圳。原深圳发展\n银行成立于 1987年 12月 22日,并于 1991年 4月 3日在深圳证券交易所上市(股票代\n码:000001)。\n(股本)为人民币 13,709,873,744元,其中包括有限售条件股份 1,905,819,165 股,无限\n售条件股份 11,804,054,579 股。贵公司的上述实收资本(股本)已经普华永道中天会计\n师事务所(特殊普通合伙)审验,并已于 2015 年 4 月 13 日出具普华永道中天验字(2015)\n第321号验资报告。\n二、本次非公开发行普通股审批及情况说明\n于 2014 年 7 月 15 日第九届董事会第五次会议《平安银行股份有限公司关于非公开发\n行普通股方案的议案》,同意提议股东大会批准贵公司向境内合格投资者非公开发行\n不超过 1,070,663,811 股普通股。于 2014 年 8 月 4 日 2014 年第二次临时股东大会审\n议通过,批准了董事会的上述提议。中国证券监督管理委员会于 2015 年 4 月 22 日出\n具证监许可[2015]697 号文《关于核准平安银行股份有限公司非公开发行股票的批\n复》核准了贵公司向境内合格投资者非公开发行不超过1,070,663,811股普通股。\n三、 审验结果\n经我们审验,截至 2015 年 5 月 5 日止,贵公司已完成普通股 598,802,395 股的发\n行,每股发行价格为人民币 16.70 元,认购资金合计人民币 9,999,999,996.50 元,全\n部以人民币现金形式汇入,由发行主承销商中信证券股份有限公司代贵公司收缴,已\n全部存入主承销商中信证券股份有限公司于平安银行股份有限公司北京分行营业部开\n立的19014508950004银行账号内。\n4' # noqa: E501 + }] + + self._run_remove_header(samples) + + def test_false_negative_case(self): + samples = [{ + 'text': + '深圳发展银行股份有限公司关于前次募集资金使用情况\n的专项说明\n一、前次募集资金的数额和资金到位时间\n本行经中国证券监督管理委员会证监公司字[2000]154 号文批准,于 2000\n年在1998年末总股本1,551,847,092股的基础上,以每10股配3股的比例向全\n体股东配售新股。根据本次实际配股认购情况及深圳发展银行和主承销商海通证\n券有限公司签定的配股承销协议,本次实际配股总数为393,975,057股,每股配\n股价 8.00 元,实收配股资金 3,151,800,456 元,扣除与本次配股相关的费用\n61,915,333.08 元,实际配股募集资金为 3,089,885,122.92 元,于 2000 年 11\n月 24 日全部到位并经深圳鹏城会计师事务所(深鹏所验字[2000]163 号验资报\n告)验证。\n二、 截至2006年12月31日前次募集资金的实际使用情况\n根据配股说明书承诺的募集资金使用计划,截至2006年12月31日止,前\n次募集资金全部用于充实本行资本金和拓展分支机构。前次募集资金到位后,增\n加了本行的股本及资本公积金共计3,089,885,122.92元,充实了本行的资本金。\n配股说明书中承诺的拓展分支机构情况及其相关的营运资金投入情况和实\n际情况比较如下:\n1、配股说明书承诺的项目和实际营运投入情况比较:(见下表及说明)配股说明书承诺的项目和实际营运投入情况比较表\n金额单位:人民币元\n配股说明书承诺 实际投入营运资金\n项目 计划投入 2000年投入 2001年投入 2002年投入 合计\n1、 拓展分支机构,\n拨付新设机构营运资金\n① 北京分行 495,000,000.00 200,000,000.00 295,000,000.00 - 495,000,000.00\n② 大连分行 440,000,000.00 100,000,000.00 340,000,000.00 - 440,000,000.00\n③ 重庆分行 500,000,000.00 100,000,000.00 400,000,000.00 - 500,000,000.00\n④ 南京分行 430,000,000.00 100,000,000.00 330,000,000.00 - 430,000,000.00\n⑤ 青岛分行 500,000,000.00 - 100,159,277.60 399,840,722.40 500,000,000.00\n小计 2,365,000,000.00 500,000,000.00 1,465,159,277.60 399,840,722.40 2,365,000,000.00\n2、充实资本金及营运资金\n① 总行 215,000,000.00 224,885,122.92 - - 224,885,122.92\n② 广州分行 300,000,000.00 100,000,000.00 200,000,000.00 - 300,000,000.00\n③ 上海分行 200,000,000.00 100,000,000.00 100,000,000.00 - 200,000,000.00\n小计 715,000,000.00 424,885,122.92 300,000,000.00 - 724,885,122.92\n合计 3,080,000,000.00 924,885,122.92 1,765,159,277.60 399,840,722.40 3,089,885,122.92\n注:截至2002年12月31日,前次募股资金已按配股说明书承诺的项目及用途全部投入使用完毕,截至2006年12月31日,前次募股资金较2002年12\n月31日的使用情况没有发生变化。上表所列表明,截至2006年12月31日,本行前次募集资金已按配股说明\n书承诺,全部用于充实资本金;实际募集资金大于计划使用资金的差额\n9,885,122.92 元,用于补充总行资本金。前次募集资金在未拨付上述分支行营\n运资金前,用作充实总行资本金用途。\n本行配股说明书中承诺的拓展分支机构情况及其相关的营运资金投入事项\n已全部完成。\n2、项目实际完成程度及效益情况\n北京、大连、重庆、广州及上海分行在配股前已开业,青岛及南京分行在\n筹建阶段,经人民银行批准,南京分行于 2000 年 11 月正式营业,青岛分行于\n2002年3月正式营业(2001年拨付筹办资金100,159,277.60元)。\n本行未在2000年度的配股说明书中进行专项的盈利预测。募集资金投入各\n分行及补充总行营运资金所产生直接效益亦无法单独核算列示。\n本次配股募集资金3,089,885,122.92元,全部用作充实资本金,提高了本\n行资本充足率,占2006年12月31日本行净资产的48%。\n本行前次募集资金对于本行前次募集资金后的生存与发展以及效益方面有\n直接和关键的作用。\n3、募集资金实际使用情况和本行其他信息披露文件披露内容比较\n2000 年度配股事项完成后,本行已在 2000 年 12 月 6 日在《中国证券报》\n及《证券时报》发布公告,本行的2000年至2006年年度报告及其他定期报告中\n关于前次募集资金的信息披露和实际使用情况相符。\n特此报告。\n深圳发展银行股份有限公司董事会\n2007年6月1日', # noqa: E501 + 'target': + '深圳发展银行股份有限公司关于前次募集资金使用情况\n的专项说明\n一、前次募集资金的数额和资金到位时间\n本行经中国证券监督管理委员会证监公司字[2000]154 号文批准,于 2000\n年在1998年末总股本1,551,847,092股的基础上,以每10股配3股的比例向全\n体股东配售新股。根据本次实际配股认购情况及深圳发展银行和主承销商海通证\n券有限公司签定的配股承销协议,本次实际配股总数为393,975,057股,每股配\n股价 8.00 元,实收配股资金 3,151,800,456 元,扣除与本次配股相关的费用\n61,915,333.08 元,实际配股募集资金为 3,089,885,122.92 元,于 2000 年 11\n月 24 日全部到位并经深圳鹏城会计师事务所(深鹏所验字[2000]163 号验资报\n告)验证。\n二、 截至2006年12月31日前次募集资金的实际使用情况\n根据配股说明书承诺的募集资金使用计划,截至2006年12月31日止,前\n次募集资金全部用于充实本行资本金和拓展分支机构。前次募集资金到位后,增\n加了本行的股本及资本公积金共计3,089,885,122.92元,充实了本行的资本金。\n配股说明书中承诺的拓展分支机构情况及其相关的营运资金投入情况和实\n际情况比较如下:\n1、配股说明书承诺的项目和实际营运投入情况比较:(见下表及说明)配股说明书承诺的项目和实际营运投入情况比较表\n金额单位:人民币元\n配股说明书承诺 实际投入营运资金\n项目 计划投入 2000年投入 2001年投入 2002年投入 合计\n1、 拓展分支机构,\n拨付新设机构营运资金\n小计 2,365,000,000.00 500,000,000.00 1,465,159,277.60 399,840,722.40 2,365,000,000.00\n2、充实资本金及营运资金\n注:截至2002年12月31日,前次募股资金已按配股说明书承诺的项目及用途全部投入使用完毕,截至2006年12月31日,前次募股资金较2002年12\n月31日的使用情况没有发生变化。上表所列表明,截至2006年12月31日,本行前次募集资金已按配股说明\n书承诺,全部用于充实资本金;实际募集资金大于计划使用资金的差额\n9,885,122.92 元,用于补充总行资本金。前次募集资金在未拨付上述分支行营\n运资金前,用作充实总行资本金用途。\n本行配股说明书中承诺的拓展分支机构情况及其相关的营运资金投入事项\n已全部完成。\n2、项目实际完成程度及效益情况\n北京、大连、重庆、广州及上海分行在配股前已开业,青岛及南京分行在\n筹建阶段,经人民银行批准,南京分行于 2000 年 11 月正式营业,青岛分行于\n2002年3月正式营业(2001年拨付筹办资金100,159,277.60元)。\n本行未在2000年度的配股说明书中进行专项的盈利预测。募集资金投入各\n分行及补充总行营运资金所产生直接效益亦无法单独核算列示。\n本次配股募集资金3,089,885,122.92元,全部用作充实资本金,提高了本\n行资本充足率,占2006年12月31日本行净资产的48%。\n本行前次募集资金对于本行前次募集资金后的生存与发展以及效益方面有\n直接和关键的作用。\n3、募集资金实际使用情况和本行其他信息披露文件披露内容比较\n2000 年度配股事项完成后,本行已在 2000 年 12 月 6 日在《中国证券报》\n及《证券时报》发布公告,本行的2000年至2006年年度报告及其他定期报告中\n关于前次募集资金的信息披露和实际使用情况相符。\n特此报告。\n深圳发展银行股份有限公司董事会\n2007年6月1日' # noqa: E501 + }] + + self._run_remove_header(samples) + + def test_long_text_case(self): + samples = [{ + 'text': + '证券代码:000001 证券简称:平安银行 公告编号:2019-024\n优先股代码:140002 优先股简称:平银优01\n可转债代码:127010 可转债简称:平银转债\n平安银行股份有限公司\n2019年第一季度报告\n二〇一九年四月二十四日第一节 重要提示\n一、本行董事会、监事会及董事、监事、高级管理人员保证本报告内容的真实、准确、完整,\n不存在虚假记载、误导性陈述或者重大遗漏,并承担个别和连带的法律责任。\n二、本行第十届董事会第二十四次会议审议了 2019 年第一季度报告。本次董事会会议应出席\n董事14人,实到董事14人。会议一致同意此报告。\n三、本行董事长谢永林、行长胡跃飞、首席财务官项有志、会计机构负责人朱培卿保证 2019\n年第一季度报告中财务报表的真实、准确、完整。\n四、非标准审计意见提示\n□适用 √不适用\n本行本季度财务报表未经审计,普华永道中天会计师事务所(特殊普通合伙)对个别项目及财\n务报表编制流程执行了商定程序。\n1第二节 公司基本情况\n一、主要会计数据和财务指标\n公司是否需追溯调整或重述以前年度会计数据\n□是 √否\n本行于2019年1月1日起实施《企业会计准则第21号——租赁》(财会〔2018〕35号),并自\n2019年第一季度报告起按上述会计准则要求进行会计报表披露。根据衔接规定,本行按首次执行本\n会计准则的累积影响数,调整财务报表相关项目金额,不调整可比期间信息,相关影响请参阅《平\n安银行股份有限公司关于会计政策变更的公告》。\n(货币单位:人民币百万元)\n期末比上年末\n项 目 2019年3月31日 2018年12月31日\n增减\n资产总额 3,530,180 3,418,592 3.3%\n股东权益 250,938 240,042 4.5%\n归属于普通股股东的股东权益 230,985 220,089 5.0%\n股本 17,170 17,170 -\n归属于普通股股东的每股净资产(元/股) 13.45 12.82 5.0%\n项 目 2019年1-3月 2018年1-3月 同比增减\n营业收入 32,476 28,026 15.9%\n归属于本公司股东的净利润 7,446 6,595 12.9%\n扣除非经常性损益后归属于本公司股东的净利润 7,422 6,555 13.2%\n经营活动产生的现金流量净额 53,184 41,442 28.3%\n每股经营活动产生的现金流量净额(元/股) 3.10 2.41 28.6%\n基本每股收益(元/股) 0.38 0.33 15.2%\n稀释每股收益(元/股) 0.36 0.33 9.1%\n扣除非经常性损益后的基本每股收益(元/股) 0.38 0.33 15.2%\n平均总资产收益率(未年化) 0.21% 0.20% 0.01个百分点\n平均总资产收益率(年化) 0.86% 0.80% 0.06个百分点\n加权平均净资产收益率(未年化) 2.91% 2.79% 0.12个百分点\n加权平均净资产收益率(年化) 12.15% 11.87% 0.28个百分点\n扣除非经常性损益后的加权平均净资产收益率\n2.90% 2.77% 0.13个百分点\n(未年化)\n扣除非经常性损益后的加权平均净资产收益率\n12.11% 11.80% 0.31个百分点\n(年化)\n注:(1)本行于2016年3月7日非公开发行200亿元非累积型优先股,在计算“每股收益”及“加权平均净资产收\n2益率”时,分子均扣减了已发放的优先股股息。\n(2)本行于2019年1月25日发行260亿元A股可转换公司债券,在计算“稀释每股收益”时已考虑可转换公\n司债券转为普通股的稀释性影响。\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日 2017年12月31日 期末比上年末增减\n吸收存款本金 2,286,977 2,128,557 2,000,420 7.4%\n其中:企业存款 1,765,403 1,666,966 1,659,421 5.9%\n个人存款 521,574 461,591 340,999 13.0%\n发放贷款和垫款本金总额 2,051,445 1,997,529 1,704,230 2.7%\n其中:企业贷款 860,913 843,516 855,195 2.1%\n一般企业贷款 798,940 801,814 840,439 (0.4%)\n贴现 61,973 41,702 14,756 48.6%\n个人贷款 1,190,532 1,154,013 849,035 3.2%\n一般个人贷款 698,225 680,718 545,407 2.6%\n信用卡应收账款 492,307 473,295 303,628 4.0%\n注:(1)根据《中国人民银行关于调整金融机构存贷款统计口径的通知》(银发〔2015〕14号),从2015年开始,非\n存款类金融机构存放在存款类金融机构的款项纳入“各项存款”、存款类金融机构拆放给非存款类金融机构的款\n项纳入“各项贷款”统计口径。按此统计口径,2019年3月31日的各项存款为26,027亿元,各项贷款为21,101\n亿元。\n(2)根据财政部《关于修订印发2018年度金融企业财务报表格式的通知》(财会〔2018〕36号)的规定,基于\n实际利率法计提的利息计入金融工具账面余额中,于资产负债表日尚未收到或尚未支付的利息在“其他资产”或\n“其他负债”列示。除非特别说明,本报告提及的“发放贷款和垫款”、“吸收存款”及其明细项目均为不含息\n金额。\n截至披露前一交易日的公司总股本\n截至披露前一交易日的公司总股本(股) 17,170,411,366\n用最新股本计算的全面摊薄每股收益(元/股,1-3月累计) 0.38\n报告期末至季度报告披露日股本是否因发行新股、增发、配股、股权激励行权、回购等原因发\n生变化且影响所有者权益金额\n□是 √否\n非经常性损益项目和金额\n√适用 □不适用\n3(货币单位:人民币百万元)\n项 目 2019年1-3月\n非流动性资产处置净损益 12\n其他 19\n所得税影响 (7)\n合 计 24\n注:非经常性损益根据证监会《公开发行证券的公司信息披露解释性公告第1号——非经常性损益》的定义计算。\n本行报告期不存在将根据《公开发行证券的公司信息披露解释性公告第1号——非经常性损益》\n定义、列举的非经常性损益项目界定为经常性损益的项目的情形。\n二、监管指标和财务比率\n(单位:%)\n项 目 标准值 2019年3月31日 2018年12月31日 2017年12月31日\n资本充足率 ≥10.5 11.50 11.50 11.20\n一级资本充足率 ≥8.5 9.59 9.39 9.18\n核心一级资本充足率 ≥7.5 8.75 8.54 8.28\n不良贷款率 ≤5 1.73 1.75 1.70\n拨备覆盖率 ≥150 170.59 155.24 151.08\n拨贷比 ≥2.5 2.94 2.71 2.57\n成本收入比(年初至期末) 不适用 29.62 30.32 29.89\n存贷差(年初至期末,年化/未年化) 不适用 4.09/1.01 4.03 3.99\n净利差(年初至期末,年化/未年化) 不适用 2.44/0.60 2.26 2.20\n净息差(年初至期末,年化/未年化) 不适用 2.53/0.62 2.35 2.37\n注:监管指标根据监管口径列示。\n三、报告期末股东总数及前十名股东持股情况表\n1、普通股股东和表决权恢复的优先股股东总数及前10名股东持股情况表\n(单位:股)\n报告期末表决权恢复的\n报告期末普通股股东总数(户) 354,508 -\n优先股股东总数(如有)\n前10名普通股东持股情况\n质押或冻结\n持有有限售条 情况\n股东名称 股东性质 持股比例(%) 持股总数\n件股份数量 股份\n数量\n状态\n中国平安保险(集团)股份有限公\n境内法人 49.56 8,510,493,066 - - -\n司-集团本级-自有资金\n4中国平安人寿保险股份有限公司\n境内法人 6.11 1,049,462,784 - - -\n-自有资金\n香港中央结算有限公司 境外法人 4.17 716,660,451 - - -\n中国证券金融股份有限公司 境内法人 2.50 429,232,688 - - -\n中国平安人寿保险股份有限公司\n境内法人 2.27 389,735,963 - - -\n-传统-普通保险产品\n中央汇金资产管理有限责任公司 境内法人 1.26 216,213,000 - - -\n深圳中电投资股份有限公司 境内法人 0.83 142,402,769 - - -\n河南鸿宝集团有限公司 境内法人 0.59 100,453,307 - - -\n中信信托有限责任公司-中信信\n托锐进 43 期高毅晓峰投资集合资 境内法人 0.39 66,346,066 - - -\n金信托计划\n上海高毅资产管理合伙企业(有限\n境内法人 0.29 49,221,916 - - -\n合伙)-高毅晓峰2号致信基金\n前10名无限售条件股东持股情况\n持有无限售条 股份种类\n股东名称\n件股份数量 股份种类 数量\n中国平安保险(集团)股份有限公司-集团本级-自有资金 8,510,493,066 人民币普通股 8,510,493,066\n中国平安人寿保险股份有限公司-自有资金 1,049,462,784 人民币普通股 1,049,462,784\n香港中央结算有限公司 716,660,451 人民币普通股 716,660,451\n中国证券金融股份有限公司 429,232,688 人民币普通股 429,232,688\n中国平安人寿保险股份有限公司-传统-普通保险产品 389,735,963 人民币普通股 389,735,963\n中央汇金资产管理有限责任公司 216,213,000 人民币普通股 216,213,000\n深圳中电投资股份有限公司 142,402,769 人民币普通股 142,402,769\n河南鸿宝集团有限公司 100,453,307 人民币普通股 100,453,307\n中信信托有限责任公司-中信信托锐进 43 期高毅晓峰投资集合资\n66,346,066 人民币普通股 66,346,066\n金信托计划\n上海高毅资产管理合伙企业(有限合伙)-高毅晓峰2号致信基金 49,221,916 人民币普通股 49,221,916\n1、中国平安人寿保险股份有限公司为中国平安保险(集团)股份有限公司控股子公司和\n一致行动人,“中国平安保险(集团)股份有限公司-集团本级-自有资金”、“中国平安\n上述股东关联关系或一致行\n人寿保险股份有限公司-自有资金”与“中国平安人寿保险股份有限公司-传统-普通保\n动的说明\n险产品”具有关联关系。\n2、本行未知其他股东间的关联关系,也未知其是否属于一致行动人。\n前10名普通股股东参与融\n资融券业务股东情况说明 无\n(如有)\n公司前10名普通股股东、前10名无限售条件普通股股东在报告期内是否进行约定购回交易\n□是 √否\n52、优先股股东总数及前10名优先股股东持股情况表\n√适用 □不适用\n(单位:股)\n报告期末优先股股东总数(户) 15\n前10名优先股股东持股情况\n持股比例 持有有限售条 质押或冻结情况\n股东名称 股东性质 持股数量\n(%) 件的股份数量 股份状态 数量\n中国平安人寿保险股份有限公司-\n境内法人 29.00 58,000,000 - - -\n分红-个险分红\n中国平安人寿保险股份有限公司-\n境内法人 19.34 38,670,000 - - -\n万能-个险万能\n中国平安财产保险股份有限公司-\n境内法人 9.67 19,330,000 - - -\n传统-普通保险产品\n中邮创业基金-华夏银行-华夏银\n境内法人 8.95 17,905,000 - - -\n行股份有限公司\n交银施罗德资管-交通银行-交通\n境内法人 8.95 17,905,000 - - -\n银行股份有限公司\n中国银行股份有限公司上海市分行 境内法人 4.47 8,930,000 - - -\n中国邮政储蓄银行股份有限公司 境内法人 2.98 5,950,000 - - -\n华润深国投信托有限公司-投资1\n境内法人 2.98 5,950,000 - - -\n号单一资金信托\n华宝信托有限责任公司-投资2号\n境内法人 2.98 5,950,000 - - -\n资金信托\n招商财富-邮储银行-中国邮政储\n境内法人 2.98 5,950,000 - - -\n蓄银行股份有限公司\n1、中国平安人寿保险股份有限公司和中国平安财产保险股份有限公司为中国平安\n保险(集团)股份有限公司控股子公司和一致行动人,“中国平安人寿保险股份有\n上述股东关联关系或一致行动的说\n限公司-分红-个险分红”、“中国平安人寿保险股份有限公司-万能-个险万能”\n明\n与“中国平安财产保险股份有限公司-传统-普通保险产品”具有关联关系。\n2、本行未知其他股东间的关联关系,也未知其是否属于一致行动人。\n注:(1)本行已发行的优先股不设限售期,均为无限售条件优先股;\n(2)本行无表决权恢复的优先股股东。\n63、前10名可转换公司债券持有人情况\n持有人名称 持有数量(张) 占比(%)\n中国平安保险(集团)股份有限公司-集团本级-自有资金 128,865,886 49.56\n中国平安人寿保险股份有限公司-自有资金 15,890,965 6.11\n中国平安人寿保险股份有限公司-传统-普通保险产品 5,901,382 2.27\n招商银行股份有限公司-兴全合宜灵活配置混合型证券投资基金 5,469,333 2.10\n全国社保基金四一四组合 3,473,563 1.34\n深圳中电投资股份有限公司 2,156,263 0.83\n广发证券股份有限公司 2,016,823 0.78\n中国石油天然气集团公司企业年金计划-中国工商银行股份有限公司 1,783,837 0.69\n国元证券股份有限公司 1,590,000 0.61\n中国银河证券股份有限公司 1,583,595 0.61\n注:(1)中国平安人寿保险股份有限公司为中国平安保险(集团)股份有限公司控股子公司和一致行动人,“中国\n平安保险(集团)股份有限公司-集团本级-自有资金”、“中国平安人寿保险股份有限公司-自有资金”与“中国\n平安人寿保险股份有限公司-传统-普通保险产品”具有关联关系;\n(2)本行未知其他股东间的关联关系,也未知其是否属于一致行动人。\n7第三节 重要事项\n一、 报告期主要财务数据、财务指标变动30%以上的情况及原因\n√适用 □不适用\n(货币单位:人民币百万元)\n项 目 本期金额 变动金额 变动比率 变动原因分析\n拆出资金 100,231 27,297 37.4% 拆放境内、外同业款项增加\n分类为以公允价值计量且其变动计入其他\n其他债权投资 111,039 40,375 57.1%\n综合收益的债券投资、同业投资规模增加\n租赁使用权资产 7,239 7,239 上年末为零 本年根据租赁准则新增该报表项目\n其他资产 23,284 9,506 69.0% 应收清算款项等增加\n交易性金融负债 15,721 7,146 83.3% 债券卖空业务导致交易性金融负债增加\n卖出回购金融资产款项 21,502 13,514 169.2% 卖出回购债券规模增加\n预计负债 1,271 411 47.8% 对或有事项、财务担保合同计提损失准备\n租赁负债 7,700 7,700 上年末为零 本年根据租赁准则新增该报表项目\n其他综合收益 1,385 599 76.2% 其他债权投资公允价值变动增加\n手续费及佣金支出 2,019 666 49.2% 信用卡业务量增长带来手续费支出增加\n自2018年三季度起根据新金融工具会计准\n则要求,将“以公允价值计量且其变动计入\n投资收益 2,571 2,155 518.0%\n当期损益的金融工具”产生的收入/支出从\n利息收入/支出计入投资收益\n公允价值变动损益 (642) (1,015) (272.1%) 交易性金融资产公允价值变动损益减少\n汇兑损益 169 347 上年同期为负 汇率波动导致汇兑损益增加\n其他业务收入 17 (59) (77.6%) 基期数小,上年同期为0.76亿元\n资产处置损益 12 (32) (72.7%) 基期数小,上年同期为0.44亿元\n计提的抵债资产减值损失增加;基期数小,\n其他资产减值损失 231 224 3,200.0%\n上年同期为0.07亿元\n营业外收入 29 24 480.0% 基期数小,上年同期为0.05亿元\n二、重要事项进展情况及其影响和解决方案的分析说明\n√适用 □不适用\n2019年1月25日,本行完成公开发行260亿元可转换公司债券。2019年1月30日,本行收到\n中国证券登记结算有限责任公司深圳分公司出具的《证券登记证明》。经深圳证券交易所批准,本\n行公开发行可转换公司债券已于2019年2月18日起在深圳证券交易所挂牌交易。\n2019年3月7日和4月10日,本行全额赎回规模为人民币90亿元的10年期二级资本债券和\n规模为人民币60亿元的10年期二级资本债券。\n8重要事项概述 披露日期 临时报告披露网站查询索引\n本行公开发行人民币260亿元A股可转 《中国证券报》、《证券时报》、《上\n2019年1月17日、2019年1月25日\n换公司债券 海证券报》、《证券日报》和巨潮资讯\n本行对二级资本债券行使赎回选择权 2019年3月9日、2019年4月12日 网(www.cninfo.com.cn)\n股份回购的实施进展情况\n□适用 √不适用\n采用集中竞价方式减持回购股份的实施进展情况\n□适用 √不适用\n三、公司实际控制人、股东、关联方、收购人以及公司等承诺相关方在报告期内超期未履行完\n毕的承诺事项\n□适用 √不适用\n公司报告期不存在公司实际控制人、股东、关联方、收购人以及公司等承诺相关方在报告期内\n超期未履行完毕的承诺事项。\n四、对2019年1-6月经营业绩的预计\n预测年初至下一报告期期末的累计净利润可能为亏损或者与上年同期相比发生大幅度变动的\n警示及原因说明\n□适用 √不适用\n五、证券投资情况\n报告期末,本行所持金融债券(政策性银行债、各类普通金融债、次级金融债,不含企业债)\n账面价值为1,633.17亿元,其中前十大面值金融债券的有关情况如下:\n(货币单位:人民币百万元)\n债券名称 面值 票面年利率(%) 到期日 减值准备\n2017年政策性银行债券 6,710 4.24 2027/08/24 -\n2018年政策性银行债券 3,990 4.69 2023/03/23 -\n2010年政策性银行债券 3,860 2.84 2020/02/25 -\n2017年政策性银行债券 3,650 4.44 2022/11/09 -\n2019年商业银行债券 3,050 3.50 2022/03/27 5.63\n2011年政策性银行债券 3,030 3.85 2021/02/17 -\n2016年商业银行债券 3,000 3.25 2021/03/07 2.34\n2016年商业银行债券 3,000 3.20 2021/03/29 2.34\n2016年政策性银行债券 2,780 2.96 2021/02/18 -\n2015年政策性银行债券 2,570 4.21 2025/04/13 -\n9六、委托理财\n□适用 √不适用\n报告期内,本行未发生正常业务范围之外的委托理财事项。\n七、衍生品投资情况\n(货币单位:人民币百万元)\n年初合约金额 期末合约金额 报告期公允价值\n合约种类\n(名义金额) (名义金额) 变动情况\n外汇衍生工具 874,747 1,013,726 411\n利率衍生工具 3,168,549 3,383,066 130\n贵金属衍生工具 84,071 87,477 654\n合计 4,127,367 4,484,269 1,195\n注:本行在董事会确立的风险偏好和市场风险总体限额框架内,开展包括衍生品的资金交易和投资业务。衍生\n品金融工具名义金额只体现交易量,并不反映其实际风险暴露。本行开展的外汇及利率衍生品业务主要采取对冲策\n略,实际汇率及利率风险暴露较小。\n八、报告期内接待调研、沟通、采访等活动情况\n接待时间 接待方式 接待对象类型 调研的基本情况索引\n2019/01/10 投行会议 机构 巨潮资讯网\n(www.cninfo.com.cn)\n2019/01/24 实地调研 机构\n《平安银行股份有限公司\n2019/03/25 实地调研 机构 投资者关系活动记录表》\n九、违规对外担保情况\n□适用 √不适用\n公司报告期无违规对外担保情况。\n十、控股股东及其关联方对上市公司的非经营性占用资金情况\n□适用 √不适用\n公司报告期不存在控股股东及其关联方对上市公司的非经营性占用资金。\n10十一、管理层讨论与分析\n(一)总体情况\n2019 年是新中国成立 70 周年,也是决胜全面建成小康社会的关键之年。正确把握金融本质,\n深化金融供给侧结构性改革,平衡好稳增长和防风险的关系,增强金融服务实体经济能力,打好防\n范化解包括金融风险在内的重大风险攻坚战,推动金融业高质量发展,是全年金融工作的重中之重。\n2019 年一季度,宏观经济运行总体平稳,供给侧结构性改革持续深化,“三大攻坚战”稳步推进,\n经济发展新动力进一步形成。\n本行紧跟国家战略,顺应经济金融形势,坚持以打造“中国最卓越、全球领先的智能化零售银\n行”为目标,持续深化“科技引领、零售突破、对公做精”策略方针,不断加强科技创新和场景应\n用,坚定推进智能化零售业务转型,持续深化对公业务从粗放型增长向效益型发展转型,严控各类\n金融风险,不断加大对民营企业、小微企业等实体经济的支持力度,各项业务稳健发展,资产质量\n持续改善,战略转型成效进一步显现。\n1、整体经营稳中趋好\n2019 年一季度,本行实现营业收入 324.76 亿元,同比增长 15.9%;其中,利息净收入 207.74\n亿元,同比增长11.2%;非利息净收入117.02亿元,同比增长25.3%。减值损失前营业利润225.52\n亿元,同比增长 17.1%;净利润 74.46 亿元,同比增长 12.9%;2019 年一季度净息差 2.53%,同比\n上升28个基点、环比上升3个基点,非利息净收入占比36.0%,同比上升2.7个百分点。\n2019 年 3 月末,本行资产总额 35,301.80 亿元,较上年末增长 3.3%;吸收存款余额 22,869.77\n亿元,较上年末增长7.4%;发放贷款和垫款总额(含贴现)20,514.45亿元,较上年末增长2.7%。\n2019 年 3 月末,本行保本理财产品余额 941.83 亿元、较上年末增长 13.9%,结构性存款余额\n4,848.06亿元、较上年末增长11.8%,非保本理财产品余额5,584.61亿元、较上年末增长3.8%。\n2、零售转型稳健前行\n2019年一季度,本行贯彻“零售突破”的策略,继续深化综合金融优势,重点发力“基础零售、\n消费金融、私行财富”三大业务模块,提升“风险和成本控制”两大核心能力,充分发挥科技赋能\n作用,优化业务流程与服务体验,转型工作稳健进行。\n(1)基础零售\n2019 年 3 月末,本行管理零售客户资产(AUM)16,632.98 亿元、较上年末增长 17.4%,零售\n客户数8,701万户、较上年末增长3.7%。个人存款余额5,215.74亿元、较上年末增长13.0%,平安\n口袋银行APP注册客户数6,765万户,较上年末增长8.7%;月活客户数2,603万户。\n2019年一季度,本行持续聚焦基础零售客户获客及经营,通过场景化、科技化手段,利用不同\n的互联网场景打造多种获客路径,坚持科技赋能和大数据驱动客户分析及经营策略的制定,促进获\n客及客户经营效率与产能提升。在获客方面,本行依托集团生态圈,深挖场景价值,持续优化口袋\n银行APP功能体验,构建银行线上生态圈,实现从用户到客户的转化;在经营方面,搭建智能运营\n平台,围绕客户生命周期,构造“客群-渠道”智能化产品推荐体系,并结合内外部资源构建 KYC\n11(充分了解客户)体系,实现大数据赋能客群经营,提升客户经营能力及管理效率。\n(2)私行财富\n2019 年 3 月末,本行财富客户 66.45 万户、较上年末增长 12.3%;私行达标客户 3.5 万户、较\n上年末增长16.7%。\n2019年一季度,本行在私人银行和财富管理方面,强化落实经营转型策略,在产品、体系和机\n制三方面积极推进战略落地。在产品平台和权益体系搭建上,本行充分整合内外部资源,持续供应\n优质产品和权益体验,支持私财业务发展。同时,大力推动投顾团队建设,充分利用 AI 科技力量\n和平安集团综合金融模式的优势,打造一支专业化、智能化的投顾团队。在内部管理体系方面,运\n作产品管理委员会机制,多元化引入优质资产;强化内嵌风险团队,严格把控资产风险;在机制建\n设方面,“分支行”、“直营”和“综合金融”三大业务模式并驾齐驱,强化私行客户经营能力和资\n产配置能力。\n(3)消费金融\n2019年3月末,个人贷款余额11,905.32亿元、较上年末增长3.2%。2018年末以来,鉴于宏观\n经济仍存下行压力,在确保资产质量稳定的前提下,本行主动优化贷款类产品投放策略,适当提高\n信用卡、贷款投放门槛,推动目标客群上移。同时,本行根据市场需求,持续优化客群画像、产品\n方案、流程、客户体验,覆盖更多场景,满足客户不同阶段消费及经营融资需求。目前,多项举措\n正在内部逐渐深化、生效,预计下半年起,个人贷款及信用卡业务主要规模指标增长将出现积极的\n变化。在优化投向配置的同时,本行着力支持小微企业、促进实体经济发展。\n2019年3月末,本行信用卡流通卡量5,352万张,较上年末增长3.9%;信用卡贷款余额4,923.07\n亿元,较上年末增长4.0%。2019年一季度,信用卡总交易金额8,010.30亿元,同比增长43.2%;信\n用卡商城交易量同比增长24.7%;2019年3月末,信用卡APP客户数已突破3,700万户。本行信用\n卡继续依托互联网平台深入打造“快、易、好”极致客户体验,并加强线上线下融合:线上完善信\n用卡商城服务平台,深入精细化运营,提升用户价值感;线下搭建精准营销服务平台,针对性地推\n出适合不同客群的精准营销活动,打造围绕不同客户偏好的差异化营销策略,提升客户不同用卡场\n景的消费体验。\n2019年3月末,本行“新一贷”余额1,551.90亿元,较上年末增长0.9%。“新一贷”借助科技\n力量赋能贷前、贷中、贷后全流程,全力打造“SAT+T”全线上申请流程模式,并通过嵌入数据直\n连、机器人流程自动化(RPA)、微表情等AI技术,实现在线智能“核身+核审”,不仅大幅简化客\n户申请所需材料、节约客户申请时间,为客户提供良好的移动金融交互体验,也积极利用互联网技\n术持续加强信贷全流程的防风险、反欺诈能力建设。\n2019 年 3 月末,本行住房按揭贷款余额 1,861.37 亿元,较上年末增长 2.1%。本行严格落实国\n家政策规定和监管要求,支持居民家庭首套自住购房需求,并将继续在合规前提下稳步开展住房信\n贷业务。\n2019 年 3 月末,本行汽车金融贷款余额 1,699.21 亿元,较上年末降幅 1.2%。在汽车消费市场\n12整体增速放缓的趋势下,本行不断通过产品创新、流程优化、AI智能化等多项举措,深耕车生态经\n营策略,在不断提升客户体验的同时,进一步挖潜汽车消费金融市场。\n(4)风险控制\n2019年3月末,本行个人贷款不良率1.10%,较上年末上升0.03个百分点。受到宏观经济下行、\n共债风险上升、汽车消费下滑等外部因素的影响,消费金融全行业的风险都有所上升,本行零售产\n品的不良率也略有上升,但整体风险表现仍维持在相对较低水平。其中信用卡不良率 1.34%,较上\n年末上升0.02个百分点;“新一贷”不良率1.14%,较上年末上升0.14个百分点;汽车金融业务的\n不良率 0.62%,较上年末上升 0.08 个百分点。本行自 2017 年底开始提前进行风险政策调整,重点\n防范共债风险,有效控制并降低了高风险客户占比,新发放业务的资产质量较好,预计这些管制措\n施的优化效应将会在 2019 年下半年逐渐展现,零售主要贷款的不良率下半年将呈现稳中向好的趋\n势。\n(5)科技平台赋能\n2019 年一季度本行持续加大科技投入,在线上进一步升级迭代,嵌入口袋银行 APP 和口袋银\n行家 APP;在线下持续复制推广“轻型化、社区化、智能化、多元化”的零售新门店,截至 2019\n年3月末,全国已开业177家新门店。同时,本行整合打造智能化OMO(Online Merge Offline,线\n上线下相融合)服务体系,通过综合化、场景化、个性化让客户能在线上线下无缝切换,为客户带\n来了更好的金融生活体验。\n2019 年,本行零售推行全面 AI 化,通过科技手段将零售业务在经营、服务、管理方面所需的\n能力标准化、系统化、智能化,并赋能给客户、员工以及第三方合作伙伴,打造全方位的AI Bank。\n目前已孵化出AI客服、AI营销和AI风控等项目成果,其中,2019年3月末AI客服非人工占比已\n达82.2%,较上年末提升2.2个百分点;AI营销已普遍应用于各业务场景中,销售转化效果较常规\n手段提升最高达3倍,并已全面赋能一线队伍,使其具备随时随地为客户提供专业服务的能力;AI\n风控已搭建了统一风控平台,实现了信用风险、欺诈风险的统一管控,并进一步提升了零售信贷产\n品的自动化审核能力,2019年一季度,“AI+风控”欺诈防堵金额为1.41亿元。\n(6)综合金融\n2019年一季度,综拓渠道迁徙客户(不含信用卡)净增74.39万户,占零售整体净增客户(不\n含信用卡)的比例为34.2%,其中财富客户净增2.76万户,占整体净增财富客户的比例为37.9%;\n管理零售客户资产(AUM)余额净增841.31亿元,占零售整体净增客户资产余额的比例为34.1%。\n综拓渠道发放“新一贷”152.36亿元,占“新一贷”整体发放的比例为55.03%;发放汽融贷款121\n亿元,占汽融贷款整体发放的比例为 37.0%。信用卡通过交叉销售渠道发卡 112 万张,在新增发卡\n量中占比为33.7%。零售全渠道代销集团保险累计实现非利息净收入8.73亿元,同比增长40.6%。\n133、对公聚焦精品打造\n2019 年本行公司业务按照对公做精的战略要求,借助科技手段持续打造智能化精品公司银行,\n努力实现在行业、客户、科技、协同四个方面的突破。具体来说,公司业务聚焦行业,以行业化、\n场景化模式推进集成营销;聚焦客户,在目标行业内精选客群,为客户提供系统的、智能的管理平\n台;聚焦科技,将智能与科技的基因融入业务的每个细胞,运用区块链、物联网等技术成果,实现\n业务升级;聚焦协同,利用集团综合金融优势,成为集团团体综合金融的发动机。\n2019 年一季度,对公精品业务打造成效初显,截至 3 月末企业存款余额 17,654.03 亿元,较上\n年末增长5.9%。存款增长得益于支付结算及供应链金融平台的打造,稳定了对公的基础客群。\n(1)精品公司业务经营\n① 互联网支付结算\n互联网支付结算聚焦与国计民生相关的战略行业和互联网新兴产业,深入研究细分行业平台需\n求,开展互联网头部客户攻坚战。同时,进一步锻造互联网支付结算“前中后台”能力,为客户提\n供更加高效的服务体验。一是前台,客户端操作移动化;二是中台,业务处理智能化,简化技术联\n调、项目实施和投产上线业务流程,提高作业效率;三是后台,运用区块链、大数据、反欺诈等金\n融科技手段,有效防范支付结算风险,保障用户资金和交易安全。2019年一季度,本行互联网支付\n结算新投产平台86个,交易笔数超1.5亿笔,平台交易量超5,300亿元。\n② 智能供应链金融\n本行持续升级供应链金融服务,应用人工智能、区块链、云计算、大数据等科技赋能,推出供\n应链应收账款服务平台,聚焦产业链核心企业及其上游客户,提供在线应收账款转让、融资、管理、\n结算等综合金融服务,协同多方构建中小企业融资服务生态。\n2019 年一季度,本行供应链应收账款服务平台累计交易量 78 亿元,已为 137 家核心企业及其\n上游供应商提供了金融服务,单笔融资最小金额仅 5.2 万元,有效解决了中小供应商融资难、融资\n贵问题,支持实体经济发展。\n③ 跨境金融\n本行持续推进跨境金融业务线上化,落地离在岸网银整合项目、离岸口袋财务迭代升级、离岸\n网银线上融资迭代升级、银企直联平台升级服务客群等线上化项目,增加客户可线上化操作的业务\n品种,提升客户体验。同时,本行聚焦跨境互联网交易平台客群,不断丰富跨境e金融平台服务场\n景,构建跨境电商平台金融服务生态。2019 年一季度,跨境 e 金融平台累计线上交易规模近 1,500\n亿元,同比增幅7.2%。\n④ 口袋财务\n本行口袋财务整合产品服务,实现了离岸、国际、现金管理业务移动化,同时积极引入集团综\n合金融产品,与平安证券、平安租赁、平安医保科技在场景方面的合作实现较大突破。2019年3月\n末,口袋财务注册客户27.8万户;2019年一季度,交易笔数85.3万笔、同比增长超7倍,交易金\n额6,419亿元、同比增长超10倍。\n14⑤ 小企业数字金融\n本行小企业数字金融通过深入企业经营场景,借助量化模型和数据分析完善中小企业精准画像,\n解决中小微企业信息不对称,风险较难把控的问题。2019 年,在推动和优化标准化产品的基础上,\n围绕车、房、医疗、大零售、大食品等行业积极探索,进一步为行业终端客户制定综合金融服务方\n案,解决融资难、融资贵的问题。2019 年 3 月末,本行小企业数字金融服务客户数 17,355 户,较\n上年末增长23.1%;2019年一季度,本行小企业数字金融累计发放贷款42亿元。\n⑥ 投行及同业业务\n本行坚定推进机构销售转型,以科技为手段,构建金融服务生态圈,打造金融机构的连接器。\n2019年一季度,本行继续拓展销售产品种类,扩大销售规模,做深客户经营,建立覆盖全国、全类\n型金融机构的销售网络;在此基础上,进一步优化销售产品结构,提高债券、非标资产及资管产品\n销售占比,大力发展交易做市业务及代客业务,提升交易贡献占比。2019年3月末,本行“行e通”\n平台累计合作客户达 2,115 户;2019 年一季度,同业机构销售业务量达到 1,300 亿元,为去年同期\n的 3 倍,债券及非标资产的销售占比进一步提升;利率掉期做市交易量达到 5,730 亿元,同比增长\n57.3%;代客衍生业务交易量实现 36.7 亿美元,同比增长 58.0%。本行积极推动真投行转型和轻型\n融资能力建设,发力债券承销、理财直融工具等领域,满足客户融资需求。2019年一季度,经银保\n监会注册发行的理财直融产品承销规模197亿元,市场发行份额排名位居前列。\n⑦ 政府金融\n顺应国家改革发展新趋势,依托平安集团“金融+科技”,本行在政府金融业务领域从传统的金\n融服务提供方转型为助力改革推进的合作伙伴,以科技赋能提升服务客户综合能力,聚焦财政、住\n房、司法三大生态持续创新。\n2019 年3 月末,本行新上线政府金融类平台 21个,累计达464 个。在财政生态领域,本行与\n平安集团重金所合作,引入区块链、AI智能辅助业务决策,研发涉企资金补助产品,助力政府穿透\n式资金监管。在住房生态领域,公积金业务、商品房交易资金监管等方面均有突破。在司法生态领\n域,锁定破产清算重点突破,研发全流程破产案件管理系统,利用区块链、人脸识别、OCR等技术,\n减轻破产管理人案件管理压力。同时,本行协同集团智慧司法多级联动营销,2019年一季度累计获\n客300户。\n(2)持续推进综合金融,成为集团团体综合金融的发动机\n本行充分运用集团综合金融资源与平台,做强银行渠道,围绕不同场景下的客户群体,主动设\n计场景化、嵌入式、线上化、智能化的综合服务方案,着力推进保险与投融突破。借助科技手段深\n度实现系统流程优化和管理能力升级。2019 年一季度,销售平安团体保险保费规模 2.02 亿元,新\n增银投合作项目规模424.09亿元。\n(3)特殊资产智慧经营\n2019 年,依托集团“金融+科技”优势,本行全力打造特殊资产智慧经营平台,平台借鉴行业\n领先的互联网企业“AI大脑”设计,以大数据为内驱动力,聚焦智能清收、智慧管理、生态经营三\n15大能力建设,推动特殊资产清收向智慧模式转型,不良资产清收业绩大幅提升。2019年一季度,本\n行收回不良资产总额95.22亿元、同比增长26.8%,其中信贷资产(贷款本金)88.88亿元;收回的\n贷款本金中,已核销贷款55.34亿元,未核销不良贷款33.54亿元;不良资产收回额中83.7%为现金\n收回,其余为以物抵债等方式收回。\n4、科技赋能成效显现\n本行将“科技引领”作为全行转型升级的驱动力,实行精益和敏捷的双模研发体系,进一步完\n善数据治理体系,加快技术架构向分布式架构转型,围绕产品创新、客户体验、风险管理,积极探\n索前沿技术的研究创新和实际应用,大力推动人工智能、大数据、区块链等新技术嵌入各项业务场\n景,助推银行在产品、服务、平台等方面转型升级。\n2019 年一季度,在零售条线,全面启动 AI 中台的建设,通过打造知识库平台、生物特征识别\n平台、业务中台等系统平台赋能前端业务场景,目前已在信用卡、投顾、客服、移动柜面、远程运\n营等 40 多项业务场景中进行探索或实践。在对公条线,持续完善产品库和案例库建设,综合运用\n视觉、语音、文本、图像识别等人工智能技术建设对公客户 360°画像平台,持续提升精准营销、\n自动审批、智慧风控等能力;运用深度学习技术,支持对支票、汇票、本票在内的数十种票据快速、\n准确识别,实现票据信息录入和处理的集中化、自动化和便捷化。科技与业务的融合更加紧密,全\n方位的AI Bank体系正加快形成。\n5、资产质量持续改善\n本行积极应对外部风险、调整业务结构,将贷款更多地投放到资产质量较好的零售业务。对公\n持续做精,新发放贷款聚焦重点行业、重点客户和重点产品,同时继续做好存量资产结构调整,加\n大问题资产清收处置力度,资产质量持续改善。\n2019 年 3 月末,本行逾期贷款占比、逾期 90 天以上贷款占比、关注率和不良率较上年末均有\n下降,偏离度低于1。其中:本行逾期贷款占比2.45%,较上年末下降0.03个百分点;逾期90天以\n上贷款占比1.66%,较上年末下降0.04个百分点;关注率2.58%,较上年末下降0.15个百分点;不\n良贷款率1.73%,较上年末下降0.02个百分点。不良贷款偏离度96%,较上年末下降1个百分点。\n2019年一季度,本行计提的信用及资产减值损失为128.85亿元,同比增长20.7%,其中发放贷\n款和垫款计提的减值损失为112.21 亿元;2019年3 月末,贷款减值准备余额为 604.08 亿元,较上\n年末增长 11.5%;拨贷比为 2.94%,较上年末增加 0.23 个百分点;拨备覆盖率为 170.59%、较上年\n末增加15.35个百分点,逾期90天以上贷款拨备覆盖率为177.71%、较上年末增加18.26个百分点,\n风险抵补能力进一步增强。\n6、支持服务实体经济\n本行充分发挥科技的赋能作用,多措并举服务实体经济,全面推进精准扶贫建设。\n一是细化区域政策和行业政策,主动融入国家战略,助力“粤港澳大湾区”、“一带一路”、“长\n三角一体化”等经济带建设,聚焦重点行业,切实服务地方经济发展。2019年3月末,本行重点行\n业客户授信占比48%。\n16二是积极贯彻落实国家要求,加强金融服务民营企业力度,支持小微企业发展。一是在科技运\n用上,运用人工智能、生物识别、大数据、区块链、云计算等前沿科技,打造“供应链应收账款服\n务平台”、“小企业数字金融”、“新一贷”等精品业务,切实支持民营企业、中小微企业发展,解决\n小企业融资难、融资贵问题。二是在制度执行上,一方面采取“差异化纾困+精准化服务”策略,\n借助平安集团纾困基金和团金合作模式,为企业提供定制化金融服务,为经营正常、流动性遇到暂\n时困难的民营企业提供支持;另一方面通过实施差异化信贷定价政策和风险容忍度,借力科技创新、\n产品创新和渠道创新,全方位支持小微企业发展。三是在执行情况和政策效果上,2019 年一季度,\n本行新增投放民营企业贷款客户占新增投放所有对公贷款客户达70%以上;2019年3月末,本行单\n户授信1,000万(含)以下的小微企业贷款较上年末增长3.5%,高于全行贷款平均增速,有贷款余\n额户数较上年末增加8,639户,该类小微企业贷款利率较上年末下降0.66个百分点,不良率控制在\n合理范围。\n三是持续推动精准扶贫,通过“金融+产业”扶贫,建立“金融+科技+产销”的扶贫闭环。2019\n年一季度,本行新增投放扶贫资金 43.85 亿元(其中产业扶贫贷款 12.95 亿元,扶贫政府债30.9亿\n元),累计覆盖和惠及建档立卡贫困人口超过43万人,直接帮扶建档立卡贫困人口超过2,800人。\n7、夯实基础提升资本\n本行持续推动精细化资本管理,全面实施经济资本管理,建立以经济增加值(EVA)和经济资\n本回报率(RAROC)为核心的综合绩效考核机制,通过调配业务、向低风险业务引导,加大税务核\n销力度,加大抵债资产处置力度、减少无效资本占用,提升资本水平。2019年3月末,本行核心一\n级资本充足率、一级资本充足率及资本充足率分别为 8.75%,9.59%及 11.50%,均满足监管达标要\n求。其中,核心一级资本充足率及一级资本充足率较上年末分别提高0.21个百分点、0.20个百分点,\n由于2019年3月7日全额赎回了90亿元二级资本债券,资本充足率与上年末持平。\n本行在通过利润留存补充资本的基础上,积极拓展外源性资本补充渠道,持续推进资本工具的\n发行工作。本行已于2019年1月25日完成发行260亿元A股可转换公司债券,转股后将有效补充\n本行核心一级资本。同时,本行正积极推进300亿元合格二级资本债券的发行工作,募集资金将依\n据适用法律,全部用于补充本行二级资本,进一步提高本行资本充足水平。\n8、合理配置网点布局\n本行持续实施网点智能化建设,合理配置网点布局,截至 2019 年 3 月末,本行有 81 家分行、\n共1,053家营业机构;全国已开业177家零售新门店。\n17(二)资产质量\n1、发放贷款和垫款五级分类情况\n(货币单位:人民币百万元)\n2019年3月31日 2018年12月31日\n期末比上年末\n项 目\n增减\n余额 占比 余额 占比\n正常贷款 1,963,035 95.69% 1,908,072 95.52% 2.9%\n关注贷款 52,998 2.58% 54,552 2.73% (2.8%)\n不良贷款 35,412 1.73% 34,905 1.75% 1.5%\n其中:次级 13,530 0.66% 17,955 0.90% (24.6%)\n可疑 6,459 0.32% 4,509 0.23% 43.2%\n损失 15,423 0.75% 12,441 0.62% 24.0%\n发放贷款和垫款本金总额 2,051,445 100.00% 1,997,529 100.00% 2.7%\n发放贷款和垫款减值准备 (60,408) (54,187) 11.5%\n其中:以摊余成本计量的发放贷\n(60,086) (54,033) 11.2%\n款和垫款减值准备\n以公允价值计量且其变动\n计入其他综合收益的发放 (322) (154) 109.1%\n贷款和垫款减值准备\n不良贷款率 1.73% 1.75% -0.02个百分点\n拨备覆盖率 170.59% 155.24% +15.35个百分点\n逾期90天以上贷款拨备覆盖率 177.71% 159.45% +18.26个百分点\n拨贷比 2.94% 2.71% +0.23个百分点\n2、发放贷款和垫款按产品划分的结构分布及质量情况\n(货币单位:人民币百万元)\n2019年3月31日 2018年12月31日\n项 目 不良率增减\n余额 不良率 余额 不良率\n企业贷款 860,913 2.59% 843,516 2.68% -0.09个百分点\n其中:一般企业贷款 798,940 2.79% 801,814 2.82% -0.03个百分点\n贴现 61,973 - 41,702 - -\n个人贷款 1,190,532 1.10% 1,154,013 1.07% +0.03个百分点\n其中:住房按揭贷款 186,137 0.14% 182,363 0.09% +0.05个百分点\n新一贷 155,190 1.14% 153,745 1.00% +0.14个百分点\n汽车金融贷款 169,921 0.62% 172,029 0.54% +0.08个百分点\n信用卡应收账款 492,307 1.34% 473,295 1.32% +0.02个百分点\n其他(注) 186,977 1.86% 172,581 1.97% -0.11个百分点\n发放贷款和垫款本金总额 2,051,445 1.73% 1,997,529 1.75% -0.02个百分点\n注:“其他”包括持证抵押贷款、小额消费贷款和其他保证或质押类贷款等。\n18企业贷款不良率较上年末有所下降,一方面本行持续加强问题授信处置力度,严格管控正常资\n产的新增下迁;另一方面,强化资产质量管控体系建设和完善机制保障,进一步夯实资产质量管理\n基础。同时,坚持新客户准入风险标准,从源头改善资产结构,提升资产质量。\n受宏观经济下行、共债风险上升、汽车消费下滑等外部因素影响,消费金融全行业风险都有所\n上升,本行个人贷款不良率较上年末也略有上升,但整体风险表现仍维持在相对较低水平,主要情\n况如下:\n(1)本行住房按揭贷款通过持续调整和优化客群结构,加大对优质客户的投放力度,提升新\n发放贷款质量,将按揭贷款不良率维持在较低的水平。\n(2)本行“新一贷”主动调整发放节奏,一方面进一步优化客户结构、提高新发放优质客群\n占比、对不同风险级别的客户进行差异化风险管理,另一方面持续加大催清收资源投入、多策并举\n提升催清收成效,整体风险水平在可控范围;同时本行充分运用互联网技术提升产品竞争力,运用\n人脸识别、微表情技术、智能语音等前沿科技手段,实现业务集中化、自动化、线上化和智能化,\n提升客户体验、服务效率,持续强化风险控制。\n(3)本行汽车金融业务由于2018年以来进行的产品结构调整,高收益产品发放占比进一步提\n升,不良率虽小幅上升,但依然处于可控范围内。同时本行通过对车辆识别技术与VIN码识别技术\n的深度运用,实现车型自动化精准定位,自动化的后台估价体系的搭建,实现全流程识人、识车、\n识价一秒完成,有效降低人工成本,提升业务能效,规避信贷及欺诈风险。在催收管理方面,持续\n推进催收系统优化升级,提升整体作业效率,有效融入科技手段,实现失联客户信息修复,以确保\n整体资产质量保持稳定,总体风险在可控范围之内。\n(4)本行信用卡业务全流程贯彻风险管理理念,充分利用量化工具,有效管控风险。一方面\n通过大数据平台和先进定量分析技术,结合风控模型、AI智能识别等科学工具全面优化风险管理策\n略,有效改善新户获客结构与品质,优化存量结构,确保组合资产可持续发展。另一方面,在科技\n引领的理念支持下,引入 AI 智能催收降低迁徙率,并持续优化进行规模推广;优化评分模型推进\n差异化催收策略;精细梳理不良资产,增加催清收资源投入,确保风险水平可控。\n193、逾期贷款情况\n(货币单位:人民币百万元)\n2019年3月31日 2018年12月31日\n项 目\n余额 占贷款总额比 余额 占贷款总额比\n本金和利息逾期90天以内贷款 16,282 0.79% 15,496 0.78%\n本金或利息逾期90天以上贷款 33,993 1.66% 33,984 1.70%\n2019年3月末,本行逾期90天以内贷款(含本金未逾期,利息逾期90天以内贷款)余额162.82\n亿元,占贷款总额比 0.79%,较上年末略升 0.01 个百分点;逾期 90 天以上贷款(含本金未逾期,\n利息逾期90天以上贷款)余额339.93亿元,占贷款总额比1.66%,较上年末下降0.04个百分点。\n本行已积极采取各项措施,分类制定清收和重组转化方案,加大问题资产清收处置力度,进一步做\n好风险管理和化解工作。\n(三)利息收支情况\n1、主要资产、负债项目的日均余额以及平均收益率或平均成本率情况\n(货币单位:人民币百万元)\n2019年1-3月 2018年1-3月\n项 目 利息收入/ 平均收益/ 利息收入/ 平均收益/\n日均余额 日均余额\n支出 成本率 支出 成本率\n资产\n发放贷款和垫款(不含贴现) 1,967,432 32,045 6.61% 1,879,483 27,677 5.97%\n债券投资 596,632 4,816 3.27% 521,762 4,790 3.72%\n存放央行 220,552 824 1.52% 266,602 1,026 1.56%\n票据贴现及同业业务 548,320 4,976 3.68% 694,339 7,451 4.35%\n生息资产总计 3,332,936 42,661 5.19% 3,362,186 40,944 4.94%\n负债\n吸收存款 2,190,470 13,600 2.52% 2,017,494 11,512 2.31%\n发行债务证券 400,271 3,803 3.85% 344,088 4,047 4.77%\n其中:同业存单 307,556 2,735 3.61% 298,972 3,461 4.69%\n同业业务及其他 642,101 4,484 2.83% 762,700 6,698 3.56%\n计息负债总计 3,232,842 21,887 2.75% 3,124,282 22,257 2.89%\n利息净收入 20,774 18,687\n存贷差 4.09% 3.66%\n净利差 2.44% 2.05%\n净息差 2.53% 2.25%\n本行持续优化业务结构,个人贷款规模和占比增加,生息资产收益率有所提升,计息负债成本\n率进一步下降,2019 年一季净利差、净息差分别为 2.44%、2.53%,同比分别提升 39 个基点、28\n个基点,环比均提升3个基点。\n20(货币单位:人民币百万元)\n2019年1-3月 2018年10-12月\n项 目 利息收入/ 平均收益/ 利息收入/ 平均收益/\n日均余额 日均余额\n支出 成本率 支出 成本率\n资产\n发放贷款和垫款(不含贴现) 1,967,432 32,045 6.61% 1,928,340 31,734 6.53%\n债券投资 596,632 4,816 3.27% 543,328 4,517 3.30%\n存放央行 220,552 824 1.52% 246,809 950 1.53%\n票据贴现及同业业务 548,320 4,976 3.68% 485,719 4,026 3.29%\n生息资产总计 3,332,936 42,661 5.19% 3,204,196 41,227 5.10%\n负债\n吸收存款 2,190,470 13,600 2.52% 2,085,555 12,765 2.43%\n发行债务证券 400,271 3,803 3.85% 341,493 3,318 3.85%\n其中:同业存单 307,556 2,735 3.61% 291,053 2,671 3.64%\n同业业务及其他 642,101 4,484 2.83% 670,183 4,928 2.92%\n计息负债总计 3,232,842 21,887 2.75% 3,097,231 21,011 2.69%\n利息净收入 20,774 20,216\n存贷差 4.09% 4.10%\n净利差 2.44% 2.41%\n净息差 2.53% 2.50%\n2、发放贷款和垫款日均余额及平均收益率\n(货币单位:人民币百万元)\n2019年1-3月 2018年1-3月\n项 目\n日均余额 利息收入 平均收益率 日均余额 利息收入 平均收益率\n企业贷款(不含贴现) 796,781 10,245 5.21% 984,401 11,138 4.59%\n个人贷款(含信用卡) 1,170,651 21,800 7.55% 895,082 16,539 7.49%\n发放贷款和垫款(不含贴现) 1,967,432 32,045 6.61% 1,879,483 27,677 5.97%\n2019年1-3月 2018年10-12月\n项 目\n日均余额 利息收入 平均收益率 日均余额 利息收入 平均收益率\n企业贷款(不含贴现) 796,781 10,245 5.21% 814,313 9,709 4.73%\n个人贷款(含信用卡) 1,170,651 21,800 7.55% 1,114,027 22,025 7.84%\n发放贷款和垫款(不含贴现) 1,967,432 32,045 6.61% 1,928,340 31,734 6.53%\n213、吸收存款日均余额及平均成本率\n(货币单位:人民币百万元)\n2019年1-3月 2018年1-3月\n项 目\n日均余额 利息支出 平均成本率 日均余额 利息支出 平均成本率\n企业存款 1,708,521 10,430 2.48% 1,659,390 9,262 2.26%\n其中:活期 501,005 774 0.63% 491,670 721 0.59%\n定期 989,634 8,339 3.42% 914,998 7,134 3.16%\n其中:国库及协议存款 105,906 1,204 4.61% 103,203 1,209 4.75%\n保证金存款 217,882 1,317 2.45% 252,722 1,407 2.26%\n个人存款 481,949 3,170 2.67% 358,104 2,250 2.55%\n其中:活期 158,906 117 0.30% 137,750 103 0.30%\n定期 298,927 2,772 3.76% 194,791 1,851 3.85%\n保证金存款 24,116 281 4.73% 25,563 296 4.70%\n吸收存款 2,190,470 13,600 2.52% 2,017,494 11,512 2.31%\n2019年1-3月 2018年10-12月\n项 目\n日均余额 利息支出 平均成本率 日均余额 利息支出 平均成本率\n企业存款 1,708,521 10,430 2.48% 1,660,230 9,966 2.38%\n其中:活期 501,005 774 0.63% 492,137 759 0.61%\n定期 989,634 8,339 3.42% 956,324 7,798 3.24%\n其中:国库及协议存款 105,906 1,204 4.61% 108,243 1,217 4.46%\n保证金存款 217,882 1,317 2.45% 211,769 1,409 2.64%\n个人存款 481,949 3,170 2.67% 425,325 2,799 2.61%\n其中:活期 158,906 117 0.30% 152,553 115 0.30%\n定期 298,927 2,772 3.76% 248,074 2,390 3.82%\n保证金存款 24,116 281 4.73% 24,698 294 4.72%\n吸收存款 2,190,470 13,600 2.52% 2,085,555 12,765 2.43%\n22十二、资本充足率、杠杆率、流动性覆盖率情况\n(一)资本充足率\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日\n核心一级资本净额 209,462 199,782\n其他一级资本 19,953 19,953\n一级资本净额 229,415 219,735\n二级资本 45,820 49,380\n资本净额 275,235 269,115\n风险加权资产合计 2,393,377 2,340,236\n信用风险加权资产 2,150,812 2,090,152\n表内风险加权资产 1,938,038 1,892,934\n表外风险加权资产 206,348 194,921\n交易对手信用风险暴露的风险加权资产 6,426 2,297\n市场风险加权资产 35,745 43,264\n操作风险加权资产 206,820 206,820\n核心一级资本充足率 8.75% 8.54%\n一级资本充足率 9.59% 9.39%\n资本充足率 11.50% 11.50%\n(二)杠杆率\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日 2018年9月30日 2018年6月30日\n杠杆率 5.80% 5.75% 5.76% 5.63%\n一级资本净额 229,415 219,735 214,375 208,444\n调整后表内外资产余额 3,953,768 3,818,886 3,722,035 3,704,345\n注:主要因核心一级资本净额增加,报告期末杠杆率较2018年末增加。\n(三)流动性覆盖率\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日\n流动性覆盖率 133.64% 139.17%\n合格优质流动性资产 424,111 406,359\n净现金流出 317,350 291,995\n注:根据银保监会发布的《商业银行流动性风险管理办法》,商业银行的流动性覆盖率应当在2018年底前达到100%。\n23第四节 财务报表\n一、财务报表\n(一)资产负债表(未经审计)\n(二)利润表(未经审计)\n(三)现金流量表(未经审计)\n24平安银行股份有限公司\n资产负债表\n2019年3月31日\n货币单位:人民币百万元\n资产 2019年3月31日 2018年12月31日\n现金及存放中央银行款项 246,856 278,528\n存放同业款项 69,911 85,098\n贵金属 54,173 56,835\n拆出资金 100,231 72,934\n衍生金融资产 19,850 21,460\n买入返售金融资产 44,335 36,985\n发放贷款和垫款 1,997,321 1,949,757\n金融投资:\n交易性金融资产 163,945 148,768\n债权投资 636,607 629,366\n其他债权投资 111,039 70,664\n其他权益工具投资 1,565 1,519\n投资性房地产 198 194\n固定资产 10,623 10,899\n无形资产 4,575 4,771\n商誉 7,568 7,568\n递延所得税资产 30,860 29,468\n租赁使用权资产 7,239 -\n其他资产 23,284 13,778\n资产总计 3,530,180 3,418,592\n负债和股东权益\n负债\n向中央银行借款 142,394 149,756\n同业及其他金融机构存放款项 304,501 392,738\n拆入资金 21,669 24,606\n交易性金融负债 15,721 8,575\n衍生金融负债 18,408 21,605\n卖出回购金融资产款项 21,502 7,988\n吸收存款 2,308,782 2,149,142\n应付职工薪酬 9,848 12,238\n应交税费 11,053 9,366\n已发行债务证券 401,758 381,884\n预计负债 1,271 860\n租赁负债 7,700 -\n其他负债 14,635 19,792\n负债合计 3,279,242 3,178,550\n股东权益\n股本 17,170 17,170\n其他权益工具 23,678 19,953\n其中:优先股 19,953 19,953\n资本公积 56,465 56,465\n其他综合收益 1,385 786\n盈余公积 10,781 10,781\n一般风险准备 39,850 39,850\n未分配利润 101,609 95,037\n股东权益合计 250,938 240,042\n负债和股东权益总计 3,530,180 3,418,592\n法定代表人 行长 首席财务官 会计机构负责人\n谢永林 胡跃飞 项有志 朱培卿\n25平安银行股份有限公司\n利润表\n2019年1-3月\n货币单位:人民币百万元\n项 目 2019年1-3月 2018年1-3月\n一、营业收入 32,476 28,026\n利息净收入 20,774 18,687\n利息收入 42,661 40,944\n利息支出 21,887 22,257\n手续费及佣金净收入 9,560 8,594\n手续费及佣金收入 11,579 9,947\n手续费及佣金支出 2,019 1,353\n投资收益 2,571 416\n其中:以摊余成本计量的金融资产终止确认\n3 6\n产生的收益\n公允价值变动损益 (642) 373\n汇兑损益 169 (178)\n其他业务收入 17 76\n资产处置损益 12 44\n其他收益 15 14\n二、营业支出 9,924 8,771\n税金及附加 304 261\n业务及管理费 9,620 8,510\n三、减值损失前营业利润 22,552 19,255\n信用减值损失 12,654 10,664\n其他资产减值损失 231 7\n四、营业利润 9,667 8,584\n加:营业外收入 29 5\n减:营业外支出 25 22\n五、利润总额 9,671 8,567\n减:所得税费用 2,225 1,972\n六、净利润 7,446 6,595\n持续经营净利润 7,446 6,595\n终止经营净利润 - -\n七、其他综合收益的税后净额 599 151\n(一)不能重分类进损益的其他综合收益 (6) -\n其他权益工具投资公允价值变动 (6) -\n(二)将重分类进损益的其他综合收益 605 151\n1.以公允价值计量且其变动计入其他综\n291 137\n合收益的金融资产的公允价值变动\n2.以公允价值计量且其变动计入其他综\n314 14\n合收益的金融资产的信用损失准备\n八、综合收益总额 8,045 6,746\n九、每股收益\n(一)基本每股收益(元/股) 0.38 0.33\n(二)稀释每股收益(元/股) 0.36 0.33\n法定代表人 行长 首席财务官 会计机构负责人\n谢永林 胡跃飞 项有志 朱培卿\n26平安银行股份有限公司\n现金流量表\n2019年1-3月\n货币单位:人民币百万元\n项 目 2019年1-3月 2018年1-3月\n一、经营活动产生的现金流量:\n存放中央银行和同业款项净减少额 50,811 44,135\n向中央银行借款净增加额 - 14,245\n吸收存款和同业存放款项净增加额 69,010 59,227\n拆出资金净减少额 2,613 -\n买入返售款项净减少额 48 -\n卖出回购款项净增加额 13,515 -\n收取利息、手续费及佣金的现金 50,004 44,500\n收到其他与经营活动有关的现金 1,790 20,940\n经营活动现金流入小计 187,791 183,047\n向中央银行借款净减少额 7,715 -\n发放贷款和垫款净增加额 65,841 78,639\n拆出资金净增加额 - 3,398\n拆入资金净减少额 2,959 10,546\n为交易目的而持有的金融资产净增加额 22,611 -\n卖出回购款项净减少额 - 1,662\n应收账款净增加额 - 15,605\n支付利息、手续费及佣金的现金 17,111 15,658\n支付给职工及为职工支付的现金 7,595 6,649\n支付的各项税费 5,347 2,578\n支付其他与经营活动有关的现金 5,428 6,870\n经营活动现金流出小计 134,607 141,605\n经营活动产生的现金流量净额 53,184 41,442\n二、投资活动产生的现金流量:\n收回投资收到的现金 105,313 113,266\n取得投资收益收到的现金 7,500 7,281\n处置固定资产及其他长期资产收回的现金 238 -\n投资活动现金流入小计 113,051 120,547\n投资支付的现金 144,432 117,298\n购建固定资产、无形资产及其他长期资产所支付的现金 409 27\n投资活动现金流出小计 144,841 117,325\n投资活动产生的现金流量净额 (31,790) 3,222\n三、筹资活动产生的现金流量:\n发行债务证券收到的现金 78,154 199,253\n筹资活动现金流入小计 78,154 199,253\n偿还债务证券本金支付的现金 57,720 187,200\n偿付债务证券利息支付的现金 637 677\n分配股利及利润支付的现金 874 874\n筹资活动现金流出小计 59,231 188,751\n筹资活动产生的现金流量净额 18,923 10,502\n四、汇率变动对现金及现金等价物的影响 (1,155) (1,302)\n五、现金及现金等价物净增加/(减少)额 39,162 53,864\n加:年初现金及现金等价物余额 161,801 137,023\n六、期末现金及现金等价物余额 200,963 190,887\n法定代表人 行长 首席财务官 会计机构负责人\n谢永林 胡跃飞 项有志 朱培卿\n27二、审计报告\n第一季度报告是否经过审计\n□是 √否\n本行第一季度报告未经审计。\n平安银行股份有限公司董事会\n2019年4月24日\n28', # noqa: E501 + 'target': + '证券代码:000001 证券简称:平安银行 公告编号:2019-024\n平安银行股份有限公司\n2019年第一季度报告\n二〇一九年四月二十四日第一节 重要提示\n一、本行董事会、监事会及董事、监事、高级管理人员保证本报告内容的真实、准确、完整,\n不存在虚假记载、误导性陈述或者重大遗漏,并承担个别和连带的法律责任。\n二、本行第十届董事会第二十四次会议审议了 2019 年第一季度报告。本次董事会会议应出席\n董事14人,实到董事14人。会议一致同意此报告。\n三、本行董事长谢永林、行长胡跃飞、首席财务官项有志、会计机构负责人朱培卿保证 2019\n年第一季度报告中财务报表的真实、准确、完整。\n四、非标准审计意见提示\n□适用 √不适用\n本行本季度财务报表未经审计,普华永道中天会计师事务所(特殊普通合伙)对个别项目及财\n务报表编制流程执行了商定程序。\n1第二节 公司基本情况\n一、主要会计数据和财务指标\n公司是否需追溯调整或重述以前年度会计数据\n□是 √否\n本行于2019年1月1日起实施《企业会计准则第21号——租赁》(财会〔2018〕35号),并自\n2019年第一季度报告起按上述会计准则要求进行会计报表披露。根据衔接规定,本行按首次执行本\n会计准则的累积影响数,调整财务报表相关项目金额,不调整可比期间信息,相关影响请参阅《平\n安银行股份有限公司关于会计政策变更的公告》。\n(货币单位:人民币百万元)\n期末比上年末\n项 目 2019年3月31日 2018年12月31日\n增减\n项 目 2019年1-3月 2018年1-3月 同比增减\n扣除非经常性损益后的加权平均净资产收益率\n2.90% 2.77% 0.13个百分点\n(未年化)\n扣除非经常性损益后的加权平均净资产收益率\n12.11% 11.80% 0.31个百分点\n(年化)\n注:(1)本行于2016年3月7日非公开发行200亿元非累积型优先股,在计算“每股收益”及“加权平均净资产收\n2益率”时,分子均扣减了已发放的优先股股息。\n(2)本行于2019年1月25日发行260亿元A股可转换公司债券,在计算“稀释每股收益”时已考虑可转换公\n司债券转为普通股的稀释性影响。\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日 2017年12月31日 期末比上年末增减\n注:(1)根据《中国人民银行关于调整金融机构存贷款统计口径的通知》(银发〔2015〕14号),从2015年开始,非\n存款类金融机构存放在存款类金融机构的款项纳入“各项存款”、存款类金融机构拆放给非存款类金融机构的款\n项纳入“各项贷款”统计口径。按此统计口径,2019年3月31日的各项存款为26,027亿元,各项贷款为21,101\n亿元。\n(2)根据财政部《关于修订印发2018年度金融企业财务报表格式的通知》(财会〔2018〕36号)的规定,基于\n实际利率法计提的利息计入金融工具账面余额中,于资产负债表日尚未收到或尚未支付的利息在“其他资产”或\n“其他负债”列示。除非特别说明,本报告提及的“发放贷款和垫款”、“吸收存款”及其明细项目均为不含息\n金额。\n截至披露前一交易日的公司总股本\n报告期末至季度报告披露日股本是否因发行新股、增发、配股、股权激励行权、回购等原因发\n生变化且影响所有者权益金额\n□是 √否\n非经常性损益项目和金额\n√适用 □不适用\n3(货币单位:人民币百万元)\n注:非经常性损益根据证监会《公开发行证券的公司信息披露解释性公告第1号——非经常性损益》的定义计算。\n本行报告期不存在将根据《公开发行证券的公司信息披露解释性公告第1号——非经常性损益》\n定义、列举的非经常性损益项目界定为经常性损益的项目的情形。\n二、监管指标和财务比率\n(单位:%)\n项 目 标准值 2019年3月31日 2018年12月31日 2017年12月31日\n注:监管指标根据监管口径列示。\n三、报告期末股东总数及前十名股东持股情况表\n1、普通股股东和表决权恢复的优先股股东总数及前10名股东持股情况表\n(单位:股)\n报告期末表决权恢复的\n报告期末普通股股东总数(户) 354,508 -\n优先股股东总数(如有)\n前10名普通股东持股情况\n质押或冻结\n持有有限售条 情况\n股东名称 股东性质 持股比例(%) 持股总数\n件股份数量 股份\n数量\n状态\n中国平安保险(集团)股份有限公\n境内法人 49.56 8,510,493,066 - - -\n司-集团本级-自有资金\n4中国平安人寿保险股份有限公司\n境内法人 6.11 1,049,462,784 - - -\n-自有资金\n中国平安人寿保险股份有限公司\n境内法人 2.27 389,735,963 - - -\n-传统-普通保险产品\n中信信托有限责任公司-中信信\n托锐进 43 期高毅晓峰投资集合资 境内法人 0.39 66,346,066 - - -\n金信托计划\n上海高毅资产管理合伙企业(有限\n境内法人 0.29 49,221,916 - - -\n合伙)-高毅晓峰2号致信基金\n前10名无限售条件股东持股情况\n持有无限售条 股份种类\n股东名称\n件股份数量 股份种类 数量\n金信托计划\n上海高毅资产管理合伙企业(有限合伙)-高毅晓峰2号致信基金 49,221,916 人民币普通股 49,221,916\n1、中国平安人寿保险股份有限公司为中国平安保险(集团)股份有限公司控股子公司和\n一致行动人,“中国平安保险(集团)股份有限公司-集团本级-自有资金”、“中国平安\n上述股东关联关系或一致行\n人寿保险股份有限公司-自有资金”与“中国平安人寿保险股份有限公司-传统-普通保\n动的说明\n险产品”具有关联关系。\n2、本行未知其他股东间的关联关系,也未知其是否属于一致行动人。\n前10名普通股股东参与融\n资融券业务股东情况说明 无\n(如有)\n公司前10名普通股股东、前10名无限售条件普通股股东在报告期内是否进行约定购回交易\n□是 √否\n52、优先股股东总数及前10名优先股股东持股情况表\n√适用 □不适用\n(单位:股)\n报告期末优先股股东总数(户) 15\n前10名优先股股东持股情况\n(%) 件的股份数量 股份状态 数量\n中国平安人寿保险股份有限公司-\n境内法人 29.00 58,000,000 - - -\n分红-个险分红\n中国平安人寿保险股份有限公司-\n境内法人 19.34 38,670,000 - - -\n万能-个险万能\n中国平安财产保险股份有限公司-\n境内法人 9.67 19,330,000 - - -\n传统-普通保险产品\n中邮创业基金-华夏银行-华夏银\n境内法人 8.95 17,905,000 - - -\n行股份有限公司\n交银施罗德资管-交通银行-交通\n境内法人 8.95 17,905,000 - - -\n银行股份有限公司\n华润深国投信托有限公司-投资1\n境内法人 2.98 5,950,000 - - -\n号单一资金信托\n华宝信托有限责任公司-投资2号\n境内法人 2.98 5,950,000 - - -\n资金信托\n招商财富-邮储银行-中国邮政储\n境内法人 2.98 5,950,000 - - -\n蓄银行股份有限公司\n1、中国平安人寿保险股份有限公司和中国平安财产保险股份有限公司为中国平安\n保险(集团)股份有限公司控股子公司和一致行动人,“中国平安人寿保险股份有\n上述股东关联关系或一致行动的说\n限公司-分红-个险分红”、“中国平安人寿保险股份有限公司-万能-个险万能”\n明\n与“中国平安财产保险股份有限公司-传统-普通保险产品”具有关联关系。\n2、本行未知其他股东间的关联关系,也未知其是否属于一致行动人。\n注:(1)本行已发行的优先股不设限售期,均为无限售条件优先股;\n(2)本行无表决权恢复的优先股股东。\n63、前10名可转换公司债券持有人情况\n注:(1)中国平安人寿保险股份有限公司为中国平安保险(集团)股份有限公司控股子公司和一致行动人,“中国\n平安保险(集团)股份有限公司-集团本级-自有资金”、“中国平安人寿保险股份有限公司-自有资金”与“中国\n平安人寿保险股份有限公司-传统-普通保险产品”具有关联关系;\n(2)本行未知其他股东间的关联关系,也未知其是否属于一致行动人。\n(货币单位:人民币百万元)\n项 目 本期金额 变动金额 变动比率 变动原因分析\n拆出资金 100,231 27,297 37.4% 拆放境内、外同业款项增加\n分类为以公允价值计量且其变动计入其他\n其他债权投资 111,039 40,375 57.1%\n综合收益的债券投资、同业投资规模增加\n自2018年三季度起根据新金融工具会计准\n则要求,将“以公允价值计量且其变动计入\n投资收益 2,571 2,155 518.0%\n当期损益的金融工具”产生的收入/支出从\n利息收入/支出计入投资收益\n计提的抵债资产减值损失增加;基期数小,\n其他资产减值损失 231 224 3,200.0%\n上年同期为0.07亿元\n营业外收入 29 24 480.0% 基期数小,上年同期为0.05亿元\n二、重要事项进展情况及其影响和解决方案的分析说明\n√适用 □不适用\n2019年1月25日,本行完成公开发行260亿元可转换公司债券。2019年1月30日,本行收到\n中国证券登记结算有限责任公司深圳分公司出具的《证券登记证明》。经深圳证券交易所批准,本\n行公开发行可转换公司债券已于2019年2月18日起在深圳证券交易所挂牌交易。\n2019年3月7日和4月10日,本行全额赎回规模为人民币90亿元的10年期二级资本债券和\n规模为人民币60亿元的10年期二级资本债券。\n8重要事项概述 披露日期 临时报告披露网站查询索引\n本行公开发行人民币260亿元A股可转 《中国证券报》、《证券时报》、《上\n2019年1月17日、2019年1月25日\n换公司债券 海证券报》、《证券日报》和巨潮资讯\n本行对二级资本债券行使赎回选择权 2019年3月9日、2019年4月12日 网(www.cninfo.com.cn)\n股份回购的实施进展情况\n□适用 √不适用\n采用集中竞价方式减持回购股份的实施进展情况\n□适用 √不适用\n三、公司实际控制人、股东、关联方、收购人以及公司等承诺相关方在报告期内超期未履行完\n毕的承诺事项\n□适用 √不适用\n公司报告期不存在公司实际控制人、股东、关联方、收购人以及公司等承诺相关方在报告期内\n超期未履行完毕的承诺事项。\n四、对2019年1-6月经营业绩的预计\n预测年初至下一报告期期末的累计净利润可能为亏损或者与上年同期相比发生大幅度变动的\n警示及原因说明\n□适用 √不适用\n五、证券投资情况\n报告期末,本行所持金融债券(政策性银行债、各类普通金融债、次级金融债,不含企业债)\n账面价值为1,633.17亿元,其中前十大面值金融债券的有关情况如下:\n(货币单位:人民币百万元)\n9六、委托理财\n□适用 √不适用\n报告期内,本行未发生正常业务范围之外的委托理财事项。\n七、衍生品投资情况\n(货币单位:人民币百万元)\n年初合约金额 期末合约金额 报告期公允价值\n合约种类\n(名义金额) (名义金额) 变动情况\n注:本行在董事会确立的风险偏好和市场风险总体限额框架内,开展包括衍生品的资金交易和投资业务。衍生\n品金融工具名义金额只体现交易量,并不反映其实际风险暴露。本行开展的外汇及利率衍生品业务主要采取对冲策\n略,实际汇率及利率风险暴露较小。\n八、报告期内接待调研、沟通、采访等活动情况\n(www.cninfo.com.cn)\n2019/01/24 实地调研 机构\n《平安银行股份有限公司\n2019/03/25 实地调研 机构 投资者关系活动记录表》\n九、违规对外担保情况\n□适用 √不适用\n公司报告期无违规对外担保情况。\n十、控股股东及其关联方对上市公司的非经营性占用资金情况\n□适用 √不适用\n公司报告期不存在控股股东及其关联方对上市公司的非经营性占用资金。\n10十一、管理层讨论与分析\n(一)总体情况\n2019 年是新中国成立 70 周年,也是决胜全面建成小康社会的关键之年。正确把握金融本质,\n深化金融供给侧结构性改革,平衡好稳增长和防风险的关系,增强金融服务实体经济能力,打好防\n范化解包括金融风险在内的重大风险攻坚战,推动金融业高质量发展,是全年金融工作的重中之重。\n2019 年一季度,宏观经济运行总体平稳,供给侧结构性改革持续深化,“三大攻坚战”稳步推进,\n经济发展新动力进一步形成。\n本行紧跟国家战略,顺应经济金融形势,坚持以打造“中国最卓越、全球领先的智能化零售银\n行”为目标,持续深化“科技引领、零售突破、对公做精”策略方针,不断加强科技创新和场景应\n用,坚定推进智能化零售业务转型,持续深化对公业务从粗放型增长向效益型发展转型,严控各类\n金融风险,不断加大对民营企业、小微企业等实体经济的支持力度,各项业务稳健发展,资产质量\n持续改善,战略转型成效进一步显现。\n1、整体经营稳中趋好\n2019 年一季度,本行实现营业收入 324.76 亿元,同比增长 15.9%;其中,利息净收入 207.74\n亿元,同比增长11.2%;非利息净收入117.02亿元,同比增长25.3%。减值损失前营业利润225.52\n亿元,同比增长 17.1%;净利润 74.46 亿元,同比增长 12.9%;2019 年一季度净息差 2.53%,同比\n上升28个基点、环比上升3个基点,非利息净收入占比36.0%,同比上升2.7个百分点。\n2019 年 3 月末,本行资产总额 35,301.80 亿元,较上年末增长 3.3%;吸收存款余额 22,869.77\n亿元,较上年末增长7.4%;发放贷款和垫款总额(含贴现)20,514.45亿元,较上年末增长2.7%。\n2019 年 3 月末,本行保本理财产品余额 941.83 亿元、较上年末增长 13.9%,结构性存款余额\n4,848.06亿元、较上年末增长11.8%,非保本理财产品余额5,584.61亿元、较上年末增长3.8%。\n2、零售转型稳健前行\n2019年一季度,本行贯彻“零售突破”的策略,继续深化综合金融优势,重点发力“基础零售、\n消费金融、私行财富”三大业务模块,提升“风险和成本控制”两大核心能力,充分发挥科技赋能\n作用,优化业务流程与服务体验,转型工作稳健进行。\n(1)基础零售\n2019 年 3 月末,本行管理零售客户资产(AUM)16,632.98 亿元、较上年末增长 17.4%,零售\n客户数8,701万户、较上年末增长3.7%。个人存款余额5,215.74亿元、较上年末增长13.0%,平安\n口袋银行APP注册客户数6,765万户,较上年末增长8.7%;月活客户数2,603万户。\n2019年一季度,本行持续聚焦基础零售客户获客及经营,通过场景化、科技化手段,利用不同\n的互联网场景打造多种获客路径,坚持科技赋能和大数据驱动客户分析及经营策略的制定,促进获\n客及客户经营效率与产能提升。在获客方面,本行依托集团生态圈,深挖场景价值,持续优化口袋\n银行APP功能体验,构建银行线上生态圈,实现从用户到客户的转化;在经营方面,搭建智能运营\n平台,围绕客户生命周期,构造“客群-渠道”智能化产品推荐体系,并结合内外部资源构建 KYC\n11(充分了解客户)体系,实现大数据赋能客群经营,提升客户经营能力及管理效率。\n(2)私行财富\n2019 年 3 月末,本行财富客户 66.45 万户、较上年末增长 12.3%;私行达标客户 3.5 万户、较\n上年末增长16.7%。\n2019年一季度,本行在私人银行和财富管理方面,强化落实经营转型策略,在产品、体系和机\n制三方面积极推进战略落地。在产品平台和权益体系搭建上,本行充分整合内外部资源,持续供应\n优质产品和权益体验,支持私财业务发展。同时,大力推动投顾团队建设,充分利用 AI 科技力量\n和平安集团综合金融模式的优势,打造一支专业化、智能化的投顾团队。在内部管理体系方面,运\n作产品管理委员会机制,多元化引入优质资产;强化内嵌风险团队,严格把控资产风险;在机制建\n设方面,“分支行”、“直营”和“综合金融”三大业务模式并驾齐驱,强化私行客户经营能力和资\n产配置能力。\n(3)消费金融\n2019年3月末,个人贷款余额11,905.32亿元、较上年末增长3.2%。2018年末以来,鉴于宏观\n经济仍存下行压力,在确保资产质量稳定的前提下,本行主动优化贷款类产品投放策略,适当提高\n信用卡、贷款投放门槛,推动目标客群上移。同时,本行根据市场需求,持续优化客群画像、产品\n方案、流程、客户体验,覆盖更多场景,满足客户不同阶段消费及经营融资需求。目前,多项举措\n正在内部逐渐深化、生效,预计下半年起,个人贷款及信用卡业务主要规模指标增长将出现积极的\n变化。在优化投向配置的同时,本行着力支持小微企业、促进实体经济发展。\n2019年3月末,本行信用卡流通卡量5,352万张,较上年末增长3.9%;信用卡贷款余额4,923.07\n亿元,较上年末增长4.0%。2019年一季度,信用卡总交易金额8,010.30亿元,同比增长43.2%;信\n用卡商城交易量同比增长24.7%;2019年3月末,信用卡APP客户数已突破3,700万户。本行信用\n卡继续依托互联网平台深入打造“快、易、好”极致客户体验,并加强线上线下融合:线上完善信\n用卡商城服务平台,深入精细化运营,提升用户价值感;线下搭建精准营销服务平台,针对性地推\n出适合不同客群的精准营销活动,打造围绕不同客户偏好的差异化营销策略,提升客户不同用卡场\n景的消费体验。\n2019年3月末,本行“新一贷”余额1,551.90亿元,较上年末增长0.9%。“新一贷”借助科技\n力量赋能贷前、贷中、贷后全流程,全力打造“SAT+T”全线上申请流程模式,并通过嵌入数据直\n连、机器人流程自动化(RPA)、微表情等AI技术,实现在线智能“核身+核审”,不仅大幅简化客\n户申请所需材料、节约客户申请时间,为客户提供良好的移动金融交互体验,也积极利用互联网技\n术持续加强信贷全流程的防风险、反欺诈能力建设。\n2019 年 3 月末,本行住房按揭贷款余额 1,861.37 亿元,较上年末增长 2.1%。本行严格落实国\n家政策规定和监管要求,支持居民家庭首套自住购房需求,并将继续在合规前提下稳步开展住房信\n贷业务。\n2019 年 3 月末,本行汽车金融贷款余额 1,699.21 亿元,较上年末降幅 1.2%。在汽车消费市场\n12整体增速放缓的趋势下,本行不断通过产品创新、流程优化、AI智能化等多项举措,深耕车生态经\n营策略,在不断提升客户体验的同时,进一步挖潜汽车消费金融市场。\n(4)风险控制\n2019年3月末,本行个人贷款不良率1.10%,较上年末上升0.03个百分点。受到宏观经济下行、\n共债风险上升、汽车消费下滑等外部因素的影响,消费金融全行业的风险都有所上升,本行零售产\n品的不良率也略有上升,但整体风险表现仍维持在相对较低水平。其中信用卡不良率 1.34%,较上\n年末上升0.02个百分点;“新一贷”不良率1.14%,较上年末上升0.14个百分点;汽车金融业务的\n不良率 0.62%,较上年末上升 0.08 个百分点。本行自 2017 年底开始提前进行风险政策调整,重点\n防范共债风险,有效控制并降低了高风险客户占比,新发放业务的资产质量较好,预计这些管制措\n施的优化效应将会在 2019 年下半年逐渐展现,零售主要贷款的不良率下半年将呈现稳中向好的趋\n势。\n(5)科技平台赋能\n2019 年一季度本行持续加大科技投入,在线上进一步升级迭代,嵌入口袋银行 APP 和口袋银\n上线下相融合)服务体系,通过综合化、场景化、个性化让客户能在线上线下无缝切换,为客户带\n来了更好的金融生活体验。\n2019 年,本行零售推行全面 AI 化,通过科技手段将零售业务在经营、服务、管理方面所需的\n能力标准化、系统化、智能化,并赋能给客户、员工以及第三方合作伙伴,打造全方位的AI Bank。\n目前已孵化出AI客服、AI营销和AI风控等项目成果,其中,2019年3月末AI客服非人工占比已\n达82.2%,较上年末提升2.2个百分点;AI营销已普遍应用于各业务场景中,销售转化效果较常规\n手段提升最高达3倍,并已全面赋能一线队伍,使其具备随时随地为客户提供专业服务的能力;AI\n风控已搭建了统一风控平台,实现了信用风险、欺诈风险的统一管控,并进一步提升了零售信贷产\n品的自动化审核能力,2019年一季度,“AI+风控”欺诈防堵金额为1.41亿元。\n(6)综合金融\n2019年一季度,综拓渠道迁徙客户(不含信用卡)净增74.39万户,占零售整体净增客户(不\n含信用卡)的比例为34.2%,其中财富客户净增2.76万户,占整体净增财富客户的比例为37.9%;\n管理零售客户资产(AUM)余额净增841.31亿元,占零售整体净增客户资产余额的比例为34.1%。\n综拓渠道发放“新一贷”152.36亿元,占“新一贷”整体发放的比例为55.03%;发放汽融贷款121\n亿元,占汽融贷款整体发放的比例为 37.0%。信用卡通过交叉销售渠道发卡 112 万张,在新增发卡\n量中占比为33.7%。零售全渠道代销集团保险累计实现非利息净收入8.73亿元,同比增长40.6%。\n133、对公聚焦精品打造\n2019 年本行公司业务按照对公做精的战略要求,借助科技手段持续打造智能化精品公司银行,\n努力实现在行业、客户、科技、协同四个方面的突破。具体来说,公司业务聚焦行业,以行业化、\n场景化模式推进集成营销;聚焦客户,在目标行业内精选客群,为客户提供系统的、智能的管理平\n台;聚焦科技,将智能与科技的基因融入业务的每个细胞,运用区块链、物联网等技术成果,实现\n业务升级;聚焦协同,利用集团综合金融优势,成为集团团体综合金融的发动机。\n2019 年一季度,对公精品业务打造成效初显,截至 3 月末企业存款余额 17,654.03 亿元,较上\n年末增长5.9%。存款增长得益于支付结算及供应链金融平台的打造,稳定了对公的基础客群。\n(1)精品公司业务经营\n① 互联网支付结算\n互联网支付结算聚焦与国计民生相关的战略行业和互联网新兴产业,深入研究细分行业平台需\n求,开展互联网头部客户攻坚战。同时,进一步锻造互联网支付结算“前中后台”能力,为客户提\n供更加高效的服务体验。一是前台,客户端操作移动化;二是中台,业务处理智能化,简化技术联\n调、项目实施和投产上线业务流程,提高作业效率;三是后台,运用区块链、大数据、反欺诈等金\n融科技手段,有效防范支付结算风险,保障用户资金和交易安全。2019年一季度,本行互联网支付\n结算新投产平台86个,交易笔数超1.5亿笔,平台交易量超5,300亿元。\n② 智能供应链金融\n本行持续升级供应链金融服务,应用人工智能、区块链、云计算、大数据等科技赋能,推出供\n应链应收账款服务平台,聚焦产业链核心企业及其上游客户,提供在线应收账款转让、融资、管理、\n结算等综合金融服务,协同多方构建中小企业融资服务生态。\n2019 年一季度,本行供应链应收账款服务平台累计交易量 78 亿元,已为 137 家核心企业及其\n上游供应商提供了金融服务,单笔融资最小金额仅 5.2 万元,有效解决了中小供应商融资难、融资\n贵问题,支持实体经济发展。\n③ 跨境金融\n本行持续推进跨境金融业务线上化,落地离在岸网银整合项目、离岸口袋财务迭代升级、离岸\n网银线上融资迭代升级、银企直联平台升级服务客群等线上化项目,增加客户可线上化操作的业务\n品种,提升客户体验。同时,本行聚焦跨境互联网交易平台客群,不断丰富跨境e金融平台服务场\n景,构建跨境电商平台金融服务生态。2019 年一季度,跨境 e 金融平台累计线上交易规模近 1,500\n亿元,同比增幅7.2%。\n④ 口袋财务\n本行口袋财务整合产品服务,实现了离岸、国际、现金管理业务移动化,同时积极引入集团综\n合金融产品,与平安证券、平安租赁、平安医保科技在场景方面的合作实现较大突破。2019年3月\n末,口袋财务注册客户27.8万户;2019年一季度,交易笔数85.3万笔、同比增长超7倍,交易金\n额6,419亿元、同比增长超10倍。\n14⑤ 小企业数字金融\n本行小企业数字金融通过深入企业经营场景,借助量化模型和数据分析完善中小企业精准画像,\n解决中小微企业信息不对称,风险较难把控的问题。2019 年,在推动和优化标准化产品的基础上,\n围绕车、房、医疗、大零售、大食品等行业积极探索,进一步为行业终端客户制定综合金融服务方\n案,解决融资难、融资贵的问题。2019 年 3 月末,本行小企业数字金融服务客户数 17,355 户,较\n上年末增长23.1%;2019年一季度,本行小企业数字金融累计发放贷款42亿元。\n⑥ 投行及同业业务\n本行坚定推进机构销售转型,以科技为手段,构建金融服务生态圈,打造金融机构的连接器。\n2019年一季度,本行继续拓展销售产品种类,扩大销售规模,做深客户经营,建立覆盖全国、全类\n型金融机构的销售网络;在此基础上,进一步优化销售产品结构,提高债券、非标资产及资管产品\n销售占比,大力发展交易做市业务及代客业务,提升交易贡献占比。2019年3月末,本行“行e通”\n平台累计合作客户达 2,115 户;2019 年一季度,同业机构销售业务量达到 1,300 亿元,为去年同期\n的 3 倍,债券及非标资产的销售占比进一步提升;利率掉期做市交易量达到 5,730 亿元,同比增长\n57.3%;代客衍生业务交易量实现 36.7 亿美元,同比增长 58.0%。本行积极推动真投行转型和轻型\n融资能力建设,发力债券承销、理财直融工具等领域,满足客户融资需求。2019年一季度,经银保\n监会注册发行的理财直融产品承销规模197亿元,市场发行份额排名位居前列。\n⑦ 政府金融\n顺应国家改革发展新趋势,依托平安集团“金融+科技”,本行在政府金融业务领域从传统的金\n融服务提供方转型为助力改革推进的合作伙伴,以科技赋能提升服务客户综合能力,聚焦财政、住\n房、司法三大生态持续创新。\n2019 年3 月末,本行新上线政府金融类平台 21个,累计达464 个。在财政生态领域,本行与\n平安集团重金所合作,引入区块链、AI智能辅助业务决策,研发涉企资金补助产品,助力政府穿透\n式资金监管。在住房生态领域,公积金业务、商品房交易资金监管等方面均有突破。在司法生态领\n域,锁定破产清算重点突破,研发全流程破产案件管理系统,利用区块链、人脸识别、OCR等技术,\n减轻破产管理人案件管理压力。同时,本行协同集团智慧司法多级联动营销,2019年一季度累计获\n客300户。\n(2)持续推进综合金融,成为集团团体综合金融的发动机\n本行充分运用集团综合金融资源与平台,做强银行渠道,围绕不同场景下的客户群体,主动设\n计场景化、嵌入式、线上化、智能化的综合服务方案,着力推进保险与投融突破。借助科技手段深\n度实现系统流程优化和管理能力升级。2019 年一季度,销售平安团体保险保费规模 2.02 亿元,新\n增银投合作项目规模424.09亿元。\n(3)特殊资产智慧经营\n2019 年,依托集团“金融+科技”优势,本行全力打造特殊资产智慧经营平台,平台借鉴行业\n领先的互联网企业“AI大脑”设计,以大数据为内驱动力,聚焦智能清收、智慧管理、生态经营三\n15大能力建设,推动特殊资产清收向智慧模式转型,不良资产清收业绩大幅提升。2019年一季度,本\n行收回不良资产总额95.22亿元、同比增长26.8%,其中信贷资产(贷款本金)88.88亿元;收回的\n贷款本金中,已核销贷款55.34亿元,未核销不良贷款33.54亿元;不良资产收回额中83.7%为现金\n收回,其余为以物抵债等方式收回。\n4、科技赋能成效显现\n本行将“科技引领”作为全行转型升级的驱动力,实行精益和敏捷的双模研发体系,进一步完\n善数据治理体系,加快技术架构向分布式架构转型,围绕产品创新、客户体验、风险管理,积极探\n索前沿技术的研究创新和实际应用,大力推动人工智能、大数据、区块链等新技术嵌入各项业务场\n景,助推银行在产品、服务、平台等方面转型升级。\n2019 年一季度,在零售条线,全面启动 AI 中台的建设,通过打造知识库平台、生物特征识别\n平台、业务中台等系统平台赋能前端业务场景,目前已在信用卡、投顾、客服、移动柜面、远程运\n营等 40 多项业务场景中进行探索或实践。在对公条线,持续完善产品库和案例库建设,综合运用\n视觉、语音、文本、图像识别等人工智能技术建设对公客户 360°画像平台,持续提升精准营销、\n自动审批、智慧风控等能力;运用深度学习技术,支持对支票、汇票、本票在内的数十种票据快速、\n准确识别,实现票据信息录入和处理的集中化、自动化和便捷化。科技与业务的融合更加紧密,全\n方位的AI Bank体系正加快形成。\n5、资产质量持续改善\n本行积极应对外部风险、调整业务结构,将贷款更多地投放到资产质量较好的零售业务。对公\n持续做精,新发放贷款聚焦重点行业、重点客户和重点产品,同时继续做好存量资产结构调整,加\n大问题资产清收处置力度,资产质量持续改善。\n2019 年 3 月末,本行逾期贷款占比、逾期 90 天以上贷款占比、关注率和不良率较上年末均有\n下降,偏离度低于1。其中:本行逾期贷款占比2.45%,较上年末下降0.03个百分点;逾期90天以\n上贷款占比1.66%,较上年末下降0.04个百分点;关注率2.58%,较上年末下降0.15个百分点;不\n良贷款率1.73%,较上年末下降0.02个百分点。不良贷款偏离度96%,较上年末下降1个百分点。\n2019年一季度,本行计提的信用及资产减值损失为128.85亿元,同比增长20.7%,其中发放贷\n款和垫款计提的减值损失为112.21 亿元;2019年3 月末,贷款减值准备余额为 604.08 亿元,较上\n年末增长 11.5%;拨贷比为 2.94%,较上年末增加 0.23 个百分点;拨备覆盖率为 170.59%、较上年\n末增加15.35个百分点,逾期90天以上贷款拨备覆盖率为177.71%、较上年末增加18.26个百分点,\n风险抵补能力进一步增强。\n6、支持服务实体经济\n本行充分发挥科技的赋能作用,多措并举服务实体经济,全面推进精准扶贫建设。\n一是细化区域政策和行业政策,主动融入国家战略,助力“粤港澳大湾区”、“一带一路”、“长\n三角一体化”等经济带建设,聚焦重点行业,切实服务地方经济发展。2019年3月末,本行重点行\n业客户授信占比48%。\n16二是积极贯彻落实国家要求,加强金融服务民营企业力度,支持小微企业发展。一是在科技运\n用上,运用人工智能、生物识别、大数据、区块链、云计算等前沿科技,打造“供应链应收账款服\n务平台”、“小企业数字金融”、“新一贷”等精品业务,切实支持民营企业、中小微企业发展,解决\n小企业融资难、融资贵问题。二是在制度执行上,一方面采取“差异化纾困+精准化服务”策略,\n借助平安集团纾困基金和团金合作模式,为企业提供定制化金融服务,为经营正常、流动性遇到暂\n时困难的民营企业提供支持;另一方面通过实施差异化信贷定价政策和风险容忍度,借力科技创新、\n产品创新和渠道创新,全方位支持小微企业发展。三是在执行情况和政策效果上,2019 年一季度,\n本行新增投放民营企业贷款客户占新增投放所有对公贷款客户达70%以上;2019年3月末,本行单\n户授信1,000万(含)以下的小微企业贷款较上年末增长3.5%,高于全行贷款平均增速,有贷款余\n额户数较上年末增加8,639户,该类小微企业贷款利率较上年末下降0.66个百分点,不良率控制在\n合理范围。\n三是持续推动精准扶贫,通过“金融+产业”扶贫,建立“金融+科技+产销”的扶贫闭环。2019\n年一季度,本行新增投放扶贫资金 43.85 亿元(其中产业扶贫贷款 12.95 亿元,扶贫政府债30.9亿\n元),累计覆盖和惠及建档立卡贫困人口超过43万人,直接帮扶建档立卡贫困人口超过2,800人。\n7、夯实基础提升资本\n本行持续推动精细化资本管理,全面实施经济资本管理,建立以经济增加值(EVA)和经济资\n本回报率(RAROC)为核心的综合绩效考核机制,通过调配业务、向低风险业务引导,加大税务核\n销力度,加大抵债资产处置力度、减少无效资本占用,提升资本水平。2019年3月末,本行核心一\n级资本充足率、一级资本充足率及资本充足率分别为 8.75%,9.59%及 11.50%,均满足监管达标要\n求。其中,核心一级资本充足率及一级资本充足率较上年末分别提高0.21个百分点、0.20个百分点,\n由于2019年3月7日全额赎回了90亿元二级资本债券,资本充足率与上年末持平。\n本行在通过利润留存补充资本的基础上,积极拓展外源性资本补充渠道,持续推进资本工具的\n发行工作。本行已于2019年1月25日完成发行260亿元A股可转换公司债券,转股后将有效补充\n本行核心一级资本。同时,本行正积极推进300亿元合格二级资本债券的发行工作,募集资金将依\n据适用法律,全部用于补充本行二级资本,进一步提高本行资本充足水平。\n8、合理配置网点布局\n本行持续实施网点智能化建设,合理配置网点布局,截至 2019 年 3 月末,本行有 81 家分行、\n共1,053家营业机构;全国已开业177家零售新门店。\n17(二)资产质量\n1、发放贷款和垫款五级分类情况\n(货币单位:人民币百万元)\n2019年3月31日 2018年12月31日\n期末比上年末\n项 目\n增减\n余额 占比 余额 占比\n发放贷款和垫款减值准备 (60,408) (54,187) 11.5%\n其中:以摊余成本计量的发放贷\n(60,086) (54,033) 11.2%\n款和垫款减值准备\n以公允价值计量且其变动\n计入其他综合收益的发放 (322) (154) 109.1%\n贷款和垫款减值准备\n2、发放贷款和垫款按产品划分的结构分布及质量情况\n(货币单位:人民币百万元)\n2019年3月31日 2018年12月31日\n项 目 不良率增减\n余额 不良率 余额 不良率\n注:“其他”包括持证抵押贷款、小额消费贷款和其他保证或质押类贷款等。\n18企业贷款不良率较上年末有所下降,一方面本行持续加强问题授信处置力度,严格管控正常资\n产的新增下迁;另一方面,强化资产质量管控体系建设和完善机制保障,进一步夯实资产质量管理\n基础。同时,坚持新客户准入风险标准,从源头改善资产结构,提升资产质量。\n受宏观经济下行、共债风险上升、汽车消费下滑等外部因素影响,消费金融全行业风险都有所\n上升,本行个人贷款不良率较上年末也略有上升,但整体风险表现仍维持在相对较低水平,主要情\n况如下:\n(1)本行住房按揭贷款通过持续调整和优化客群结构,加大对优质客户的投放力度,提升新\n发放贷款质量,将按揭贷款不良率维持在较低的水平。\n(2)本行“新一贷”主动调整发放节奏,一方面进一步优化客户结构、提高新发放优质客群\n占比、对不同风险级别的客户进行差异化风险管理,另一方面持续加大催清收资源投入、多策并举\n提升催清收成效,整体风险水平在可控范围;同时本行充分运用互联网技术提升产品竞争力,运用\n人脸识别、微表情技术、智能语音等前沿科技手段,实现业务集中化、自动化、线上化和智能化,\n提升客户体验、服务效率,持续强化风险控制。\n(3)本行汽车金融业务由于2018年以来进行的产品结构调整,高收益产品发放占比进一步提\n升,不良率虽小幅上升,但依然处于可控范围内。同时本行通过对车辆识别技术与VIN码识别技术\n的深度运用,实现车型自动化精准定位,自动化的后台估价体系的搭建,实现全流程识人、识车、\n识价一秒完成,有效降低人工成本,提升业务能效,规避信贷及欺诈风险。在催收管理方面,持续\n推进催收系统优化升级,提升整体作业效率,有效融入科技手段,实现失联客户信息修复,以确保\n整体资产质量保持稳定,总体风险在可控范围之内。\n(4)本行信用卡业务全流程贯彻风险管理理念,充分利用量化工具,有效管控风险。一方面\n通过大数据平台和先进定量分析技术,结合风控模型、AI智能识别等科学工具全面优化风险管理策\n略,有效改善新户获客结构与品质,优化存量结构,确保组合资产可持续发展。另一方面,在科技\n引领的理念支持下,引入 AI 智能催收降低迁徙率,并持续优化进行规模推广;优化评分模型推进\n差异化催收策略;精细梳理不良资产,增加催清收资源投入,确保风险水平可控。\n193、逾期贷款情况\n(货币单位:人民币百万元)\n余额 占贷款总额比 余额 占贷款总额比\n2019年3月末,本行逾期90天以内贷款(含本金未逾期,利息逾期90天以内贷款)余额162.82\n亿元,占贷款总额比 0.79%,较上年末略升 0.01 个百分点;逾期 90 天以上贷款(含本金未逾期,\n利息逾期90天以上贷款)余额339.93亿元,占贷款总额比1.66%,较上年末下降0.04个百分点。\n本行已积极采取各项措施,分类制定清收和重组转化方案,加大问题资产清收处置力度,进一步做\n好风险管理和化解工作。\n(三)利息收支情况\n1、主要资产、负债项目的日均余额以及平均收益率或平均成本率情况\n(货币单位:人民币百万元)\n2019年1-3月 2018年1-3月\n项 目 利息收入/ 平均收益/ 利息收入/ 平均收益/\n日均余额 日均余额\n支出 成本率 支出 成本率\n资产\n负债\n本行持续优化业务结构,个人贷款规模和占比增加,生息资产收益率有所提升,计息负债成本\n率进一步下降,2019 年一季净利差、净息差分别为 2.44%、2.53%,同比分别提升 39 个基点、28\n个基点,环比均提升3个基点。\n20(货币单位:人民币百万元)\n2019年1-3月 2018年10-12月\n项 目 利息收入/ 平均收益/ 利息收入/ 平均收益/\n日均余额 日均余额\n支出 成本率 支出 成本率\n资产\n负债\n2、发放贷款和垫款日均余额及平均收益率\n(货币单位:人民币百万元)\n日均余额 利息收入 平均收益率 日均余额 利息收入 平均收益率\n日均余额 利息收入 平均收益率 日均余额 利息收入 平均收益率\n213、吸收存款日均余额及平均成本率\n(货币单位:人民币百万元)\n日均余额 利息支出 平均成本率 日均余额 利息支出 平均成本率\n日均余额 利息支出 平均成本率 日均余额 利息支出 平均成本率\n22十二、资本充足率、杠杆率、流动性覆盖率情况\n(一)资本充足率\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日\n(二)杠杆率\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日 2018年9月30日 2018年6月30日\n注:主要因核心一级资本净额增加,报告期末杠杆率较2018年末增加。\n(三)流动性覆盖率\n(货币单位:人民币百万元)\n项 目 2019年3月31日 2018年12月31日\n注:根据银保监会发布的《商业银行流动性风险管理办法》,商业银行的流动性覆盖率应当在2018年底前达到100%。\n23第四节 财务报表\n一、财务报表\n(一)资产负债表(未经审计)\n(二)利润表(未经审计)\n(三)现金流量表(未经审计)\n24平安银行股份有限公司\n资产负债表\n2019年3月31日\n货币单位:人民币百万元\n金融投资:\n负债和股东权益\n负债\n股东权益\n25平安银行股份有限公司\n利润表\n2019年1-3月\n货币单位:人民币百万元\n项 目 2019年1-3月 2018年1-3月\n其中:以摊余成本计量的金融资产终止确认\n3 6\n产生的收益\n1.以公允价值计量且其变动计入其他综\n291 137\n合收益的金融资产的公允价值变动\n2.以公允价值计量且其变动计入其他综\n314 14\n合收益的金融资产的信用损失准备\n八、综合收益总额 8,045 6,746\n九、每股收益\n26平安银行股份有限公司\n现金流量表\n2019年1-3月\n货币单位:人民币百万元\n项 目 2019年1-3月 2018年1-3月\n一、经营活动产生的现金流量:\n二、投资活动产生的现金流量:\n三、筹资活动产生的现金流量:\n27二、审计报告\n第一季度报告是否经过审计\n□是 √否\n本行第一季度报告未经审计。\n平安银行股份有限公司董事会\n2019年4月24日\n28' # noqa: E501 + }] + + self._run_remove_header(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py new file mode 100644 index 000000000..ad1fbe183 --- /dev/null +++ b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py @@ -0,0 +1,56 @@ +import unittest + +from data_juicer.ops.mapper.remove_words_with_incorrect_substrings_mapper import \ + RemoveWordsWithIncorrectSubstringsMapper # noqa: E501 + + +class RemoveWordsWithIncorrectSubstringsMapperTest(unittest.TestCase): + + def _run_remove_words_with_incorrect_sbstrings(self, samples, op): + for sample in samples: + result = op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_en_case(self): + + samples = [ + { + 'text': + 'This paper proposed a novel https://whiugc.com method on LLM', + 'target': 'This paper proposed a novel method on LLM' + }, + { + 'text': + "plusieurs èrdash@hqbchd.ckd d'accéder à ces wwwasdasd fonc", + 'target': "plusieurs èrdash@hqbchd.ckd d'accéder à ces fonc" + }, + ] + + op = RemoveWordsWithIncorrectSubstringsMapper( + substrings=['http', 'www', '.com', 'href', '//']) + self._run_remove_words_with_incorrect_sbstrings(samples, op) + + def test_zh_case(self): + + samples = [{ + 'text': '你好,请问你是谁', + 'target': '你好,请问你是谁' + }, { + 'text': '欢迎来到阿里巴巴!', + 'target': '欢迎来到阿里巴巴!' + }, { + 'text': '根据算子使用情况增量安装方案确定', + 'target': '根据使用情况增量安装方案确定' + }, { + 'text': '请用百度www.baidu.com进行搜索', + 'target': '请用百度www.baidu.进行搜索' + }] + + op = RemoveWordsWithIncorrectSubstringsMapper(lang='zh', + tokenization=True, + substrings=['com', '算子']) + self._run_remove_words_with_incorrect_sbstrings(samples, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_sentence_split_mapper.py b/tests/ops/mapper/test_sentence_split_mapper.py new file mode 100644 index 000000000..abd914bda --- /dev/null +++ b/tests/ops/mapper/test_sentence_split_mapper.py @@ -0,0 +1,83 @@ +import unittest + +from data_juicer.ops.mapper.sentence_split_mapper import SentenceSplitMapper + + +class SentenceSplitMapperTest(unittest.TestCase): + + def _run_helper(self, op, samples): + for sample in samples: + result = op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_en_text(self): + + samples = [ + { + 'text': + 'Smithfield employs 3,700 people at its plant in Sioux Falls, ' + 'South Dakota. The plant slaughters 19,500 pigs a day — 5 ' + 'percent of U.S. pork.', + 'target': + 'Smithfield employs 3,700 people at its plant in Sioux Falls, ' + 'South Dakota.\nThe plant slaughters 19,500 pigs a day — 5 ' + 'percent of U.S. pork.' + }, + ] + op = SentenceSplitMapper('en') + self._run_helper(op, samples) + + def test_fr_text(self): + + samples = [ + { + 'text': + 'Smithfield emploie 3,700 personnes dans son usine de' + ' Sioux Falls, dans le Dakota du Sud. L\'usine ' + 'abat 19 500 porcs par jour, soit 5 % du porc américain.', + 'target': + 'Smithfield emploie 3,700 personnes dans son usine de' + ' Sioux Falls, dans le Dakota du Sud.\nL\'usine ' + 'abat 19 500 porcs par jour, soit 5 % du porc américain.' + }, + ] + op = SentenceSplitMapper('fr') + self._run_helper(op, samples) + + def test_pt_text(self): + + samples = [ + { + 'text': + 'A Smithfield emprega 3.700 pessoas em sua fábrica em ' + 'Sioux Falls, Dakota do Sul. A fábrica ' + 'abate 19.500 porcos por dia – 5% da carne suína dos EUA.', + 'target': + 'A Smithfield emprega 3.700 pessoas em sua fábrica em ' + 'Sioux Falls, Dakota do Sul.\nA fábrica abate 19.500 ' + 'porcos por dia – 5% da carne suína dos EUA.' + }, + ] + op = SentenceSplitMapper('pt') + self._run_helper(op, samples) + + def test_es_text(self): + + samples = [ + { + 'text': + 'Smithfield emplea a 3.700 personas en su planta de ' + 'Sioux Falls, Dakota del Sur. La planta sacrifica 19.500 ' + 'cerdos al día, el 5 por ciento de la carne de cerdo de EE.', + 'target': + 'Smithfield emplea a 3.700 personas en su planta de Sioux ' + 'Falls, Dakota del Sur.\nLa planta sacrifica 19.500 cerdos ' + 'al día, el 5 por ciento de la carne de cerdo de EE.' + }, + ] + op = SentenceSplitMapper('es') + self._run_helper(op, samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_whitespace_normalization_mapper.py b/tests/ops/mapper/test_whitespace_normalization_mapper.py new file mode 100644 index 000000000..0bffdf60c --- /dev/null +++ b/tests/ops/mapper/test_whitespace_normalization_mapper.py @@ -0,0 +1,28 @@ +import unittest + +from data_juicer.ops.mapper.whitespace_normalization_mapper import \ + WhitespaceNormalizationMapper + + +class WhitespaceNormalizationMapperTest(unittest.TestCase): + + def setUp(self): + self.op = WhitespaceNormalizationMapper() + + def _run_whitespace_normalization(self, samples): + for sample in samples: + result = self.op.process(sample) + self.assertEqual(result['text'], result['target']) + + def test_case(self): + + samples = [{ + 'text': 'x \t               \u200B\u200C\u200D\u2060\u0084y', + 'target': 'x y' + }] + + self._run_whitespace_normalization(samples) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/selector/__init__.py b/tests/ops/selector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ops/selector/test_frequency_specified_field_selector.py b/tests/ops/selector/test_frequency_specified_field_selector.py new file mode 100644 index 000000000..9ed362229 --- /dev/null +++ b/tests/ops/selector/test_frequency_specified_field_selector.py @@ -0,0 +1,552 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.selector.frequency_specified_field_selector import \ + FrequencySpecifiedFieldSelector + + +class FrequencySpecifiedFieldSelectorTest(unittest.TestCase): + + def _run_frequency_selector(self, dataset: Dataset, target_list, op): + dataset = op.process(dataset) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_topratio_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = FrequencySpecifiedFieldSelector(text_key='meta.suffix', + top_ratio=0.3, + topk=5, + reverse=True) + self._run_frequency_selector(dataset, tgt_list, op) + + def test_topk_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 48 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 33 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = FrequencySpecifiedFieldSelector(text_key='meta.key1.key2.count', + top_ratio=0.4, + topk=2, + reverse=True) + self._run_frequency_selector(dataset, tgt_list, op) + + def test_reverse_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = FrequencySpecifiedFieldSelector(text_key='meta.key1.key2.count', + top_ratio=0.4, + topk=2, + reverse=False) + self._run_frequency_selector(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/selector/test_topk_specified_field_selector.py b/tests/ops/selector/test_topk_specified_field_selector.py new file mode 100644 index 000000000..8a51db54e --- /dev/null +++ b/tests/ops/selector/test_topk_specified_field_selector.py @@ -0,0 +1,682 @@ +import unittest + +from datasets import Dataset + +from data_juicer.ops.selector.topk_specified_field_selector import \ + TopkSpecifiedFieldSelector + + +class TopkSpecifiedFieldSelectorTest(unittest.TestCase): + + def _run_topk_selector(self, dataset: Dataset, target_list, op): + dataset = op.process(dataset) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_topratio_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = TopkSpecifiedFieldSelector(text_key='meta.key1.key2.count', + top_ratio=0.2, + topk=5, + reverse=True) + self._run_topk_selector(dataset, tgt_list, op) + + def test_topk_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = TopkSpecifiedFieldSelector(text_key='meta.key1.key2.count', + top_ratio=0.5, + topk=4, + reverse=True) + self._run_topk_selector(dataset, tgt_list, op) + + def test_reverse_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 34 + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': 243 + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 551 + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': 89 + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 354.32 + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 2 + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': 2 + }, + 'count': 48 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': 18 + }, + 'count': 48 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': 33 + }, + 'count': 33 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = TopkSpecifiedFieldSelector(text_key='meta.key1.key2.count', + top_ratio=0.5, + topk=3, + reverse=False) + self._run_topk_selector(dataset, tgt_list, op) + + def test_str_select(self): + ds_list = [{ + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': '34' + }, + 'count': 5 + } + } + }, { + 'text': 'a v s e c s f e f g a a a ', + 'count': 16, + 'meta': { + 'suffix': '.docx', + 'key1': { + 'key2': { + 'count': '243' + }, + 'count': 63 + } + } + }, { + 'text': '中文也是一个字算一个长度', + 'count': 162, + 'meta': { + 'suffix': '.txt', + 'key1': { + 'key2': { + 'count': '' + }, + 'count': 23 + } + } + }, { + 'text': ',。、„”“«»1」「《》´∶:?!', + 'count': None, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': None + }, + 'count': 48 + } + } + }, { + 'text': '他的英文名字叫Harry Potter', + 'count': 88, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': '551' + }, + 'count': 78 + } + } + }, { + 'text': '这是一个测试', + 'count': None, + 'meta': { + 'suffix': '.py', + 'key1': { + 'key2': { + 'count': '89' + }, + 'count': 3 + } + } + }, { + 'text': '我出生于2023年12月15日', + 'count': None, + 'meta': { + 'suffix': '.java', + 'key1': { + 'key2': { + 'count': '354.32' + }, + 'count': 67 + } + } + }, { + 'text': 'emoji表情测试下😊,😸31231\n', + 'count': 2, + 'meta': { + 'suffix': '.html', + 'key1': { + 'key2': { + 'count': '354.32' + }, + 'count': 32 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': '33' + }, + 'count': 33 + } + } + }, { + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': '2' + }, + 'count': 48 + } + } + }] + tgt_list = [{ + 'text': '使用片段分词器对每个页面进行分词,使用语言', + 'count': 666, + 'meta': { + 'suffix': '.xml', + 'key1': { + 'key2': { + 'count': '2' + }, + 'count': 48 + } + } + }, { + 'text': 'a=1\nb\nc=1+2+3+5\nd=6', + 'count': 178, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': '33' + }, + 'count': 33 + } + } + }, { + 'text': 'Today is Sun', + 'count': 101, + 'meta': { + 'suffix': '.pdf', + 'key1': { + 'key2': { + 'count': '34' + }, + 'count': 5 + } + } + }] + dataset = Dataset.from_list(ds_list) + op = TopkSpecifiedFieldSelector(text_key='meta.key1.key2.count', + top_ratio=0.5, + topk=3, + reverse=False) + self._run_topk_selector(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run.py b/tests/run.py new file mode 100644 index 000000000..6d9d6c880 --- /dev/null +++ b/tests/run.py @@ -0,0 +1,56 @@ +# The code is from adapted from +# https://github.com/alibaba/FederatedScope/blob/master/tests/run.py + +# Data-Juicer adopts Apache 2.0 license, the original license of this file +# is as follows: +# -------------------------------------------------------- +# Copyright (c) Alibaba, Inc. and its affiliates + +import argparse +import os +import sys +import unittest + +file_dir = os.path.join(os.path.dirname(__file__), '..') +sys.path.append(file_dir) + +parser = argparse.ArgumentParser('test runner') +parser.add_argument('--list_tests', action='store_true', help='list all tests') +parser.add_argument('--pattern', default='test_*.py', help='test file pattern') +parser.add_argument('--test_dir', + default='tests', + help='directory to be tested') +args = parser.parse_args() + + +def gather_test_cases(test_dir, pattern, list_tests): + test_suite = unittest.TestSuite() + discover = unittest.defaultTestLoader.discover(test_dir, + pattern=pattern, + top_level_dir=None) + for suite_discovered in discover: + + for test_case in suite_discovered: + test_suite.addTest(test_case) + if hasattr(test_case, '__iter__'): + for subcase in test_case: + if list_tests: + print(subcase) + else: + if list_tests: + print(test_case) + return test_suite + + +def main(): + runner = unittest.TextTestRunner() + test_suite = gather_test_cases(os.path.abspath(args.test_dir), + args.pattern, args.list_tests) + if not args.list_tests: + res = runner.run(test_suite) + if not res.wasSuccessful(): + exit(1) + + +if __name__ == '__main__': + main() diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/thirdparty/README.md b/thirdparty/README.md new file mode 100644 index 000000000..ae30d53ec --- /dev/null +++ b/thirdparty/README.md @@ -0,0 +1,37 @@ +# Third-parties (LLM Ecosystems) + +Dependencies of Auto Evaluation Toolkit, see `tools/evaluator/README.md` for more details. + +## Installation + +The auto-evaluation toolkit requires customized Megatron-LM and HELM. +To avoid dependency problems when installing those packages, we recommand using NGC's PyTorch container (`nvcr.io/nvidia/pytorch:22.12-py3`). +Assuming the path to your shared file system (where your data and model checkpoints are saved) is `/mnt/shared`, start the docker container with following commands. + +```shell +docker pull nvcr.io/nvidia/pytorch:22.12-py3 +docker run --gpus all -it --rm -v /mnt/shared:/workspace +``` + +After starting the docker container, run the following scripts in the container to install Megatron-LM or HELM. + +The training machines only need to install Megatron-LM: + +```shell +./setup_megatron.sh +``` + +The evaluation machine needs to install both Megatron-LM and HELM + +```shell +./setup_megatron.sh +./setup_helm.sh +``` + +The toolkit use [WandB](https://wandb.ai/) to monitor the trend of metrics during training. Above steps have installed wandb, and you only need to run `wandb login` and enter your wandb API key. If you have your own instance of wandb, run the following script. + +```shell +wandb login --host +# enter your api key +``` + diff --git a/thirdparty/README_ZH.md b/thirdparty/README_ZH.md new file mode 100644 index 000000000..552ca1009 --- /dev/null +++ b/thirdparty/README_ZH.md @@ -0,0 +1,36 @@ +# 第三方库(大语言模型生态) + +本目录包含了 Auto Evaluation Toolkit 的第三方依赖项,更多细节请参考 `tools/evaluator/README_ZH.md`。 + +## 安装 + +Auto Evaluation Toolkit 依赖于定制化的 [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) 和 [HELM](https://github.com/stanford-crfm/helm)。 +为了避免安装这些软件包时可能出现的依赖项问题,我们建议使用 NGC 的 Pytorch 容器(`nvcr.io/nvidia/pytorch:22.12-py3`)。 +假设您共享文件系统的路径(即数据集和模型检查点的存储路径)为`/mnt/shared`,请使用如下指令启动 Docker 容器。 + +```shell +docker pull nvcr.io/nvidia/pytorch:22.12-py3 +docker run --gpus all -it --rm -v /mnt/shared:/workspace +``` + +启动 Docker 容器后,在容器中运行以下脚本以安装 Megatron-LM 或 HELM。 + +训练机只需要安装Megatron-LM: + +```shell +./setup_megatron.sh +``` + +评测机需要同时安装Megatron-LM和HELM + +```shell +./setup_megatron.sh +./setup_helm.sh +``` + +工具包使用[WandB](https://wandb.ai/)来监视训练期间各指标的趋势。上面的步骤中已安装 wandb,您只需要运行 `wand login` 并输入 wandb API 密钥即可。如果您有自己的wandb实例,请运行以下脚本。 + +```shell +wandb login --host +#输入您的API密钥 +``` diff --git a/thirdparty/patch/helm.diff b/thirdparty/patch/helm.diff new file mode 100644 index 000000000..b77d066f3 --- /dev/null +++ b/thirdparty/patch/helm.diff @@ -0,0 +1,806 @@ +diff --git a/requirements.txt b/requirements.txt +index 2a97e02b..bec4a16a 100644 +--- a/requirements.txt ++++ b/requirements.txt +@@ -67,8 +67,8 @@ spacy~=3.2.4 + summ-eval~=0.892 + surge-api~=1.1.0 + # End users should install a CUDA version of PyTorch manually if needed +-torch~=1.12.1 # Summarization metrics +-torchvision~=0.13.1 ++torch~=1.13.0 # Summarization metrics ++torchvision~=0.14.0 + + # plotting + colorcet~=3.0.1 +diff --git a/setup.cfg b/setup.cfg +index fce3cfc8..2f9b6d12 100644 +--- a/setup.cfg ++++ b/setup.cfg +@@ -75,8 +75,8 @@ install_requires= + summ-eval~=0.892 + surge-api~=1.1.0 + # End users should install a CUDA version of PyTorch manually if needed +- torch~=1.12.1 # Summarization metrics +- torchvision~=0.13.1 ++ torch~=1.13.0 # Summarization metrics ++ torchvision~=0.14.0 + + # plotting + colorcet~=3.0.1 +diff --git a/src/helm/benchmark/executor.py b/src/helm/benchmark/executor.py +index 9f619df4..d60ccfa8 100644 +--- a/src/helm/benchmark/executor.py ++++ b/src/helm/benchmark/executor.py +@@ -32,6 +32,8 @@ class ExecutionSpec: + # How many threads to have at once + parallelism: int + ++ my_config_path: Optional[str] ++ + # Whether to skip execution + dry_run: bool = False + +@@ -58,7 +60,8 @@ class Executor: + elif execution_spec.local_path: + hlog(f"Running in local mode with base path: {execution_spec.local_path}") + self.service = ServerService( +- base_path=execution_spec.local_path, root_mode=True, mongo_uri=execution_spec.mongo_uri ++ base_path=execution_spec.local_path, root_mode=True, mongo_uri=execution_spec.mongo_uri, ++ my_config_path=execution_spec.my_config_path + ) + else: + raise ValueError("Either the proxy server URL or the local path must be set") +@@ -73,7 +76,7 @@ class Executor: + request_states = parallel_map( + self.process, + scenario_state.request_states, +- parallelism=self.execution_spec.parallelism, ++ parallelism=1, + ) + + hlog(f"Processed {len(request_states)} requests") +diff --git a/src/helm/benchmark/presentation/create_plots.py b/src/helm/benchmark/presentation/create_plots.py +index b5dfe7ff..599f98b8 100644 +--- a/src/helm/benchmark/presentation/create_plots.py ++++ b/src/helm/benchmark/presentation/create_plots.py +@@ -579,18 +579,25 @@ class Plotter: + + def create_all_plots(self): + """Create all the plots used in the HELM paper.""" +- self.create_accuracy_v_x_plots() +- self.create_correlation_plots() +- self.create_leaderboard_plots() +- self.create_all_accuracy_v_model_property_plots() +- self.create_accuracy_v_access_bar_plot() +- self.create_task_summary_plots() +- self.create_targeted_eval_plots() +- self.create_copyright_plot() +- self.create_bbq_plot() +- self.create_in_context_examples_plot() +- self.create_mc_ablations_plot() +- self.create_constrast_set_plots() ++ plot_funcs = [ ++ self.create_accuracy_v_x_plots, ++ self.create_correlation_plots, ++ self.create_leaderboard_plots, ++ self.create_all_accuracy_v_model_property_plots, ++ self.create_accuracy_v_access_bar_plot, ++ self.create_task_summary_plots, ++ self.create_targeted_eval_plots, ++ self.create_copyright_plot, ++ self.create_bbq_plot, ++ self.create_in_context_examples_plot, ++ self.create_mc_ablations_plot, ++ self.create_constrast_set_plots ++ ] ++ for plot_func in plot_funcs: ++ try: ++ plot_func() ++ except Exception as e: ++ hlog(f"WARNING: {plot_func.__name__} failed: {e}") + + + def main(): +diff --git a/src/helm/benchmark/run.py b/src/helm/benchmark/run.py +index 5d950dfc..cf59d150 100644 +--- a/src/helm/benchmark/run.py ++++ b/src/helm/benchmark/run.py +@@ -11,10 +11,10 @@ from helm.proxy.clients.remote_model_registry import check_and_register_remote_m + from helm.proxy.services.remote_service import create_authentication, add_service_args + + from helm.benchmark.adaptation.adapter_spec import AdapterSpec +-from .executor import ExecutionSpec +-from .runner import Runner, RunSpec, LATEST_SYMLINK +-from .slurm_runner import SlurmRunner +-from .run_specs import construct_run_specs ++from helm.benchmark.executor import ExecutionSpec ++from helm.benchmark.runner import Runner, RunSpec, LATEST_SYMLINK ++from helm.benchmark.run_specs import construct_run_specs ++from helm.benchmark.slurm_runner import SlurmRunner + + + def run_entries_to_run_specs( +@@ -68,6 +68,7 @@ def run_benchmarking( + local_path: str, + num_threads: int, + output_path: str, ++ my_config_path: str, + suite: str, + dry_run: bool, + skip_instances: bool, +@@ -84,6 +85,7 @@ def run_benchmarking( + url=url, + local_path=local_path, + parallelism=num_threads, ++ my_config_path=my_config_path, + dry_run=dry_run, + mongo_uri=mongo_uri, + ) +@@ -226,6 +228,10 @@ def main(): + help="Experimental: Enable using AutoModelForCausalLM models from Hugging Face Model Hub. " + "Format: namespace/model_name[@revision]", + ) ++ parser.add_argument( ++ "--my-config-path", type=str, default=None, ++ help='Config to support mymodel/model_name' ++ ) + parser.add_argument( + "--enable-remote-models", + nargs="+", +@@ -282,6 +288,7 @@ def main(): + local_path=args.local_path, + num_threads=args.num_threads, + output_path=args.output_path, ++ my_config_path=args.my_config_path, + suite=args.suite, + dry_run=args.dry_run, + skip_instances=args.skip_instances, +diff --git a/src/helm/benchmark/run_specs.py b/src/helm/benchmark/run_specs.py +index 0a1f6cb7..534dc8c7 100644 +--- a/src/helm/benchmark/run_specs.py ++++ b/src/helm/benchmark/run_specs.py +@@ -512,7 +512,7 @@ def get_bias_metric_specs() -> List[MetricSpec]: + def get_generative_harms_metric_specs(include_basic_metrics: bool = False) -> List[MetricSpec]: + return ( + get_bias_metric_specs() +- + get_toxicity_metric_specs() ++ # + get_toxicity_metric_specs() # Connection Error + + (get_basic_metric_specs([]) if include_basic_metrics else []) + ) + +diff --git a/src/helm/benchmark/window_services/megatron_window_service.py b/src/helm/benchmark/window_services/megatron_window_service.py +deleted file mode 100644 +index 0a37943e..00000000 +--- a/src/helm/benchmark/window_services/megatron_window_service.py ++++ /dev/null +@@ -1,10 +0,0 @@ +-from .gpt2_window_service import GPT2WindowService +- +- +-# NOTE: The only difference between this and GPT2WindowService is that +-# the request length is constrained to the sequence length. +-class MegatronWindowService(GPT2WindowService): +- @property +- def max_request_length(self) -> int: +- """Return the max request length of GPT-2.""" +- return self.max_sequence_length +diff --git a/src/helm/benchmark/window_services/my_window_service.py b/src/helm/benchmark/window_services/my_window_service.py +new file mode 100644 +index 00000000..51e3f9e8 +--- /dev/null ++++ b/src/helm/benchmark/window_services/my_window_service.py +@@ -0,0 +1,31 @@ ++from .local_window_service import LocalWindowService ++from .tokenizer_service import TokenizerService ++ ++ ++class MyWindowService(LocalWindowService): ++ def __init__(self, service: TokenizerService): ++ super().__init__(service) ++ ++ @property ++ def tokenizer_name(self) -> str: ++ return "mymodel/model" ++ ++ @property ++ def max_sequence_length(self) -> int: ++ """Return the max sequence length.""" ++ return 2048 ++ ++ @property ++ def max_request_length(self) -> int: ++ """Return the max request length.""" ++ return self.max_sequence_length ++ ++ @property ++ def end_of_text_token(self) -> str: ++ """The end of text token.""" ++ return "<|endoftext|>" ++ ++ @property ++ def prefix_token(self) -> str: ++ """The prefix token""" ++ return self.end_of_text_token +diff --git a/src/helm/benchmark/window_services/window_service_factory.py b/src/helm/benchmark/window_services/window_service_factory.py +index 2fb7ce68..22e4b0c2 100644 +--- a/src/helm/benchmark/window_services/window_service_factory.py ++++ b/src/helm/benchmark/window_services/window_service_factory.py +@@ -35,7 +35,7 @@ from .starcoder_window_service import StarCoderWindowService + from .gpt2_window_service import GPT2WindowService + from .gptj_window_service import GPTJWindowService + from .gptneox_window_service import GPTNeoXWindowService +-from .megatron_window_service import MegatronWindowService ++from .my_window_service import MyWindowService + from .opt_window_service import OPTWindowService + from .palmyra_window_service import PalmyraWindowService, SilkRoadWindowService + from .remote_window_service import get_remote_window_service +@@ -153,8 +153,8 @@ class WindowServiceFactory: + window_service = UL2WindowService(service) + elif model_name == "together/yalm": + window_service = YaLMWindowService(service) +- elif model_name == "nvidia/megatron-gpt2": +- window_service = MegatronWindowService(service) ++ elif organization == "mymodel": ++ window_service = MyWindowService(service) + elif model_name in [ + "together/llama-7b", + "together/alpaca-7b", +diff --git a/src/helm/proxy/clients/auto_client.py b/src/helm/proxy/clients/auto_client.py +index 2069a12c..a77e8f09 100644 +--- a/src/helm/proxy/clients/auto_client.py ++++ b/src/helm/proxy/clients/auto_client.py +@@ -27,7 +27,7 @@ from .google_client import GoogleClient + from .goose_ai_client import GooseAIClient + from .huggingface_client import HuggingFaceClient + from .ice_tokenizer_client import ICETokenizerClient +-from .megatron_client import MegatronClient ++from .my_client import MyModelClient + from .openai_client import OpenAIClient + from .microsoft_client import MicrosoftClient + from .perspective_api_client import PerspectiveAPIClient +@@ -40,7 +40,7 @@ from helm.proxy.clients.huggingface_model_registry import get_huggingface_model_ + class AutoClient(Client): + """Automatically dispatch to the proper `Client` based on the organization.""" + +- def __init__(self, credentials: Dict[str, str], cache_path: str, mongo_uri: str = ""): ++ def __init__(self, credentials: Dict[str, str], cache_path: str, mongo_uri: str = "", my_config_path: str = None): + self.credentials = credentials + self.cache_path = cache_path + self.mongo_uri = mongo_uri +@@ -50,6 +50,7 @@ class AutoClient(Client): + self.critique_client: Optional[CritiqueClient] = None + huggingface_cache_config = self._build_cache_config("huggingface") + self.huggingface_client = HuggingFaceClient(huggingface_cache_config) ++ self.my_config_path = my_config_path + hlog(f"AutoClient: cache_path = {cache_path}") + hlog(f"AutoClient: mongo_uri = {mongo_uri}") + +@@ -131,8 +132,8 @@ class AutoClient(Client): + cache_config=cache_config, + tokenizer_client=self._get_tokenizer_client("huggingface"), + ) +- elif organization == "nvidia": +- client = MegatronClient(cache_config=cache_config) ++ elif organization == "mymodel": ++ client = MyModelClient(cache_config=cache_config, my_config_path=self.my_config_path) + else: + raise ValueError(f"Could not find client for model: {model}") + self.clients[model] = client +@@ -182,7 +183,6 @@ class AutoClient(Client): + "huggingface", + "microsoft", + "Writer", +- "hf-internal-testing", + ]: + client = HuggingFaceClient(cache_config=cache_config) + elif organization == "openai": +@@ -205,8 +205,8 @@ class AutoClient(Client): + client = CohereClient(api_key=self.credentials["cohereApiKey"], cache_config=cache_config) + elif organization == "simple": + client = SimpleClient(cache_config=cache_config) +- elif organization == "nvidia": +- client = MegatronClient(cache_config=cache_config) ++ elif organization == "mymodel": ++ client = MyModelClient(cache_config=cache_config, my_config_path=self.my_config_path) + else: + raise ValueError(f"Could not find tokenizer client for model: {tokenizer}") + self.tokenizer_clients[tokenizer] = client +diff --git a/src/helm/proxy/clients/huggingface_client.py b/src/helm/proxy/clients/huggingface_client.py +index 55e9bc34..76944839 100644 +--- a/src/helm/proxy/clients/huggingface_client.py ++++ b/src/helm/proxy/clients/huggingface_client.py +@@ -31,10 +31,10 @@ class HuggingFaceServer: + model_kwargs["revision"] = model_config.revision + with htrack_block(f"Loading Hugging Face model for config {model_config}"): + self.model = AutoModelForCausalLM.from_pretrained( +- model_config.model_id, trust_remote_code=True, **model_kwargs ++ model_config.model_id, trust_remote_code=True, local_files_only=True, **model_kwargs + ).to(self.device) + with htrack_block(f"Loading Hugging Face tokenizer model for config {model_config}"): +- self.tokenizer = AutoTokenizer.from_pretrained(model_config.model_id, **model_kwargs) ++ self.tokenizer = AutoTokenizer.from_pretrained(model_config.model_id, local_files_only=True, **model_kwargs) + + def serve_request(self, raw_request: Dict[str, Any]): + encoded_input = self.tokenizer(raw_request["prompt"], return_tensors="pt").to(self.device) +@@ -59,43 +59,67 @@ class HuggingFaceServer: + for key in raw_request + if key not in ["engine", "prompt", "echo_prompt", "stop_sequences"] + } ++ if raw_request['echo_prompt'] and raw_request['max_new_tokens'] == 0: ++ logits = self.model.forward(**encoded_input).logits ++ sequences = encoded_input.input_ids ++ all_tokens = [self.tokenizer.convert_ids_to_tokens(sequence) for sequence in sequences] ++ all_decoded_text = raw_request["prompt"] ++ all_logprobs_of_chosen_tokens = [] ++ all_top_logprobs_dicts = [] ++ for completion_id in range(raw_request["num_return_sequences"]): ++ logprobs_of_chosen_tokens = [0.0] ++ top_logprobs_dicts = [{}] ++ for i in range(len(encoded_input.input_ids[0]) - 1): ++ logprobs = torch.nn.functional.log_softmax(logits[completion_id][i], dim=0) ++ # Get top tokens in terms of log probability. ++ topk_logprobs = torch.topk(logprobs, k=top_k_per_token) ++ top_logprobs_dicts.append( ++ { ++ self.tokenizer.convert_ids_to_tokens(k.item()): v.item() ++ for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values) ++ } ++ ) + +- # Use HuggingFace's `generate` method. +- output = self.model.generate(**encoded_input, **relevant_raw_request) +- sequences = output.sequences +- scores = output.scores +- +- # Compute logprobs for each completed sequence. +- all_logprobs_of_chosen_tokens = [] +- all_top_logprobs_dicts = [] +- for completion_id in range(raw_request["num_return_sequences"]): +- logprobs_of_chosen_tokens = [] +- top_logprobs_dicts = [] +- for i in range(len(sequences[completion_id]) - len(encoded_input.input_ids[0])): +- logprobs = torch.nn.functional.log_softmax(scores[i][completion_id], dim=0) +- +- # Get top tokens in terms of log probability. +- topk_logprobs = torch.topk(logprobs, k=top_k_per_token) +- top_logprobs_dicts.append( +- { +- self.tokenizer.convert_ids_to_tokens(k.item()): v.item() +- for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values) +- } +- ) ++ logprobs_of_chosen_tokens.append(logprobs[sequences[completion_id][i + 1]].item()) ++ all_logprobs_of_chosen_tokens.append(logprobs_of_chosen_tokens) ++ all_top_logprobs_dicts.append(top_logprobs_dicts) ++ else: ++ # Use HuggingFace's `generate` method. ++ output = self.model.generate(**encoded_input, **relevant_raw_request) ++ sequences = output.sequences ++ scores = output.scores ++ ++ # Compute logprobs for each completed sequence. ++ all_logprobs_of_chosen_tokens = [] ++ all_top_logprobs_dicts = [] ++ for completion_id in range(raw_request["num_return_sequences"]): ++ logprobs_of_chosen_tokens = [] ++ top_logprobs_dicts = [] ++ for i in range(len(sequences[completion_id]) - len(encoded_input.input_ids[0])): ++ logprobs = torch.nn.functional.log_softmax(scores[i][completion_id], dim=0) ++ ++ # Get top tokens in terms of log probability. ++ topk_logprobs = torch.topk(logprobs, k=top_k_per_token) ++ top_logprobs_dicts.append( ++ { ++ self.tokenizer.convert_ids_to_tokens(k.item()): v.item() ++ for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values) ++ } ++ ) + +- # Get log probability of chosen token. +- j = i + len(encoded_input.input_ids[0]) +- logprobs_of_chosen_tokens.append(logprobs[sequences[completion_id][j]].item()) +- all_logprobs_of_chosen_tokens.append(logprobs_of_chosen_tokens) +- all_top_logprobs_dicts.append(top_logprobs_dicts) ++ # Get log probability of chosen token. ++ j = i + len(encoded_input.input_ids[0]) ++ logprobs_of_chosen_tokens.append(logprobs[sequences[completion_id][j]].item()) ++ all_logprobs_of_chosen_tokens.append(logprobs_of_chosen_tokens) ++ all_top_logprobs_dicts.append(top_logprobs_dicts) + +- # Remove prompt from the start of each sequence if echo_prompt is False. +- if not raw_request["echo_prompt"]: +- sequences = [sequence[len(encoded_input.input_ids[0]) :] for sequence in sequences] ++ # Remove prompt from the start of each sequence if echo_prompt is False. ++ if not raw_request["echo_prompt"]: ++ sequences = [sequence[len(encoded_input.input_ids[0]) :] for sequence in sequences] + +- # TODO: Get rid of the extra tokenization? +- all_tokens = [self.tokenizer.convert_ids_to_tokens(sequence) for sequence in sequences] +- all_decoded_text = self.tokenizer.batch_decode(sequences) ++ # TODO: Get rid of the extra tokenization? ++ all_tokens = [self.tokenizer.convert_ids_to_tokens(sequence) for sequence in sequences] ++ all_decoded_text = self.tokenizer.batch_decode(sequences) + + completions = [] + for (decoded_text, tokens, logprobs_of_chosen_tokens, top_logprobs_dicts) in zip( +@@ -166,7 +190,6 @@ class HuggingFaceClient(Client): + # Get cached model server instance if possible (to save on model and tokenizer + # loading times). + model_server_instance: HuggingFaceServer = self.get_model_server_instance(request.model) +- + try: + + def do_it(): +@@ -183,7 +206,9 @@ class HuggingFaceClient(Client): + sequence_logprob: float = 0 + tokens: List[Token] = [] + +- if request.echo_prompt: ++ if request.echo_prompt and request.max_tokens == 0: ++ generated_tokens = raw_completion["tokens"] ++ elif request.echo_prompt: + # Add prompt to list of generated tokens. + generated_tokens = raw_completion["tokens"][response["input_length"] :] + for token_text in raw_completion["tokens"][: response["input_length"]]: +diff --git a/src/helm/proxy/clients/huggingface_tokenizer.py b/src/helm/proxy/clients/huggingface_tokenizer.py +index e55cf039..4bc87ede 100644 +--- a/src/helm/proxy/clients/huggingface_tokenizer.py ++++ b/src/helm/proxy/clients/huggingface_tokenizer.py +@@ -59,14 +59,17 @@ class HuggingFaceTokenizers: + # the Hugging Face Transformers library, while the fast versions are the ones provided by Hugging Face + # Tokenizers, which are written in Rust." So, use the "fast" version of the tokenizers if available. + return AutoTokenizer.from_pretrained( +- hf_tokenizer_name, local_files_only=True, use_fast=True, **tokenizer_kwargs ++ hf_tokenizer_name, local_files_only=True, use_fast=False, **tokenizer_kwargs + ) + except OSError: + hlog(f"Local files do not exist for HuggingFace tokenizer: {hf_tokenizer_name}. Downloading...") + return AutoTokenizer.from_pretrained( + hf_tokenizer_name, local_files_only=False, use_fast=True, **tokenizer_kwargs + ) +- ++ except ValueError: ++ return AutoTokenizer.from_pretrained( ++ hf_tokenizer_name, local_files_only=True, **tokenizer_kwargs ++ ) + if tokenizer_name not in HuggingFaceTokenizers.tokenizers: + with htrack_block(f"Loading {tokenizer_name} with Hugging Face Transformers"): + # To avoid deadlocks when using HuggingFace tokenizers with multiple processes +diff --git a/src/helm/proxy/clients/megatron_client.py b/src/helm/proxy/clients/megatron_client.py +deleted file mode 100644 +index 6d0e3867..00000000 +--- a/src/helm/proxy/clients/megatron_client.py ++++ /dev/null +@@ -1,99 +0,0 @@ +-import json +-import requests +-from typing import Any, Dict, List +-import traceback +- +-from helm.common.request import EMBEDDING_UNAVAILABLE_REQUEST_RESULT, Request, RequestResult, Sequence, Token +-from helm.common.tokenization_request import TokenizationRequest +-from helm.proxy.clients.huggingface_client import HuggingFaceClient +-from helm.proxy.clients.client import Client, wrap_request_time, truncate_sequence +- +- +-class MegatronClient(HuggingFaceClient): +- """Client for remote Megatron-LM server. +- +- This client expects an external Megatron-LM server to be run on localhost:5000. See the +- Megatron-LM respository for documentation on starting a Megatron text generation server: +- +- https://github.com/NVIDIA/Megatron-LM#gpt-text-generation +- """ +- +- def _send_request(self, raw_request: Dict[str, Any]) -> Dict[str, Any]: +- response = requests.request( +- method="PUT", +- # TODO(tgale): Make this configurable. +- url="http://localhost:5000/api", +- headers={ +- "Content-Type": "application/json; charset=UTF-8", +- }, +- data=json.dumps(raw_request), +- ) +- out = json.loads(response.text) +- +- # Detect if the server returned an error string. +- if type(out) != dict: +- raise ValueError(f"{response}: {response.text}") +- return out +- +- def _tokenize_response(self, text: str) -> List[Token]: +- tokenized_text = self.tokenize(TokenizationRequest(text, tokenizer="huggingface/gpt2")) +- +- # TODO(tgale): Support logprobs. +- tokens = [Token(text=str(token), logprob=0, top_logprobs={}) for token in tokenized_text.raw_tokens] +- return tokens +- +- def _make_request(self, request: Request) -> RequestResult: +- # Embedding not supported for this model +- if request.embedding: +- return EMBEDDING_UNAVAILABLE_REQUEST_RESULT +- +- # TODO(tgale): Relax these. +- assert request.num_completions == 1 +- assert not request.echo_prompt +- assert not request.stop_sequences +- assert request.top_p == 1 +- +- # TODO(tgale): Handle log probabilities. +- raw_request = { +- "prompts": [request.prompt], +- "tokens_to_generate": request.max_tokens, +- "temperature": request.temperature, +- "top_k": request.top_k_per_token, +- } +- +- cache_key = Client.make_cache_key(raw_request, request) +- response, cached = self.cache.get(cache_key, wrap_request_time(lambda: self._send_request(raw_request))) +- +- # Verify we got a single response for the prompt. +- assert len(response["text"]) == 1 +- +- # NOTE: Megatron returns the response with the prompt included. +- generated_text = response["text"][0] +- if not request.echo_prompt: +- generated_text = generated_text[len(request.prompt) :] +- +- # NOTE: Megatron returns the de-tokenized response. Re-tokenize. +- tokens = self._tokenize_response(generated_text) +- completion = Sequence(text=generated_text, logprob=0, tokens=tokens) +- completion = truncate_sequence(completion, request, print_warning=True) +- +- return RequestResult( +- success=True, +- cached=cached, +- request_time=response["request_time"], +- request_datetime=response.get("request_datetime"), +- completions=[completion], +- embedding=[], +- ) +- +- def make_request(self, request: Request) -> RequestResult: +- try: +- return self._make_request(request) +- except Exception as e: +- return RequestResult( +- success=False, +- cached=False, +- error=f"MegatronClient Error: {e}\n\n{traceback.format_exc()}", +- completions=[], +- embedding=[], +- ) +diff --git a/src/helm/proxy/clients/my_client.py b/src/helm/proxy/clients/my_client.py +new file mode 100644 +index 00000000..0eb8de52 +--- /dev/null ++++ b/src/helm/proxy/clients/my_client.py +@@ -0,0 +1,112 @@ ++import requests ++import json ++import yaml ++from typing import Any, Dict, List ++from dataclasses import asdict ++ ++from helm.common.cache import Cache, CacheConfig ++from helm.common.request import Request, RequestResult, Sequence, Token ++from helm.common.tokenization_request import ( ++ DecodeRequest, ++ DecodeRequestResult, ++ TokenizationRequest, ++ TokenizationRequestResult, ++ TokenizationToken ++) ++from helm.proxy.clients.my_tokenizer.my_tokenizer import MyTokenizer ++ ++from .client import Client, truncate_sequence, wrap_request_time ++ ++ ++class MyModelClient(Client): ++ def __init__(self, cache_config: CacheConfig, my_config_path): ++ if my_config_path == None: ++ raise NotImplementedError("--my-config-path must be set") ++ with open(my_config_path, 'r') as f: ++ cfg = yaml.safe_load(f) ++ port = cfg['port'] if 'port' in cfg else 5000 ++ self.tokenizer = MyTokenizer(cfg['tokenizer']) ++ self.url = f'http://localhost:{port}/api' ++ self.cache = Cache(cache_config) ++ ++ def _send_request(self, raw_request: Dict[str, Any]) -> Dict[str, Any]: ++ header = { ++ 'Content-Type': 'application/json; charset=UTF-8', ++ } ++ response = requests.put(url=self.url, headers=header, data=json.dumps(raw_request)) ++ return response.json() ++ ++ def make_request(self, request: Request) -> RequestResult: ++ completions: List[Sequence] = [] ++ raw_request = { ++ 'prompts': [request.prompt for _ in range(request.num_completions)], ++ 'tokens_to_generate': request.max_tokens, ++ 'temperature': request.temperature, ++ 'top_p': request.top_p, ++ 'logprobs': True, ++ 'echo_prompts': request.echo_prompt, ++ 'engine': request.model_engine ++ } ++ try: ++ def do_it(): ++ result = self._send_request(raw_request) ++ if 'text' not in result: ++ raise ValueError(f'Invalid response: {result}') ++ return result ++ cache_key = Client.make_cache_key(raw_request, request) ++ response, cached = self.cache.get(cache_key, wrap_request_time(do_it)) ++ except Exception as e: ++ error: str = f"Megatron-Server error: {e}" ++ return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[]) ++ for completion_idx in range(request.num_completions): ++ text: str = response['text'][completion_idx] ++ response['logprobs'][completion_idx].insert(0, 0) ++ if text.endswith(self.tokenizer.eos): ++ text = text[:-len(self.tokenizer.eos)] ++ response['segments'][completion_idx].pop() ++ response['logprobs'][completion_idx].pop() ++ tokens = [ ++ Token(text=text, logprob=logprob, top_logprobs={}) for text, logprob in zip(response['segments'][completion_idx], response['logprobs'][completion_idx]) ++ ] ++ completion = Sequence(text=text, logprob=sum(response['logprobs'][completion_idx]), tokens=tokens) ++ sequence = truncate_sequence(completion, request, print_warning=True) ++ completions.append(sequence) ++ return RequestResult( ++ success=True, ++ cached=cached, ++ request_time=response['request_time'], ++ request_datetime=response['request_datetime'], ++ completions=completions, ++ embedding=[] ++ ) ++ ++ def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult: ++ cache_key = asdict(request) ++ def do_it(): ++ token_ids = self.tokenizer.tokenize(request.text) ++ if request.truncation: ++ token_ids = token_ids[: request.max_length] ++ return {"tokens": token_ids} ++ result, cached = self.cache.get(cache_key, wrap_request_time(do_it)) ++ return TokenizationRequestResult( ++ success=True, ++ cached=cached, ++ text=request.text, ++ tokens=[TokenizationToken(value) for value in result["tokens"]], ++ request_time=result["request_time"], ++ ) ++ ++ def decode(self, request: DecodeRequest) -> DecodeRequestResult: ++ cache_key = asdict(request) ++ try: ++ def do_it(): ++ return {"text": self.tokenizer.decode(request.tokens)} ++ ++ result, cached = self.cache.get(cache_key, wrap_request_time(do_it)) ++ except Exception as e: ++ error: str = f"My Tokenizer error: {request.tokens}" ++ return DecodeRequestResult(success=False, cached=False, error=error, text="") ++ ++ return DecodeRequestResult( ++ success=True, cached=cached, text=result["text"], request_time=result["request_time"] ++ ) +diff --git a/src/helm/proxy/clients/my_tokenizer/__init__.py b/src/helm/proxy/clients/my_tokenizer/__init__.py +new file mode 100644 +index 00000000..e69de29b +diff --git a/src/helm/proxy/clients/my_tokenizer/config.json b/src/helm/proxy/clients/my_tokenizer/config.json +new file mode 100644 +index 00000000..3997f5c0 +--- /dev/null ++++ b/src/helm/proxy/clients/my_tokenizer/config.json +@@ -0,0 +1,7 @@ ++{ ++ "local": { ++ "type": "gpt2", ++ "vocab_path": "/home/data/panxuchen.pxc/code/helm/gpt2-zhcn3-v4.json", ++ "merge_path": "/home/data/panxuchen.pxc/code/helm/gpt2-zhcn3-v4.bpe" ++ } ++} +\ No newline at end of file +diff --git a/src/helm/proxy/clients/my_tokenizer/my_tokenizer.py b/src/helm/proxy/clients/my_tokenizer/my_tokenizer.py +new file mode 100644 +index 00000000..af96931a +--- /dev/null ++++ b/src/helm/proxy/clients/my_tokenizer/my_tokenizer.py +@@ -0,0 +1,29 @@ ++import json ++import importlib_resources as resources ++from io import open ++from transformers import AutoTokenizer, GPT2Tokenizer, LlamaTokenizer ++ ++TOKENIZER_PACKAGE: str = "helm.proxy.clients.my_tokenizer" ++TOKENIZER_CONFIG: str = "config.json" ++ ++class MyTokenizer(object): ++ ++ def __init__(self, config): ++ if config['type'] == 'huggingface': ++ self.tokenizer = AutoTokenizer.from_pretrained(config['model_path']) ++ elif config['type'] == 'gpt2': ++ self.tokenizer = GPT2Tokenizer(vocab_file=config['vocab_path'], merges_file=config['merge_path']) ++ elif config['type'] == 'sentencepiece' or config['type'] == 'llama': ++ self.tokenizer = LlamaTokenizer(vocab_file=config['tokenizer_path']) ++ else: ++ raise NotImplementedError("Unknown tokenizer") ++ ++ def tokenize(self, text): ++ return self.tokenizer(text)['input_ids'] ++ ++ def decode(self, token_ids): ++ return self.tokenizer.decode(token_ids) ++ ++ @property ++ def eos(self): ++ return self.tokenizer.eos_token +\ No newline at end of file +diff --git a/src/helm/proxy/models.py b/src/helm/proxy/models.py +index bfa3ed3c..df1c2c1d 100644 +--- a/src/helm/proxy/models.py ++++ b/src/helm/proxy/models.py +@@ -631,11 +631,11 @@ ALL_MODELS = [ + name="google/palm", + tags=[TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG], + ), +- # NVIDIA ++ # My Model + Model( +- group="nvidia", +- name="nvidia/megatron-gpt2", +- tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, GPT2_TOKENIZER_TAG, BUGGY_TEMP_0_TAG], ++ group="mymodel", ++ name="mymodel/model", ++ tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, GPT2_TOKENIZER_TAG] + ), + # For debugging + Model( +@@ -649,10 +649,15 @@ MODEL_NAME_TO_MODEL: Dict[str, Model] = {model.name: model for model in ALL_MODE + + def get_model(model_name: str) -> Model: + """Get the `Model` given the name.""" +- if model_name not in MODEL_NAME_TO_MODEL: +- raise ValueError(f"No model with name: {model_name}") ++ if model_name in MODEL_NAME_TO_MODEL: ++ return MODEL_NAME_TO_MODEL[model_name] ++ else: ++ model_name_prefix = model_name.split('/')[0] ++ if model_name_prefix == 'mymodel': ++ return MODEL_NAME_TO_MODEL['mymodel/model'] ++ else: ++ raise ValueError(f"No model with name: {model_name}") + +- return MODEL_NAME_TO_MODEL[model_name] + + + def get_model_group(model_name: str) -> str: +diff --git a/src/helm/proxy/services/server_service.py b/src/helm/proxy/services/server_service.py +index f08e4a06..222934d7 100644 +--- a/src/helm/proxy/services/server_service.py ++++ b/src/helm/proxy/services/server_service.py +@@ -40,7 +40,7 @@ class ServerService(Service): + Main class that supports various functionality for the server. + """ + +- def __init__(self, base_path: str = ".", root_mode=False, mongo_uri: str = ""): ++ def __init__(self, base_path: str = ".", root_mode=False, mongo_uri: str = "", my_config_path: str = None): + credentials_path = os.path.join(base_path, CREDENTIALS_FILE) + cache_path = os.path.join(base_path, CACHE_DIR) + ensure_directory_exists(cache_path) +@@ -52,7 +52,7 @@ class ServerService(Service): + else: + credentials = {} + +- self.client = AutoClient(credentials, cache_path, mongo_uri) ++ self.client = AutoClient(credentials, cache_path, mongo_uri, my_config_path) + self.token_counter = AutoTokenCounter(self.client.huggingface_client) + self.accounts = Accounts(accounts_path, root_mode=root_mode) + # Lazily instantiated by get_toxicity_scores() diff --git a/thirdparty/patch/megatron.diff b/thirdparty/patch/megatron.diff new file mode 100644 index 000000000..7ff33b9a2 --- /dev/null +++ b/thirdparty/patch/megatron.diff @@ -0,0 +1,1571 @@ +diff --git a/megatron/__init__.py b/megatron/__init__.py +index aa99c06..6ccd17d 100644 +--- a/megatron/__init__.py ++++ b/megatron/__init__.py +@@ -9,6 +9,7 @@ from .global_vars import get_signal_handler + from .global_vars import update_num_microbatches + from .global_vars import get_tokenizer + from .global_vars import get_tensorboard_writer ++from .global_vars import get_wandb + from .global_vars import get_adlr_autoresume + from .global_vars import get_timers + from .initialize import initialize_megatron +diff --git a/megatron/arguments.py b/megatron/arguments.py +index 6cc1cc0..c9a941d 100644 +--- a/megatron/arguments.py ++++ b/megatron/arguments.py +@@ -638,35 +638,42 @@ def _add_logging_args(parser): + ' max: report the max timing across all ranks' + ' minmax: report min and max timings across all ranks' + ' all: report timings of all ranks.') +- group.add_argument('--tensorboard-log-interval', type=int, default=1, +- help='Report to tensorboard interval.') ++ group.add_argument('--tracker-log-interval', type=int, default=1, ++ help='Report to trackers interval.') + group.add_argument('--tensorboard-queue-size', type=int, default=1000, + help='Size of the tensorboard queue for pending events ' + 'and summaries before one of the ‘add’ calls forces a ' + 'flush to disk.') +- group.add_argument('--log-timers-to-tensorboard', action='store_true', +- help='If set, write timers to tensorboard.') +- group.add_argument('--log-batch-size-to-tensorboard', action='store_true', +- help='If set, write batch-size to tensorboard.') +- group.add_argument('--no-log-learnig-rate-to-tensorboard', ++ group.add_argument('--log-timers-to-tracker', action='store_true', ++ help='If set, write timers to trackers.') ++ group.add_argument('--log-batch-size-to-tracker', action='store_true', ++ help='If set, write batch-size to trackers.') ++ group.add_argument('--no-log-learnig-rate-to-tracker', + action='store_false', +- help='Disable learning rate logging to tensorboard.', +- dest='log_learning_rate_to_tensorboard') +- group.add_argument('--no-log-loss-scale-to-tensorboard', ++ help='Disable learning rate logging to trackers.', ++ dest='log_learning_rate_to_tracker') ++ group.add_argument('--no-log-loss-scale-to-tracker', + action='store_false', +- help='Disable loss-scale logging to tensorboard.', +- dest='log_loss_scale_to_tensorboard') +- group.add_argument('--log-validation-ppl-to-tensorboard', ++ help='Disable loss-scale logging to trackers.', ++ dest='log_loss_scale_to_tracker') ++ group.add_argument('--log-validation-ppl-to-tracker', + action='store_true', + help='If set, write validation perplexity to ' +- 'tensorboard.') +- group.add_argument('--log-memory-to-tensorboard', ++ 'trackers.') ++ group.add_argument('--log-memory-to-tracker', + action='store_true', +- help='Enable memory logging to tensorboard.') +- group.add_argument('--log-world-size-to-tensorboard', ++ help='Enable memory logging to trackers.') ++ group.add_argument('--log-world-size-to-tracker', + action='store_true', +- help='Enable world size logging to tensorboard.') +- ++ help='Enable world size logging to trackers.') ++ group.add_argument('--wandb-project', type=str, default=None, ++ help='Wandb project name') ++ group.add_argument('--wandb-group', type=str, default=None, ++ help='Wandb group name') ++ group.add_argument('--wandb-master-name', type=str, default='master', ++ help='The name of master node in wandb') ++ group.add_argument('--wandb-worker-name', type=str, default='worker', ++ help='The name of worker node in wandb') + return parser + + +@@ -706,7 +713,7 @@ def _add_regularization_args(parser): + def _add_training_args(parser): + group = parser.add_argument_group(title='training') + +- group.add_argument('--micro-batch-size', type=int, default=None, ++ group.add_argument('--micro-batch-size', type=int, default=1, + help='Batch size per model instance (local batch size). ' + 'Global batch size is local batch size times data ' + 'parallel size times number of micro batches.') +@@ -841,6 +848,8 @@ def _add_training_args(parser): + help='Disable fusing gradient accumulation to weight ' + 'gradient computation of linear layers', + dest='gradient_accumulation_fusion') ++ group.add_argument('--use-rmsnorm', action='store_true', ++ help='Enable using RMSNorm instead of normal LayerNorm.') + return parser + + +@@ -925,6 +934,9 @@ def _add_checkpointing_args(parser): + help='Do not load optimizer when loading checkpoint.') + group.add_argument('--no-load-rng', action='store_true', default=None, + help='Do not load rng state when loading checkpoint.') ++ group.add_argument("--load-iteration", type=int, default=0, ++ help='Load the checkpoint of this iteration, ' ++ 'set 0 to load the latest checkpoint.') + group.add_argument('--finetune', action='store_true', + help='Load model for finetuning. Do not load optimizer ' + 'or rng state from checkpoint and set iteration to 0. ' +diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py +index e88b585..6dc8c5b 100644 +--- a/megatron/checkpointing.py ++++ b/megatron/checkpointing.py +@@ -357,27 +357,30 @@ def fix_query_key_value_ordering(model, checkpoint_version): + " checkpoint version {}".format(checkpoint_version)) + + +-def _load_base_checkpoint(load_dir, rank0=False): ++def _load_base_checkpoint(load_dir, load_iteration=0, rank0=False): + """ Load the base state_dict from the given directory + + If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. + """ ++ if load_iteration != 0: ++ iteration = load_iteration ++ release = False ++ else: ++ # Read the tracker file and set the iteration. ++ tracker_filename = get_checkpoint_tracker_filename(load_dir) + +- # Read the tracker file and set the iteration. +- tracker_filename = get_checkpoint_tracker_filename(load_dir) +- +- # If no tracker file, return nothing +- if not os.path.isfile(tracker_filename): +- if not rank0: +- print_rank_0('WARNING: could not find the metadata file {} '.format( +- tracker_filename)) +- print_rank_0(' will not load any checkpoints and will start from ' +- 'random') +- return None, "", False ++ # If no tracker file, return nothing ++ if not os.path.isfile(tracker_filename): ++ if not rank0: ++ print_rank_0('WARNING: could not find the metadata file {} '.format( ++ tracker_filename)) ++ print_rank_0(' will not load any checkpoints and will start from ' ++ 'random') ++ return None, "", False + +- # Otherwise, read the tracker file and either set the iteration or +- # mark it as a release checkpoint. +- iteration, release = read_metadata(tracker_filename) ++ # Otherwise, read the tracker file and either set the iteration or ++ # mark it as a release checkpoint. ++ iteration, release = read_metadata(tracker_filename) + + # Checkpoint. + if rank0: +@@ -431,7 +434,7 @@ def load_args_from_checkpoint(args, load_arg='load'): + print_rank_0('No load directory specified, using provided arguments.') + return args + +- state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=True) ++ state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, load_iteration=args.load_iteration, rank0=True) + + # Args. + if not state_dict: +@@ -482,6 +485,8 @@ def load_args_from_checkpoint(args, load_arg='load'): + _set_arg('apply_layernorm_1p', force=True) + _set_arg('tokenizer_type') + _set_arg('padded_vocab_size') ++ _set_arg('use_rmsnorm', force=True) ++ _set_arg('no_persist_layer_norm', force=True) + if checkpoint_version < 3.0: + _set_arg('tensor_model_parallel_size', + 'model_parallel_size') +@@ -504,7 +509,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri + + model = unwrap_model(model) + +- state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False) ++ state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, args.load_iteration, rank0=False) + + # Checkpoint not loaded. + if state_dict is None: +diff --git a/megatron/global_vars.py b/megatron/global_vars.py +index 4e0118e..fe079c1 100644 +--- a/megatron/global_vars.py ++++ b/megatron/global_vars.py +@@ -19,6 +19,8 @@ _GLOBAL_TENSORBOARD_WRITER = None + _GLOBAL_ADLR_AUTORESUME = None + _GLOBAL_TIMERS = None + _GLOBAL_SIGNAL_HANDLER = None ++_GLOBAL_WANDB = None ++ + + def get_args(): + """Return arguments.""" +@@ -56,6 +58,12 @@ def get_tensorboard_writer(): + return _GLOBAL_TENSORBOARD_WRITER + + ++def get_wandb(): ++ """Return wandb object. It can be None so no need ++ to check if it is initialized""" ++ return _GLOBAL_WANDB ++ ++ + def get_adlr_autoresume(): + """ADLR autoresume object. It can be None so no need + to check if it is initialized.""" +@@ -92,12 +100,13 @@ def set_global_variables(args, build_tokenizer=True): + if build_tokenizer: + _ = _build_tokenizer(args) + _set_tensorboard_writer(args) ++ _set_wandb(args) + _set_adlr_autoresume(args) + _set_timers(args) + + if args.exit_signal_handler: + _set_signal_handler() +- ++ + + def set_args(args): + global _GLOBAL_ARGS +@@ -153,6 +162,36 @@ def _set_tensorboard_writer(args): + 'no TensorBoard logs will be written.', flush=True) + + ++def _set_wandb(args): ++ global _GLOBAL_WANDB ++ _ensure_var_is_not_initialized(_GLOBAL_WANDB, ++ 'wandb writer') ++ is_local_main = (args.rank + 1) % torch.cuda.device_count() == 0 ++ node_rank = args.rank // torch.cuda.device_count() ++ description = os.environ.get('RUN_DESCRIPTION', default='') ++ if hasattr(args, 'wandb_project') and \ ++ args.wandb_project and is_local_main: ++ try: ++ import wandb ++ is_master = args.rank == (args.world_size - 1) ++ wandb.init( ++ project=args.wandb_project, ++ group=args.wandb_group, ++ name=args.wandb_master_name if is_master ++ else f'{args.wandb_worker_name}-{node_rank}', ++ save_code=False, ++ config=args, ++ force=False, ++ notes=description, ++ tags=['master'if is_master else 'worker'] ++ ) ++ if args.rank == (args.world_size - 1): ++ _GLOBAL_WANDB = wandb ++ except Exception: ++ print("WARNING: Skip wandb setup. Please execute " ++ "'wandb login' to enable wandb.", flush=True) ++ ++ + def _set_adlr_autoresume(args): + """Initialize ADLR autoresume.""" + global _GLOBAL_ADLR_AUTORESUME +@@ -186,6 +225,3 @@ def _ensure_var_is_initialized(var, name): + def _ensure_var_is_not_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is None, '{} is already initialized.'.format(name) +- +- +- +diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py +index fd8591e..e57cba6 100644 +--- a/megatron/model/fused_layer_norm.py ++++ b/megatron/model/fused_layer_norm.py +@@ -11,6 +11,7 @@ from torch.nn import init + import importlib + + from megatron.core.utils import make_viewless_tensor ++from megatron import get_args + + try: + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN +@@ -18,7 +19,7 @@ try: + except: + HAVE_PERSIST_LAYER_NORM = False + +-from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction ++from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction, FusedRMSNormAffineFunction + + + global fused_layer_norm_cuda +@@ -32,6 +33,8 @@ class MixedFusedLayerNorm(torch.nn.Module): + sequence_parallel=False, + apply_layernorm_1p=False): + super(MixedFusedLayerNorm, self).__init__() ++ args = get_args() ++ self.use_rmsnorm = args.use_rmsnorm + + self.apply_layernorm_1p = apply_layernorm_1p + +@@ -77,7 +80,10 @@ class MixedFusedLayerNorm(torch.nn.Module): + weight = self.weight + 1 if self.apply_layernorm_1p else self.weight + + if self.no_persist_layer_norm: +- return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps) ++ if not self.use_rmsnorm: ++ return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps) ++ else: ++ return FusedRMSNormAffineFunction.apply(input, weight, self.normalized_shape, self.eps) + else: + output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) + +diff --git a/megatron/text_generation/api.py b/megatron/text_generation/api.py +index 090b630..726c82d 100644 +--- a/megatron/text_generation/api.py ++++ b/megatron/text_generation/api.py +@@ -8,26 +8,28 @@ import torch + from megatron.core import mpu + from .communication import broadcast_float_list + from .generation import ( +- generate_tokens_probs_and_return_on_first_stage, +- score_and_return_on_first_stage, +- beam_search_and_return_on_first_stage) ++ generate_tokens_probs_and_return_on_first_stage, ++ score_and_return_on_first_stage, ++ beam_search_and_return_on_first_stage) + from .tokenization import ( + tokenize_prompts, ++ tokenize_sequences, + detokenize_generations) + ++ + def generate_and_post_process(model, + prompts=None, + tokens_to_generate=0, + return_output_log_probs=False, ++ echo_prompts=False, + top_k_sampling=0, + top_p_sampling=0.0, + top_p_decay=0.0, + top_p_bound=0.0, + temperature=1.0, + add_BOS=False, +- use_eod_token_for_early_termination=True, +- stop_on_double_eol=False, +- stop_on_eol=False, ++ use_stop_tokens_for_early_termination=True, ++ stop_sequences=None, + prevent_newline_after_colon=False, + random_seed=-1): + """Run inference and post-process outputs, i.e., detokenize, +@@ -38,6 +40,7 @@ def generate_and_post_process(model, + model, + prompts=prompts, + tokens_to_generate=tokens_to_generate, ++ echo_prompts=echo_prompts, + return_output_log_probs=return_output_log_probs, + top_k_sampling=top_k_sampling, + top_p_sampling=top_p_sampling, +@@ -45,45 +48,44 @@ def generate_and_post_process(model, + top_p_bound=top_p_bound, + temperature=temperature, + add_BOS=add_BOS, +- use_eod_token_for_early_termination=use_eod_token_for_early_termination, +- stop_on_double_eol=stop_on_double_eol, +- stop_on_eol=stop_on_eol, ++ use_stop_tokens_for_early_termination=use_stop_tokens_for_early_termination, ++ stop_sequences=stop_sequences, + prevent_newline_after_colon=prevent_newline_after_colon, + random_seed=random_seed) + + # Only post-process on first stage. + if mpu.is_pipeline_first_stage(): +- tokens, prompts_plus_generations, prompts_plus_generations_segments = \ +- detokenize_generations(tokens, lengths, True) +- + if return_output_log_probs: ++ tokens, generations, generations_segments = \ ++ detokenize_generations(tokens, lengths, return_output_log_probs) + output_log_probs = output_log_probs.cpu().numpy().tolist() +- for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)): ++ for i, (prob, seg) in enumerate(zip(output_log_probs, generations_segments)): + output_log_probs[i] = prob[:len(seg)-1] +- +- return prompts_plus_generations, prompts_plus_generations_segments, \ +- output_log_probs, tokens +- ++ return generations, generations_segments, \ ++ output_log_probs, tokens ++ else: ++ tokens, generations = detokenize_generations(tokens, lengths, return_output_log_probs) ++ return generations, None, None, tokens + return None + + def generate(model, + prompts=None, + tokens_to_generate=0, + return_output_log_probs=False, ++ echo_prompts=False, + top_k_sampling=0, + top_p_sampling=0.0, + top_p_decay=0.0, + top_p_bound=0.0, + temperature=1.0, + add_BOS=False, +- use_eod_token_for_early_termination=True, +- stop_on_double_eol=False, +- stop_on_eol=False, ++ use_stop_tokens_for_early_termination=True, ++ stop_sequences=None, + prevent_newline_after_colon=False, + random_seed=-1): + """Given prompts and input parameters, run inference and return: + tokens: prompts plus the generated tokens. +- lengths: length of the prompt + generations. Note that we can ++ lengths: length of the (prompt +) generations. Note that we can + discard tokens in the tokens tensor that are after the + corresponding length. + output_log_probs: log probs of the tokens. +@@ -93,11 +95,22 @@ def generate(model, + values = [tokens_to_generate, + return_output_log_probs, + top_k_sampling, top_p_sampling, top_p_decay, top_p_bound, +- temperature, add_BOS, use_eod_token_for_early_termination, +- stop_on_double_eol, +- stop_on_eol, ++ temperature, add_BOS, use_stop_tokens_for_early_termination, + prevent_newline_after_colon, + random_seed] ++ if stop_sequences != None: ++ stop_tokens = [] ++ for i, tokens in enumerate(tokenize_sequences(stop_sequences)): ++ if len(tokens) == 1: ++ stop_tokens.append(tokens[0]) ++ else: ++ print( ++ f"Stop sequence [{stop_sequences[i]}] is not supported because its tokenized length exceeds 1") ++ stop_tokens = torch.tensor(stop_tokens, dtype=torch.int64) ++ values.append(len(stop_tokens)) ++ values.extend(stop_tokens) ++ else: ++ values.append(0) + values_float_tensor = broadcast_float_list(len(values), float_list=values) + tokens_to_generate = int(values_float_tensor[0].item()) + return_output_log_probs = bool(values_float_tensor[1].item()) +@@ -107,11 +120,14 @@ def generate(model, + top_p_bound = values_float_tensor[5].item() + temperature = values_float_tensor[6].item() + add_BOS = bool(values_float_tensor[7].item()) +- use_eod_token_for_early_termination = bool(values_float_tensor[8].item()) +- stop_on_double_eol = bool(values_float_tensor[9].item()) +- stop_on_eol = bool(values_float_tensor[10].item()) +- prevent_newline_after_colon = bool(values_float_tensor[11].item()) +- random_seed = int(values_float_tensor[12].item()) ++ use_stop_tokens_for_early_termination = bool(values_float_tensor[8].item()) ++ prevent_newline_after_colon = bool(values_float_tensor[9].item()) ++ random_seed = int(values_float_tensor[10].item()) ++ stop_tokens_length = int(values_float_tensor[11].item()) ++ if stop_tokens_length > 0: ++ stop_tokens = values_float_tensor[12: 12 + stop_tokens_length].int() ++ else: ++ stop_tokens = None + + if random_seed != -1: + torch.random.manual_seed(random_seed) +@@ -120,14 +136,14 @@ def generate(model, + # Note that these tensors are broadcaseted to all ranks. + if torch.distributed.get_rank() == 0: + assert prompts is not None +- ++ + context_tokens_tensor, context_length_tensor = tokenize_prompts( + prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS) + + if tokens_to_generate == 0: + return score_and_return_on_first_stage( + model, context_tokens_tensor, context_length_tensor) +- ++ + # Main inference function. + # Note that the outputs are available on the first stage. + return generate_tokens_probs_and_return_on_first_stage( +@@ -138,10 +154,10 @@ def generate(model, + top_p_decay=top_p_decay, + top_p_bound=top_p_bound, + temperature=temperature, +- use_eod_token_for_early_termination=use_eod_token_for_early_termination, +- stop_on_double_eol=stop_on_double_eol, +- stop_on_eol=stop_on_eol, +- prevent_newline_after_colon=prevent_newline_after_colon) ++ use_stop_tokens_for_early_termination=use_stop_tokens_for_early_termination, ++ stop_tokens=stop_tokens, ++ prevent_newline_after_colon=prevent_newline_after_colon, ++ echo_prompts=echo_prompts) + + def beam_search_and_post_process(model, + prompts=None, +diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py +index 098706e..4e73d0e 100644 +--- a/megatron/text_generation/generation.py ++++ b/megatron/text_generation/generation.py +@@ -35,7 +35,7 @@ def score_and_return_on_first_stage(model, tokens, lengths): + assert max_prompt_length == tokens.size(1) + + if max_prompt_length > args.max_position_embeddings: +- raise ValueError("Length of prompt + tokens_to_generate longer than allowed") ++ raise ValueError(f"Length of prompt + tokens_to_generate ({max_prompt_length}) longer than allowed ({args.max_position_embeddings})") + + if max_prompt_length * batch_size > args.max_tokens_to_oom: + raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) +@@ -90,10 +90,10 @@ def generate_tokens_probs_and_return_on_first_stage( + return_output_log_probs=False, + top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0, + temperature=1.0, +- use_eod_token_for_early_termination=True, +- stop_on_double_eol=False, +- stop_on_eol=False, +- prevent_newline_after_colon=True ++ use_stop_tokens_for_early_termination=True, ++ stop_tokens=None, ++ prevent_newline_after_colon=True, ++ echo_prompts=False + ): + """Main token generation function. + Arguments: +@@ -109,8 +109,8 @@ def generate_tokens_probs_and_return_on_first_stage( + if top-k > 0 then we expect top-p=0. + if top-p > 0 then we check for top-k=0. + temperature: sampling temperature. +- use_eod_token_for_early_termination: if True, do early termination if +- all the sequences have reached this token. ++ use_stop_tokens_for_early_termination: if True, do early termination if ++ all the sequences have reached stop tokens. + prevent_newline_after_colon: if True, it will disable generating new line \n after : + Note: Outside of model, other parameters only need to be available on + rank 0. +@@ -130,7 +130,7 @@ def generate_tokens_probs_and_return_on_first_stage( + max_sequence_length = tokens.size(1) + + if max_sequence_length > args.max_position_embeddings: +- raise ValueError("Length of prompt + tokens_to_generate longer than allowed") ++ raise ValueError(f"Length of prompt + tokens_to_generate ({max_sequence_length}) longer than allowed ({args.max_position_embeddings})") + + if max_sequence_length * batch_size > args.max_tokens_to_oom: + raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) +@@ -213,19 +213,18 @@ def generate_tokens_probs_and_return_on_first_stage( + # Calculate the log probabilities. + if return_output_log_probs: + log_probs = F.log_softmax(logits, dim=2) +- if return_output_log_probs: +- # Pick the tokens that we need to get the log +- # probabilities for. Note that next input token is +- # the token which we selected in the current logits, +- # so shift by 1. +- indices = torch.unsqueeze( +- tokens[ +- :, +- (prev_context_length + 1):(context_length + 1)], +- 2) +- output_log_probs[:, +- prev_context_length:context_length] = \ +- torch.gather(log_probs, 2, indices).squeeze(2) ++ # Pick the tokens that we need to get the log ++ # probabilities for. Note that next input token is ++ # the token which we selected in the current logits, ++ # so shift by 1. ++ indices = torch.unsqueeze( ++ tokens[ ++ :, ++ (prev_context_length + 1):(context_length + 1)], ++ 2) ++ output_log_probs[:, ++ prev_context_length:context_length] = \ ++ torch.gather(log_probs, 2, indices).squeeze(2) + + # Update the tokens on the first stage so the next input to + # the network is correct. +@@ -240,14 +239,11 @@ def generate_tokens_probs_and_return_on_first_stage( + if mpu.is_pipeline_last_stage(): + # TODO(rprenger) These stopping methods are tokenizer dependent + # instead tokenization should be in the inference loop so stop sequences can be used +- if stop_on_double_eol: +- hit_double_eol = (new_sample == 628).byte() & started.byte() +- hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() +- done_token = hit_double_eol | hit_two_eols +- elif stop_on_eol: +- hit_double_eol = (new_sample == 628).byte() & started.byte() +- hit_eol = (new_sample == 198).byte() & started.byte() +- done_token = hit_double_eol | hit_eol ++ if stop_tokens is not None and len(stop_tokens) > 0: ++ done_token = torch.any( ++ new_sample.expand(stop_tokens.shape[0], new_sample.shape[0]) == stop_tokens.unsqueeze(dim=1), dim=0) \ ++ & started.byte() ++ + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() +@@ -259,7 +255,7 @@ def generate_tokens_probs_and_return_on_first_stage( + done = torch.all(is_generation_done) + done = broadcast_from_last_pipeline_stage(1, torch.uint8, + tensor=done) +- if use_eod_token_for_early_termination and done: ++ if use_stop_tokens_for_early_termination and done: + break + + # =================================================== +@@ -269,7 +265,7 @@ def generate_tokens_probs_and_return_on_first_stage( + tokens = tokens[:, :(context_length + 1)] + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: +- output_log_probs = output_log_probs[:, :context_length] ++ output_log_probs = output_log_probs[:, :context_length].contiguous() + + # ====================================== + # Broadcast to the first pipeline stage. +@@ -281,7 +277,13 @@ def generate_tokens_probs_and_return_on_first_stage( + output_log_probs_size = (batch_size, context_length) + output_log_probs = broadcast_from_last_to_first_pipeline_stage( + output_log_probs_size, torch.float32, output_log_probs) +- ++ if not echo_prompts and mpu.is_pipeline_first_stage(): ++ generated_sequence_lengths -= lengths ++ for i, (sequence, length) in enumerate(zip(tokens, lengths)): ++ tokens[i] = sequence.roll(-length.item(), dims=0) ++ if return_output_log_probs: ++ for i, (prob, length) in enumerate(zip(output_log_probs, lengths)): ++ output_log_probs[i] = prob.roll(-(length.item() - 1), dims=0) + return tokens, generated_sequence_lengths, output_log_probs + + def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True): +diff --git a/megatron/text_generation/tokenization.py b/megatron/text_generation/tokenization.py +index accead3..3e8caa5 100644 +--- a/megatron/text_generation/tokenization.py ++++ b/megatron/text_generation/tokenization.py +@@ -32,7 +32,7 @@ def detokenize_generations(tokens_gpu_tensor, + for token in sequence_tokens: + if args.tokenizer_type in ['SentencePieceTokenizer', + 'GPTSentencePieceTokenizer']: +- word = tokenizer.decoder[token] ++ word = tokenizer.detokenize([token]) + elif args.tokenizer_type == 'NullTokenizer': + word = str(token) + else: +@@ -49,6 +49,12 @@ def detokenize_generations(tokens_gpu_tensor, + + return tokens, prompts_plus_generations + ++def tokenize_sequences(sequences=None): ++ sequence_tokens = None ++ if sequences != None: ++ tokenizer = get_tokenizer() ++ sequence_tokens = [tokenizer.tokenize(ele) for ele in sequences] ++ return sequence_tokens + + def tokenize_prompts(prompts=None, tokens_to_generate=None, + add_BOS=None, rank=0): +diff --git a/megatron/text_generation_server.py b/megatron/text_generation_server.py +index 58550f2..60f8fc8 100644 +--- a/megatron/text_generation_server.py ++++ b/megatron/text_generation_server.py +@@ -1,13 +1,13 @@ + # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + import datetime ++import time + import torch + import json + import threading +-from flask import Flask, request, jsonify, current_app ++import asyncio ++from flask import Flask, request, jsonify + from flask_restful import Resource, Api +-from megatron import get_args + from megatron.text_generation import generate_and_post_process +-from megatron.text_generation import beam_search_and_post_process + + + GENERATE_NUM = 0 +@@ -17,6 +17,8 @@ lock = threading.Lock() + class MegatronGenerate(Resource): + def __init__(self, model): + self.model = model ++ asyncio.set_event_loop(asyncio.new_event_loop()) ++ self.loop = asyncio.get_event_loop() + + @staticmethod + def send_do_generate(): +@@ -27,209 +29,216 @@ class MegatronGenerate(Resource): + def send_do_beam_search(): + choice = torch.cuda.LongTensor([BEAM_NUM]) + torch.distributed.broadcast(choice, 0) +- ++ ++ def check(self, raw_req): ++ if not 'prompts' in raw_req: ++ return 'prompts argument required', 400 ++ if len(raw_req['prompts']) == 0: ++ return "prompts is empty", 400 ++ if len(raw_req['prompts']) > 128: ++ return "Maximum number of prompts is 128", 400 ++ ++ async def generate(self, req): ++ MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate ++ start_time = time.time() ++ response, response_seg, response_logprobs, _ = \ ++ generate_and_post_process( ++ self.model, ++ prompts=req['prompts'], ++ tokens_to_generate=req['tokens_to_generate'], ++ echo_prompts=req['echo_prompts'], ++ return_output_log_probs=req['logprobs'], ++ top_k_sampling=req['top_k'], ++ top_p_sampling=req['top_p'], ++ top_p_decay=req['top_p_decay'], ++ top_p_bound=req['top_p_bound'], ++ temperature=req['temperature'], ++ add_BOS=req['add_BOS'], ++ use_stop_tokens_for_early_termination=True, ++ stop_sequences=req['stop_sequences'], ++ prevent_newline_after_colon=req['prevent_newline_after_colon'], ++ random_seed=req['random_seed']) ++ end_time = time.time() ++ print(f"Response(use {end_time - start_time}s): " + str(response)) ++ return { ++ "text": response, ++ "segments": response_seg, ++ "logprobs": response_logprobs ++ } ++ + def put(self): +- args = get_args() +- +- if not "prompts" in request.get_json(): ++ raw_req = request.get_json() ++ ++ if not "prompts" in raw_req: + return "prompts argument required", 400 + +- if "max_len" in request.get_json(): ++ if "max_len" in raw_req: + return "max_len is no longer used. Replace with tokens_to_generate", 400 + +- if "sentences" in request.get_json(): ++ if "sentences" in raw_req: + return "sentences is no longer used. Replace with prompts", 400 + +- prompts = request.get_json()["prompts"] +- if not isinstance(prompts, list): ++ if isinstance(raw_req["prompts"], str): ++ raw_req['prompts'] = [raw_req['prompts']] ++ ++ if not isinstance(raw_req["prompts"], list): + return "prompts is not a list of strings", 400 + +- if len(prompts) == 0: ++ if len(raw_req['prompts']) == 0: + return "prompts is empty", 400 + +- if len(prompts) > 128: ++ if len(raw_req['prompts']) > 128: + return "Maximum number of prompts is 128", 400 + +- tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow +- if "tokens_to_generate" in request.get_json(): +- tokens_to_generate = request.get_json()["tokens_to_generate"] +- if not isinstance(tokens_to_generate, int): ++ if 'tokens_to_generate' in raw_req: ++ if not isinstance(raw_req['tokens_to_generate'], int): + return "tokens_to_generate must be an integer greater than 0" +- if tokens_to_generate < 0: ++ if raw_req['tokens_to_generate'] < 0: + return "tokens_to_generate must be an integer greater than or equal to 0" ++ else: ++ raw_req['tokens_to_generate'] = 64 + + logprobs = False +- if "logprobs" in request.get_json(): +- logprobs = request.get_json()["logprobs"] ++ if "logprobs" in raw_req: ++ logprobs = raw_req["logprobs"] + if not isinstance(logprobs, bool): + return "logprobs must be a boolean value" +- +- if tokens_to_generate == 0 and not logprobs: +- return "tokens_to_generate=0 implies logprobs should be True" +- +- temperature = 1.0 +- if "temperature" in request.get_json(): +- temperature = request.get_json()["temperature"] +- if not (type(temperature) == int or type(temperature) == float): +- return "temperature must be a positive number less than or equal to 100.0" +- if not (0.0 < temperature <= 100.0): +- return "temperature must be a positive number less than or equal to 100.0" +- ++ else: ++ raw_req['logprobs'] = False ++ ++ if raw_req['tokens_to_generate'] == 0 and not raw_req['logprobs']: ++ print("tokens_to_generate=0 implies logprobs should be True") ++ raw_req['logprobs'] = True ++ ++ if "echo_prompts" in raw_req: ++ if not isinstance(raw_req['echo_prompts'], bool): ++ return "echo_prompts must be a bool" ++ else: ++ raw_req['echo_prompts'] = False ++ + top_k = 0.0 +- if "top_k" in request.get_json(): +- top_k = request.get_json()["top_k"] ++ if "top_k" in raw_req: ++ top_k = raw_req["top_k"] + if not (type(top_k) == int): + return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000" + if not (0 <= top_k <= 1000): + return "top_k must be equal to or greater than 0 and less than or equal to 1000" ++ else: ++ raw_req['top_k'] = 0.0 + +- top_p = 0.0 +- if "top_p" in request.get_json(): +- top_p = request.get_json()["top_p"] +- if not (type(top_p) == float): ++ if "top_p" in raw_req: ++ top_p = raw_req["top_p"] ++ if not (type(top_p) == float or type(top_p) == int): + return "top_p must be a positive float less than or equal to 1.0" + if top_p > 0.0 and top_k > 0.0: + return "cannot set both top-k and top-p samplings." + if not (0 <= top_p <= 1.0): + return "top_p must be less than or equal to 1.0" ++ else: ++ raw_req['top_p'] = 0.0 + +- top_p_decay = 0.0 +- if "top_p_decay" in request.get_json(): +- top_p_decay = request.get_json()["top_p_decay"] ++ if "top_p_decay" in raw_req: ++ top_p_decay = raw_req["top_p_decay"] + if not (type(top_p_decay) == float): + return "top_p_decay must be a positive float less than or equal to 1.0" + if top_p == 0.0: + return "top_p_decay cannot be set without top_p" + if not (0 <= top_p_decay <= 1.0): + return "top_p_decay must be less than or equal to 1.0" +- ++ else: ++ raw_req['top_p_decay'] = 0.0 ++ + top_p_bound = 0.0 +- if "top_p_bound" in request.get_json(): +- top_p_bound = request.get_json()["top_p_bound"] ++ if "top_p_bound" in raw_req: ++ top_p_bound = raw_req["top_p_bound"] + if not (type(top_p_bound) == float): + return "top_p_bound must be a positive float less than or equal to top_p" + if top_p == 0.0: + return "top_p_bound cannot be set without top_p" + if not (0.0 < top_p_bound <= top_p): + return "top_p_bound must be greater than 0 and less than top_p" +- +- add_BOS = False +- if "add_BOS" in request.get_json(): +- add_BOS = request.get_json()["add_BOS"] +- if not isinstance(add_BOS, bool): ++ else: ++ raw_req['top_p_bound'] = 0.0 ++ ++ if "temperature" in raw_req: ++ temperature = raw_req["temperature"] ++ if not (type(temperature) == int or type(temperature) == float): ++ return "temperature must be a positive number less than or equal to 100.0" ++ if not (0.0 <= temperature <= 100.0): ++ return "temperature must be a positive number less than or equal to 100.0" ++ else: ++ raw_req['temperature'] = 0.0 ++ ++ if raw_req['temperature'] == 0.0: ++ raw_req['top_k'] = 1 ++ raw_req['top_p'] = 0 ++ ++ if "add_BOS" in raw_req: ++ if not isinstance(raw_req["add_BOS"], bool): + return "add_BOS must be a boolean value" ++ else: ++ raw_req['add_BOS'] = False + +- if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS: ++ if any([len(prompt) == 0 for prompt in raw_req['prompts']]) and not raw_req["add_BOS"]: + return "Empty prompts require add_BOS=true" + +- stop_on_double_eol = False +- if "stop_on_double_eol" in request.get_json(): +- stop_on_double_eol = request.get_json()["stop_on_double_eol"] +- if not isinstance(stop_on_double_eol, bool): +- return "stop_on_double_eol must be a boolean value" +- +- stop_on_eol = False +- if "stop_on_eol" in request.get_json(): +- stop_on_eol = request.get_json()["stop_on_eol"] +- if not isinstance(stop_on_eol, bool): +- return "stop_on_eol must be a boolean value" +- +- prevent_newline_after_colon = False +- if "prevent_newline_after_colon" in request.get_json(): +- prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"] +- if not isinstance(prevent_newline_after_colon, bool): ++ if "stop_sequences" in raw_req: ++ if not isinstance(raw_req["stop_sequences"], list): ++ return "stop_sequences must be a str list" ++ for seq in raw_req['stop_sequences']: ++ if not isinstance(seq, str): ++ return "stop_sequences must be a str list" ++ else: ++ raw_req["stop_sequences"] = None ++ ++ if "prevent_newline_after_colon" in raw_req: ++ if not isinstance(raw_req["prevent_newline_after_colon"], bool): + return "prevent_newline_after_colon must be a boolean value" ++ else: ++ raw_req['prevent_newline_after_colon'] = False + +- random_seed = -1 +- if "random_seed" in request.get_json(): +- random_seed = request.get_json()["random_seed"] ++ if "random_seed" in raw_req: ++ random_seed = raw_req["random_seed"] + if not isinstance(random_seed, int): + return "random_seed must be integer" + if random_seed < 0: + return "random_seed must be a positive integer" ++ else: ++ raw_req['random_seed'] = 1234 + + no_log = False +- if "no_log" in request.get_json(): +- no_log = request.get_json()["no_log"] ++ if "no_log" in raw_req: ++ no_log = raw_req["no_log"] + if not isinstance(no_log, bool): + return "no_log must be a boolean value" + + beam_width = None +- if "beam_width" in request.get_json(): +- beam_width = request.get_json()["beam_width"] ++ if "beam_width" in raw_req: ++ beam_width = raw_req["beam_width"] + if not isinstance(beam_width, int): + return "beam_width must be integer" + if beam_width < 1: + return "beam_width must be an integer > 1" +- if len(prompts) > 1: ++ if len(raw_req['prompts']) > 1: + return "When doing beam_search, batch size must be 1" + +- stop_token=50256 +- if "stop_token" in request.get_json(): +- stop_token = request.get_json()["stop_token"] +- if not isinstance(stop_token, int): +- return "stop_token must be an integer" +- +- length_penalty = 1 +- if "length_penalty" in request.get_json(): +- length_penalty = request.get_json()["length_penalty"] ++ if "length_penalty" in raw_req: ++ length_penalty = raw_req["length_penalty"] + if not isinstance(length_penalty, float): + return "length_penalty must be a float" +- +- with lock: # Need to get lock to keep multiple threads from hitting code +- +- if not no_log: +- print("request IP: " + str(request.remote_addr)) +- print(json.dumps(request.get_json()),flush=True) +- print("start time: ", datetime.datetime.now()) +- +- try: +- if beam_width is not None: +- MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search +- response, response_seg, response_scores = \ +- beam_search_and_post_process( +- self.model, +- prompts=prompts, +- tokens_to_generate=tokens_to_generate, +- beam_size = beam_width, +- add_BOS=add_BOS, +- stop_token=stop_token, +- num_return_gen=beam_width, # Returning whole beam +- length_penalty=length_penalty, +- prevent_newline_after_colon=prevent_newline_after_colon +- ) +- +- return jsonify({"text": response, +- "segments": response_seg, +- "scores": response_scores}) +- else: +- MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate +- response, response_seg, response_logprobs, _ = \ +- generate_and_post_process( +- self.model, +- prompts=prompts, +- tokens_to_generate=tokens_to_generate, +- return_output_log_probs=logprobs, +- top_k_sampling=top_k, +- top_p_sampling=top_p, +- top_p_decay=top_p_decay, +- top_p_bound=top_p_bound, +- temperature=temperature, +- add_BOS=add_BOS, +- use_eod_token_for_early_termination=True, +- stop_on_double_eol=stop_on_double_eol, +- stop_on_eol=stop_on_eol, +- prevent_newline_after_colon=prevent_newline_after_colon, +- random_seed=random_seed) +- +- return jsonify({"text": response, +- "segments": response_seg, +- "logprobs": response_logprobs}) +- +- except ValueError as ve: +- return ve.args[0] +- print("end time: ", datetime.datetime.now()) +- ++ else: ++ raw_req['length_penalty'] = 1 ++ ++ if not no_log: ++ print("request IP: " + str(request.remote_addr)) ++ print(json.dumps(raw_req),flush=True) ++ print("start time: ", datetime.datetime.now()) ++ try: ++ result = self.loop.run_until_complete(self.generate(raw_req)) ++ return jsonify(result) ++ except ValueError as ve: ++ return ve.args[0] ++ + + class MegatronServer(object): + def __init__(self, model): +@@ -237,5 +246,5 @@ class MegatronServer(object): + api = Api(self.app) + api.add_resource(MegatronGenerate, '/api', resource_class_args=[model]) + +- def run(self, url): +- self.app.run(url, threaded=True, debug=False) ++ def run(self, host, port): ++ self.app.run(host=host, port=port, threaded=True, debug=False) +diff --git a/megatron/timers.py b/megatron/timers.py +index a9478fa..8399ad9 100644 +--- a/megatron/timers.py ++++ b/megatron/timers.py +@@ -287,7 +287,7 @@ class Timers: + print(output_string, flush=True) + + +- def write(self, names, writer, iteration, normalizer=1.0, ++ def write(self, names, writer, wandb, iteration, normalizer=1.0, + reset=False, barrier=False): + """Write timers to a tensorboard writer + Note that we only report maximum time across ranks to tensorboard. +@@ -302,3 +302,9 @@ class Timers: + for name in name_to_min_max_time: + _, max_time = name_to_min_max_time[name] + writer.add_scalar(name + '-time', max_time, iteration) ++ if wandb is not None: ++ wandb_log_dic = {} ++ for name in name_to_min_max_time: ++ _, max_time = name_to_min_max_time[name] ++ wandb_log_dic[f'timer/{name}'] = max_time ++ wandb.log(wandb_log_dic, iteration) +diff --git a/megatron/training.py b/megatron/training.py +index b821ae7..e0dfc54 100644 +--- a/megatron/training.py ++++ b/megatron/training.py +@@ -15,6 +15,7 @@ from megatron import get_args + from megatron import get_signal_handler + from megatron import get_timers + from megatron import get_tensorboard_writer ++from megatron import get_wandb + from megatron import get_current_global_batch_size + from megatron import get_num_microbatches + from megatron import is_last_rank +@@ -171,14 +172,14 @@ def pretrain(train_valid_test_dataset_provider, + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, config, +- verbose=True, write_to_tensorboard=not args.skip_train) ++ verbose=True, write_to_tracker=not args.skip_train) + + if args.do_test: + prefix = f'iteration {iteration} on test set' + evaluate_and_print_results(prefix, forward_step_func, + test_data_iterator, model, + iteration, process_non_loss_data_func, config, +- verbose=True, write_to_tensorboard=not args.skip_train) ++ verbose=True, write_to_tracker=not args.skip_train) + + + def update_train_iters(args): +@@ -504,6 +505,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, + args = get_args() + timers = get_timers() + writer = get_tensorboard_writer() ++ wandb = get_wandb() + + # Advanced, skipped, and Nan iterations. + advanced_iters_key = 'advanced iterations' +@@ -571,68 +573,97 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, + + # Tensorboard values. + # Timer requires all the ranks to call. +- if args.log_timers_to_tensorboard and \ +- (iteration % args.tensorboard_log_interval == 0): +- timers.write(timers_to_log, writer, iteration, ++ if args.log_timers_to_tracker and \ ++ (iteration % args.tracker_log_interval == 0): ++ timers.write(timers_to_log, writer, wandb, iteration, + normalizer=total_iterations) +- if writer and (iteration % args.tensorboard_log_interval == 0): +- if args.log_learning_rate_to_tensorboard: +- writer.add_scalar('learning-rate', learning_rate, iteration) +- writer.add_scalar('learning-rate vs samples', learning_rate, +- args.consumed_train_samples) +- if args.log_batch_size_to_tensorboard: +- writer.add_scalar('batch-size', batch_size, iteration) +- writer.add_scalar('batch-size vs samples', batch_size, +- args.consumed_train_samples) +- for key in loss_dict: +- writer.add_scalar(key , loss_dict[key], iteration) +- writer.add_scalar(key + ' vs samples', loss_dict[key], +- args.consumed_train_samples) +- if args.log_loss_scale_to_tensorboard: +- writer.add_scalar('loss-scale', loss_scale, iteration) +- writer.add_scalar('loss-scale vs samples', loss_scale, +- args.consumed_train_samples) +- if args.log_world_size_to_tensorboard: +- writer.add_scalar('world-size', args.world_size, iteration) +- writer.add_scalar('world-size vs samples', args.world_size, +- args.consumed_train_samples) +- if grad_norm is not None: +- writer.add_scalar('grad-norm', grad_norm, iteration) +- writer.add_scalar('grad-norm vs samples', grad_norm, +- args.consumed_train_samples) +- if num_zeros_in_grad is not None: +- writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) +- writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, +- args.consumed_train_samples) +- if params_norm is not None: +- writer.add_scalar('params-norm', params_norm, iteration) +- writer.add_scalar('params-norm vs samples', params_norm, +- args.consumed_train_samples) +- if args.log_memory_to_tensorboard: +- mem_stats = torch.cuda.memory_stats() +- writer.add_scalar( +- "mem-reserved-bytes", +- mem_stats["reserved_bytes.all.current"], +- iteration, +- ) +- writer.add_scalar( +- "mem-allocated-bytes", +- mem_stats["allocated_bytes.all.current"], +- iteration, +- ) +- writer.add_scalar( +- "mem-allocated-count", +- mem_stats["allocation.all.current"], +- iteration, +- ) ++ if iteration % args.tracker_log_interval == 0: ++ if writer: ++ if args.log_learning_rate_to_tracker: ++ writer.add_scalar('learning-rate', learning_rate, iteration) ++ writer.add_scalar('learning-rate vs samples', learning_rate, ++ args.consumed_train_samples) ++ if args.log_batch_size_to_tracker: ++ writer.add_scalar('batch-size', batch_size, iteration) ++ writer.add_scalar('batch-size vs samples', batch_size, ++ args.consumed_train_samples) ++ for key in loss_dict: ++ writer.add_scalar(key , loss_dict[key], iteration) ++ writer.add_scalar(key + ' vs samples', loss_dict[key], ++ args.consumed_train_samples) ++ if args.log_loss_scale_to_tracker: ++ writer.add_scalar('loss-scale', loss_scale, iteration) ++ writer.add_scalar('loss-scale vs samples', loss_scale, ++ args.consumed_train_samples) ++ if args.log_world_size_to_tracker: ++ writer.add_scalar('world-size', args.world_size, iteration) ++ writer.add_scalar('world-size vs samples', args.world_size, ++ args.consumed_train_samples) ++ if grad_norm is not None: ++ writer.add_scalar('grad-norm', grad_norm, iteration) ++ writer.add_scalar('grad-norm vs samples', grad_norm, ++ args.consumed_train_samples) ++ if num_zeros_in_grad is not None: ++ writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) ++ writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, ++ args.consumed_train_samples) ++ if params_norm is not None: ++ writer.add_scalar('params-norm', params_norm, iteration) ++ writer.add_scalar('params-norm vs samples', params_norm, ++ args.consumed_train_samples) ++ if args.log_memory_to_tracker: ++ writer.add_scalar( ++ "mem-reserved-bytes", ++ mem_stats["reserved_bytes.all.current"], ++ iteration, ++ ) ++ writer.add_scalar( ++ "mem-allocated-bytes", ++ mem_stats["allocated_bytes.all.current"], ++ iteration, ++ ) ++ writer.add_scalar( ++ "mem-allocated-count", ++ mem_stats["allocation.all.current"], ++ iteration, ++ ) ++ if wandb: ++ wandb_log_dic = {} ++ if args.log_learning_rate_to_tracker: ++ wandb_log_dic['train/learning_rate'] = learning_rate ++ if args.log_batch_size_to_tracker: ++ wandb_log_dic['train/global_batch_size'] = batch_size ++ for key in loss_dict: ++ wandb_log_dic[f'train/{key}'] = loss_dict[key] ++ if args.log_loss_scale_to_tracker: ++ wandb_log_dic['train/loss_scale'] = loss_scale ++ if args.log_world_size_to_tracker: ++ wandb_log_dic['train/world_size'] = args.world_size ++ if grad_norm is not None: ++ wandb_log_dic['train/grad_norm'] = grad_norm ++ if num_zeros_in_grad is not None: ++ wandb_log_dic['train/num_zeros_in_grad'] = num_zeros_in_grad ++ if params_norm is not None: ++ wandb_log_dic['train/params_norm'] = params_norm ++ if args.log_memory_to_tracker: ++ mem_stats = torch.cuda.memory_stats() ++ wandb_log_dic['train/mem-reserved-bytes'] = mem_stats["reserved_bytes.all.current"] ++ wandb_log_dic['train/mem-allocated-bytes'] = mem_stats["allocated_bytes.all.current"] ++ wandb_log_dic['train/mem-allocated-count'] = mem_stats["allocation.all.current"] ++ wandb.log(wandb_log_dic, iteration) + + if iteration % args.log_interval == 0: + elapsed_time = timers('interval-time').elapsed(barrier=True) + elapsed_time_per_iteration = elapsed_time / total_iterations +- if writer: +- if args.log_timers_to_tensorboard: ++ samples_per_second = batch_size / elapsed_time_per_iteration ++ tokens_per_second = samples_per_second * args.seq_length ++ wandb_log_dic = {} ++ if args.log_timers_to_tracker: ++ if writer: + writer.add_scalar('iteration-time', + elapsed_time_per_iteration, iteration) ++ if wandb: ++ wandb_log_dic['timer/time_per_iteration'] = elapsed_time_per_iteration + log_string = ' iteration {:8d}/{:8d} |'.format( + iteration, args.train_iters) + log_string += ' consumed samples: {:12d} |'.format( +@@ -641,6 +672,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, + elapsed_time_per_iteration * 1000.0) + log_string += ' learning rate: {:.3E} |'.format(learning_rate) + log_string += ' global batch size: {:5d} |'.format(batch_size) ++ log_string += ' samples per second: {:.3f} |'.format(samples_per_second) ++ log_string += ' tokens per second: {:.3f} |'.format(tokens_per_second) + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, + nan_iters_key]: +@@ -651,7 +684,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, + total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) + log_string += ' loss scale: {:.1f} |'.format(loss_scale) + if grad_norm is not None: +- log_string += ' grad norm: {:.3f} |'.format(grad_norm) ++ if isinstance(grad_norm, dict): ++ log_string += ' total grad norm: {:.3f} |'.format(grad_norm['total_grad_norm']) ++ log_string += ' embedding grad norm: {:.3f} |'.format(grad_norm['embed_grad_norm']) ++ else: ++ log_string += ' grad norm: {:.3f} |'.format(grad_norm) + if num_zeros_in_grad is not None: + log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) + if params_norm is not None: +@@ -660,6 +697,17 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, + total_loss_dict[skipped_iters_key]) + log_string += ' number of nan iterations: {:3d} |'.format( + total_loss_dict[nan_iters_key]) ++ if writer: ++ writer.add_scalar('samples_per_second', samples_per_second, iteration) ++ writer.add_scalar('tokens_per_second', tokens_per_second, iteration) ++ writer.add_scalar('skipped_iterations', total_loss_dict[skipped_iters_key], iteration) ++ writer.add_scalar('nan_iterations', total_loss_dict[nan_iters_key], iteration) ++ if wandb: ++ wandb_log_dic['train/samples_per_second'] = samples_per_second ++ wandb_log_dic['train/tokens_per_second'] = samples_per_second * args.seq_length ++ wandb_log_dic['train/skipped_iterations'] = total_loss_dict[skipped_iters_key] ++ wandb_log_dic['train/nan_iterations'] = total_loss_dict[nan_iters_key] ++ wandb.log(wandb_log_dic, iteration) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 +@@ -889,10 +937,10 @@ def evaluate(forward_step_func, + def evaluate_and_print_results(prefix, forward_step_func, + data_iterator, model, + iteration, process_non_loss_data_func, config, +- verbose=False, write_to_tensorboard=True): ++ verbose=False, write_to_tracker=True): + """Helper function to evaluate and dump results on screen.""" + args = get_args() +- if write_to_tensorboard: ++ if write_to_tracker: + writer = get_tensorboard_writer() + else: + writer = None +@@ -912,7 +960,7 @@ def evaluate_and_print_results(prefix, forward_step_func, + writer.add_scalar('{} validation vs samples'.format(key), + total_loss_dict[key].item(), + args.consumed_train_samples) +- if args.log_validation_ppl_to_tensorboard: ++ if args.log_validation_ppl_to_tracker: + writer.add_scalar('{} validation ppl'.format(key), ppl, + iteration) + writer.add_scalar('{} validation ppl vs samples'.format(key), +diff --git a/tools/inference.py b/tools/inference.py +new file mode 100644 +index 0000000..2d83fc6 +--- /dev/null ++++ b/tools/inference.py +@@ -0,0 +1,207 @@ ++ ++"""Inference tools.""" ++from abc import ABC, abstractmethod ++import time ++import jsonlines ++import torch ++import os ++import sys ++sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ++ os.path.pardir))) ++from megatron import get_args, print_rank_0 ++from megatron.core import mpu ++from megatron.checkpointing import load_checkpoint ++from megatron.model import GPTModel ++from megatron.initialize import initialize_megatron ++from megatron.training import get_model ++from megatron.text_generation import generate_and_post_process ++ ++ ++class JsonlineReader(): ++ def __init__(self, input): ++ self.input = jsonlines.open(input, 'r') ++ ++ def __iter__(self): ++ return self ++ ++ def __next__(self): ++ try: ++ line = self.input.read(skip_empty=True, allow_none=False) ++ except EOFError: ++ raise StopIteration ++ return line ++ ++ ++class JsonlineWriter(): ++ def __init__(self, output): ++ self.output = jsonlines.open(output, 'w', flush=True) ++ ++ def write(self, response): ++ self.output.write(response) ++ ++ ++class AbstractFormatter(ABC): ++ ++ @abstractmethod ++ def format_input(self, request): ++ raise NotImplementedError( ++ f'FORMAT_INPUT is not provided for {self.name}') ++ ++ @abstractmethod ++ def format_output(self, request, response): ++ raise NotImplementedError( ++ f'FORMAT_OUTPUT is not provided for {self.name}') ++ ++ ++class Formatter(AbstractFormatter): ++ """Default formatter implementation""" ++ ++ def format_input(self, request): ++ if not isinstance(request['text'], list): ++ request['text'] = [request['text']] ++ return request ++ ++ def format_output(self, request, response): ++ return { ++ 'prompt': request['text'], ++ 'text': response['text'], ++ 'segments': response['segments'], ++ 'logprobs': response['logprobs'] ++ } ++ ++ ++class GPTEvalFormatter(Formatter): ++ """Formatter for FastChat llm-judge""" ++ ++ def format_input(self, request): ++ if not isinstance(request['text'], list): ++ request['text'] = [request['text']] ++ request['text'] = [f"Question:{text}\\n\\nAnswer:" for text in request['text']] ++ return request ++ ++ def format_output(self, request, response): ++ return { ++ 'question_id': request['question_id'], ++ 'text': response['text'][0], ++ 'model_id': response['model_name'], ++ 'metadata': {} ++ } ++ ++ ++def load_formatter(name): ++ if name == None: ++ return Formatter() ++ elif name == 'gpt_eval': ++ return GPTEvalFormatter() ++ else: ++ raise NotImplementedError(f"Formatter for {name} is not implemented") ++ ++ ++RUN_SIG = 0 ++STOP_SIG = 1 ++ ++def run_infer(model, prompts, writer, formatter): ++ start_time = time.time() ++ state = torch.cuda.LongTensor([RUN_SIG]) ++ torch.distributed.broadcast(state, 0) ++ request = [prompt['text'][0] for prompt in prompts] ++ print(f'request: {request}') ++ texts, segments, logprobs, _ = generate_and_post_process( ++ model, ++ prompts=request, ++ tokens_to_generate=args.tokens_to_generate, ++ echo_prompts=args.echo_prompts, ++ return_output_log_probs=args.log_probs, ++ top_k_sampling=args.top_k, ++ top_p_sampling=args.top_p, ++ temperature=args.temperature ++ ) ++ end_time = time.time() ++ print(f'response: {texts}') ++ print(f'inference time: {end_time - start_time}') ++ for i, prompt in enumerate(prompts): ++ result = formatter.format_output( ++ request=prompt, ++ response={ ++ 'text': [texts[i]], ++ 'segments': [segments[i]] if segments is not None else None, ++ 'logprobs': [logprobs[i]] if logprobs is not None else None, ++ 'model_name': args.model_name ++ }) ++ writer.write(result) ++ ++def infer(model, args): ++ if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: ++ reader = JsonlineReader(args.input) ++ writer = JsonlineWriter(args.output) ++ formatter = load_formatter(args.formatter) ++ prompts = [] ++ for prompt in reader: ++ prompts.append(formatter.format_input(prompt)) ++ if len(prompts) >= args.batch: ++ run_infer(model, prompts, writer, formatter) ++ prompts.clear() ++ if len(prompts) > 0: ++ run_infer(model, prompts, writer, formatter) ++ state = torch.cuda.LongTensor([STOP_SIG]) ++ torch.distributed.broadcast(state, 0) ++ else: ++ while True: ++ state = torch.cuda.LongTensor(1) ++ torch.distributed.broadcast(state, 0) ++ if state[0].item() == RUN_SIG: ++ generate_and_post_process(model) ++ else: ++ break ++ print(f"rank {torch.distributed.get_rank()} finish inference") ++ ++ ++def model_provider(pre_process=True, post_process=True): ++ """Build the model.""" ++ ++ print_rank_0('building GPT model ...') ++ model = GPTModel(num_tokentypes=0, parallel_output=False, ++ pre_process=pre_process, post_process=post_process) ++ ++ return model ++ ++ ++def add_inference_args(parser): ++ group = parser.add_argument_group(title='inference') ++ group.add_argument('--input', type=str, required=True) ++ group.add_argument('--output', type=str, required=True) ++ group.add_argument('--formatter', type=str, default=None) ++ group.add_argument('--tokens-to-generate', type=int, default=512) ++ group.add_argument('--top-k', type=int, default=0) ++ group.add_argument('--top-p', type=float, default=0) ++ group.add_argument('--temperature', type=float, default=1.0) ++ group.add_argument('--log-probs', type=bool, default=False) ++ group.add_argument('--echo-prompts', type=bool, default=False) ++ group.add_argument('--batch', type=int, default=1) ++ group.add_argument('--model-name', type=str, default='my_llm') ++ return parser ++ ++ ++def check_args(args): ++ if args.temperature == 0.0: ++ args.top_p = 0.0 ++ args.top_k = 1 ++ assert args.temperature >= 0.0 and args.temperature <= 100.0, 'temperature must be a positive number less than or equal to 100.0' ++ ++ ++if __name__ == '__main__': ++ initialize_megatron(extra_args_provider=add_inference_args, ++ args_defaults={ ++ 'no_load_rng': True, ++ 'no_load_optim': True, ++ 'use_checkpoint_args': True ++ }) ++ args = get_args() ++ check_args(args) ++ # todo: support interleaved pipeline schedule ++ model = get_model(model_provider, wrap_with_ddp=False) ++ if args.load is not None: ++ _ = load_checkpoint(model, None, None) ++ assert len(model) == 1, "Load checkpoint failed" ++ model = model[0] ++ infer(model, args) +diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py +index 5278915..f9757c8 100644 +--- a/tools/run_text_generation_server.py ++++ b/tools/run_text_generation_server.py +@@ -5,7 +5,6 @@ import os + import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), + os.path.pardir))) +-import socket + from megatron import get_args + from megatron import print_rank_0 + from megatron.core import mpu +@@ -31,15 +30,8 @@ def model_provider(pre_process=True, post_process=True): + + def add_text_generate_args(parser): + group = parser.add_argument_group(title='text generation') +- +- group.add_argument("--temperature", type=float, default=1.0, +- help='Sampling temperature.') +- group.add_argument("--top_p", type=float, default=0.0, +- help='Top p sampling.') +- group.add_argument("--top_k", type=int, default=0, +- help='Top k sampling.') +- group.add_argument("--out-seq-length", type=int, default=1024, +- help='Size of the output generated text.') ++ group.add_argument("--port", type=int, default=5000, ++ help='Text generation server port.') + return parser + + +@@ -66,7 +58,7 @@ if __name__ == "__main__": + model = model[0] + if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: + server = MegatronServer(model) +- server.run("0.0.0.0") ++ server.run("0.0.0.0", args.port) + + while True: + choice = torch.cuda.LongTensor(1) diff --git a/thirdparty/setup_helm.sh b/thirdparty/setup_helm.sh new file mode 100755 index 000000000..2d2650466 --- /dev/null +++ b/thirdparty/setup_helm.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +export THIRD_PARTY_DIR=$(cd $(dirname $0); pwd) +export HELM_DIR=${THIRD_PARTY_DIR}/helm + +# install conda +conda &> /dev/null +if [ $? -ne 0 ]; then + echo "> setup conda ..." + CONDA_DIR=${HOME}/miniconda3 + wget https://repo.anaconda.com/miniconda/Miniconda3-py38_23.1.0-1-Linux-x86_64.sh + bash Miniconda3-py38_23.1.0-1-Linux-x86_64.sh -b -p $CONDA_DIR + export PATH=$CONDA_DIR/bin:$PATH +fi + +# setup helm +echo "> setup helm ..." +git clone https://github.com/stanford-crfm/helm.git +cd $HELM_DIR +git reset 33ca6e62 --hard +git apply ${THIRD_PARTY_DIR}/patch/helm.diff +conda create -n crfm-helm python=3.8 pip -y +eval "$(conda shell.bash hook)" +conda activate crfm-helm +pip install -e . \ No newline at end of file diff --git a/thirdparty/setup_megatron.sh b/thirdparty/setup_megatron.sh new file mode 100755 index 000000000..78b4edc1b --- /dev/null +++ b/thirdparty/setup_megatron.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +export THIRD_PARTY_DIR=$(cd $(dirname $0); pwd) +export MEGATRON_DIR=${THIRD_PARTY_DIR}/Megatron-LM + + +# setup megatron +echo "> setup Megatron-LM ..." +git clone https://github.com/NVIDIA/Megatron-LM.git +cd $MEGATRON_DIR +git reset 040eac9 --hard +git apply ${THIRD_PARTY_DIR}/patch/megatron.diff +pip install flash-attn flask flask_restful jsonlines asyncio wandb sentencepiece diff --git a/tools/.DS_Store b/tools/.DS_Store new file mode 100644 index 000000000..1dc7bbd41 Binary files /dev/null and b/tools/.DS_Store differ diff --git a/tools/analyze_data.py b/tools/analyze_data.py new file mode 100644 index 000000000..5d8db6e54 --- /dev/null +++ b/tools/analyze_data.py @@ -0,0 +1,13 @@ +from loguru import logger + +from data_juicer.core import Analyser + + +@logger.catch +def main(): + analyser = Analyser() + analyser.run() + + +if __name__ == '__main__': + main() diff --git a/tools/converter/batch_convert.sh b/tools/converter/batch_convert.sh new file mode 100644 index 000000000..62cfbedbc --- /dev/null +++ b/tools/converter/batch_convert.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +set -e + +WORKDIR=/home/data/gaodawei.gdw/stanford_alpaca + +MODEL_TO_CONVERT=( +) + +PATH_TO_SAVE=( +) + +for i in "${!MODEL_TO_CONVERT[@]}"; do + path_model=${MODEL_TO_CONVERT[i]} + path_save=${PATH_TO_SAVE[i]} + + echo $i ":" $path_model "to" $path_save + + python ${WORKDIR}/convert/convert_gpt_to_transformers.py \ + --load_path ${path_model} \ + --save_path ${path_save} \ + --max_shard_size "10GB" \ + --tokenizer_name "decapoda-research/llama-7b-hf" \ + --print-checkpoint-structure +done diff --git a/tools/converter/convert_gpt_to_transformers.py b/tools/converter/convert_gpt_to_transformers.py new file mode 100644 index 000000000..07308aeed --- /dev/null +++ b/tools/converter/convert_gpt_to_transformers.py @@ -0,0 +1,605 @@ +# Some code here has been modified from: +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/configuration_llama.py +# -------------------------------------------------------- + +# Data-Juicer adopts Apache 2.0 license, the original license of this file +# is as follows: +# -------------------------------------------------------- +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import sys +import os +import re + +import types + +import torch + +from transformers import AutoTokenizer, LlamaConfig +from modeling_megatron_llama import MegatronLlamaConfig +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint + + +def add_checkpointing_args(parser): + parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository") + parser.add_argument( + "--load_path", + type=str, + required=True, + help="Path to the checkpoint to convert.", + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to the converted checkpoint.", + ) + parser.add_argument("--print-checkpoint-structure", action="store_true") + return parser + + +def add_transformers_checkpoint_args(parser): + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help=( + "The name of the pre-trained tokenizer to save. " + "If not None, the tokenizer will be saved. " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + parser.add_argument( + "--max_shard_size", + type=str, + default="10GB", + help=( + "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + + return parser + + +# The simple map of names for "automated" rules. +megatron_to_transformers = { + "attention.dense": ".self_attn.o_proj.", + "self_attention.dense": ".self_attn.o_proj.", + # TODO: one to two vectors + "mlp.dense_h_to_4h": ".mlp.{}_proj.", + "mlp.dense_4h_to_h": ".mlp.down_proj.", +} +transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()} + +tensor_parallel_params = [ + # megatron-lm layers to merge across tp ranks + "self_attention.query_key_value.weight", + "self_attention.query_key_value.bias", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "mlp.dense_4h_to_h.weight", + # deprecated + "attention.query_key_value.weight", + "attention.query_key_value.bias", + "attention.dense.weight", + # transformers layers to split across tp ranks + "attn.c_attn.weight", + "attn.c_attn.bias", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_fc.bias", + "mlp.c_proj.weight", +] + + +def recursive_print(name, val, spaces=0): + """ + Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + name (str): the name of the current tensor parameter + val (Tuple(int)): the shape of the current tensor parameter + spaces (int): the number of spaces to print before the output for a nested structure + """ + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). + if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val.keys(): + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +def megatron_to_transformers_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] for compatibility with later versions + of NVIDIA Megatron-LM. The inverse operation is performed inside Megatron-LM to read checkpoints: + https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 If param is the weight tensor of the + self-attention block, the returned tensor will have to be transposed one more time to be read by HuggingFace GPT2. + This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def transformers_to_megatron_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM chekpoint versions. Input + is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version + 1.0 and [num_heads * num_splits * hidden_size, :] for version 2.0 and later. If param is the weight tensor of the + self-attention block, the param needs to be already transposed before calling this function. + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + # Input is [num_splits * num_heads * hidden_size, :] + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def merge_transformers_sharded_states(path, num_checkpoints): + """ + Merge sharded checkpoints from transformers into a single checkpoint. + + Args: + path (str): the path to the sharded checkpoints + num_checkpoints (int): the number of checkpoints to merge + """ + state_dict = {} + for i in range(1, num_checkpoints + 1): + checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin") + current_chunk = torch.load(checkpoint_path, map_location="cpu") + state_dict.update(current_chunk) + return state_dict + + +def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank): + """ + Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline + parallel size and pipeline parallel rank. + + Args: + args (argparse.Namespace): the arguments to the script + tp_size (int): the tensor parallel size + pp_size (int): the pipeline parallel size + pp_rank (int): the pipeline parallel rank + """ + tp_state_dicts = [] + for i in range(tp_size): + sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}" + checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0] + checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name) + state_dict = torch.load(checkpoint_path, map_location="cpu") + tp_state_dicts.append(state_dict) + return tp_state_dicts + + +def get_element_from_dict_by_path(d, path): + """ + Get element from dictionary by path. If element is not present, recursively add empty dictionaries. + + Args: + d (dict): the dictionary to get the element from + path (list): the path to the element which is delimited by "." + """ + path = path.split(".") + for k in path: + if k not in d: + d[k] = {} + d = d[k] + return d + + +def convert_checkpoint_from_megatron_to_transformers(args): + """ + Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints + with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards + using HuggingFace Transformers checkpoint sharding functionality. This greatly extends the functionality of + `convert_megatron_gpt2_checkpoint.py` + + Args: + args (argparse.Namespace): the arguments to the script + """ + # Load Megatron-LM checkpoint arguments from the state dict + sub_dirs = os.listdir(args.load_path) + possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"] + for sub_dir in possible_sub_dirs: + if sub_dir in sub_dirs: + rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0] + rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name) + break + print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}") + state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") + megatron_args = state_dict.get("args", None) + if megatron_args is None: + raise ValueError( + "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints" + " containing all the megatron arguments. This is because it loads all config related to model" + " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" + " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron" + " arguments to use this utility." + ) + + # Create Transformers GPT2 config from Megatron-LM arguments + if megatron_args is not None: + # dawei: use swish as activation function + if megatron_args.swiglu: + activation_function = "silu" + elif megatron_args.bias_gelu_fusion: + activation_function = "gelu_fast" + elif megatron_args.openai_gelu: + activation_function = "gelu_new" + else: + activation_function = "gelu" + else: + # in the very early days this used to be "gelu_new" + activation_function = "gelu_new" + vocab_size = ( + megatron_args.padded_vocab_size + if getattr(megatron_args, "orig_vocab_size", None) is None + else megatron_args.orig_vocab_size + ) + print("vocab size:", vocab_size) + + config = MegatronLlamaConfig( + # dawei: from megatron-lm + vocab_size=vocab_size, + hidden_size=megatron_args.hidden_size, + intermediate_size=megatron_args.ffn_hidden_size, # 10880 + num_hidden_layers=megatron_args.num_layers, + num_attention_heads=megatron_args.num_attention_heads, + hidden_act=activation_function, + max_position_embeddings=megatron_args.max_position_embeddings, + rms_norm_eps=megatron_args.layernorm_epsilon, + + # dawei: from official config of llama + max_sequence_length=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + architectures=["MegatronLlamaForCausalLM"], + + use_bias=True, + ) + + output_state_dict = {} + + checkpoint_version = state_dict.get("checkpoint_version", 0.0) + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + dtype = torch.float32 + # The regex to extract layer names. + layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + + # Convert. + print("Converting") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) + + # Convert and store the position embeddings. + position_embeddings = get_element_from_dict_by_path( + tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight" + ) + output_state_dict["model.embed_position.weight"] = position_embeddings.to(dtype) + + # Convert and store the word embeddings. + word_embeddings = torch.cat( + [ + get_element_from_dict_by_path( + tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight" + ) + for tp_rank in range(tp_size) + ], + dim=0, + ) + word_embeddings = word_embeddings[:vocab_size].to(dtype) + output_state_dict["model.embed_tokens.weight"] = word_embeddings + + # Transformer Layers + print("Converting transformer layers") + # The number of heads. + heads = config.num_attention_heads + # The hidden_size per head. + hidden_size_per_head = config.hidden_size // config.num_attention_heads + n_positions = config.max_position_embeddings + num_layers = config.num_hidden_layers // pp_size + + for pp_rank in range(pp_size): + if pp_size > 0: + print(f"Converting pipeline parallel rank {pp_rank}") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank) + + # The transformer. + path = ( + "model.language_model.transformer" + if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys() + else "model.language_model.encoder" + ) + # Extract the layers. + for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items(): + # Match the name. + m = layer_re.match(key) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + layer_idx = int(m.group(1)) + pp_rank * num_layers + # The name of the operation. + # dawei: input_layernorm, self_attention, mlp, post_attention_layernorm + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + # The name of the layer. + layer_name = f"model.layers.{layer_idx}" + + if op_name + "." + weight_or_bias not in tensor_parallel_params: + # dawei: input_layernorm.weight, input_layernorm.bias, self_attention.dense.bias, + # dawei: self_attention_layernorm.weight, self_attention_layernorm.bias, mlp.dense_4h_to_h.bias + # dawei: post_attention_layernorm.weight, post_attention_layernorm.bias + params = val.to(dtype) + else: + # dawei: self_attention.query_key_value.weight, self_attention_query_value.bias, self_attention.dense.weight, + # mlp.dense_h_to_4h.weight, mlp.dense_h_to_4h.bias, + # mlp.dense_4h_to_h.weight + dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h", "attention.dense"] else 0 + # dawei: maybe only stored in the first chunk + + # dawei: fix bug in swiglu and dense_h_to_4h.weight + + if op_name == "mlp.dense_h_to_4h" and weight_or_bias == "weight": + params_list = [val] + [ + get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key] + for tp_rank in range(1, tp_size) + ] + ws, vs = list(), list() + for p in params_list: + w, v = torch.chunk(p, 2, dim=0) + ws.append(w) + vs.append(v) + params = torch.cat(ws + vs, dim=dim).to(dtype) + + else: + params = torch.cat( + [val] + + [ + get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key] + for tp_rank in range(1, tp_size) + ], + dim=dim, + ).to(dtype) + + # For layernorm(s), simply store the layer norm. + # dawei: ignore the bias for layernorm + if op_name.endswith("layernorm"): + # dawei: input_layernorm & post_attention_layernorm + if weight_or_bias == "weight": + # dawei: skip bias + ln_name = "input_layernorm" if op_name.startswith("input") else "post_attention_layernorm" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params + + # Transpose the QKV matrix. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "weight": + # dawei: (gpt2) self_attention.query_key_value.weight + out_val = megatron_to_transformers_fix_query_key_value_ordering( + params, + checkpoint_version, + 3, + heads, + hidden_size_per_head, + ) + # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D. + + # dawei: (3*D) x D + out_val = out_val.contiguous() + + # dawei: split into 3 weight + # (3*D) x D ==> D x D, still [out_dim, in_dim] + q, k, v = torch.chunk(out_val, 3, dim=0) + # Store. + output_state_dict[layer_name + ".self_attn.q_proj.weight"] = q + output_state_dict[layer_name + ".self_attn.k_proj.weight"] = k + output_state_dict[layer_name + ".self_attn.v_proj.weight"] = v + + # Transpose the bias. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "bias": + # dawei: (gpt2) self_attention.query_key_value.bias + out_val = megatron_to_transformers_fix_query_key_value_ordering( + params, checkpoint_version, 3, heads, hidden_size_per_head + ) + # dawei: split in to 3 bias + q_b, k_b, v_b = torch.chunk(out_val, 3, dim=0) + + # Store. No change of shape. + output_state_dict[layer_name + ".self_attn.q_proj.bias"] = q_b + output_state_dict[layer_name + ".self_attn.k_proj.bias"] = k_b + output_state_dict[layer_name + ".self_attn.v_proj.bias"] = v_b + + elif ( + op_name == "mlp.dense_h_to_4h" and weight_or_bias == "weight" + ): + # dawei: mlp.dense_h_to_4h.weight + out_name = megatron_to_transformers[op_name] + gate, up = torch.chunk(params, 2, dim=0) + output_state_dict[layer_name + out_name.format("gate") + "weight"] = gate + output_state_dict[layer_name + out_name.format("up") + "weight"] = up + + # Transpose the weights. + elif weight_or_bias == "weight": + # dawei: self_attention.dense.weight, mlp.dense_4h_to_h.weight + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = params + + elif ( + op_name == "mlp.dense_h_to_4h" and weight_or_bias == "bias" + ): + # dawei: mlp.dense_h_to_4h.bias + out_name = megatron_to_transformers[op_name] + gate_b, up_b = torch.chunk(params, 2, dim=0) + output_state_dict[layer_name + out_name.format("gate") + "bias"] = gate_b + output_state_dict[layer_name + out_name.format("up") + "bias"] = up_b + + # Copy the bias. + elif weight_or_bias == "bias": + # dawei: (gpt2) self_attention.query_key_value.bias + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "bias"] = params + + if config.num_hidden_layers != (layer_idx + 1): + raise ValueError(f"Expected {config.num_hidden_layers} layers but found {layer_idx + 1}") + + # The final layernorm. + print("Converting final layernorm") + params = get_element_from_dict_by_path(tp_state_dicts[0], str(path)) + output_state_dict["model.norm.weight"] = params["final_layernorm.weight"].to(dtype) + + # For LM head, transformers' wants the matrix to weight embeddings. + print("Converting LM head") + output_state_dict["lm_head.weight"] = word_embeddings.to(dtype) + + # It should be done! + print("Conversion from Megatron-LM to Transformers is done!") + + # Print the structure of converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + # Add tokenizer class info to config + # see https://github.com/huggingface/transformers/issues/13906) + + print("Tokenizer_name: ", args.tokenizer_name) + if args.tokenizer_name is None: + tokenizer_name = "gpt2" + else: + tokenizer_name = args.tokenizer_name + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False) + tokenizer_class = type(tokenizer).__name__ + config.tokenizer_class = tokenizer_class + + # Store the config to file. + print("Saving config") + config.save_pretrained(args.save_path) + + # Save tokenizer based on args + if args.tokenizer_name is not None: + print(f"Adding {tokenizer_class} tokenizer files") + tokenizer.save_pretrained(args.save_path) + + # Store the state_dict to file. + max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size + shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size) + + # Save the model + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(args.save_path, shard_file)) + + if index is None: + print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}") + else: + save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + print( + f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + +def main(): + parser = argparse.ArgumentParser() + parser = add_checkpointing_args(parser) + parser = add_transformers_checkpoint_args(parser) + args = parser.parse_args() + + convert_checkpoint_from_megatron_to_transformers(args) + + +if __name__ == "__main__": + main() diff --git a/tools/converter/modeling_megatron_llama.py b/tools/converter/modeling_megatron_llama.py new file mode 100644 index 000000000..f35e6665c --- /dev/null +++ b/tools/converter/modeling_megatron_llama.py @@ -0,0 +1,951 @@ +# Some codes are adapted from https://pypi.org/project/transformers + +# Data-Juicer adopts Apache 2.0 license, the original license of this file +# is as follows +# -------------------------------------------------------- + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.configuration_utils import PretrainedConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MegatronLlamaConfig" + + +LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class MegatronLlamaConfig(PretrainedConfig): + model_type = "megatron-llama" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + use_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_bias = use_bias + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +############################################################################################ + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + use_bias: bool + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=use_bias) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=use_bias) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=use_bias) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MegatronLlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: MegatronLlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=config.use_bias, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MegatronLlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = MegatronLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: MegatronLlamaConfig + """ + + def __init__(self, config: MegatronLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # TODO: position embeddings, should be removed if rotary position embedding + self.embed_position = nn.Embedding(config.max_sequence_length, config.hidden_size) + + # word embeddings + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # TODO: remove in the future + position_embeds = self.embed_position(position_ids) + inputs_embeds += position_embeds + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MegatronLlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + # used for model + config.use_bias = True + + self.model = LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = MegatronLlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/tools/evaluator/.DS_Store b/tools/evaluator/.DS_Store new file mode 100644 index 000000000..da8df02d7 Binary files /dev/null and b/tools/evaluator/.DS_Store differ diff --git a/tools/evaluator/README.md b/tools/evaluator/README.md new file mode 100644 index 000000000..2bacf9d3d --- /dev/null +++ b/tools/evaluator/README.md @@ -0,0 +1,86 @@ +# Auto Evaluation Toolkit + +Automatically evaluate your model and monitor changes of metrics during the training process. + +## Preparation + +1. Multiple GPU machines (at least 2, one for evaluation, others for training). + +2. Mount a shared file system (e.g., NAS) to the same path (e.g., `/mnt/shared`) on the above machines. + +3. Install data-juicer in the shared file system (e.g., `/mnt/shared/code/data-juicer`). + +4. Install third-party dependencies (Megatron-LM and HELM) accoroding to `thirdparty/README.md` on each machine. + +5. Prepare your dataset and tokenizer, preprocess your dataset with Megatron-LM into mmap format (see README of Megatron-LM for more details) in the shared file system (e.g., `/mnt/shared/dataset`). + +6. Run Megatron-LM on training machines and save the checkpoint in the shared file system (e.g., `/mnt/shared/checkpoints`). + +## Usage + +Use `evaluator.py` to automatically evaluate your models with HELM and OpenAI API. + +```shell +python tools/evaluator.py \ + --config \ + --begin-iteration \ + [--end-iteration ] \ + [--iteration-interval ] \ + [--check-interval ] \ + [--model-type ] \ + [--eval-type ] \ +``` + +- `config`: a yaml file containing various settings required to run the evaluation (see [Configuration](#configuration) for details) +- `begin_iteration`: iteration of the first checkpoint to be evaluated +- `end_iteration`: iteration of the last checkpoint to be evaluated. If not set, continuously monitor the training process and evaluate the generated checkpoints. +- `iteration_interval`: iteration interval between two checkpoints, default is 1000 iterations +- `check_interval`: time interval between checks, default is 30 minutes +- `model_type`: type of your model, support `megatron` and `huggingface` for now + - `megatron`: evaluate Megatron-LM checkpoints (default) + - `huggingface`: evaluate HuggingFace model, only support gpt eval type +- `eval-type`: type of the evaluation to run, support `helm` and `gpt` for now + - `helm`: evaluate your model with HELM (default), you can change the benchmarks to run by modifying the helm specific template file + - `gpt`: evaluate your model with OpenAI API, more details can be found in `gpt_eval/README.md` + +> e.g., +> `python evaluator.py --config --begin-iteration 2000 --iteration-interval 1000 --check-interval 10` +> will use HELM to evaluate a Megatron-LM checkpoint every 1000 iterations starting from iteration 2000, and check whether there is a new checkpoint meets the condition every 10 minutes + +After running the `evaluator.py`, you can use `recorder/wandb_writer.py` to visualize the evaluation results, more details can be found in `recorder/README.md`. + +## Configuration + +The format of `config_file` is as follows: + +```yaml +auto_eval: + project_name: # your project name + model_name: # your model name + cache_dir: # path of cache dir + megatron: + process_num: # number of process to run megatron + megatron_home: # root dir of Megatron-LM + checkpoint_path: # path of checkpoint dir + tokenizer_type: # support gpt2 or sentencepiece for now + vocab_path: # configuration for gpt2 tokenizer type, path to vocab file + merge_path: # configuration for gpt2 tokenizer type, path to merge file + tokenizer_path: # configuration for sentencepiece tokenizer type, path to model file + max_tokens: # max tokens to generate in inference + token_per_iteration: # billions tokens per iteraion + helm: + helm_spec_template_path: # path of helm spec template file, default is tools/evaluator/config/helm_spec_template.conf + helm_output_path: # path of helm output dir + helm_env_name: # helm conda env name + gpt_evaluation: + # openai config + openai_api_key: # your api key + openai_organization: # your organization + # files config + question_file: # default is tools/evaluator/gpt_eval/config/question.jsonl + baseline_file: # default is tools/evaluator/gpt_eval/answer/openai/gpt-3.5-turbo.jsonl + prompt_file: # default is tools/evaluator/gpt_eval/config/prompt.jsonl + reviewer_file: # default is tools/evaluator/gpt_eval/config/reviewer.jsonl + answer_file: # path to generated answer file + result_file: # path to generated review file +``` diff --git a/tools/evaluator/README_ZH.md b/tools/evaluator/README_ZH.md new file mode 100644 index 000000000..f86090416 --- /dev/null +++ b/tools/evaluator/README_ZH.md @@ -0,0 +1,87 @@ +# Auto Evaluation Toolkit + +在训练过程中自动评测您的模型并持续监控指标的变化。 + +## 准备工作 + +1. 多台GPU机器(至少2台,一台用于运行评测,其他机器用于训练模型)。 + +2. 将共享文件系统(例如NAS)挂载到上述机器上的相同路径(例如`/mnt/shared`)。 + +3. 在共享文件系统中安装 data-juicer(例如 `/mnt/shared/code/data-juicer`)。 + +4. 根据 `thirdparty/README_ZH.md` 在每台机器上安装第三方依赖项(Megatron-LM 和 HELM)。 + +5. 准备数据集和 tokenizer,在共享文件系统(例如`/mnt/shared/dataset`)中使用 Megatron-LM 提供的预处理工具将数据集预处理为 mmap 格式(更多详细信息,请参阅 Megatron-LM 的 README)。 + +6. 在训练机器上运行 Megatron-LM 并将检查点保存在共享文件系统中(例如`/mnt/shared/checkpoints`)。 + +## 用法 + +通过 `evaluator.py` 来使用 HELM 或 OpenAI API 自动评估您的模型。 + +```shell +python tools/evaluator.py \ + --config \ + --begin-iteration \ + [--end-iteration ] \ + [--iteration-interval ] \ + [--check-interval ] \ + [--model-type ] \ + [--eval-type ] \ +``` + +- `config`: 包含运行评估所需的各种设置的 yaml 文件(详细信息请参阅[配置](#配置)) +- `begin_iteration`: 首个需要评估的检查点的 iteration +- `end_iteration`: 最后一个需要评估的检查点的 iteration。如果没有设置,该进程将持续监控训练过程中产生的检查点。 +- `iteration_interval`: 两次评测之间的 iteration 间隔,默认为 1000。 +- `check_interval`: 两次检查是否有满足条件检查点的时间间隔,默认为 30 分钟。 +- `model_type`: 被评测的模型类型,当前支持 `megatron` 和 `huggingface` + - `megatron`: 即 Megatron-LM 检查点,默认为此项 + - `huggingface`: 即 HuggingFace 模型 +- `eval-type`: type of the evaluation to run, support `helm` and `gpt` for now + - `helm`: 使用 HELM 评测,默认为此项,当前仅支持评测 Megatron-LM 模型。 + - `gpt`: 使用 OpenAI API 评测,更多细节请见 `gpt_eval/README_ZH.md` + +> 例: +> `python evaluator.py --config --begin-iteration 2000 --iteration-interval 1000 --check-interval 10` +> 将会使用 HELM 从 Megatron-LM 训练到 2000 iteration 开始每隔 1000 iterations 评测一个检查点,并会每隔 30 分钟检测一次是否有新的检查点生成 + +在运行 `evaluator.py` 之后, 可以使用 `recorder/wandb_writer.py` 将评测结果记录到 wandb 并可视化展示,更多细节请参考 `recorder/README_ZH.md`。 + +## 配置 + +`config_file` 文件格式如下: + +```yaml +auto_eval: + project_name: # 项目名称 + model_name: # 模型名称 + cache_dir: # 缓存目录路径 + megatron: + process_num: # 运行 megatron-lm 所需的进程数 + megatron_home: # Megatron-LM 代码根目录 + checkpoint_path: # 检查点保存根目录 + tokenizer_type: # 目前支持 gpt2 或 sentencepiece + vocab_path: # 针对 gpt2 tokenizer 的配置项, vocab 文件的路径 + merge_path: # 针对 gpt2 tokenizer 的配置项, merge 文件的路径 + tokenizer_path: # 针对 sentencepiece tokenizer 的配置项, model 文件的路径 + max_tokens: # 在执行生成任务时最大生成的 token 数量 + token_per_iteration: # 训练时每次迭代所使用的 token 数量(单位:B) + helm: + helm_spec_template_path: # helm 评测模版文件, 默认为 tools/evaluator/config/helm_spec_template.conf,可通过修改此文件来调整运行的评测 + helm_output_path: # helm 输出目录路径 + helm_env_name: # helm 的 conda 环境名 + gpt_evaluation: + # openai config + openai_api_key: + openai_organization: + # files config + question_file: # 默认为 tools/evaluator/gpt_eval/config/question.jsonl + baseline_file: # 默认为 tools/evaluator/gpt_eval/answer/openai/gpt-3.5-turbo.jsonl + prompt_file: # 默认为 tools/evaluator/gpt_eval/config/prompt.jsonl + reviewer_file: # 默认为 tools/evaluator/gpt_eval/config/reviewer.jsonl + answer_file: # 生成的回答文件的路径 + result_file: # 生成的评价文件的路径 + +``` diff --git a/tools/evaluator/config/evaluator_example.yaml b/tools/evaluator/config/evaluator_example.yaml new file mode 100644 index 000000000..ee69c0f3f --- /dev/null +++ b/tools/evaluator/config/evaluator_example.yaml @@ -0,0 +1,33 @@ +auto_eval: + project_name: + model_name: + cache_dir: + megatron: + process_num: + megatron_home: + checkpoint_path: + tokenizer_type: + vocab_path: + merge_path: + max_tokens: + token_per_iteration: + # tokenizer_path: + # log_path: + helm: + helm_spec_template_path: + helm_output_path: + helm_env_name: + gpt_evaluation: + # openai config + openai_api_key: + openai_organization: + # files config + question_file: ./tools/eval/gpt_eval/config/question.jsonl + answer_file: + baseline_file: ./tools/eval/gpt_eval/answer/openai/chatgpt.jsonl + prompt_file: ./tools/eval/gpt_eval/config/prompt.jsonl + reviewer_file: ./tools/eval/gpt_eval/config/reviewer.jsonl + result_file: + wandb: + project: + base_url: \ No newline at end of file diff --git a/tools/evaluator/config/helm_spec_template.conf b/tools/evaluator/config/helm_spec_template.conf new file mode 100644 index 000000000..d21560a75 --- /dev/null +++ b/tools/evaluator/config/helm_spec_template.conf @@ -0,0 +1,107 @@ +# Don't modify this file !!! + +entries: [ + {description: "mmlu:model=,subject=abstract_algebra,data_augmentation=canonical", priority: 2} + {description: "mmlu:model=,subject=anatomy,data_augmentation=canonical", priority: 3} + {description: "mmlu:model=,subject=college_chemistry,data_augmentation=canonical", priority: 2} + {description: "mmlu:model=,subject=computer_security,data_augmentation=canonical", priority: 2} + {description: "mmlu:model=,subject=econometrics,data_augmentation=canonical", priority: 2} + {description: "mmlu:model=,subject=global_facts,data_augmentation=canonical", priority: 3} + {description: "mmlu:model=,subject=jurisprudence,data_augmentation=canonical", priority: 3} + {description: "mmlu:model=,subject=philosophy,data_augmentation=canonical", priority: 3} + {description: "mmlu:model=,subject=professional_medicine,data_augmentation=canonical", priority: 3} + {description: "mmlu:model=,subject=us_foreign_policy,data_augmentation=canonical", priority: 2} + {description: "mmlu:model=,subject=astronomy,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=business_ethics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=clinical_knowledge,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=college_biology,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=college_computer_science,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=college_mathematics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=college_medicine,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=college_physics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=conceptual_physics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=electrical_engineering,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=elementary_mathematics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=formal_logic,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_biology,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_chemistry,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_computer_science,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_european_history,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_geography,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_government_and_politics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_macroeconomics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_mathematics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_microeconomics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_physics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_psychology,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_statistics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_us_history,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=high_school_world_history,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=human_aging,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=human_sexuality,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=international_law,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=logical_fallacies,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=machine_learning,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=management,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=marketing,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=medical_genetics,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=miscellaneous,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=moral_disputes,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=moral_scenarios,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=nutrition,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=prehistory,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=professional_accounting,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=professional_law,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=professional_psychology,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=public_relations,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=security_studies,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=sociology,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=virology,data_augmentation=canonical", priority: 4} + {description: "mmlu:model=,subject=world_religions,data_augmentation=canonical", priority: 4} + + {description: "imdb:model=,data_augmentation=canonical", priority: 1} + + {description: "raft:subset=ade_corpus_v2,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=banking_77,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=neurips_impact_statement_risks,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=one_stop_english,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=overruling,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=semiconductor_org_types,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=tweet_eval_hate,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=twitter_complaints,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=systematic_review_inclusion,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=tai_safety_research,model=,data_augmentation=canonical", priority: 2} + {description: "raft:subset=terms_of_service,model=,data_augmentation=canonical", priority: 2} + + {description: "summarization_cnndm:model=,temperature=0.3,device=cpu", priority: 1} + + {description: "truthful_qa:model=,task=mc_single,data_augmentation=canonical", priority: 1} + + {description: "boolq:model=,data_augmentation=canonical", priority: 1} + + {description: "narrative_qa:model=,data_augmentation=canonical", priority: 2} + + {description: "natural_qa:model=,mode=openbook_longans,data_augmentation=canonical", priority: 1} + + {description: "natural_qa:model=,mode=closedbook,data_augmentation=canonical", priority: 1} + + {description: "quac:model=,data_augmentation=canonical", priority: 1} + + {description: "commonsense:model=,dataset=hellaswag,method=multiple_choice_separate_original,data_augmentation=canonical", priority: 1} + {description: "commonsense:model=,dataset=openbookqa,method=multiple_choice_separate_calibrated,data_augmentation=canonical", priority: 2} + + {description: "msmarco:model=,data_augmentation=canonical,track=regular,valid_topk=30", priority: 2} + {description: "msmarco:model=,data_augmentation=canonical,track=trec,valid_topk=30", priority: 1} + + {description: "summarization_xsum_sampled:model=,temperature=0.3,device=cpu", priority: 1} + + {description: "civil_comments:model=,demographic=all,data_augmentation=canonical", priority: 1} + {description: "civil_comments:model=,demographic=male,data_augmentation=canonical", priority: 2} + {description: "civil_comments:model=,demographic=female,data_augmentation=canonical", priority: 2} + {description: "civil_comments:model=,demographic=LGBTQ,data_augmentation=canonical", priority: 2} + {description: "civil_comments:model=,demographic=christian,data_augmentation=canonical", priority: 2} + {description: "civil_comments:model=,demographic=muslim,data_augmentation=canonical", priority: 2} + {description: "civil_comments:model=,demographic=other_religions,data_augmentation=canonical", priority: 2} + {description: "civil_comments:model=,demographic=black,data_augmentation=canonical", priority: 2} + {description: "civil_comments:model=,demographic=white,data_augmentation=canonical", priority: 2} +] \ No newline at end of file diff --git a/tools/evaluator/evaluator.py b/tools/evaluator/evaluator.py new file mode 100644 index 000000000..5efc443be --- /dev/null +++ b/tools/evaluator/evaluator.py @@ -0,0 +1,261 @@ +import argparse +import yaml +import os +import subprocess +import time +import shutil + +from gpt_eval.gpt_evaluator import GPTEvaluator +from recorder.wandb_writer import HelmWriter + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + parser.add_argument( + '--model-type', choices=['megatron', 'huggingface'], default='megatron') + parser.add_argument('--eval-type', choices=['helm', 'gpt'], default='helm') + parser.add_argument('--iteration-interval', type=int, default=1000) + parser.add_argument('--begin-iteration', type=int, default=None) + parser.add_argument('--end-iteration', type=int, default=None) + parser.add_argument('--check-iterval', type=int, default=30) + return parser.parse_args() + + +def check_args(args): + if args.begin_iteration == None: + print( + f"--begin-iteration is not provided, use the value of --iteration-interval ({args.iteration_interval}).") + args.begin_iteration = args.iteration_interval + if args.end_iteration == None: + print(f"--end-iteration is not provided, evaluator will monitor the traning process continuously.") + args.end_iteration = float('inf') + + +class Evaluator(): + + def __init__(self, args): + with open(args.config, 'r', encoding='utf-8') as f: + self.config = yaml.safe_load(f)['auto_eval'] + self.eval_type = args.eval_type + self.iteration_interval = args.iteration_interval + self.begin_iteration = args.begin_iteration + self.end_iteration = args.end_iteration + self.check_iterval = args.check_iterval + self.load_config() + + def load_config(self): + self.project_name = self.config['project_name'] + self.model_name = self.config['model_name'] + self.full_name = f'{self.project_name}-{self.model_name}' + # load cache dir + self.cur_dir = os.path.abspath(os.getcwd()) + self.cache_dir = self.config['cache_dir'] if 'cache_dir' in self.config else os.path.join( + self.cur_dir, 'cache') + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir) + # load megatron config + if 'megatron' in self.config: + os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + os.environ['OMP_NUM_THREADS'] = '4' + self.megatron_process_num = self.config['megatron']['process_num'] + self.megatron_checkpoint_path = self.config['megatron']['checkpoint_path'] + # for different tokenizer + if self.config['megatron']['tokenizer_type'] == 'sentencepiece': + self.tokenizer_type = 'sentencepiece' + self.vocab_path = None + self.merge_path = None + self.tokenizer_path = self.config['megatron']['tokenizer_path'] + elif self.config['megatron']['tokenizer_type'] == 'gpt2': + self.tokenizer_type = 'gpt2' + self.vocab_path = self.config['megatron']['vocab_path'] + self.merge_path = self.config['megatron']['merge_path'] + self.tokenizer_path = None + else: + raise NotImplementedError( + f"tokenizer type: {self.config['megatron']['tokenizer_type']} is not supported") + self.megatron_log_path = os.path.join( + self.cache_dir, 'megatron.log') + if 'log_path' in self.config['megatron']: + self.megatron_log_path = self.config['megatron']['log_path'] + self.megatron_server_port = 5000 + if 'port' in self.config['megatron']: + self.megatron_server_port = self.config['megatron']['port'] + self.megatron_home = self.cur_dir + if 'megatron_home' in self.config['megatron']: + self.megatron_home = self.config['megatron']['megatron_home'] + self.max_tokens = 512 + if 'max_tokens' in self.config['megatron']: + self.max_tokens = self.config['megatron']['max_tokens'] + self.megatron_token_per_iteration = 0 + if 'token_per_iteration' in self.config['megatron']: + self.megatron_token_per_iteration = self.config['megatron']['token_per_iteration'] + # load helm config + if 'helm' in self.config: + self.helm_spec_template_path = self.config['helm']['helm_spec_template_path'] + self.helm_output_path = self.config['helm']['helm_output_path'] + self.helm_spec_path = os.path.join( + self.cache_dir, 'helm_spec.conf') + self.helm_cache_path = os.path.join(self.cache_dir, 'helm_cache') + self.helm_suite_name = self.full_name + self.helm_conda_env = self.config['helm']['helm_env_name'] if 'helm_env_name' in self.config['helm'] else 'crfm-helm' + self.helm_eval_instances = self.config['helm'][ + 'eval_instances'] if 'eval_instances' in self.config['helm'] else 100 + self.helm_benchmarks = self.config['helm']['benchmarks'] if 'benchmarks' in self.config['helm'] else None + self.helm_mymodel_config = os.path.join( + self.cache_dir, 'helm_config.yaml') + with open(self.helm_mymodel_config, 'w', encoding='utf-8') as f: + mymodel_config = { + 'port': self.megatron_server_port, + 'tokenizer': { + 'type': self.tokenizer_type, + 'vocab_path': self.vocab_path, + 'merge_path': self.merge_path, + 'tokenizer_path': self.tokenizer_path + } + } + yaml.dump(mymodel_config, f) + if self.eval_type == 'gpt': + self.gpt_question_file = self.config['gpt_evaluation']['question_file'] + self.gpt_answer_file = self.config['gpt_evaluation']['answer_file'] + if 'wandb' in self.config: + self.wandb_base_url = self.config['wandb']['base_url'] if 'base_url' in self.config['wandb'] else None + self.wandb_project = self.config['wandb']['project'] if 'project' in self.config['wandb'] else self.project_name + + def _set_megatron_tokenizer(self, args): + if self.tokenizer_type == 'gpt2': + args.append('GPT2BPETokenizer') + args.append('--vocab-file') + args.append(self.vocab_path) + args.append('--merge-file') + args.append(self.merge_path) + elif self.tokenizer_type == 'sentencepiece': + args.append('SentencePieceTokenizer') + args.append('--tokenizer-model') + args.append(self.tokenizer_path) + + def run_megatron_server(self, iteration): + while not self.megatron_checkpoint_exists(iteration): + print(f'Wait for megatron checkpoint {iteration}') + time.sleep(self.check_iterval * 60) + # setup megatron server + print( + f'Start megatron text generation server for checkpoint iter_{iteration}') + args = ['torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', '--nproc_per_node', str(self.megatron_process_num), '--nnodes', '1', '--node_rank', '0', os.path.join(self.megatron_home, 'tools/run_text_generation_server.py'), '--port', + str(self.megatron_server_port), '--use-checkpoint-args', '--load', self.megatron_checkpoint_path, + '--load-iteration', str(iteration), '--tokenizer-type'] + self._set_megatron_tokenizer(args) + logfile = open(self.megatron_log_path, 'w') + os.chdir(self.megatron_home) + process = subprocess.Popen(args, stdout=logfile, stderr=logfile) + os.chdir(self.cur_dir) + return { + 'process': process, + 'logfile': logfile + } + + def stop_megatron_server(self, process, logfile): + process.terminate() + logfile.close() + print(f'Stop megatron text generation server') + + def run_megatron_inference(self, iteration): + while not self.megatron_checkpoint_exists(iteration): + time.sleep(self.check_iterval * 60) + print(f'Wait for megatron checkpoint {iteration}') + print(f'Start megatron inference for checkpoint iter_{iteration}') + args = ['torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', '--nproc_per_node', '1', '--nnodes', + str(self.megatron_process_num), '--node_rank', '0', 'tools/inference.py', '--use-checkpoint-args', + '--formatter', 'gpt_eval', '--tokens-to-generate', str( + self.max_tokens), '--input', self.gpt_question_file, + '--output', self.gpt_answer_file, '--load', self.megatron_checkpoint_path, '--load-iteration', + str(iteration), '--model-name', f'{self.full_name}/{iteration}', '--tokenizer-type'] + self._set_megatron_tokenizer(args) + logfile = open(self.megatron_log_path, 'w') + os.chdir(self.megatron_home) + subprocess.run(args) + os.chdir(self.cur_dir) + logfile.close() + return {} + + def megatron_checkpoint_exists(self, iteration): + with open(os.path.join(self.megatron_checkpoint_path, 'latest_checkpointed_iteration.txt'), 'r') as f: + latest_checkpoint_iter = int(f.readline()) + if iteration > latest_checkpoint_iter: + return False + checkpoint_path = os.path.join( + self.megatron_checkpoint_path, 'iter_{:07d}'.format(iteration)) + return os.path.exists(checkpoint_path) + + def replace_pattern(self, input_file, output_file, pattern, s): + with open(input_file, 'r', encoding='utf-8') as input, open(output_file, 'w', encoding='utf-8') as output: + lines = input.readlines() + for i in range(len(lines)): + lines[i] = lines[i].replace(pattern, s) + output.writelines(lines) + + def run_helm_eval(self, iteration): + print(f'Start helm evaluation for checkpoint iter_{iteration}') + if os.path.exists(self.helm_cache_path): + shutil.rmtree(self.helm_cache_path) + self.replace_pattern(self.helm_spec_template_path, self.helm_spec_path, + '', f'mymodel/{self.full_name}/{iteration}') + helm_run_args = ['conda', 'run', '-n', self.helm_conda_env, '--no-capture-output', 'helm-run', '-n', '4', '-m', str(self.helm_eval_instances), + '--conf-paths', self.helm_spec_path, '--my-config-path', self.helm_mymodel_config, + '--local-path', self.helm_cache_path, + '--suite', self.helm_suite_name, '-o', self.helm_output_path] + subprocess.check_call(helm_run_args) + print(f'run helm summarize for checkpoint iter_{iteration}') + helm_summarize_args = ['conda', 'run', '-n', self.helm_conda_env, '--no-capture-output', + 'helm-summarize', '--suite', self.helm_suite_name, '-o', self.helm_output_path] + subprocess.check_call(helm_summarize_args) + print(f'Finish helm evaluation for checkpoint iter_{iteration}') + + def run_gpt_eval(self, iteration): + GPTEvaluator(self.config['gpt_evaluation']).run() + + def write_wandb(self): + if self.eval_type == 'helm': + helm_config = { + 'model_name': self.full_name, + 'source': 'helm', + 'helm_output_dir': self.helm_output_path, + 'helm_suite_name': self.helm_suite_name, + 'token_per_iteration': self.megatron_token_per_iteration + } + if self.helm_benchmarks is not None: + helm_config['benchmarks'] = self.helm_benchmarks + HelmWriter(project_name=self.wandb_project, + base_url=self.wandb_base_url, helm_config=helm_config) + + def evaluate(self, start_gen_func, start_eval_func, stop_gen_func, stop_eval_func): + cur_iter = self.begin_iteration + while cur_iter <= self.end_iteration: + states = start_gen_func(cur_iter) + start_eval_func(cur_iter) + stop_eval_func() + stop_gen_func(**states) + cur_iter += self.iteration_interval + + def dummy_stop(self, args=None): + return + + def run(self): + if self.eval_type == 'helm': + start_gen_func = self.run_megatron_server + start_eval_func = self.run_helm_eval + stop_gen_func = self.stop_megatron_server + stop_eval_func = self.dummy_stop + elif self.eval_type == 'gpt': + start_gen_func = self.run_megatron_inference + start_eval_func = self.run_gpt_eval + stop_gen_func = self.dummy_stop + stop_eval_func = self.dummy_stop + self.evaluate(start_gen_func, start_eval_func, + stop_gen_func, stop_eval_func) + + +if __name__ == '__main__': + args = parse_args() + check_args(args) + Evaluator(args).run() diff --git a/tools/evaluator/gpt_eval/README.md b/tools/evaluator/gpt_eval/README.md new file mode 100644 index 000000000..5661b0dd3 --- /dev/null +++ b/tools/evaluator/gpt_eval/README.md @@ -0,0 +1,74 @@ +# GPT EVAL -- Evaluate your model with OpenAI API + +## Quick Start + +1. Prepare your model and the baseline model + - your model: Huggingface and Megatron-LM format models are supported, other models will be supported in future releases + - baseline model: Huggingface, Megatron-LM or OpenAI model + > Evaluating Megatron-LM models requires a customized Megatron-LM which is provided in `thirdparty` + +2. generate answers using `answer_generator.py` for both your model and the baseline model + 1. prepare the benchmark dataset: the toolkit has provided Vicuna Bench(`./config/question.jsonl`), and you can create custom dataset to generate answers. The custom datasets must be a single file in jsonl format, and each json object in it contains 3 attributes: + - question_id: int type + - text: the specific content of the question, string type + - category: the type of the question, string type + + 2. build the config file (`config.yaml`): the format of the file is as follows: + ```yaml + answer_generation: + model_name: + question_file: # path of the benchmark dataset file + answer_file: # path of the answer file generated by the model + batch_size: # batch size when generating answers + max_tokens: # maximum token size for each generated answer + temperature: + # Choose one of the following configurations according to your model type + # Config for huggingface + huggingface: + model_path: # path of your model + tokenizer_path: # path of your tokenizer + # Config for megatron-lm + megatron: + megatron_home: # root dir of Megatron-LM code + process_num: # number of processes to run megatron + checkpoint_path: # megatron checkpoint dir path + tokenizer_type: # only support 'gpt2' and 'sentencepiece' for now + vocab_path: # path to the vocab file for gpt2 tokenizer + merge_path: # path to the merge file for gpt2 tokenizer + tokenizer_path: # path to the tokenizer model for sentencepiece tokenizer + iteration: # iteration of the checkpoint to load + # Config for openai + openai: + openai_organization: + openai_api_key: + model: # the type of model,e.g., gpt-3.5-turbo + max_retry: # the maxium number of retries when api access fails + ``` + 3. run the script + ```shell + python answer_generator.py --config + ``` + +3. get OpenAI API evaluation result via `gpt_evaluator.py` + 1. prepare dependencies: make sure the following files are ready + - question_file: the benchmark dataset file in previous step + - answer_file: the answer file of your model in previous step + - baseline_file: the answer file of the baseline model in previous step + - prompt_file: a file contains multiple prompt templates, the toolkit has provided a sample file (`config/prompt.json`) + - reviewer_file: a file contains multiple reviewer templates (including the model type and other parameters used in the OpenAI api request),the toolkit has provided a sample file (`config/reviewer.json`) + 2. build the config file (`config.yaml`): the format of the file is as follows: + ```yaml + gpt_evaluation: + openai_organization: + openai_api_key: + question_file: + answer_file: + baseline_file: + prompt_file: + reviewer_file: + result_file: # path of the evaulation result + ``` + 3. run the script + ```shell + python gpt_evaluator.py --config + ``` \ No newline at end of file diff --git a/tools/evaluator/gpt_eval/README_ZH.md b/tools/evaluator/gpt_eval/README_ZH.md new file mode 100644 index 000000000..b2ebaedf4 --- /dev/null +++ b/tools/evaluator/gpt_eval/README_ZH.md @@ -0,0 +1,82 @@ +# GPT EVAL —— 使用 OpenAI API 评测大模型 + +## 快速上手 + +1. 准备待评测的模型以及对照模型 + - 待评测模型:当前支持 Huggingface 以及 Megatron-LM 格式,后续会陆续支持加载其他常见模型检查点格式 + - 对照模型:Huggingface, Megatron-LM 模型或 OpenAI 提供的模型 + > 评测 Megatron-LM 模型需要使用 `thirdparty` 文件夹提供的定制化 Megatron-LM + +2. 使用 `answer_generator.py` 在评测数据集上分别生成待评测模型及对照模型的回答 + 1. 准备数据集:工具包内已经提供了 Vicuna 评测数据集 (`./config/question.jsonl`),同时支持用户使用自定义数据集生成回答,自定义数据集要求为单个 jsonl 文件,其中每个 json 对象包含以下3个域: + - question_id: int 类型,用于标识该问题 + - text: string 类型,问题的具体内容 + - category: string 类型,该问题的类型 + + 2. 编写配置文件: 运行脚本需要的 yaml 文件格式如下 + + ```yaml + answer_generation: + model_name: + question_file: # 评测数据文件路径 + answer_file: # 模型生成回答文件路径 + batch_size: # 生成回答时的 batch size + max_tokens: # 生成回答的最大 token 数量 + temperature: + # 以下配置根据模型来源选择一种即可 + # huggingface 配置 + huggingface: + model_path: # 文件路径或 huggingface model path + tokenizer_path: # 文件路径或 huggingface model path + # megatron-lm 配置 + megatron: + megatron_home: # Megatron-LM 代码根目录 + process_num: # 运行 megatron-lm 所需的进程数 + checkpoint_path: # megatron checkpoint 文件夹路径 + tokenizer_type: # 目前仅支持 'gpt2' 和 'sentencepiece' + vocab_path: # gpt2 tokenizer 的 vocab 文件路径 + merge_path: # gpt2 tokenizer 的 merge 文件路径 + tokenizer_path: # sentencepiece tokenizer 的 model 文件路径 + iteration: # 待加载 checkpoint 的 iteration + # openai 配置 + openai: + openai_organization: + openai_api_key: + model: # 评测模型的类型,例如 gpt-3.5-turbo + max_retry: # api 访问失败时最大重试次数 + ``` + + 3. 运行脚本 + + ```shell + python answer_generator.py --config + ``` + + > 工具包内已经提供了 gpt-3.5-turbo 在 Vicuna 评测集上的回答文件 + +3. 通过 `gpt_evaluator.py` 调用 OpenAI API 获得评价结果 + 1. 准备评测依赖项:运行脚本前需准备如下文件 + - question_file: 即上一步中的评测数据文件 + - answer_file: 即上一步得到的待评测模型的回答文件 + - baseline_file: 即上一步得到的对照模型的回答文件 + - prompt_file: prompt 模板文件,工具包内已提供一份样本 (`config/prompt.json`) + - reviewer_file: reviewer 模板文件(包括评测时使用的模型类型和其他参数),工具包内已提供一份样本 (`config/reviewer.json`) + 2. 编写配置文件:运行脚本所需的 yaml 文件格式如下 + + ```yaml + gpt_evaluation: + openai_organization: + openai_api_key: + question_file: + answer_file: + baseline_file: + prompt_file: + reviewer_file: + result_file: # 评价结果输出文件路径 + ``` + + 3. 运行脚本 + + ```shell + python gpt_evaluator.py --config + ``` diff --git a/tools/evaluator/gpt_eval/__init__.py b/tools/evaluator/gpt_eval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/evaluator/gpt_eval/answer/openai/gpt-3.5-turbo.jsonl b/tools/evaluator/gpt_eval/answer/openai/gpt-3.5-turbo.jsonl new file mode 100644 index 000000000..08b228ef3 --- /dev/null +++ b/tools/evaluator/gpt_eval/answer/openai/gpt-3.5-turbo.jsonl @@ -0,0 +1,10 @@ +{"question_id": 1, "model_id": "openai-gpt-3.5-turbo", "text": "Here are some tips to improve your time management skills:\n\n1. Set clear goals: Start by setting specific and achievable goals for yourself. This will help you prioritize your tasks and stay focused on what's important.\n\n2. Prioritize tasks: Make a to-do list and prioritize tasks based on their importance and urgency. Focus on completing high-priority tasks first.\n\n3. Break tasks into smaller, manageable chunks: Large tasks can be overwhelming, so break them down into smaller, more manageable tasks. This will make it easier to tackle them and track your progress.\n\n4. Use a time management system: Find a time management system that works for you, such as using a planner, calendar, or digital tools like task management apps. This will help you stay organized and keep track of your tasks and deadlines.\n\n5. Avoid procrastination: Procrastination can waste a lot of time and lead to unnecessary stress. Try using techniques like the Pomodoro Technique (working in short bursts with breaks in between) or setting specific deadlines for yourself to help overcome procrastination.\n\n6. Delegate tasks: If possible, delegate tasks that can be done by others. This will free up your time to focus on more important or complex tasks.\n\n7. Learn to say no: It's important to recognize your limits and be selective about taking on additional tasks. Learn to say no when you already have a full plate to avoid overcommitting yourself.\n\n8. Minimize distractions: Identify and minimize distractions that can disrupt your focus and productivity. This may include turning off notifications on your phone, closing unnecessary tabs on your computer, or finding a quiet workspace.\n\n9. Practice time blocking: Time blocking is the practice of scheduling specific time slots for different tasks or activities. This helps you allocate dedicated time for each task and prevents multitasking.\n\n10. Take regular breaks: Taking regular breaks can help refresh your mind and prevent burnout. Use your breaks to relax, stretch, or engage in activities that help you recharge.\n\nRemember, improving time management skills takes practice and consistency. Experiment with different techniques and find what works best for you."} +{"question_id": 2, "model_id": "openai-gpt-3.5-turbo", "text": "There are several effective ways to deal with stress. Here are some strategies you can try:\n\n1. Exercise regularly: Physical activity helps to reduce stress hormones and release endorphins, which elevate mood and promote relaxation.\n\n2. Practice deep breathing or meditation: Deep breathing exercises and meditation can help calm your mind and reduce stress. These techniques focus on slow, deep breaths and can be done anywhere.\n\n3. Get enough sleep: Adequate sleep is essential for managing stress. Make sure you establish a regular sleep routine and create a relaxing bedtime routine to promote better sleep.\n\n4. Prioritize and manage your time: Organize your tasks, set realistic goals, and prioritize your responsibilities. This will help you feel more in control and reduce stress caused by overwhelm.\n\n5. Take breaks and practice self-care: Give yourself regular breaks throughout the day to relax and recharge. Engage in activities that you enjoy, such as reading, listening to music, or spending time in nature.\n\n6. Seek social support: Reach out to family, friends, or support groups. Sharing your feelings and concerns with others can help you gain perspective and receive emotional support.\n\n7. Practice stress-reducing techniques: Engage in activities that help you relax and reduce stress, such as taking a warm bath, practicing yoga, or engaging in hobbies you enjoy.\n\n8. Maintain a healthy lifestyle: Eat a balanced diet, limit caffeine and alcohol intake, and avoid smoking. These healthy choices can support your body's ability to cope with stress.\n\n9. Establish boundaries: Learn to say no to extra commitments and set boundaries in your personal and professional life. This will help prevent overwhelm and reduce stress.\n\n10. Seek professional help if needed: If stress becomes overwhelming and starts affecting your daily life, consider seeking support from a mental health professional who can provide guidance and strategies to manage stress effectively.\n\nRemember, it's essential to find what works best for you as everyone's response to stress is different. Experiment with various techniques and incorporate them into your routine to find what helps you the most."} +{"question_id": 3, "model_id": "openai-gpt-3.5-turbo", "text": "Python and JavaScript are both popular programming languages, but they have some key differences:\n\n1. Syntax: Python has a clean and easy-to-read syntax, emphasizing readability and simplicity. JavaScript has a C-style syntax, which can be more complex and sometimes harder to read.\n\n2. Purpose: Python is a general-purpose language used for various tasks, including web development, data analysis, machine learning, and scripting. JavaScript, on the other hand, is primarily used for web development, both on the client-side (running in a web browser) and the server-side (running on a server).\n\n3. Paradigm: Python is a multi-paradigm language that supports procedural, object-oriented, and functional programming. JavaScript is also multi-paradigm but is primarily known as a prototype-based, object-oriented language. It also supports functional programming to some extent.\n\n4. Typing: Python is a dynamically typed language, which means that variable types are determined at runtime. JavaScript is also dynamically typed, but it has more flexibility with its loosely typed nature, allowing variables to change types during runtime.\n\n5. Libraries and frameworks: Both Python and JavaScript have a wide range of libraries and frameworks available. Python has extensive libraries for scientific computing, data analysis, and machine learning (e.g., NumPy, Pandas, TensorFlow). JavaScript, being mainly used for web development, has many frameworks and libraries for front-end development (e.g., React, Angular, Vue.js) and back-end development (e.g., Node.js, Express.js).\n\n6. Execution environment: Python programs are typically run using a Python interpreter, which converts the code into machine-readable instructions. JavaScript, on the other hand, is primarily executed by web browsers to run client-side code. However, with the introduction of Node.js, JavaScript can now also be executed on the server-side.\n\nIt's important to note that while there are differences, both Python and JavaScript are powerful languages with their own strengths and use cases. Choosing between them often depends on the specific requirements of a project."} +{"question_id": 4, "model_id": "openai-gpt-3.5-turbo", "text": "To increase your productivity while working from home, consider implementing the following strategies:\n\n1. Establish a dedicated workspace: Set up a designated area for work that is separate from your leisure space. This helps create a mental boundary between work and personal life.\n\n2. Stick to a routine: Create a daily schedule that includes specific work hours, breaks, and time for other activities. Following a routine can help maintain focus and ensure you accomplish tasks efficiently.\n\n3. Dress for success: Although tempting, avoid staying in your pajamas all day. Dressing up as you would for a regular workday can put you in a professional mindset and boost productivity.\n\n4. Minimize distractions: Identify and eliminate distractions in your surroundings. Turn off notifications on your phone, close unnecessary tabs on your computer, and let your family or housemates know when you need uninterrupted work time.\n\n5. Set clear goals and prioritize tasks: Start each day by setting clear goals and prioritizing tasks. Break down larger projects into smaller, manageable tasks to maintain motivation and track progress.\n\n6. Utilize productivity tools: Take advantage of various productivity tools and apps available. Tools like project management software, time trackers, and to-do list apps can help you stay organized and focused.\n\n7. Take regular breaks: Incorporate short breaks into your schedule to rest and recharge. Use this time to stretch, take a walk, or engage in other activities that help clear your mind and reduce stress.\n\n8. Practice time management techniques: Techniques like the Pomodoro Technique, where you work for a designated period (e.g., 25 minutes) followed by a short break (e.g., 5 minutes), can enhance focus and productivity.\n\n9. Stay connected with colleagues: Regularly communicate and collaborate with your colleagues through virtual meetings, messaging platforms, or email. Staying connected can help maintain a sense of accountability and foster collaboration.\n\n10. Prioritize self-care: Taking care of your physical and mental well-being is crucial for productivity. Make sure to get enough sleep, exercise regularly, eat healthily, and practice relaxation techniques such as meditation or deep breathing.\n\nRemember, finding what works best for you may require some experimentation and adjustments. By implementing these strategies and being mindful of your work habits, you can increase your productivity while working from home."} +{"question_id": 5, "model_id": "openai-gpt-3.5-turbo", "text": "Of course! Quantum computing is a rapidly evolving field that combines principles from physics, computer science, and mathematics to perform computations using quantum bits or qubits.\n\nIn classical computing, information is stored in bits, which can represent either a 0 or a 1. However, in quantum computing, qubits can exist in a superposition, meaning they can represent both 0 and 1 simultaneously. This allows quantum computers to process and manipulate multiple possibilities simultaneously, leading to the potential for much faster and more efficient computations.\n\nAnother important concept in quantum computing is entanglement. Entanglement allows two or more qubits to become linked in such a way that the state of one qubit is dependent on the state of the others. This property enables quantum computers to perform certain types of calculations more efficiently than classical computers.\n\nQuantum algorithms, such as Shor's algorithm for factoring large numbers or Grover's algorithm for database search, take advantage of these unique quantum properties to solve specific problems faster than classical algorithms.\n\nHowever, building practical quantum computers is still a significant challenge due to issues such as maintaining qubit stability, reducing errors caused by decoherence, and scaling up the number of qubits. Researchers and engineers are actively working on developing and improving quantum hardware, software, and error correction techniques to overcome these obstacles.\n\nWhile quantum computing is still in its early stages, it holds great promise for solving complex problems in fields like cryptography, optimization, drug discovery, and materials science. As the technology continues to advance, it has the potential to revolutionize various industries and scientific domains."} +{"question_id": 6, "model_id": "openai-gpt-3.5-turbo", "text": "There are several key differences between plant-based and animal-based protein sources:\n\n1. Source: Plant-based protein sources come from plants, such as legumes, grains, nuts, and seeds, while animal-based protein sources come from animal products like meat, poultry, fish, eggs, and dairy.\n\n2. Amino Acid Profile: Animal-based proteins are considered complete proteins as they contain all essential amino acids required by the body. On the other hand, plant-based proteins may be incomplete as they often lack one or more essential amino acids. However, by combining different plant protein sources (e.g., beans and rice), one can obtain a complete amino acid profile.\n\n3. Digestibility: Animal-based proteins are generally more easily digested and absorbed by the body compared to plant-based proteins. This is due to the differences in their protein structures and the presence of certain enzymes needed for digestion.\n\n4. Nutrient Content: Animal-based proteins often contain higher amounts of certain nutrients like vitamin B12, iron, zinc, and omega-3 fatty acids. Plant-based proteins, on the other hand, are typically richer in dietary fiber, antioxidants, and phytochemicals.\n\n5. Health Impact: Research suggests that a diet high in plant-based proteins may be associated with lower risks of chronic diseases, such as heart disease, diabetes, and certain cancers. Animal-based proteins, when consumed in excess or in processed forms, may be linked to an increased risk of these conditions.\n\n6. Environmental Impact: Plant-based protein sources tend to have a lower carbon footprint and require fewer resources (land, water, and energy) compared to animal-based sources. Choosing plant-based proteins can contribute to reducing greenhouse gas emissions and environmental degradation.\n\nIt's important to note that both plant-based and animal-based protein sources can be part of a healthy diet. The choice between them depends on individual dietary preferences, health considerations, and ethical or environmental concerns."} +{"question_id": 7, "model_id": "openai-gpt-3.5-turbo", "text": "Developing critical thinking skills involves practicing and honing your ability to analyze, evaluate, and interpret information. Here are some strategies to help you improve your critical thinking skills:\n\n1. Question assumptions: Challenge your own beliefs and assumptions, as well as those of others. Ask yourself why you hold certain beliefs, and consider alternative perspectives.\n\n2. Seek different perspectives: Expose yourself to diverse viewpoints and opinions. Engage in discussions with people who have different backgrounds and experiences than you. This helps broaden your understanding and prevents bias.\n\n3. Evaluate sources: When encountering information, whether it's from books, articles, or online sources, critically evaluate their credibility, reliability, and relevance. Look for evidence, check for biases, and consider the author's credentials.\n\n4. Analyze arguments: Break down arguments into their premises and conclusions. Identify any logical fallacies or weaknesses in reasoning. Strengthening your ability to identify flawed arguments will help you construct stronger arguments of your own.\n\n5. Practice active listening: Pay attention to what others are saying and try to understand their perspectives before formulating your own response. Actively listen by asking clarifying questions and seeking further information.\n\n6. Practice reflection: Regularly reflect on your own thoughts, decisions, and actions. Analyze the reasoning behind them and consider alternative approaches or solutions.\n\n7. Solve problems: Engage in problem-solving activities that require critical thinking. This could involve puzzles, brain teasers, or real-life situations. Practice breaking down complex problems into smaller, manageable parts.\n\n8. Develop intellectual curiosity: Cultivate a thirst for knowledge and a desire to learn. Explore topics beyond your comfort zone and engage in continuous learning.\n\n9. Practice clear communication: Develop your ability to clearly articulate your thoughts and ideas. This helps you organize your thinking and enables others to understand and challenge your arguments.\n\n10. Reflect on past experiences: Analyze past situations where critical thinking was required. Identify what worked well and what could have been improved. This reflection will help you refine your critical thinking skills for future scenarios.\n\nRemember, developing critical thinking skills is a continuous process that requires practice and an open mind."} +{"question_id": 8, "model_id": "openai-gpt-3.5-turbo", "text": "The education sector faces several major challenges today. Some of the key challenges include:\n\n1. Access to quality education: Many students, especially those in low-income areas or developing countries, lack access to quality education due to factors like poverty, limited resources, and inadequate infrastructure.\n\n2. Educational inequality: Disparities in educational opportunities exist based on factors such as socioeconomic status, gender, race, and location. This inequality leads to unequal learning outcomes and perpetuates social and economic disparities.\n\n3. Technology integration: While technology has the potential to enhance learning experiences, its integration into education systems comes with challenges. Ensuring equitable access to technology, promoting digital literacy among educators and students, and effectively using technology to enhance learning outcomes are ongoing challenges.\n\n4. Teacher shortage and quality: Many regions face a shortage of well-qualified and motivated teachers. Attracting and retaining high-quality educators, providing professional development opportunities, and addressing teacher burnout are critical challenges.\n\n5. Education funding: Adequate funding is essential for providing quality education. However, many education systems struggle with limited financial resources, hindering efforts to improve infrastructure, teacher salaries, and learning resources.\n\n6. Curriculum relevance: The rapidly changing world requires educational systems to adapt their curricula to meet the evolving needs of students and the job market. Ensuring that curricula are relevant, inclusive, and effectively prepare students for the future is a significant challenge.\n\n7. Mental health and well-being: The increasing prevalence of mental health issues among students poses a challenge for the education sector. Schools need to prioritize student well-being, provide support services, and create a positive and inclusive learning environment.\n\nAddressing these challenges requires collaborative efforts among governments, policymakers, educators, communities, and relevant stakeholders to ensure that all students have access to quality education and equal opportunities for success."} +{"question_id": 9, "model_id": "openai-gpt-3.5-turbo", "text": "There are several primary factors that influence consumer behavior. These include:\n\n1. Personal factors: These include age, gender, occupation, income level, education, lifestyle, and personality traits. These factors can greatly influence a consumer's preferences, needs, and purchasing decisions.\n\n2. Psychological factors: Psychological factors such as perception, motivation, learning, beliefs, attitudes, and emotions also play a crucial role in consumer behavior. For example, a consumer's perception of a product, their motivation to buy it, and their emotional response to the brand can influence their decision-making process.\n\n3. Social factors: Social influences, such as family, friends, peers, and social class, can significantly impact consumer behavior. Consumers often seek opinions and recommendations from their social network before making a purchase, and they may also be influenced by societal norms and cultural values.\n\n4. Cultural factors: Culture, subculture, and cultural values can strongly influence consumer behavior. Different cultures have varying beliefs, attitudes, and preferences that affect how consumers perceive and respond to marketing messages and products.\n\n5. Situational factors: The immediate context in which a consumer finds themselves, such as the physical environment, time constraints, and situational factors, can also impact their buying decisions. For example, a consumer may be more likely to make impulse purchases when they encounter a sale or promotional offer.\n\nIt is essential to understand and consider these factors when developing marketing strategies and targeting specific consumer segments."} +{"question_id": 10, "model_id": "openai-gpt-3.5-turbo", "text": "There are several effective strategies for conflict resolution in the workplace. Here are a few:\n\n1. Communication: Encourage open and honest communication between the parties involved. Create a safe space where individuals feel comfortable expressing their concerns and actively listen to each other's perspectives.\n\n2. Active listening: Practice active listening by giving full attention to what each person is saying without interruption or judgment. This helps to understand the underlying issues and emotions involved in the conflict.\n\n3. Collaboration: Encourage the parties to find a solution together through collaboration. Facilitate discussions where they can brainstorm ideas, find common ground, and work towards a mutually beneficial resolution.\n\n4. Mediation: In some cases, it may be helpful to involve a neutral third party as a mediator. A skilled mediator can help facilitate the conversation, ensure fairness, and guide the parties towards a resolution.\n\n5. Focus on interests, not positions: Encourage the parties to identify their underlying interests rather than focusing solely on their positions. By understanding each other's needs and motivations, it becomes easier to find creative solutions that address everyone's concerns.\n\n6. Respect and empathy: Encourage respect and empathy towards each other's perspectives and emotions. This helps to build trust and create an environment conducive to resolving conflicts.\n\n7. Seek feedback and follow-up: After a resolution has been reached, encourage the parties involved to provide feedback on the effectiveness of the process. Use this feedback to learn and improve conflict resolution strategies for future situations.\n\nRemember, every conflict is unique, and the most effective strategy may vary depending on the situation. It's essential to remain flexible and adapt your approach accordingly."} diff --git a/tools/evaluator/gpt_eval/answer_generator.py b/tools/evaluator/gpt_eval/answer_generator.py new file mode 100644 index 000000000..a8e8e6f09 --- /dev/null +++ b/tools/evaluator/gpt_eval/answer_generator.py @@ -0,0 +1,237 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +import subprocess +import yaml +import jsonlines +import argparse +import openai +import time +import json +import os +import requests + +from abc import ABC, abstractmethod +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + return parser.parse_args() + + +def format_question(question): + return f'Question: {question}\n\nAnswer:' + + +class AbstractGenerator(ABC): + + @abstractmethod + def generate(self, texts, max_tokens, temperature): + raise NotImplementedError(f'GENERATE is not implemented') + + def close(self): + # do nothing + return + + +class HuggingfaceGenerator(AbstractGenerator): + + def __init__(self, config): + self.model = AutoModelForCausalLM.from_pretrained( + config['model_path']).cuda() + self.tokenizer = AutoTokenizer.from_pretrained( + config['tokenizer_path'], padding_side='left') + self.tokenizer.pad_token = self.tokenizer.eos_token + + def generate(self, texts, max_tokens, temperature): + texts = [format_question(text) for text in texts] + inputs = self.tokenizer( + texts, return_tensors='pt', padding=True).to(self.model.device) + outputs = self.model.generate( + **inputs, max_new_tokens=max_tokens, do_sample=True, temperature=temperature) + return [self.tokenizer.decode( + output[inputs.input_ids.shape[1]:], skip_special_tokens=True) for output in outputs] + + +class OpenAIGenerator(AbstractGenerator): + + def __init__(self, config): + openai.organization = config['openai_organization'] + openai.api_key = config['openai_api_key'] + self.model = config['model'] + if 'max_retry' in config: + self.max_retry = config['max_retry'] + else: + self.max_retry = 5 + if 'retry_wait' in config: + self.retry_wait = config['retry_wait'] + else: + self.retry_wait = 5 + + def generate(self, texts, max_tokens, temperature): + outputs = [] + for text in texts: + output = "" + for _ in range(self.max_retry): + try: + response = openai.ChatCompletion.create( + model=self.model, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": text, + }, + ], + temperature=temperature, + max_tokens=max_tokens, + ) + output = response["choices"][0]["message"]["content"] + break + except Exception as e: + print(e) + time.sleep(self.retry_wait) + if len(output) == 0: + print(f"Failed to answer [{text}]") + outputs.append(output) + return outputs + + +class MegatronGenerator(AbstractGenerator): + + def __init__(self, config): + os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + os.environ['OMP_NUM_THREADS'] = '4' + self.cur_dir = os.path.abspath(os.getcwd()) + self.process_num = config['process_num'] + self.checkpoint_path = config['checkpoint_path'] + self.load_iteration = config['iteration'] + # for different tokenizer + if config['tokenizer_type'] == 'sentencepiece': + self.tokenizer_type = 'sentencepiece' + self.vocab_path = None + self.merge_path = None + self.tokenizer_path = config['tokenizer_path'] + elif config['tokenizer_type'] == 'gpt2': + self.tokenizer_type = 'gpt2' + self.vocab_path = config['vocab_path'] + self.merge_path = config['merge_path'] + self.tokenizer_path = None + else: + raise NotImplementedError("Unsupported tokenizer type") + self.megatron_home = self.cur_dir + if 'megatron_home' in config: + self.megatron_home = config['megatron_home'] + print(f"Megatron-LM home: {self.megatron_home}") + self.server_port = config['port'] if 'port' in config else 5000 + self.handle = self._run_megatron_server() + self.url = f'http://localhost:{self.server_port}/api' + self.header = { + 'Content-Type': 'application/json; charset=UTF-8', + } + print('Start Megatron text generation server') + time.sleep(30) + + def _set_megatron_tokenizer(self, args): + if self.tokenizer_type == 'gpt2': + args.append('GPT2BPETokenizer') + args.append('--vocab-file') + args.append(self.vocab_path) + args.append('--merge-file') + args.append(self.merge_path) + elif self.tokenizer_type == 'sentencepiece': + args.append('SentencepieceTokenizer') + args.append('--tokenizer-model') + args.append(self.tokenizer_path) + + def _run_megatron_server(self): + args = ['torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', '--nproc_per_node', '1', '--nnodes', str(self.process_num), '--node_rank', '0', 'tools/run_text_generation_server.py', '--port', str( + self.server_port), '--use-checkpoint-args', '--load', self.checkpoint_path, '--load-iteration', str(self.load_iteration), '--tokenizer-type'] + self._set_megatron_tokenizer(args) + os.chdir(self.megatron_home) + process = subprocess.Popen( + args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + os.chdir(self.cur_dir) + return process + + def _request(self, prompts, max_tokens, temperature): + for _ in range(5): + try: + response = requests.put(self.url, headers=self.header, data=json.dumps({ + 'prompts': prompts, + 'tokens_to_generate': max_tokens, + 'temperature': temperature, + 'echo_prompts': False + })).json() + except Exception as e: + response = { + 'message': e + } + if 'text' not in response: + print(f'Error in megatron response: {response}, retry in 10s') + time.sleep(10) + else: + break + return response['text'] + + def generate(self, texts, max_tokens, temperature): + texts = [format_question(text) for text in texts] + return self._request(texts, max_tokens, temperature) + + def close(self): + self.handle.terminate() + + +class TextGenerator(): + def __init__(self, args): + with open(args.config, 'r') as f: + config = yaml.safe_load(f)['answer_generation'] + self.questions = [q for q in jsonlines.open( + config['question_file'], 'r')] + if not os.path.exists(os.path.dirname(config['answer_file'])): + os.makedirs(os.path.dirname(config['answer_file'])) + self.answer_writer = jsonlines.open( + config['answer_file'], 'w', flush=True) + self.batch_size = config['batch_size'] + self.max_tokens = config['max_tokens'] + self.temperature = config['temperature'] + self.model_name = config['model_name'] + if 'huggingface' in config: + self.generator = HuggingfaceGenerator(config['huggingface']) + elif 'openai' in config: + self.generator = OpenAIGenerator(config['openai']) + elif 'megatron' in config: + self.generator = MegatronGenerator(config['megatron']) + else: + raise NotImplementedError("Generator not found") + + def generate(self, questions): + texts = [question['text'] for question in questions] + answer_texts = self.generator.generate( + texts, self.max_tokens, self.temperature) + for (question, answer_text) in zip(questions, answer_texts): + self.answer_writer.write({ + 'question_id': question['question_id'], + 'model_id': self.model_name, + 'text': answer_text, + }) + + def run(self): + questions = [] + for question in tqdm(self.questions): + questions.append(question) + if len(questions) % self.batch_size == 0: + self.generate(questions) + questions.clear() + if len(questions) > 0: + self.generate(questions) + self.generator.close() + self.answer_writer.close() + + +if __name__ == '__main__': + args = parse_args() + TextGenerator(args).run() diff --git a/tools/evaluator/gpt_eval/config/config.yaml b/tools/evaluator/gpt_eval/config/config.yaml new file mode 100644 index 000000000..08e06d264 --- /dev/null +++ b/tools/evaluator/gpt_eval/config/config.yaml @@ -0,0 +1,36 @@ +answer_generation: + model_name: my_model + question_file: ./config/question.jsonl + answer_file: ./answer/myorg/mymodel.jsonl + batch_size: 4 + max_tokens: 512 + temperature: 0.7 + # config for huggingface + huggingface: + model_path: myorg/mymodel + tokenizer_path: myorg/mymodel + # # config for openai + # openai: + # openai_organization: + # openai_api_key: + # model: + # max_retry: + # # config for megatron-lm + # megatron: + # process_num: + # checkpoint_path: + # tokenizer_type: + # vocab_path: + # merge_path: + # iteration: +gpt_evaluation: + # openai config + openai_organization: + openai_api_key: + # files config + question_file: ./config/question.jsonl + answer_file: ./answer/myorg/mymodel.jsonl + baseline_file: ./answer/openai/gpt-3.5-turbo.jsonl + prompt_file: ./config/prompt.jsonl + reviewer_file: ./config/reviewer.jsonl + result_file: ./review/myorg/mymodel-gpt3.5-turbo.jsonl \ No newline at end of file diff --git a/tools/evaluator/gpt_eval/config/prompt.jsonl b/tools/evaluator/gpt_eval/config/prompt.jsonl new file mode 100644 index 000000000..f88859d17 --- /dev/null +++ b/tools/evaluator/gpt_eval/config/prompt.jsonl @@ -0,0 +1,3 @@ +{"category": "general", "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[The Start of Assistant 1's Answer]\n{answer_1}\n\n[The End of Assistant 1's Answer]\n\n[The Start of Assistant 2's Answer]\n{answer_2}\n\n[The End of Assistant 2's Answer]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nThen, output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively.\n\nOutput format:\nEvaluation evidence: \n,"}, "description": "Prompt for general questions"} +{"category": "coding", "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[The Start of Assistant 1's Answer]\n{answer_1}\n\n[The End of Assistant 1's Answer]\n\n[The Start of Assistant 2's Answer]\n{answer_2}\n\n[The End of Assistant 2's Answer]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. Please first provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nThen, output a single line containing only two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively.\n\nOutput format:\nEvaluation evidence: \n,"}, "description": "Prompt for coding questions"} +{"category": "math", "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[The Start of Assistant 1's Answer]\n{answer_1}\n\n[The End of Assistant 1's Answer]\n\n[The Start of Assistant 2's Answer]\n{answer_2}\n\n[The End of Assistant 2's Answer]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question displayed above.\nFirst, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please first provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nThen, output a single line containing only two values indicating the scores for Assistant 1 and 2, ranging from 1 to 10, respectively.\n\nOutput format:\nEvaluation evidence: \n,"}, "description": "Prompt for math questions"} diff --git a/tools/evaluator/gpt_eval/config/question.jsonl b/tools/evaluator/gpt_eval/config/question.jsonl new file mode 100644 index 000000000..c946b8f79 --- /dev/null +++ b/tools/evaluator/gpt_eval/config/question.jsonl @@ -0,0 +1,80 @@ +{"question_id": 1, "text": "How can I improve my time management skills?", "category": "generic"} +{"question_id": 2, "text": "What are the most effective ways to deal with stress?", "category": "generic"} +{"question_id": 3, "text": "What are the main differences between Python and JavaScript programming languages?", "category": "generic"} +{"question_id": 4, "text": "How can I increase my productivity while working from home?", "category": "generic"} +{"question_id": 5, "text": "Can you explain the basics of quantum computing?", "category": "generic"} +{"question_id": 6, "text": "What are the differences between plant-based and animal-based protein sources?", "category": "generic"} +{"question_id": 7, "text": "How can I develop my critical thinking skills?", "category": "generic"} +{"question_id": 8, "text": "What are the major challenges faced by the education sector today?", "category": "generic"} +{"question_id": 9, "text": "What are the primary factors that influence consumer behavior?", "category": "generic"} +{"question_id": 10, "text": "What are the most effective strategies for conflict resolution in the workplace?", "category": "generic"} +{"question_id": 11, "text": "What are some potential implications of using a single-use plastic bottle versus a reusable bottle on both the environment and human health?", "category": "knowledge"} +{"question_id": 12, "text": "What factors would you consider when designing an inclusive and accessible public transportation system?", "category": "knowledge"} +{"question_id": 13, "text": "How can governments utilize fiscal and monetary policies to combat economic recessions?", "category": "knowledge"} +{"question_id": 14, "text": "How do language and cultural barriers affect the way people communicate and form relationships in multicultural societies?", "category": "knowledge"} +{"question_id": 15, "text": "Describe a scenario where artificial intelligence could be used to improve the quality and efficiency of healthcare delivery.", "category": "knowledge"} +{"question_id": 16, "text": "Explain the process of gene editing using CRISPR-Cas9 technology, and discuss its potential applications and ethical implications.", "category": "knowledge"} +{"question_id": 17, "text": "How do vaccinations work to protect individuals and communities from infectious diseases, and what is herd immunity?", "category": "knowledge"} +{"question_id": 18, "text": "How do social media platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?", "category": "knowledge"} +{"question_id": 19, "text": "How do cultural, social, and economic factors influence people's food choices, and how can this knowledge be used to promote healthier diets?", "category": "knowledge"} +{"question_id": 20, "text": "Explain the process of natural selection and how it contributes to the evolution and adaptation of species.", "category": "knowledge"} +{"question_id": 21, "text": "How would you introduce yourself as a medieval knight at a royal banquet?", "category": "roleplay"} +{"question_id": 22, "text": "As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?", "category": "roleplay"} +{"question_id": 23, "text": "If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?", "category": "roleplay"} +{"question_id": 24, "text": "As a superhero, how would you explain your origin story to a curious child?", "category": "roleplay"} +{"question_id": 25, "text": "Imagine you are a time traveler from the year 3000. What technological advancements would you tell people about?", "category": "roleplay"} +{"question_id": 26, "text": "As a sports commentator, describe the winning play in the final seconds of a championship game.", "category": "roleplay"} +{"question_id": 27, "text": "Pretend to be a world-famous chef. How would you describe your signature dish to a panel of judges?", "category": "roleplay"} +{"question_id": 28, "text": "You are a mountain climber reaching the summit of Mount Everest. Describe your emotions and the view from the top.", "category": "roleplay"} +{"question_id": 29, "text": "As a space colonist on Mars, describe your daily life and the challenges you face living on another planet.", "category": "roleplay"} +{"question_id": 30, "text": "Pretend to be a character in a post-apocalyptic world. Describe how you survive and the allies you encounter.", "category": "roleplay"} +{"question_id": 31, "text": "How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?", "category": "common-sense"} +{"question_id": 32, "text": "What are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed?", "category": "common-sense"} +{"question_id": 33, "text": "Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?", "category": "common-sense"} +{"question_id": 34, "text": "How can you determine if a person is genuinely interested in a conversation or simply being polite?", "category": "common-sense"} +{"question_id": 35, "text": "Why might someone prefer to shop at a small, locally-owned business instead of a large chain store, even if the prices are higher?", "category": "common-sense"} +{"question_id": 36, "text": "How can you assess the credibility of a source of information, such as a news article or blog post, without relying solely on the reputation of the author or publisher?", "category": "common-sense"} +{"question_id": 37, "text": "Why do some people enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, while others avoid these experiences?", "category": "common-sense"} +{"question_id": 38, "text": "How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?", "category": "common-sense"} +{"question_id": 39, "text": "Do we have a moral obligation to explore space, or should we focus on solving Earth's problems first?", "category": "common-sense"} +{"question_id": 40, "text": "In a world where automation is becoming increasingly prevalent, is it more important to prioritize job creation or technological progress?", "category": "common-sense"} +{"question_id": 41, "text": "How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 42, "text": "How many atoms are in a grain of salt? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 43, "text": "How many lightning strikes occur on Earth each day? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 44, "text": "How many balloons would it take to lift a house like in the movie \"Up\"? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 45, "text": "How many text messages are sent globally in a minute? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 46, "text": "How many words are spoken daily on Earth? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 47, "text": "How many snowflakes fall during a typical winter? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 48, "text": "How many pages are in all the books ever written? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 49, "text": "How many times has the Earth orbited the Sun since the beginning of life? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 50, "text": "How many songs have been recorded throughout history? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} +{"question_id": 51, "text": "What if the Internet had been invented during the Renaissance period?", "category": "counterfactual"} +{"question_id": 52, "text": "What if the Aztecs had successfully repelled the Spanish conquistadors?", "category": "counterfactual"} +{"question_id": 53, "text": "What if the Black Death had not occurred in the 14th century?", "category": "counterfactual"} +{"question_id": 54, "text": "What if Isaac Newton had focused on biology instead of physics?", "category": "counterfactual"} +{"question_id": 55, "text": "What if the Beatles had never formed as a band?", "category": "counterfactual"} +{"question_id": 56, "text": "What if Alan Turing had not cracked the Enigma code during World War II?", "category": "counterfactual"} +{"question_id": 57, "text": "What if the Suez Canal had never been constructed?", "category": "counterfactual"} +{"question_id": 58, "text": "What if the Maya civilization had never mysteriously collapsed?", "category": "counterfactual"} +{"question_id": 59, "text": "What if Christopher Columbus had not discovered the Americas?", "category": "counterfactual"} +{"question_id": 60, "text": "What if Vincent van Gogh had been a successful artist during his lifetime?", "category": "counterfactual"} +{"question_id": 61, "text": "Develop a C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file.", "category": "coding"} +{"question_id": 62, "text": "Implement a Python function to find the longest common subsequence of two input strings using dynamic programming.", "category": "coding"} +{"question_id": 63, "text": "Implement a regular expression in Python to validate an email address.", "category": "coding"} +{"question_id": 64, "text": "Write a program to find the nth Fibonacci number using dynamic programming.", "category": "coding"} +{"question_id": 65, "text": "Implement a binary search algorithm to find a specific element in a sorted array.", "category": "coding"} +{"question_id": 66, "text": "Implement a queue data structure using two stacks in Python.", "category": "coding"} +{"question_id": 67, "text": "Implement a program to find the common elements in two arrays without using any extra data structures.", "category": "coding"} +{"question_id": 68, "text": "Given that f(x) = 5x^3 - 2x + 3, find the value of f(2).", "category": "math"} +{"question_id": 69, "text": "Solve for x in the equation 3x + 10 = 5(x - 2).", "category": "math"} +{"question_id": 70, "text": "If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?", "category": "math"} +{"question_id": 71, "text": "Can you help me write a formal email to a potential business partner proposing a joint venture?", "category": "writing"} +{"question_id": 72, "text": "Can you help me write a resignation letter to my current employer, while leaving on good terms and expressing gratitude for the opportunities provided?", "category": "writing"} +{"question_id": 73, "text": "Use an appropriate format to structure a formal letter of recommendation for a student applying to a prestigious graduate program in computer science.", "category": "writing"} +{"question_id": 74, "text": "Write a compelling product launch announcement email to inform our customers of our new software solution.", "category": "writing"} +{"question_id": 75, "text": "Draft an apology email to a customer who experienced a delay in their order, and provide reassurance that the issue has been resolved.", "category": "writing"} +{"question_id": 76, "text": "Write a script for a YouTube video exploring the history and cultural significance of jazz.", "category": "writing"} +{"question_id": 77, "text": "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "category": "writing"} +{"question_id": 78, "text": "Write a captivating movie review for a recently released science fiction film, discussing its plot, characters, and special effects.", "category": "writing"} +{"question_id": 79, "text": "Structure a podcast script for an episode discussing the influence of streaming platforms on the music industry.", "category": "writing"} +{"question_id": 80, "text": "Write a symphony concert review, discussing the orchestra's performance and overall audience experience.", "category": "writing"} diff --git a/tools/evaluator/gpt_eval/config/reviewer.jsonl b/tools/evaluator/gpt_eval/config/reviewer.jsonl new file mode 100644 index 000000000..bc9b6fefb --- /dev/null +++ b/tools/evaluator/gpt_eval/config/reviewer.jsonl @@ -0,0 +1,3 @@ +{"category": "general", "metadata": {"temperature": 0.2, "max_tokens": 1024, "model": "gpt-3.5-turbo"}} +{"category": "coding", "metadata": {"temperature": 0.2, "max_tokens": 1024, "model": "gpt-3.5-turbo"}} +{"category": "math", "metadata": {"temperature": 0.2, "max_tokens": 1024, "model": "gpt-3.5-turbo"}} \ No newline at end of file diff --git a/tools/evaluator/gpt_eval/gpt_evaluator.py b/tools/evaluator/gpt_eval/gpt_evaluator.py new file mode 100644 index 000000000..1561829e9 --- /dev/null +++ b/tools/evaluator/gpt_eval/gpt_evaluator.py @@ -0,0 +1,175 @@ +# Some code here has been modified from: +# https://github.com/lm-sys/FastChat +# -------------------------------------------------------- + +import jsonlines +import openai +import logging +import time +import argparse +import yaml +import os +from multiprocessing import Pool +from tqdm import tqdm + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True, + help="Config file path") + parser.add_argument('--worker-num', type=int, default=4, + help="Number of workers for OpenAI API") + parser.add_argument("--max-retry", type=int, default=5, + help='Retry times for OpenAI API') + parser.add_argument("--debug", action='store_true', + help='Run without calling OpenAI API') + return parser.parse_args() + + +class GPTEvaluator(): + + def __init__(self, config): + openai.organization = config['openai_organization'] + openai.api_key = config['openai_api_key'] + self.questions = [q for q in jsonlines.open( + config['question_file'], 'r')] + self.answers = [a for a in jsonlines.open( + config['answer_file'], 'r')] + self.baseline = [b for b in jsonlines.open( + config['baseline_file'], 'r')] + self.prompt_templates = { + p['category']: p for p in jsonlines.open(config['prompt_file'], 'r')} + self.reviewers = { + z['category']: z for z in jsonlines.open(config['reviewer_file'], 'r')} + if not os.path.exists(os.path.dirname(config['result_file'])): + os.makedirs(os.path.dirname(config['result_file'])) + self.result_writer = jsonlines.open( + config['result_file'], 'w', flush=True) + self.worker_num = config['worker_num'] if 'worker_num' in config else 4 + self.max_retry = config['max_retry'] if 'max_retry' in config else 5 + self.debug = config['debug'] if 'debug' in config else False + + def generate_prompt(self, question, answer, baseline, prompts): + if question['category'] in self.reviewers.keys(): + reviewer = self.reviewers[question['category']] + prompt_json = prompts[question['category']] + else: + reviewer = self.reviewers['general'] + prompt_json = prompts['general'] + sys_prompt = prompt_json["system_prompt"] + prompt_template = prompt_json["prompt_template"] + defaults = prompt_json["defaults"] + prompt1 = prompt_template.format( + question=question['text'], answer_1=answer['text'], answer_2=baseline['text'], **defaults + ) + prompt2 = prompt_template.format( + question=question['text'], answer_1=baseline['text'], answer_2=answer['text'], **defaults + ) + return sys_prompt, prompt1, prompt2, reviewer + + def parse_score(self, review): + review = review.strip('\n') + score_pair = review.split("\n")[-1] + score_pair.strip() + sp = score_pair.split(",") + try: + if len(sp) == 2: + return [float(sp[0]), float(sp[1])] + else: + logger.error( + f"Invalid score pair." + ) + return [0, 0] + except Exception as e: + logger.error("Invalid answer") + return [0, 0] + + def run(self): + results = [] + requests = [] + question_num = len(self.questions) + for i in range(question_num): + sys_prompt, prompt1, prompt2, reviewer = self.generate_prompt( + self.questions[i], self.answers[i], self.baseline[i], self.prompt_templates) + results.append({ + 'question_id': self.questions[i]['question_id'], + 'metadata': reviewer['metadata'], + 'model1': self.answers[i]['model_id'], + 'model2': self.baseline[i]['model_id'] + }) + pool = Pool(processes=self.worker_num) + requests.append({ + 'sys_prompt': sys_prompt, 'user_prompt': prompt1, 'temperature': reviewer['metadata']['temperature'], 'max_tokens': reviewer['metadata']['max_tokens'], 'model': reviewer['metadata']['model'], 'debug': self.debug, 'retry': self.max_retry}) + requests.append({ + 'sys_prompt': sys_prompt, 'user_prompt': prompt2, 'temperature': reviewer['metadata']['temperature'], 'max_tokens': reviewer['metadata']['max_tokens'], 'model': reviewer['metadata']['model'], 'debug': self.debug, 'retry': self.max_retry}) + reviews = pool.map(eval, requests) + target_score = 0.0 + baseline_score = 0.0 + cnt = 0 + for i, review in enumerate(tqdm(reviews)): + scores = self.parse_score(review) + idx = i // 2 + if i % 2 == 0: + results[idx]['review1'] = review + results[idx]['score1'] = scores + target_score += scores[0] + baseline_score += scores[1] + else: + results[idx]['review2'] = review + results[idx]['score2'] = scores + target_score += scores[1] + baseline_score += scores[0] + self.result_writer.write(results[idx]) + cnt += 1 + target_avg_score = target_score / cnt / 2 + baseline_avg_score = baseline_score / cnt / 2 + print("-------------------------") + print(f"> {results[0]['model1']}: {target_avg_score}") + print(f"> {results[0]['model2']}: {baseline_avg_score}") + print("-------------------------") + self.result_writer.write({ + f"{results[0]['model1']}": target_avg_score, + f"{results[0]['model2']}": baseline_avg_score + }) + self.result_writer.close() + + +def eval(request): + if request['debug']: + logger.info(f"Fake response {request['user_prompt']}") + return "Fake response\n10,9\n" + for _ in range(request['retry']): + try: + response = openai.ChatCompletion.create( + model=request['model'], + messages=[ + {"role": "system", "content": request['sys_prompt']}, + { + "role": "user", + "content": request['user_prompt'], + }, + ], + temperature=request['temperature'], + max_tokens=request['max_tokens'], + ) + content = response["choices"][0]["message"]["content"] + logger.info(content) + return content + except Exception as e: + logger.error(e) + time.sleep(5) + logger.error(f"Failed after {request['retry']} retries.") + return "error" + + +if __name__ == "__main__": + args = parse_args() + config = yaml.safe_load(open(args.config, 'r', encoding='utf-8'))['gpt_evaluation'] + config['worker_num'] = args.worker_num + config['max_retry'] = args.max_retry + config['debug'] = args.debug + evaluator = GPTEvaluator(config) + evaluator.run() diff --git a/tools/evaluator/recorder/README.md b/tools/evaluator/recorder/README.md new file mode 100644 index 000000000..5654a313e --- /dev/null +++ b/tools/evaluator/recorder/README.md @@ -0,0 +1,112 @@ +# Evaluation Results Recorder + +Record your evaluation results to [WandB](https://wandb.ai/) with `wandb_writer.py`. + +With `wandb_writer.py`, you can: + +- visualize the changes of evaluation metrics of your model during the training process +![Metrics](../../../docs/imgs/eval-02.png "change of metrics") +- make a leaderboard to compare the metrics of different models +![Leaderboard](../../../docs/imgs/eval-01.png "Leaderboard") + +## Usage + +```shell +python wandb_writer.py --config [--print-only] +``` + +- `config_file`: path to the configuration file (see [Configuration](#configuration) for details) +- `--print-only`: only print the result to command line, do not write to wandb + +## Configuration + +We provided three example files in `config` folder for three different cases. + +The general format is as follows: + +```yaml +project: # your wandb project name +base_url: # your wandb instance url +# other specific configuration items +``` + +### Parse from HELM output + +The following configuration is used to parse evaluation results from HELM output folder and record them to wandb. + +```yaml +# general configurations +# ... + +evals: # evaluations to record + - eval_type: helm # only support helm for now + model_name: # your model name + source: helm # use helm to parse from helm output directory + helm_output_dir: + helm_suite_name: + token_per_iteration: + benchmarks: # benchmark metrics to be recorded, and below are some examples + - name: mmlu + metrics: + - EM + - name: boolq + metrics: + - EM + - name: narrative_qa + metrics: + - F1 + - name: hellaswag + metrics: + - EM + - ... +``` + +> We use 16 core metrics of HELM as the default benchmarks if the `benchmarks` field is not provided, the 16 metrics are as follows: +> +> `mmlu.EM, raft.EM, imdb.EM, truthful_qa.EM, summarization_cnndm.ROUGE-2, summarization_xsum.ROUGE-2, boolq.EM, msmarco_trec.NDCG@10, msmarco_regular.RR@10, narrative_qa.F1, natural_qa_closedbook.F1, natural_qa_openbook_longans.F1, civil_comments.EM, hellaswag.EM, openbookqa.EM` + +### Parse from configuration file + +The scores of metrics can be given directly in the configuration file, the following is an example. + +```yaml +# general configurations +# ... + +evals: # evaluations to record + - eval_type: helm + model_name: llama-7B # your model name + source: file # use file to parse from configuration + token_num: 1000 + eval_result: # evaluation results to be recorded + mmlu: + EM: 0.345 + boolq: + EM: 0.751 + narrative_qa: + F1: 0.524 + hellaswag: + EM: 0.747 + ... +``` + +### Make leaderboard + +The following configuration is used to make a leaderboard. + +```yaml +# general configurations +# ... +leaderboard: True +leaderboard_metrics: # metrics required for the leaderboard + - mmlu.EM + - boolq.EM + - quac.F1 + - hellaswag.EM + - ... +excluded_models: # models that do not participate in the leaderboard + - + - ... +``` + +> We use 16 core metrics of HELM as the default leaderboard metrics if the `leaderboard_metrics` field is not provided, the 16 metrics are as same as the default benchmark metrics. \ No newline at end of file diff --git a/tools/evaluator/recorder/README_ZH.md b/tools/evaluator/recorder/README_ZH.md new file mode 100644 index 000000000..1cdbe7cb6 --- /dev/null +++ b/tools/evaluator/recorder/README_ZH.md @@ -0,0 +1,113 @@ +# Evaluation Results Recorder + +使用 `wandb_writer.py` 将评测结果记录到 [WandB](https://wandb.ai/) 并可视化展示。 + +`wandb_writer.py` 能够: + +- 可视化模型在训练过程中各项评测指标的变化 +![Metrics](../../../docs/imgs/eval-02.png "指标变化") +- 制作排行榜来比较不同模型的各项评测指标 +![Leaderboard](../../../docs/imgs/eval-01.png "排行榜") + +## 用法 + +```shell +python wandb_writer.py --config [--print-only] +``` + +- `config_file`: yaml 配置文件路径(配置项细节请见[配置](#配置)) +- `--print-only`: 仅将结果打印到命令行,不执行写wandb操作,用于调试 + +## 配置 + +我们在 `config` 文件夹中为三种不同的情况提供了三个示例文件,其中通用的配置项格式如下: + +```yaml +project: # wandb 项目名 +base_url: # wandb 实例 url +# other specific configuration items +``` + +其他配置项根据实际需要填写。 + +### 从 HELM 输出目录中提取评测结果 + +如下配置项用于从 HELM 的输出目录中提取评测结果并记录到 wandb 中。 + +```yaml +# general configurations +# ... + +evals: # evaluations to record + - eval_type: helm # 目前仅支持 helm + model_name: # 模型名字 + source: helm # helm 或 file,这里使用 helm 来从 helm 输出目录提取评测结果 + helm_output_dir: + helm_suite_name: + token_per_iteration: + benchmarks: # 需要记录到 wandb 的评测指标,如下是一些样例 + - name: mmlu + metrics: + - EM + - name: boolq + metrics: + - EM + - name: narrative_qa + metrics: + - F1 + - name: hellaswag + metrics: + - EM + - ... +``` + +> 本工具使用 HELM 的 16 组核心指标作为默认评测指标。如果配置中没有提供 benchmarks 域,则会自动使用如下16个评测指标: +> +> `mmlu.EM, raft.EM, imdb.EM, truthful_qa.EM, summarization_cnndm.ROUGE-2, summarization_xsum.ROUGE-2, boolq.EM, msmarco_trec.NDCG@10, msmarco_regular.RR@10, narrative_qa.F1, natural_qa_closedbook.F1, natural_qa_openbook_longans.F1, civil_comments.EM, hellaswag.EM, openbookqa.EM` + +### 从配置文件中读取评测结果 + +评测结果可以直接写在配置文件中,该选项主要用于快速向 wandb 记录已有的评测结果。 + +```yaml +# general configurations +# ... + +evals: # evaluations to record + - eval_type: helm + model_name: llama-7B # 模型名字 + source: file # helm 或 file,这里使用 file 来直接从配置文件提取评测结果 + token_num: 1000 # 需要提供该模型训练时使用的 token 数量(单位:B) + eval_result: # 需要被记录的评测结果,如下为一些样例 + mmlu: + EM: 0.345 + boolq: + EM: 0.751 + narrative_qa: + F1: 0.524 + hellaswag: + EM: 0.747 + ... +``` + +### 构建排行榜 + +如下配置用于对在同一个 wandb 项目中的数据生成排行榜。 + +```yaml +# general configurations +# ... +leaderboard: True +leaderboard_metrics: # 排行榜中需要统计的指标(仅有包含全部指标评测结果的模型才会进入榜单) + - mmlu.EM + - boolq.EM + - quac.F1 + - hellaswag.EM + - ... +excluded_models: # 不参与排行榜的模型名称 + - + - ... +``` + +> 工具使用 HELM 的 16 组核心指标作为默认的排行榜指标。如果没有提供 `leaderboard_metrics` 域,则会自动使用如下16个评测指标: +> `mmlu.EM, raft.EM, imdb.EM, truthful_qa.EM, summarization_cnndm.ROUGE-2, summarization_xsum.ROUGE-2, boolq.EM, msmarco_trec.NDCG@10, msmarco_regular.RR@10, narrative_qa.F1, natural_qa_closedbook.F1, natural_qa_openbook_longans.F1, civil_comments.EM, hellaswag.EM, openbookqa.EM` diff --git a/tools/evaluator/recorder/__init__.py b/tools/evaluator/recorder/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/evaluator/recorder/config/leaderboard_example.yaml b/tools/evaluator/recorder/config/leaderboard_example.yaml new file mode 100644 index 000000000..1a954facb --- /dev/null +++ b/tools/evaluator/recorder/config/leaderboard_example.yaml @@ -0,0 +1,13 @@ +project: +base_url: +leaderboard: True +leaderboard_metrics: + - mmlu.EM + - boolq.EM + - quac.F1 + - raft.EM + - hellaswag.EM + - ... +excluded_runs: + - + - ... \ No newline at end of file diff --git a/tools/evaluator/recorder/config/llama_example.yaml b/tools/evaluator/recorder/config/llama_example.yaml new file mode 100644 index 000000000..646a2001c --- /dev/null +++ b/tools/evaluator/recorder/config/llama_example.yaml @@ -0,0 +1,40 @@ +project: +base_url: +evals: + - eval_type: helm + model_name: llama-7b + source: file + token_num: 1000 + eval_result: + mmlu: + EM: 0.345 + raft: + EM: 0.583 + imdb: + EM: 0.933 + truthful_qa: + EM: 0.297 + summarization_cnndm: + ROUGE-2: 0.149 + summarization_xsum: + ROUGE-2: 0.127 + boolq: + EM: 0.751 + msmarco_trec: + NDCG@10: 0.482 + msmarco_regular: + RR@10: 0.252 + narrative_qa: + F1: 0.524 + natural_qa_closedbook: + F1: 0.297 + natural_qa_openbook_longans: + F1: 0.580 + quac: + F1: 0.332 + civil_comments: + EM: 0.578 + hellaswag: + EM: 0.747 + openbookqa: + EM: 0.574 \ No newline at end of file diff --git a/tools/evaluator/recorder/config/mymodel_example.yaml b/tools/evaluator/recorder/config/mymodel_example.yaml new file mode 100644 index 000000000..f0231294a --- /dev/null +++ b/tools/evaluator/recorder/config/mymodel_example.yaml @@ -0,0 +1,26 @@ +project: +base_url: +evals: + - eval_type: helm + model_name: + source: helm + helm_output_dir: + helm_suite_name: + token_per_iteration: + benchmarks: + - name: mmlu + metrics: + - EM + - name: boolq + metrics: + - EM + - name: quac + metrics: + - F1 + - name: raft + metrics: + - EM + - name: hellaswag + metrics: + - EM + - ... \ No newline at end of file diff --git a/tools/evaluator/recorder/wandb_writer.py b/tools/evaluator/recorder/wandb_writer.py new file mode 100644 index 000000000..188a428b9 --- /dev/null +++ b/tools/evaluator/recorder/wandb_writer.py @@ -0,0 +1,306 @@ +import wandb +import argparse +import json +import yaml +import os + + +def get_args(): + parser = argparse.ArgumentParser( + description="write evaluation result into wandb", + allow_abbrev=False + ) + parser.add_argument('--config', type=str, required=True) + parser.add_argument('--summary-only', action='store_true') + parser.add_argument('--print-only', action='store_true') + return parser.parse_args() + + +class Writer(): + + def __init__(self, + project_name, + base_url=None, + print_only=False, + summary_only=False + ) -> None: + self.project = project_name + self.base_url = base_url + self.print_only = print_only + self.summary_only = summary_only + + +DEFAULT_HELM_BENCHMARKS = [ + {"name": "mmlu", "metrics": ["EM"]}, + {"name": "raft", "metrics": ["EM"]}, + {"name": "imdb", "metrics": ["EM"]}, + {"name": "truthful_qa", "metrics": ["EM"]}, + {"name": "summarization_cnndm", "metrics": ["ROUGE-2"]}, + {"name": "summarization_xsum", "metrics": ["ROUGE-2"]}, + {"name": "boolq", "metrics": ["EM"]}, + {"name": "msmarco_trec", "metrics": ["NDCG@10"]}, + {"name": "msmarco_regular", "metrics": ["RR@10"]}, + {"name": "narrative_qa", "metrics": ["F1"]}, + {"name": "natural_qa_closedbook", "metrics": ["F1"]}, + {"name": "natural_qa_openbook_longans", "metrics": ["F1"]}, + {"name": "quac", "metrics": ["F1"]}, + {"name": "civil_comments", "metrics": ["EM"]}, + {"name": "hellaswag", "metrics": ["EM"]}, + {"name": "openbookqa", "metrics": ["EM"]} +] + +DEFAULT_HELM_METRICS = [ + "mmlu.EM", "raft.EM", "imdb.EM", "truthful_qa.EM", "summarization_cnndm.ROUGE-2", "summarization_xsum.ROUGE-2", "boolq.EM", "msmarco_trec.NDCG@10", "msmarco_regular.RR@10", + "narrative_qa.F1", "natural_qa_closedbook.F1", "natural_qa_openbook_longans.F1", "civil_comments.EM", "hellaswag.EM", "openbookqa.EM" +] + + +class HelmWriter(Writer): + + def __init__(self, + project_name, + helm_config=None, + leaderboard=False, + base_url=None, + print_only=False, + summary_only=False) -> None: + super().__init__(project_name, base_url, print_only, summary_only) + self.conf = helm_config + self.leaderboard = leaderboard + if self.leaderboard: + self.leaderboard_metrics = self.conf['leaderboard_metrics'] if 'leaderboard_metrics' in self.conf else DEFAULT_HELM_METRICS + self.excluded_models = self.conf['excluded_models'] if 'excluded_models' in self.conf else [ + ] + return + self.parse_from_helm = False + self.parse_from_file = False + self.source = self.conf['source'] if 'source' in self.conf else 'helm' + # parse from helm output dir + if self.source == 'helm': + self.helm_root = self.conf['helm_output_dir'] + self.suite_name = self.conf['helm_suite_name'] + if 'benchmarks' in self.conf: + self.scenarios = self.conf['benchmarks'] + else: + self.scenarios = DEFAULT_HELM_BENCHMARKS + self.parse_from_helm = True + # parse from config file + elif self.source == 'file': + self.eval_result = self.conf['eval_result'] + self.parse_from_file = True + self.default_iteration = 0 + self.model = None + if 'model_name' in self.conf: + self.model = self.conf['model_name'] + if 'default_iteration' in self.conf: + self.default_iteration = self.conf['default_iteration'] + + def make_leaderboard(self): + api = wandb.Api(overrides={ + 'base_url': self.base_url + }) + runs = api.runs(path=f'{self.project}', filters={'tags': 'summary'}) + result = {} + token_num = {} + token_per_iteration = {} + for run in runs: + if run.group == 'leaderboard' or run.group in self.excluded_models: + continue + print(run.id) + run_name = run.group + history = run.scan_history( + keys=['_step'] + self.leaderboard_metrics, page_size=2000, min_step=0) + if 'token_num' in run.config: + token_num[run_name] = run.config['token_num'] + if 'token_per_iteration' in run.config: + token_per_iteration[run_name] = run.config['token_per_iteration'] + for step in history: + for metric_name, score in step.items(): + if metric_name in ['_step', 'average']: + continue + if metric_name not in result: + result[metric_name] = {} + if run_name not in result[metric_name]: + result[metric_name][run_name] = {} + result[metric_name][run_name][step['_step']] = score + sum_scores = {} + for metric_scores in result.values(): + self.cal_score(metric_scores) + for run_name, iters in metric_scores.items(): + for iter, score in iters.items(): + if run_name not in sum_scores: + sum_scores[run_name] = {} + if iter not in sum_scores[run_name]: + sum_scores[run_name][iter] = score + else: + sum_scores[run_name][iter] += score + if self.print_only: + print(sum_scores) + else: + run = wandb.init( + project=self.project, + group='leaderboard', + name='leaderboard', + save_code=False, + id=f'{self.project}-leaderboard', + tags=['leaderboard'], + reinit=True) + data = [] + for name, iters in sum_scores.items(): + for iter, score in iters.items(): + if name in token_num: + data.append([name, token_num[name], score]) + elif name in token_per_iteration: + data.append( + [name, iter * token_per_iteration[name], score]) + else: + data.append([name, None, score]) + table = wandb.Table(data=data, columns=[ + 'model', 'token_num', 'score']) + wandb.log( + {'benchmark_score': wandb.plot.bar(table, 'model', 'score')}) + run.finish() + + def cal_score(self, scores): + max_score = 0.0 + min_score = 1.0 + for subject, iters in scores.items(): + max_score = max(max(iters.values()), max_score) + min_score = min(min(iters.values()), min_score) + for subject, iters in scores.items(): + for iter, score in iters.items(): + scores[subject][iter] = ( + score - min_score) / (max_score - min_score) + + def write(self): + if self.leaderboard: + self.make_leaderboard() + elif self.parse_from_helm: + self.parse_scenarios() + elif self.parse_from_file: + self.write_wandb('summary', { + self.default_iteration: self.eval_result + }, 'summary') + else: + print('do nothing, please check your config file') + + def parse_scenarios(self): + summary = {} + for scenario in self.scenarios: + try: + result = self.parse_scenario( + scenario['name'], scenario['metrics'], self.model) + if not self.summary_only: + self.write_wandb(scenario['name'], result, 'detail') + self.make_summary(scenario['name'], result, summary) + except Exception as e: + print(f"Fail to parse {scenario['name']}: {e}") + self.write_wandb('summary', summary, 'summary') + + def make_summary(self, scenario_name, eval_result, summary): + print(f"summarize for {scenario_name}") + for iteration, scenarios in eval_result.items(): + if iteration not in summary: + summary[iteration] = dict() + if scenario_name not in summary[iteration]: + summary[iteration][scenario_name] = dict() + for _, metrics in scenarios[scenario_name].items(): + summary[iteration][scenario_name] = metrics + break + + def make_average(self, summary): + for iteration, scenarios in summary.items(): + score = 0.0 + count = 0 + for _, metrics in scenarios.items(): + for _, value in metrics.items(): + score += value + count += 1 + break + summary[iteration]['average'] = score / count + + def parse_scenario(self, scenario_name, scenario_metrics, model=None): + evaluate_result = {} + with open(os.path.join(self.helm_root, 'runs', self.suite_name, 'groups', f'{scenario_name}.json')) as f: + print(f"parsing {scenario_name}.json") + subjects = json.load(f) + for subject in subjects: + print(f" parsing {subject['title']}") + record_column_idx = {} + for i, column in enumerate(subject['header']): + if column['value'] in scenario_metrics: + record_column_idx[column['value']] = i + for row in subject['rows']: + iteration = self.default_iteration + try: + iteration = int(row[0]['value'].split('_')[-1]) + except Exception: + pass + try: + iteration = int(row[0]['value'].split('/')[-1]) + except Exception: + pass + if iteration not in evaluate_result: + evaluate_result[iteration] = dict() + if scenario_name not in evaluate_result[iteration]: + evaluate_result[iteration][scenario_name] = dict() + evaluate_result[iteration][scenario_name][subject['title'].split(',')[ + 0]] = dict() + for metric, i in record_column_idx.items(): + evaluate_result[iteration][scenario_name][subject['title'].split(',')[ + 0]][metric] = row[i]['value'] + return evaluate_result + + def write_wandb(self, name, result, tag): + if self.print_only: + print(result) + return + config = {} + if 'token_num' in self.conf: + config['token_num'] = self.conf['token_num'] + if 'token_per_iteration' in self.conf: + config['token_per_iteration'] = self.conf['token_per_iteration'] + run = wandb.init( + project=self.project, + group=self.model, + name=name, + save_code=False, + id=f'{self.project}-{self.model}-{name}', + tags=['evalate', tag], + config=config, + reinit=True) + print(f"write {name} to wandb") + for iteration in sorted(result.keys()): + print(f" write iteration {iteration} to wandb") + wandb.log(result[iteration], int(iteration)) + run.finish() + + +def main(): + args = get_args() + config = yaml.safe_load(open(args.config, 'r', encoding='utf-8')) + eval_configs = config['evals'] if 'evals' in config else [] + for eval in eval_configs: + if eval['eval_type'] == 'helm': + HelmWriter( + project_name=config['project'], + base_url=config['base_url'], + print_only=args.print_only, + summary_only=args.summary_only, + helm_config=eval + ).write() + else: + raise NotImplementedError( + f"Unsupported type for eval type {eval['eval_type']}") + if 'leaderboard' in config and config['leaderboard'] == True: + HelmWriter( + project_name=config['project'], + base_url=config['base_url'], + leaderboard=True, + helm_config=config, + print_only=args.print_only + ).write() + + +if __name__ == "__main__": + main() diff --git a/tools/postprocess/README.md b/tools/postprocess/README.md new file mode 100644 index 000000000..a2decf093 --- /dev/null +++ b/tools/postprocess/README.md @@ -0,0 +1,32 @@ +# Postprocess tools + +This folder contains some postprocess scripts for additional processing of your processed dataset using data-juicer. + +## Usage + +### Mix multiple datasets with optional weights + +Use `data_mixture.py` to mix multiple datasets. + +This script will randomly select samples from every dataset and mix theses samples and export to a new_dataset. + + +```shell +python tools/postprocess/data_mixture.py \ + --data_path \ + --export_path \ + --export_shard_size \ + --num_proc + +# get help +python tools/postprocess/data_mixture.py --help +``` + +- `data_path`: a dataset file or a list of dataset files or a list of both them, optional weights, if not set, 1.0 as default. +- `export_path`: a dataset file name for exporting mixed dataset, support `json` / `jsonl` / `parquet`. +- `export_shard_size`: dataset file size in Byte. If not set, mixed dataset will be exported into only one file. +- `num_proc`: process num to load and export datasets. + +- e.g., `python tools/postprocess/data_mixture.py --data_path ds.jsonl ds_dir ds_file.json` + +**Note:** All datasets must have the same meta field, so we can use `datasets` to align their features. diff --git a/tools/postprocess/README_ZH.md b/tools/postprocess/README_ZH.md new file mode 100644 index 000000000..585dd4b97 --- /dev/null +++ b/tools/postprocess/README_ZH.md @@ -0,0 +1,31 @@ +# Postprocess tools + +此文件夹包含一些后处理脚本,用于对 data-juicer 处理后的数据集进行进一步处理。 + +## 用法 + +### 将多个数据集以可选的权重混合 + +使用 `data_mixture.py` 将多个数据集混合。 + +该脚本将从每个数据集中随机选择样本并混合这些样本并导出到新的数据集。 + +```shell +python tools/postprocess/data_mixture.py \ + --data_path \ + --export_path \ + --export_shard_size \ + --num_proc + +# get help +python tools/postprocess/data_mixture.py --help +``` + +- `data_path`: 数据集文件或数据集文件列表或两者的列表。可附加可选权重,权重未设置时默认值为 1.0。 +- `export_path`: 用于导出混合数据集的数据集文件名,支持 `json` / `jsonl` / `parquet` 格式。 +- `export_shard_size`: 数据集文件大小(以字节为单位)。 如果未设置,混合数据集将仅导出到一个文件中。 +- `num_proc`: 加载以及导出数据集使用的进程数量 + +- 例,`python tools/postprocess/data_mixture.py --data_path ds.jsonl ds_dir ds_file.json` + +**注意事项:** 所有数据集必须具有相同的元字段,从而可以使用 `datasets` 来对齐它们的特征。 diff --git a/tools/postprocess/data_mixture.py b/tools/postprocess/data_mixture.py new file mode 100644 index 000000000..144b89761 --- /dev/null +++ b/tools/postprocess/data_mixture.py @@ -0,0 +1,72 @@ +import argparse + +from data_juicer.core.exporter import Exporter +from data_juicer.format import load_formatter + + +def parse_args(): + """Parse all arguments.""" + parser = argparse.ArgumentParser( + description='Mix multiple datasets Arguments') + parser.add_argument('--data_path', + nargs='*', + default=None, + help='Path to datasets. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ...') + + parser.add_argument('--export_path', + default='mixed.jsonl', + help='Path to save the mixed dataset. ' + 'Supported suffixes include ' + '["jsonl", "json", "parquet"]') + + parser.add_argument('--export_shard_size', + type=int, + default=0, + help='Shard size of exported dataset in Byte. In ' + 'default, it\'s 0, which means export the whole ' + 'dataset into only one file. If it\'s set a ' + 'positive number, the exported dataset will be ' + 'split into several dataset shards, and the max ' + 'size of each shard won\'t larger than the ' + 'export_shard_size') + + parser.add_argument('--num_proc', + type=int, + default=4, + help='Number of processes to process dataset.') + + args = parser.parse_args() + + return args + + +def run_mixture(): + """ + Mix multiple datasets into one dataset. + Randomly select samples from every dataset and mix theses + samples, then export to a new mixed dataset + + `data_path` with optional weight(1.0 as default), + e.g. + 1) a single data path + 2) multiple datasets in the format: dataset1-path + dataset1-file dataset3-path ...' + + """ + args = parse_args() + data_path = ' '.join(args.data_path) + formatter = load_formatter(data_path) + dataset = formatter.load_dataset(args.num_proc, args) + exporter = Exporter( + export_path = args.export_path, + export_shard_size = args.export_shard_size, + num_proc = args.num_proc, + export_stats = False) + exporter.export(dataset) + + +if __name__ == '__main__': + run_mixture() diff --git a/tools/preprocess/README.md b/tools/preprocess/README.md new file mode 100644 index 000000000..198719fb6 --- /dev/null +++ b/tools/preprocess/README.md @@ -0,0 +1,144 @@ +# Preprocess Tools + +This folder contains some preprocess scripts for additional processing of your dataset before using data-juicer. + +## Usage + +### Split datasets to sub-datasets by language + +This tool will split raw dataset to different sub-datasets by language information. + + +```shell +python tools/preprocess/dataset_split_by_language.py \ + --src_dir \ + --target_dir \ + --suffixes \ + --text_key \ + --num_proc + +# get help +python tools/preprocess/dataset_split_by_language.py --help +``` +- `src_dir`: you just need to set this argument to the path which stores your datasets. +- `target_dir`: result directory to store the converted jsonl files. +- `text_key`: key name of field that stores sample text. Default: text +- `suffixes`: the suffix of files that will be read. Default: None +- `num_proc` (optional): number of process workers. Default it's 1. + +### Convert raw arxiv data to jsonl + +This tool is used to convert the raw arxiv data downloaded from S3 into the jsonl format which is friendly to data-juicer. + + + +```shell +python tools/preprocess/raw_arxiv_to_jsonl.py \ + --arxiv_src_dir \ + --target_dir \ + --temp_dir \ + --num_proc + +# get help +python tools/preprocess/raw_arxiv_to_jsonl.py --help +``` +- `arxiv_src_dir`: if you download raw arxiv data as Redpajama did, you will get a directory src which includes thousands of tar files whose filenames are like `arXiv_src_yymm_xxx.tar`. You just need to set this argument to the path of this dir. +- `target_dir`: result directory to store the converted jsonl files. +- `temp_dir`: directory to store intermediate files, and they will be removed once the conversion ends. Default it's `./tmp` +- `num_proc` (optional): number of process workers. Default it's 1. + +**Note:** + +* For downloading process, please refer to [here](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/arxiv). + +* Before you downloading, converting or processing, you might make sure that your drive space is large enough to store the raw data (over 3TB), converted data (over 3TB), at least processed data (about 500-600GB), and even more cache data during processing. + +### Convert raw stack_exchange data to jsonl + +Use `raw_stackexchange_to_jsonl.py` to convert raw stack_exchange data. + +This tool is used for converting the raw Stack Exchange data downloaded from from [Archive](https://archive.org/download/stackexchange) to several jsonl files. + + + +```shell +python tools/preprocess/raw_arxiv_stackexchange_to_jsonl.py \ + --src_dir \ + --target_dir \ + --topk \ + --num_proc \ + +# get help +python tools/preprocess/raw_stackexchange_to_jsonl.py --help +``` +- `src_dir`: if you download raw Stack Exchange data as Redpajama did, you will get a directory src which includes hundreds of 7z files whose filenames are like `*.*.com.7z `. You need to unzip these files and rename the POSTs.xml to the corresponding compressed package name and place it in that dir. For more details, please refer to [here](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/stack_exchange). +- `target_dir`: result directory to store the converted jsonl files. +- `topk` (optional): select the topk sites with the most content. Default it's 28. +- `num_proc` (optional): number of process workers. Default it's 1. + +**Note:** Before you downloading, converting or processing, you might make sure that your drive space is large enough to store the raw data (over 100GB), converted data (over 100GB) + +### Convert raw Alpaca-CoT data to jsonl + +Use `raw_alpaca_cot_merge_add_meta.py` to convert raw Alpaca-CoT data. + +This tool is used for converting the raw Alpaca-Cot data downloaded from [HuggingFace](https://huggingface.co/QingyiSi/Alpaca-CoT) to jsonl files. + + + +```shell +python tools/preprocess/raw_alpaca_cot_merge_add_meta.py \ + --src_dir \ + --target_dir \ + --num_proc + +# get help +python tools/preprocess/raw_alpaca_cot_merge_add_meta.py --help +``` +- `src_dir`: you just need to set this argument to the path which stores Alpaca_CoT data. +- `target_dir`: result directory to store the converted jsonl files. +- `num_proc` (optional): number of process workers. Default it's 1. + +### reformat csv or tsv file + +This tool is used for reformat csv or tsv files which may have Nan values in some field to several jsonl files. + + + +```shell +python tools/preprocess/reformat_csv_nan_value.py \ + --src_dir \ + --target_dir \ + --suffixes \ + --is_tsv \ + --keep_default_na \ + --num_proc + +# get help +python tools/preprocess/reformat_csv_nan_value.py --help +``` +- `src_dir`: you just need to set this argument to the path which stores filenames are like "*.csv" or "*.tsv". +- `target_dir`: result directory to store the converted jsonl files. +- `suffixes`: what kind of suffixes you want to process, multi-suffixes args like "--suffixes '.tsv', '.csv' " +- `is_tsv`: if true, sep will be set to '\t', otherwize ',' as default. +- `keep_default_na`: if False, strings will be parsed as NaN, otherwise only the default NaN values are used for parsing. +- `num_proc` (optional): number of process workers. Default it's 1. + +### reformat jsonl file + +This tool is used for reformat jsonl files which may have Nan values in some field. + + + +```shell +python tools/preprocess/reformat_jsonl_nan_value.py \ + --src_dir \ + --target_dir \ + --num_proc + +# get help +python tools/preprocess/reformat_jsonl_nan_value.py --help +``` +- `src_dir`: you just need to set this argument to the path which stores filenames are like "*.jsonl". +- `target_dir`: result directory to store the converted jsonl files. +- `num_proc` (optional): number of process workers. Default it's 1. diff --git a/tools/preprocess/README_ZH.md b/tools/preprocess/README_ZH.md new file mode 100644 index 000000000..b9e47ba84 --- /dev/null +++ b/tools/preprocess/README_ZH.md @@ -0,0 +1,139 @@ +# 预处理工具 + +此文件夹包含一些预处理脚本,用于在使用 data-juicer 之前对数据集进行处理。 + +## 用法 + +### 按语言将数据集拆分为子数据集 + +该工具将根据语言信息将原始数据集拆分为不同的子数据集。 + +```shell +python tools/preprocess/dataset_split_by_language.py \ + --src_dir \ + --target_dir \ + --suffixes \ + --text_key \ + --num_proc + +# get help +python tools/preprocess/dataset_split_by_language.py --help +``` + +- `src_dir`: 将此参数设置为存储数据集的路径即可。 +- `target_dir`: 用于存储转换后的 jsonl 文件的结果目录。 +- `text_key`: 存储示例文本的字段的 key,默认为 text。 +- `suffixes`: 待读取文件的后缀名,默认为 None。 +- `num_proc` (可选): worker 进程数量,默认为 1。 + +### 将原始 arxiv 数据转换为 jsonl + +该工具用于将从 S3 下载的原始 arxiv 数据转换为对 data-juicer 友好的 jsonl 格式。 + +```shell +python tools/preprocess/raw_arxiv_to_jsonl.py \ + --arxiv_src_dir \ + --target_dir \ + --temp_dir \ + --num_proc + +# get help +python tools/preprocess/raw_arxiv_to_jsonl.py --help +``` + +- `arxiv_src_dir`: 如果你像 Redpajama 一样下载原始 arxiv 数据,你将得到一个目录 src,其中包含数千个 tar 文件,其文件名类似于 `arXiv_src_yymm_xxx.tar`。 您只需将此参数设置为该目录的路径即可。 +- `target_dir`: 用于存储转换后的 jsonl 文件的结果目录。 +- `temp_dir`: 用于存储临时文件的目录,该目录将在转化结束时自动被删除,默认为 `./tmp`。 +- `num_proc` (可选): worker 进程数量,默认为 1。 + +**注意事项:** + +* 下载过程请参考[这里](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/arxiv)。 + +* 在下载、转换或处理之前,您需要确保您的硬盘空间足够大,可以存储原始数据(超过 3TB)、转换后的数据(超过 3TB)、最小处理后的数据(大约 500-600GB),以及处理期间的缓存数据。 + +### 将原始 stack_exchange 数据转换为 jsonl + +使用 `raw_stackexchange_to_jsonl.py` 来转化原始 stack_exchange 数据. + +该工具用于将从 [Archive](https://archive.org/download/stackexchange) 下载的原始 Stack Exchange 数据转化为多个 jsonl 文件. + +```shell +python tools/preprocess/raw_arxiv_stackexchange_to_jsonl.py \ + --src_dir \ + --target_dir \ + --topk \ + --num_proc \ + +# get help +python tools/preprocess/raw_stackexchange_to_jsonl.py --help +``` + +- `src_dir`: 如果像 Redpajama 一样下载原始 Stack Exchange 数据,你将得到一个目录 src,其中包含数百个 7z 文件,其文件名类似于 `*.*.com.7z`。 您需要解压这些文件并将 POSTs.xml 重命名为相应的压缩包名称并将其放在该目录中。更多详情请参考[这里](https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/stack_exchange)。 +- `target_dir`: 用于存储转换后的 jsonl 文件的结果目录。 +- `topk` (可选): 选择内容最多的 k 个站点,默认为 28. +- `num_proc` (可选): worker 进程数量,默认为 1。 + +**注意事项:** 在下载、转换或处理之前,您需要确保您的硬盘空间足够大,可以存储原始数据(超过 100GB)、转换后的数据(超过 100GB) + +### 将原始 Alpaca-CoT 数据转换为 jsonl + +使用 `raw_alpaca_cot_merge_add_meta.py` 来转化原始 Alpaca-CoT 数据. + +该工具用于将从 [HuggingFace]( https://huggingface.co/QingyiSi/Alpaca-CoT) 下载的原始 Alpaca-Cot 数据转化为 jsonl 文件. + +```shell +python tools/preprocess/raw_alpaca_cot_merge_add_meta.py \ + --src_dir \ + --target_dir \ + --num_proc \ + +# get help +python tools/preprocess/raw_alpaca_cot_merge_add_meta.py --help +``` + +- `src_dir`: 将此参数设置为存储Alpaca-CoT数据集的路径。 +- `target_dir`: 用于存储转换后的 jsonl 文件的结果目录。 +- `num_proc` (可选): worker 进程数量,默认为 1。 + +### 重新格式化 csv 或者 tsv 文件 + +此工具用于将某些字段中可能具有 NaN 值的 csv 或 tsv 文件格式化为多个 jsonl 文件。 + +```shell +python tools/preprocess/reformat_csv_nan_value.py \ + --src_dir \ + --target_dir \ + --suffixes \ + --is_tsv \ + --keep_default_na \ + --num_proc + +# get help +python tools/preprocess/reformat_csv_nan_value.py --help +``` + +- `src_dir`: 将此参数设置为存储数据集的路径,例如 `*.csv` 或 `*.tsv` 即可。 +- `target_dir`: 用于存储转换后的 jsonl 文件的结果目录。 +- `suffixes`: 待读取文件的后缀名,可指定多个,例如 `--suffixes '.tsv', '.csv'` +- `is_tsv`: 如果为 true,则分隔符将设置为 `\t`,否则默认设置为 `,`。 +- `keep_default_na`: 如果为 False,字符串将被解析为 NaN,否则仅使用默认的 NaN 值进行解析。 +- `num_proc` (可选): worker 进程数量,默认为 1。 + +### 重新格式化 jsonl 文件 + +该工具用于重新格式化某些字段中可能包含 Nan 值的 jsonl 文件。 + +```shell +python tools/preprocess/reformat_jsonl_nan_value.py \ + --src_dir \ + --target_dir \ + --num_proc + +# get help +python tools/preprocess/reformat_jsonl_nan_value.py --help +``` + +- `src_dir`: 将此参数设置为存储数据集的路径,例如 `*.jsonl`. +- `target_dir`: 用于存储转换后的 jsonl 文件的结果目录。 +- `num_proc` (可选): worker 进程数量,默认为 1。 diff --git a/tools/preprocess/dataset_split_by_language.py b/tools/preprocess/dataset_split_by_language.py new file mode 100644 index 000000000..ab9fd89e6 --- /dev/null +++ b/tools/preprocess/dataset_split_by_language.py @@ -0,0 +1,91 @@ +# This tool is used to split datasets to sub-datasets +# by fast-text lanuage model. + +import os + +import fire +import pandas as pd +from jsonargparse import Namespace +from loguru import logger + +from data_juicer.format import load_formatter +from data_juicer.ops.filter.language_id_score_filter import \ + LanguageIDScoreFilter + + +def keep_by_lang(sample, lang): + """ + Keep samples with the specified language. + :param sample: a sample in dataset + :param lang: the specified language + :return: True to keep, False to discard + """ + if sample['stats']['lang'] == lang: + return True + return False + + +def main(src_dir, + target_dir, + text_keys_to_load=None, + text_key_to_process='text', + suffixes=[], + num_proc=1): + """ + Load dataset from the source directory, then apply language identification + using the operation filter called `LanguageIDScoreFilter`, + finally, split the dataset by language and save it. + :param src_dir: path thats store dataset directory + :param target_dir: path to store subset files(`jsonl` format) + :param text_key: key name of field that stores sample text, default "text: + :param suffixes: files with suffixes to be loaded, default None + :param num_proc: number of processes to process dataset, default 1. + """ + if text_keys_to_load is None: + text_keys_to_load = ['text'] + # check if the source directory exists. + if not os.path.exists(src_dir): + raise ValueError('The raw source data directory does not exist,' + ' Please check and retry.') + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + # Note: + # key name of `"keys_to_load"` in sample will be rename to "text" + formatter = load_formatter(src_dir, + keys_to_load=text_keys_to_load, + suffixes=suffixes) + tmp_cfg = Namespace({'text_key_to_process': text_key_to_process}) + dataset = formatter.load_dataset(num_proc, tmp_cfg) + + op = LanguageIDScoreFilter(text_key=tmp_cfg['text_key_to_process']) + + if 'stats' not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name='stats', + column=[{}] * dataset.num_rows) + + # identify language + dataset = dataset.map(op.compute_stats, num_proc=num_proc) + + langs = pd.DataFrame(dataset['stats'])['lang'] + unique_langs = list(set(langs)) + + logger.info(f'There are {len(dataset)} in dataset') + logger.info(f'Languages in dataset are {unique_langs}') + + # split and save subset of dataset by language + for lang in unique_langs: + ds = dataset.filter(keep_by_lang, + num_proc=num_proc, + fn_kwargs=dict(lang=lang)) + + logger.info(f'There are {len(ds)} with language [{lang}]') + jsonl_fp = os.path.join(target_dir, lang + '.jsonl') + ds.to_json(jsonl_fp, force_ascii=False) + + +if __name__ == '__main__': + fire.Fire(main) \ No newline at end of file diff --git a/tools/preprocess/raw_alpaca_cot_merge_add_meta.py b/tools/preprocess/raw_alpaca_cot_merge_add_meta.py new file mode 100644 index 000000000..5da7dbcd9 --- /dev/null +++ b/tools/preprocess/raw_alpaca_cot_merge_add_meta.py @@ -0,0 +1,507 @@ +# This tool is used for converting the raw Alpaca-Cot data downloaded +# from Huggingface (ref: https://huggingface.co/QingyiSi/Alpaca-CoT) +# to several jsonl files. + +import os +import pathlib +from multiprocessing import Pool + +import fire +from datasets import load_dataset +from loguru import logger + +meta_dict = { + 'Chain-of-Thought': { # sub directory + 'Task': 'MT', # Alpaca-Cot original Task + 'Gen': 'HG', # Alpaca-Cot original Gen + 'Lang': 'EN/CN', # Alpaca-Cot original Language + 'Dataset': 'Chain-of-Thought', # sub directory + 'Multi-round Dialog': + False, # whether is Multi-round Dialog data, added by Data-Juicer + 'IFT': True, # whether is IFT data, added by Data-Juicer + 'SFT': False, # whether is SFT data, added by Data-Juicer + 'Preference': + False, # whether is Preference data, added by Data-Juicer + }, + 'GPT4all': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'GPT4all', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': True, + 'Preference': False, + }, + 'GPTeacher': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'EN', + 'Dataset': 'GPTeacher', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'Guanaco': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'ML', + 'Dataset': 'Guanaco', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'HC3': { + 'Task': 'TS', + 'Gen': 'MIX', + 'Lang': 'EN/CN', + 'Dataset': 'HC3', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': True, + }, + 'alpaca': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'EN', + 'Dataset': 'alpaca', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'Natural-Instructions': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'ML', + 'Dataset': 'Natural-Instructions', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'belle_cn': { + 'Task': 'TS/MT', + 'Gen': 'SI', + 'Lang': 'CN', + 'Dataset': 'belle_cn', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'instinwild': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'EN/CN', + 'Dataset': 'instinwild', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'prosocial-dialog': { + 'Task': 'TS', + 'Gen': 'MIX', + 'Lang': 'EN', + 'Dataset': 'prosocial-dialog', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'finance': { + 'Task': 'TS', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'finance', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'xP3': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'ML', + 'Dataset': 'xP3', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'firefly': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'CN', + 'Dataset': 'firefly', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'instruct': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'instruct', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'CodeAlpaca': { + 'Task': 'TS', + 'Gen': 'SI', + 'Lang': 'EN', + 'Dataset': 'CodeAlpaca', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'alpacaGPT4': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'EN/CN', + 'Dataset': 'alpacaGPT4', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': True, + }, + 'webGPT': { + 'Task': 'TS', + 'Gen': 'MIX', + 'Lang': 'EN', + 'Dataset': 'webGPT', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': True, + }, + 'dolly': { + 'Task': 'TS', + 'Gen': 'HG', + 'Lang': 'EN', + 'Dataset': 'dolly', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'baize': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'baize', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'hh-rlhf': { + 'Task': 'TS', + 'Gen': 'MIX', + 'Lang': 'EN', + 'Dataset': 'hh-rlhf', + 'Multi-round Dialog': True, + 'IFT': False, + 'SFT': True, + 'Preference': True, + }, + 'OIG': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'OIG', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'GAOKAO': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'CN', + 'Dataset': 'GAOKAO', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'camel': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'EN', + 'Dataset': 'camel', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'FLAN-Muffin': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'FLAN-Muffin', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'COIG': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'CN', + 'Dataset': 'COIG', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'gpt4tools': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'EN', + 'Dataset': 'gpt4tools', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'ShareGPT': { + 'Task': 'MT', + 'Gen': 'MIX', + 'Lang': 'EN', + 'Dataset': 'ShareGPT', + 'Multi-round Dialog': True, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'Auto-CoT': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'Auto-CoT', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'MOSS': { + 'Task': 'TS', + 'Gen': 'SI', + 'Lang': 'EN/CN', + 'Dataset': 'MOSS', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'ultrachat': { + 'Task': 'TS', + 'Gen': 'SI', + 'Lang': 'EN', + 'Dataset': 'ultrachat', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'Chinese-medical': { + 'Task': 'TS', + 'Gen': 'COL', + 'Lang': 'CN', + 'Dataset': 'Chinese-medical', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': False, + }, + 'CSL': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'CN', + 'Dataset': 'CSL', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'pCLUE': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'CN', + 'Dataset': 'pCLUE', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'news_commentary': { + 'Task': 'TS', + 'Gen': 'COL', + 'Lang': 'CN', + 'Dataset': 'news_commentary', + 'Multi-round Dialog': False, + 'IFT': True, + 'SFT': False, + 'Preference': False, + }, + 'StackExchange': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + 'Dataset': 'StackExchange', + 'Multi-round Dialog': False, + 'IFT': False, + 'SFT': True, + 'Preference': True, + }, + "ConvAI2": { + "Task": "TS", + "Gen": "HG", + "Lang": "EN", + "Dataset": "ConvAI2", + "Multi-round Dialog": False, + "IFT": False, + "SFT": True, + "Preference": False, + }, + "FastChat": { + "Task": "MT", + "Gen": "SI", + "Lang": "EN", + "Dataset": "FastChat", + "Multi-round Dialog": False, + "IFT": False, + "SFT": True, + "Preference": False, + }, + 'Tabular-LLM-Data': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN/CN', + "Dataset": "Tabular-LLM-Data", + "Multi-round Dialog": False, + "IFT": True, + "SFT": False, + "Preference": False, + }, + 'ThoughtSource': { + 'Task': 'MT', + 'Gen': 'COL', + 'Lang': 'EN', + "Dataset": "ThoughtSource", + "Multi-round Dialog": False, + "IFT": True, + "SFT": False, + "Preference": False, + } +} + + +def merge_and_add_meta(filename, target_dir): + """ + Merge `instruction`/`input`/`output` to `text` for process, + and add meta info. + :param filename: input dataset file + :param target_dir: path to save updated dataset + """ + + ds = load_dataset('json', data_files=[filename], split='train') + + if 'instruction' in ds.features and \ + 'input' in ds.features and \ + 'output' in ds.features: + for column_name in ds.column_names: + if column_name not in ['instruction', 'input', 'output']: + ds = ds.remove_columns(column_name) + else: + logger.warning(f'Can not find ["instruction", "input", "output"] in \ + {filename}, do nothing.') + return + + meta = None + for key in meta_dict.keys(): + if key in filename: + meta = meta_dict[key] + + if meta is None: + logger.warning(f'Can not find meta in {filename}, do nothing.') + return + + def _merge_and_add_meta(sample, path, meta): + """ + Merge `instruction`/`input`/`output` to `text` for process, + and add meta info. + :param sample: a dict sample in dataset + :param path: sample in which file + :param meta: meta added to sample + :return: updated sample + """ + sample['text'] = ' '.join( + [sample['instruction'], sample['input'], sample['output']]) + sample['meta'] = meta + sample['meta']['origin_path'] = path + return sample + + path = ''.join(['Alpaca-CoT', filename.split('Alpaca-CoT')[1]]) + ds = ds.map(_merge_and_add_meta, + num_proc=48, + fn_kwargs={ + 'path': path, + 'meta': meta + }) + + if len(ds) > 0: + out_file = ''.join([target_dir, filename.split('Alpaca-CoT')[1]]) + out_file = out_file.replace('.json', '.jsonl') + dir_name = os.path.dirname(out_file) + os.makedirs(dir_name, exist_ok=True) + ds.to_json(out_file, force_ascii=False) + + +def fp_iter(src_dir): + """ + Find all tar files in the source directory. + :param src_dir: path to source dataset directory + :return: iterator over json files + """ + for fp in pathlib.Path(src_dir).rglob('*.json'): + yield fp + + +def main(src_dir, target_dir, num_proc=4): + """ + Load dataset from the source directory, then apply language identification + using the operation filter called `LanguageIDScoreFilter`, + finally, split the dataset by language and save it. + :param src_dir: path thats store dataset directory + :param target_dir: path to store subset files(`jsonl` format) + :param num_proc: number of processes to process dataset, default 1. + """ + + # check if the source directory exists. + if not os.path.exists(src_dir): + raise ValueError('The raw source data directory does not exist,' + ' Please check and retry.') + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + pool = Pool(num_proc) + + for fp in fp_iter(src_dir): + pool.apply_async(merge_and_add_meta, args=(str(fp), target_dir)) + + pool.close() + pool.join() + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/preprocess/raw_arxiv_to_jsonl.py b/tools/preprocess/raw_arxiv_to_jsonl.py new file mode 100644 index 000000000..d531cdbf5 --- /dev/null +++ b/tools/preprocess/raw_arxiv_to_jsonl.py @@ -0,0 +1,146 @@ +# Part of the code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/blob/main/data_prep/arxiv/arxiv_cleaner.py +# -------------------------------------------------------- +# +# This tool is used for converting the raw arxiv data downloaded from S3 +# (ref: https://info.arxiv.org/help/bulk_data_s3.html) to several jsonl files. +# +# For downloading process, please refer to: +# https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/arxiv +# +# Notice: before you downloading, converting or processing, you might make sure +# that your drive space is large enough to store the raw data (over 3TB), +# converted data (over 3TB), at least processed data (about 500-600GB), and +# even more cache data during processing. + +import gzip +import os +import pathlib +import tarfile +import tempfile +from multiprocessing import Pool + +import fire +import jsonlines as jl +from loguru import logger + + +@logger.catch +def tex_proj_loader(file_or_dir_path: pathlib.Path): + """ + Load the tex files from a tar file or a gzip file. + :param file_or_dir_path: path to tar file or the gzip file + :return: a list of content in tex files + """ + + files_and_content = [] + try: + # if it is a directory, open it as a tarfile + with tarfile.open(file_or_dir_path) as sub_tf: + for member in sub_tf.getmembers(): + if member.name.endswith('.tex'): + file_content = sub_tf.extractfile(member).read() + try: + file_content = file_content.decode('utf-8') + except UnicodeDecodeError: + logger.error(f'UnicodeDecodeError: {file_or_dir_path}') + return None + files_and_content.append(file_content) + except tarfile.ReadError: + # otherwise we try opening it as a gzip file + try: + with gzip.open(file_or_dir_path, 'rb') as gz: + file_content = gz.read() + except Exception as e: + # all fails, we skip this file + logger.error(f'{e}: {file_or_dir_path}') + return None + + try: + file_content = file_content.decode('utf-8') + except UnicodeDecodeError: + logger.error(f'UnicodeDecodeError: {file_or_dir_path}') + return None + files_and_content.append(file_content) + except Exception as e: + logger.error(f'{e}: {file_or_dir_path}') + return None + + return files_and_content + + +@logger.catch +def convert_tar_to_jsonl(tar_fp, jsonl_fp, tmp_dir): + """ + Extract the contents of tex files from tar file, convert and + save to jsonl file + :param tar_fp: path to tar file + :param jsonl_fp: path to save jsonl file + :param tmp_dir: a temporary directory to save extracted files + """ + failed = 0 + success = 0 + with tempfile.TemporaryDirectory(dir=tmp_dir, prefix=tar_fp.name) as td: + with jl.open(jsonl_fp, mode='w') as writer: + with tarfile.open(tar_fp) as tf: + tf.extractall(members=tf.getmembers(), path=td) + for proj_dir_or_file in pathlib.Path(td).rglob('*.gz'): + data = tex_proj_loader(proj_dir_or_file) + if data is None: + failed += 1 + continue + success += 1 + writer.write_all([{'text': txt} for txt in data]) + + logger.info(f'{jsonl_fp} done. Fail: {failed}, success: {success}') + + +def tar_fp_iter(src_dir): + """ + Find all tar files in the source directory. + :param src_dir: path to source dataset directory + :return: iterator over tar files + """ + for tar_fp in pathlib.Path(src_dir).glob('*.tar'): + yield tar_fp + + +def main(arxiv_src_dir, target_dir, work_dir='./tmp/', num_proc=1): + """ + :param arxiv_src_dir: if you download raw arxiv data as Redpajama did, + you will get a directory src which includes thousands of tar + files whose filenames are like "arXiv_src_yymm_xxx.tar". You + just need to set this argument to the path of this dir. + :param target_dir: result directory to store the converted jsonl files. + :param work_dir: directory to store intermediate files, and they will + be removed once the conversion ends. Default it's "./tmp" + :param num_proc: number of process workers. Default it's 1. + """ + # check if the source directory exists. + if not os.path.exists(arxiv_src_dir): + raise ValueError('The raw arxiv source data directory does not exist,' + ' Please check and retry.') + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + if not os.path.exists(work_dir): + os.makedirs(work_dir, exist_ok=True) + + # convert in multiprocess + pool = Pool(num_proc) + for tar_fp in tar_fp_iter(arxiv_src_dir): + logger.info(f'Start to process {tar_fp}') + jsonl_fp = os.path.join(target_dir, + tar_fp.name.replace('.tar', '.jsonl')) + pool.apply_async(convert_tar_to_jsonl, + args=( + tar_fp, + jsonl_fp, + work_dir, + )) + pool.close() + pool.join() + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/preprocess/raw_stackexchange_to_jsonl.py b/tools/preprocess/raw_stackexchange_to_jsonl.py new file mode 100644 index 000000000..a9f267211 --- /dev/null +++ b/tools/preprocess/raw_stackexchange_to_jsonl.py @@ -0,0 +1,242 @@ +# Part of the code here has been modified from: +# https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/stack_exchange +# -------------------------------------------------------- +# +# This tool is used for converting the raw Stack Exchange data downloaded from +# from Archive (ref: https://archive.org/download/stackexchange) to several +# jsonl files. +# +# For downloading process, please refer to: +# https://github.com/togethercomputer/RedPajama-Data/tree/main/data_prep/stack_exchange +# +# Notice: before you downloading, converting or processing, you might make sure +# that your drive space is large enough to store the raw data (over 100GB), +# converted data (over 100GB) + +import json +import os +import xml.etree.ElementTree as ET +from multiprocessing import Pool + +import fire +from loguru import logger +from tqdm import tqdm + + +@logger.catch +def get_sites_count(path, topk=28): + """ + Take top-K sites(`.xml`) by its size of content + :param path: path to stack_exchage data + :param topk: number of top-k sites + :return + 1) a dict stores pair of site and its size of content + 2) a list of topk sites + """ + + logger.info('Got counts for all sites.') + sites = os.listdir(path) + sites = [x for x in sites if x.endswith('.xml')] + counts = {} + for site in tqdm(sites): + if site == '.DS_Store': + continue + # read xml file and count contents + with open(os.path.join(path, site), 'r') as f: + # read # lines + count = sum(1 for line in f) + counts[site] = count - 3 # subtract the header + # sort the counts + counts = { + k: v + for k, v in sorted( + counts.items(), key=lambda item: item[1], reverse=True) + } + # take first 28 + sites = list(counts.keys())[:topk] + return counts, sites + + +@logger.catch +def get_parents(site, counts): + """ + Find all answers's parent id, and groups by parent id + :param site: site(xml) name + :param counts: a dict stores pair of site and its size of content + :return: a dict stores pair of parent question id and list of answer id + """ + parents = {} + with open(site, 'r') as f: + for i, line in enumerate(tqdm(f, + total=counts[os.path.basename(site)])): + # first 2 lines are header + # e.g., counts = 2: total=5 lines, 2,3 are data + # last line is footer + if i > 1 and i <= counts[os.path.basename(site)] + 1: + root = ET.fromstring(line) + if 'ParentId' in root.attrib: + # this is an answer + if root.attrib['ParentId'] not in parents: + parents[root.attrib['ParentId']] = [] + parents[root.attrib['ParentId']].append({ + 'id': + root.attrib['Id'], + 'text': + root.attrib['Body'], + 'score': + root.attrib['Score'] + }) + logger.info((f'Got {len(parents)} questions for {site}.')) + return parents + + +@logger.catch +def get_qapairs(site, counts, parents): + """ + Find and group all matched pairs of question and answer in site file + :param site: site(.xml) name + :param counts: a dict stores pair of site and its size of content + :param parents: a dict stores pair of parent question id and + list of answer id + :return: a list of qa pairs + """ + qa_pairs = [] + with open(site, 'r') as f: + for i, line in enumerate(tqdm(f, + total=counts[os.path.basename(site)])): + # first 2 lines are header + # e.g., counts = 2: total=5 lines, 2,3 are data + # last line is footer + if i > 1 and i <= counts[os.path.basename(site)] + 1: + root = ET.fromstring(line) + if 'ParentId' not in root.attrib: + post_id = root.attrib['Id'] + if post_id in parents: + # this is a question + qa_pairs.append({ + 'question': { + 'id': post_id, + 'text': f"{root.attrib['Title']} \ + {root.attrib['Body']}", + 'score': root.attrib['Score'] + }, + 'answers': parents[post_id] + }) + else: + if 'Title' in root.attrib: + # if there's a title => then a valid question + body = root.attrib[ + 'Body'] if 'Body' in root.attrib else '' + score = root.attrib[ + 'Score'] if 'Score' in root.attrib else 0 + qa_pairs.append({ + 'question': { + 'id': post_id, + 'text': f"{root.attrib['Title']} {body}", + 'score': score + }, + }) + logger.info((f'Got {len(qa_pairs)} qa_pairs for {site}.')) + return qa_pairs + + +@logger.catch +def process_qa_pair(pair, site_name, site_count): + """ + Sort answers by their score for question in qa pair sample, + add meta info to sample + :param pair: input qa pair sample + :param site_name: site name of qa pair + :param site_count: content size of site + :return: a dict of qa pair, including ["text", "meta"] + """ + # sort answers by score + if 'answers' in pair: + pair['answers'] = sorted(pair['answers'], + key=lambda x: x['score'], + reverse=True) + answers = '\nA: '.join([x['text'] for x in pair['answers']]) + text = f"Q: {pair['question']['text']}\nA: {answers}" + else: + text = f"Q: {pair['question']['text']}" + return { + 'text': text, + 'meta': { + 'site_count': site_count, + 'url': f"https://{site_name}/questions/{pair['question']['id']}", + 'timestamp': '2023-03-29', + 'source': 'stackexchange', + 'question_score': pair['question']['score'], + } + } + + +@logger.catch +def process_site(site, counts, src_dir, target_dir, num_proc=24): + """ + Convert one raw Stack Exchange site data to jsonl file. + 1) find all answers's parent id and groups by parent id + 2) find matched pair of question and answers + 3) sort answers by their score for each question + :param site: site name endwith `".xml"` + :param counts: dict stores pair of site name and its size + :parma src_dir: dir path of site + :param target_dir: path to save jsonl file + :param num_proc: number of process workers. Default it's 24. + """ + logger.info(f'Processing {site}...') + logger.info(f'|{site}|{counts[site]}|') + site_path = os.path.join(src_dir, site) + parents = get_parents(site_path, counts) + qa_pairs = get_qapairs(site_path, counts, parents) + + site_name = site.removesuffix('.xml') + if 'stackoverflow_part' in site_name: + site_name = 'stackoverflow.com' + + site_name_list = [site_name] * len(qa_pairs) + counts_list = [counts[site]] * len(qa_pairs) + tasks = [*zip(qa_pairs, site_name_list, counts_list)] + with Pool(num_proc) as p: + results = p.starmap(process_qa_pair, iterable=tasks) + logger.info(f"Writing {len(results)} results to \ + {os.path.join(target_dir, site_name+'.jsonl')}") + + with open(os.path.join(target_dir, site_name + '.jsonl'), 'w') as f: + for result in results: + f.write(json.dumps(result) + '\n') + + +@logger.catch +def main(src_dir, target_dir, topk=28, num_proc=1): + """ + Convert the raw Stack Exchange data downloaded from from Archive + (ref: https://archive.org/download/stackexchange) to several + jsonl files. + :param src_dir: if you download raw Stack Exchange data as Redpajama did, + you will get a directory src which includes hundreds of 7z files + whose filenames are like "*.*.com.7z ". You need to unzip these + files and rename the POSTs.xml to the corresponding compressed + package name and place it in that dir. + :param target_dir: result directory to store the converted jsonl files. + :param topk: select the topk sites with the most content. + Default it's 28. + :param num_proc: number of process workers. Default it's 1. + """ + # check if the source directory exists + if not os.path.exists(src_dir): + raise ValueError( + 'The raw stack_exchange source data directory does not exist,' + ' Please check and retry.') + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + # select topk sites from src dir by most contents number + counts, sites = get_sites_count(src_dir, topk=topk) + for site in sites: + logger.info(f'Start to process {site}') + process_site(site, counts, src_dir, target_dir, num_proc) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/preprocess/reformat_csv_nan_value.py b/tools/preprocess/reformat_csv_nan_value.py new file mode 100644 index 000000000..a455af9f9 --- /dev/null +++ b/tools/preprocess/reformat_csv_nan_value.py @@ -0,0 +1,85 @@ +# This tool is used for reformat csv or tsv files which may contain Nan values +# in some field to several jsonl files. + +import os +import pathlib +from multiprocessing import Pool + +import fire +from datasets import Dataset + + +def reformat_nan_value(fp, jsonl_fp, keep_default_na, kwargs): + """ + Reformat a csv/tsv file with kwargs. + :param fp: a csv/tsv file + :param jsonl_fp: path to save jsonl file + :param keep_default_na: if False, no string will be parsed as NaN, + otherwise only the default NaN values are used for parsing. + :param kwargs: for tsv file, kwargs["sep'} is `\t` + :return: iterator over files, + """ + ds = Dataset.from_csv(fp, keep_default_na=keep_default_na, **kwargs) + ds.to_json(jsonl_fp, force_ascii=False) + pass + + +def fp_iter(src_dir, suffix): + """ + Find all files endswith the specified suffix in the source directory. + :param src_dir: path to source dataset directory + :return: iterator over files, + """ + for fp in pathlib.Path(src_dir).glob(f'*{suffix}'): + yield fp + + +def main(src_dir, + target_dir, + suffixes=['.csv'], + is_tsv=False, + keep_default_na=False, + num_proc=1, + **kwargs): + """ + Reformat csv or tsv files that may contain Nan values using HuggingFace + to load with extra args, e.g. set `keep_default_na` to False + :param src_dir: path thats stores filenames are like "*.csv" or "*.tsv". + :param target_dir: path to store the converted jsonl files. + :param suffixes: files with suffixes to be to process, multi-suffixes args + like `--suffixes "'.tsv', '.csv'" + :param is_tsv: if True, sep will be set to '\t'. Default ','. + :param keep_default_na: if False, no strings will be parsed as NaN, + otherwise only the default NaN values are used for parsing. + :param num_proc: number of process workers, Default 1. + :param kwargs: optional extra args for Dataset loading csv/tsv + """ + # check if the source directory exists + if not os.path.exists(src_dir): + raise ValueError('The raw source data directory does not exist,' + ' Please check and retry.') + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + if kwargs is None: + kwargs = {} + + if is_tsv: + kwargs['sep'] = '\t' + + if isinstance(suffixes, str): + suffixes = [suffixes] + + pool = Pool(num_proc) + for suffix in suffixes: + for fp in fp_iter(src_dir, suffix): + jsonl_fp = os.path.join(target_dir, + fp.name.replace(suffix, '.jsonl')) + pool.apply_async(reformat_nan_value, + args=(str(fp), jsonl_fp, keep_default_na, kwargs)) + pool.close() + pool.join() + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/preprocess/reformat_jsonl_nan_value.py b/tools/preprocess/reformat_jsonl_nan_value.py new file mode 100644 index 000000000..b5c1a7013 --- /dev/null +++ b/tools/preprocess/reformat_jsonl_nan_value.py @@ -0,0 +1,97 @@ +# This tool is used for reformat jsonl files which may have Nan values +# in some field. + +import os +import pathlib +from multiprocessing import Pool + +import fire +import jsonlines +import pandas as pd +from datasets import Dataset + + +def check_dict_non_nan(obj): + """ + Check if all fields in the dict object are non-Nan + :papram: a dict object + :return: True if all fields in the dict object are non-Nan, + else False + """ + no_nan = True + for key, value in obj.items(): + if isinstance(value, dict): + no_nan = no_nan & check_dict_non_nan(value) + elif pd.isna(value) or pd.isnull(value): + return False + return no_nan + + +def get_non_nan_features(src_dir): + """ + Get the first object feature which does not contain Nan value. + :param src_dir: path which stores jsonl files. + :return: reference feature of dataset. + """ + for fp in fp_iter(src_dir): + with jsonlines.open(fp, 'r') as reader: + for obj in reader: + if check_dict_non_nan(obj): + ds = Dataset.from_list([obj]) + return ds.features + return None + + +def reformat_jsonl(fp, jsonl_fp, features): + """ + Reformat a jsonl file with reference features + :param fp: input jsonl file + :param jsonl_fp: formated jsonl file + :param features: reference feature to use for dataset. + """ + with jsonlines.open(fp, 'r') as reader: + objs = [obj for obj in reader] + ds = Dataset.from_list(objs, features=features) + ds.to_json(jsonl_fp, force_ascii=False) + + +def fp_iter(src_dir): + """ + Find all jsonl files in the source directory. + :param src_dir: path to source dataset directory + :return: iterator over jsonl files + """ + for fp in pathlib.Path(src_dir).glob('*.jsonl'): + yield fp + + +def main(src_dir, target_dir, num_proc=1): + """ + Reformat the jsonl files which may contain Nan values. Traverse jsonl + files to find the first object that does not contain Nan as a + reference feature type, then set it for loading all jsonl files. + :param src_dir: path thats stores jsonl files. + :param target_dir: path to store the converted jsonl files. + :param num_proc: number of process workers. Default it's 1. + """ + + # check if the source directory exists + if not os.path.exists(src_dir): + raise ValueError('The raw source data directory does not exist,' + ' Please check and retry.') + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + pool = Pool(num_proc) + features = get_non_nan_features(src_dir) + for fp in fp_iter(src_dir): + print(fp) + jsonl_fp = os.path.join(target_dir, fp.name) + pool.apply_async(reformat_jsonl, args=(str(fp), jsonl_fp, features)) + + pool.close() + pool.join() + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/process_data.py b/tools/process_data.py new file mode 100644 index 000000000..4780736c7 --- /dev/null +++ b/tools/process_data.py @@ -0,0 +1,13 @@ +from loguru import logger + +from data_juicer.core import Executor + + +@logger.catch +def main(): + executor = Executor() + executor.run() + + +if __name__ == '__main__': + main() diff --git a/tools/quality_classifier/README.md b/tools/quality_classifier/README.md new file mode 100644 index 000000000..e8a55e4cc --- /dev/null +++ b/tools/quality_classifier/README.md @@ -0,0 +1,139 @@ +# Quality Classifier Toolkit + +Help you reproduce and apply quality classifier to your web datasets similar to GPT-3 quality classifier. + +The whole toolkit is based on PySpark. And the basic structure of quality classifiers here consists of: +- tokenizer: the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark or [sentencepiece](https://github.com/google/sentencepiece) model +- feature extractor: [HashingTF](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.HashingTF.html#hashingtf) +- classifier: [LogisticRegression](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.classification.LogisticRegression.html#logisticregression) + +## Usage + +### Predict with existing classifiers + +Use `predict.py` to predict a document score of "quality" and a label for each sample to indicate whether this sample should be kept according to the score. + +```shell +# predict doc_score for a dataset +python predict.py \ + \ + \ + [--model ] \ + [--tokenizer ] \ + [--keep_method ] \ + [--text_key ] \ + [--overall_stats] + +# print the usage message +python predict.py --help +``` + +- `dataset_path`: the input dataset path. The suffix of the path should be one of the `[json, jsonl, parquet]`. +- `result_path`: the path to store the dataset with prediction results. The suffix of the path should be one of the `[json, jsonl, parquet]`. +- `model_path`: (Optional. Default: "gpt3") the path to the model used to predict. You can use one of the models we provide `[gpt3, chinese, code]`. Or you can use the model trained by yourself using the `train.py` script. +- `tokenizer`: (Optional. Default: None) the tokenizer to tokenize texts to be classified. If it's None, the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark will be used. Besides, you can use one of the tokenizers we provide `[zh.sp.model, code.sp.model]`. Or you can set it to a path to your own [sentencepiece](https://github.com/google/sentencepiece) model. +- `keep_method`: (Optional. Default: "gpt3") the method used to decide whether a sample should be kept according to the doc_score. Should be one of `[gpt3, label]`. +- `text_key`: (Optional. Default: "text") the field name to store texts to be classified in the input dataset. +- `overall_stats`: (Optional. Default: False) whether to generate an overall stats report of document scores. + +### Train your own quality classifier + +Use `train.py` to train your own quality classifier for your datasets. + +```shell +# train a quality classifier for your own dataset +python train.py \ + ] \ + ] \ + [--output_model_path ] \ + [--num_training_samples ] \ + [--train_test_split_ratio ] \ + [--tokenizer ] \ + [--evaluation ] \ + [--text_key ] + +# print the usage message +python train.py --help +``` + +- `positive_datasets`: the paths to the positive datasets. It could be a string for a single dataset, e.g. `'pos.parquet'`, or a list of strings for multiple datasets, e.g. `'["pos1.parquet", "pos2.parquet"]'`. +- `negative_datasets`: the paths to the negative datasets. Similar to `positive_datasets`. +- `output_model_path`: (Optional. Default: "my_quality_model") the path to store the trained classifier. +- `num_training_samples`: (Optional. Default: 0) number of samples used to train the model for pos/neg datasets respectively. Default 0 means using all samples to train. +- `train_test_split_ratio`: (Optional. Default: 0.8) ratio to split training set, and the rest of samples will be test set used to evaluate. +- `tokenizer`: (Optional. Default: None) the tokenizer to tokenize texts to be classified. If it's None, the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark will be used. Besides, you can use one of the tokenizers we provide `[zh.sp.model, code.sp.model]`. Or you can set it to a path to your own [sentencepiece](https://github.com/google/sentencepiece) model. +- `evaluation`: (Optional, Default: True) whether to evaluate the trained classifier using the test set after training. +- `text_key`: (Optional. Default: "text") the field name to store texts to be classified in the input dataset. + +### Evaluate a quality classifier + +Use `eval.py` to evaluate a quality classifier to report Precision, Recall, and F1 metrics. + +```shell +# evaluate a quality classifier on your own dataset +python eval.py \ + [--positive_datasets ] \ + [--negative_datasets ] \ + [--model ] \ + [--tokenizer ] \ + [--text_key ] + +# print the usage message +python eval.py --help +``` + +- `positive_datasets`: (Optional. Default: None) the paths to the positive datasets. It could be a string for a single dataset, e.g. `'pos.parquet'`, or a list of strings for multiple datasets, e.g. `'["pos1.parquet", "pos2.parquet"]'`. +- `negative_datasets`: (Optional. Default: None) the paths to the negative datasets. Similar to `positive_datasets`. +- `model_path`: (Optional. Default: "my_quality_model") the path to the model to be evaluated. You can evaluate one of the models we provide `[gpt3, chinese, code]`. Or you can evaluate the model trained by yourself using the `train.py` script. +- `tokenizer`: (Optional. Default: None) the tokenizer to tokenize texts to be classified. If it's None, the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark will be used. Besides, you can use one of the tokenizers we provide `[zh.sp.model, code.sp.model]`. Or you can set it to a path to your own [sentencepiece](https://github.com/google/sentencepiece) model. +- `text_key`: (Optional. Default: "text") the field name to store texts to be classified in the input dataset. + +## Model Zoo + +We provide 3 models we trained before: `gpt3`, `chinese`, `code`. Each model has its tokenizer and keep method. Tokenizers "xx.sp.model" are trained on the training data using [sentencepiece](https://github.com/google/sentencepiece). + +| model | tokenizer | keep method | positive datasets | negative datasets | +|-----------|--------------------|------------------|----------------------------------------------------|------------------------------------------| +| `gpt3` | standard Tokenizer | pareto | Wikipedia-en & books1 & OpenWebText2 | CommonCrawl | +| `chinese` | zh.sp.model | label | Wikipedia-zh & Wudao | Samples in Chinese from CommonCrawl | +| `code` | code.sp.model | label | Samples with max_stars_count >= 1372 from TheStack | Random samples from the rest of TheStack | + +- `gpt3`: GPT-3 quality classifier reproduced by us. +- `chinese`: A Chinese quality classifier trained by the same pipeline as `gpt3`, but with different tokenizer and training data. +- `code`: (Experimental) A code quality classifier trained by the same pipeline as `gpt3`, but with different tokenizer and training data. We only keep "programming" and "markup" language types of samples for training. +- Experiments of these classifiers on corresponding test sets are shown in the table below: + +| model | Precision | Recall | F1 | +|-----------|------------|--------|--------| +| `gpt3` | 96.82% | 98.14% | 97.47% | +| `chinese` | 98.00% | 99.30% | 98.64% | +| `code` | 71.23% | 54.21% | 61.56% | + +- Keep ratios of `gpt3` and `chiense` classifiers on CommonCrawl are shown in the table below: + +| model | keep ratio @ label | keep ratio @ pareto | +|--------------------------------------|---------------------|---------------------| +| GPT-3 quality classifier (estimated) | - | ~1.3% | +| `gpt3` | 3.22% | 1.41% | +| `chinese` | 1.81% | - | + +## More about Quality Classifier + +### Method + +The quality classifiers here mainly refer to the GPT-3 quality classifier mentioned in the Appendix A of GPT-3 paper: + +> In order to improve the quality of Common Crawl, we developed an automatic filtering method to remove low quality documents. Using the original WebText as a proxy for high-quality documents, we trained a classifier to distinguish these from raw Common Crawl. We then used this classifier to re-sample Common Crawl by prioritizing documents which were predicted by the classifier to be higher quality. The classifier is trained using logistic regression classifier with features from Spark’s standard tokenizer and HashingTF 10. For the positive examples, we used a collection of curated datasets such as WebText, Wikiedia, and our web books corpus as the positive examples, and for the negative examples, we used unfiltered Common Crawl. We used this classifier to score Common Crawl documents. We kept each document in our dataset iff +> +> np.random.pareto(α) > 1 − document_score +> +> We chose α = 9 in order to take mostly documents the classifier scored highly, but still include some documents that were out of distribution. α was chosen to match the distribution of scores from our classifier on WebText. We found this re-weighting increased quality as measured by loss on a range of out-of-distribution generative text samples. + +### Tokenizers + +- Standard Tokenizer in Spark: split texts by whitespaces. +- zh/code.sp.model: trained using sentencepiece with BPE. + +### Keep Methods +- label: `doc_score > 0.5` +- pareto: `doc_score > 1 - np.random.pareto(α), α = 9` diff --git a/tools/quality_classifier/README_ZH.md b/tools/quality_classifier/README_ZH.md new file mode 100644 index 000000000..607ddff00 --- /dev/null +++ b/tools/quality_classifier/README_ZH.md @@ -0,0 +1,141 @@ +# Quality Classifier Toolkit + +帮助您复现类似于 GPT-3 质量分类器并将其应用到您的 Web 数据集。 + +整个工具包基于PySpark,分类器的基本模块包括: + +- tokenizer: PySpark 的 [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) 或 [sentencepiece](https://github.com/google/sentencepiece) 模型 +- feature extractor: [HashingTF](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.HashingTF.html#hashingtf) +- classifier: [LogisticRegression](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.classification.LogisticRegression.html#logisticregression) + +## 用法 + +### 使用现有的分类器进行预测 + +使用 `predict.py` 来预测一个文档的“质量”分数,并为每个样本添加一个标签,以根据分数判断是否应该保留该样本。 + +```shell +# 预测数据集的 doc_score +python predict.py \ + \ + \ + [--model ] \ + [--tokenizer ] \ + [--keep_method ] \ + [--text_key ] \ + [--overall_stats] + +# 打印帮助信息 +python predict.py --help +``` + +- `dataset_path`: 输入数据集路径。要求路径的后缀为 `[json, jsonl, parquet]` 之一。 +- `result_path`: 存储带有预测结果的数据集的路径。要求路径的后缀为`[json, jsonl, parquet]`之一。 +- `model_path`: (可选,默认为 `gpt3`) 用于预测的模型的路径。您可以使用我们提供的模型之一`[gpt3, chinese,code]`。或者您可以使用`train.py`脚本使用自己训练的模型。 +- `tokenizer`: (可选,默认为 None) 用于标记要分类的文本的标记器。 如果为 None,则将使用 PySpark 的 [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer)。 此外,您可以使用我们提供的标记器`[zh.sp.model, code.sp.model]`之一。您也可以将其设置为您自己的 [sentencepiece](https://github.com/google/sentencepiece) 模型的路径。 +- `keep_method`: (可选,默认为 `gpt3`) 根据 doc_score 决定是否保留样本的方法。应为 `[gpt3, label]` 之一。 +- `text_key`: (可选,默认为 `text`) 用于存储输入数据集中需要被分类的文本的字段名称。 +- `overall_stats`: (可选,默认为 False) 是否生成文档分数的汇总统计报告。 + +### 训练自己的质量分类器 + +使用`train.py`在您的数据集上训练您自己的质量分类器。 + +```shell +# 为自己的数据集训练质量分类器 +python train.py \ + ] \ + ] \ + [--output_model_path ] \ + [--num_training_samples ] \ + [--train_test_split_ratio ] \ + [--tokenizer ] \ + [--evaluation ] \ + [--text_key ] + +# 打印帮助信息 +python train.py --help +``` + +- `positive_datasets`: 正样本数据集的路径。可以是单个数据集的字符串,例如 `'pos.parquet'`,或多个数据集的字符串列表,例如 `'["pos1.parquet", "pos2.parquet"]'`。 +- `negative_datasets`: 负样本数据集的路径,配置方法与 `positive_datasets` 相似。 +- `output_model_path`: (可选,默认值为 `my_quality_model`) 存储训练好的分类器的路径。 +- `num_training_samples`: (可选,默认值为 0) 分别用于训练 正/负样本数据集模型的样本数量。 默认0表示使用所有样本进行训练。 +- `train_test_split_ratio`: (可选,默认值为0.8) 分割训练集的比率,其余样本将作为测试集用于评估。 +- `tokenizer`: (可选,默认值为None) 用于对要分类的文本进行标记的标记生成器。如果为 None,则将使用 PySpark 的[标准 Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) 此外,您可以使用我们提供的标记器`[zh.sp.model,code.sp.model]`之一。也可以将其设置为您自己的 [sentencepiece](https://github.com/google/sentencepiece) 模型的路径。 +- `evaluation`: (可选,默认值为 True) 是否在训练后使用测试集评估训练好的分类器。 +- `text_key`: (可选,默认值为 `text`) 用于存储输入数据集中需要被分类的文本的字段名称。 + +### 评估质量分类器 + +使用`eval.py`以报告精度、召回率和 F1 指标来评估质量分类器。 + +```shell +# 在自己的数据集上评估质量分类器 +python eval.py \ + [--positive_datasets ] \ + [--negative_datasets ] \ + [--model ] \ + [--tokenizer ] \ + [--text_key ] + +# 打印帮助信息 +python eval.py --help +``` + +- `positive_datasets`: (Optional. Default: None) the paths to the positive datasets. It could be a string for a single dataset, e.g. `'pos.parquet'`, or a list of strings for multiple datasets, e.g. `'["pos1.parquet", "pos2.parquet"]'`. +- `negative_datasets`: (Optional. Default: None) the paths to the negative datasets. Similar to `positive_datasets`. +- `model_path`: (Optional. Default: "my_quality_model") the path to the model to be evaluated. You can evaluate one of the models we provide `[gpt3, chinese, code]`. Or you can evaluate the model trained by yourself using the `train.py` script. +- `tokenizer`: (Optional. Default: None) the tokenizer to tokenize texts to be classified. If it's None, the [standard Tokenizer](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.Tokenizer.html#tokenizer) of PySpark will be used. Besides, you can use one of the tokenizers we provide `[zh.sp.model, code.sp.model]`. Or you can set it to a path to your own [sentencepiece](https://github.com/google/sentencepiece) model. +- `text_key`: (Optional. Default: "text") the field name to store texts to be classified in the input dataset. + +## Model Zoo + +我们提供了已训练好的三个模型:`gpt3`,`chinese`,`code`。每个模型都有其 tokenizer 和 keep method。其中Tokenizer `xx.sp.model` 使用 [sentencepiece](https://github.com/google/sentencepiece) 的训练数据进行训练。 + +| model | tokenizer | keep method | positive datasets | negative datasets | +|-----------|--------------------|------------------|----------------------------------------------------|------------------------------------------| +| `gpt3` | standard Tokenizer | pareto | Wikipedia-en & books1 & OpenWebText2 | CommonCrawl | +| `chinese` | zh.sp.model | label | Wikipedia-zh & Wudao | Samples in Chinese from CommonCrawl | +| `code` | code.sp.model | label | Samples with max_stars_count >= 1372 from TheStack | Random samples from the rest of TheStack | + +- `gpt3`: 我们复现的 GPT-3质量分类器。 +- `chinese`: 通过与`gpt3`相同的流程训练的中文质量分类器,但使用不同的标记器和训练数据。 +- `code`: (Experimental) 通过与`gpt3`相同的流程进行训练,但使用不同的标记器和训练数据得到的代码质量分类器。我们只保留 “programming” 和 “markup” 语言类型的样本进行训练。 +- 这些分类器在相应测试集上的实验如下表所示: + +| model | Precision | Recall | F1 | +|-----------|------------|--------|--------| +| `gpt3` | 96.82% | 98.14% | 97.47% | +| `chinese` | 98.00% | 99.30% | 98.64% | +| `code` | 71.23% | 54.21% | 61.56% | + +- Common Crawl 上 `gpt3`和 `chinese` 分类器的 keep ratio 如下表所示: + +| model | keep ratio @ label | keep ratio @ pareto | +|--------------------------------------|---------------------|---------------------| +| GPT-3 quality classifier (estimated) | - | ~1.3% | +| `gpt3` | 3.22% | 1.41% | +| `chinese` | 1.81% | - | + +## 有关质量分类器的更多信息 + +### 方法 + +这里的质量分类器主要参考GPT-3论文附录A中提到的GPT-3质量分类器: + +> In order to improve the quality of Common Crawl, we developed an automatic filtering method to remove low quality documents. Using the original WebText as a proxy for high-quality documents, we trained a classifier to distinguish these from raw Common Crawl. We then used this classifier to re-sample Common Crawl by prioritizing documents which were predicted by the classifier to be higher quality. The classifier is trained using logistic regression classifier with features from Spark’s standard tokenizer and HashingTF 10. For the positive examples, we used a collection of curated datasets such as WebText, Wikiedia, and our web books corpus as the positive examples, and for the negative examples, we used unfiltered Common Crawl. We used this classifier to score Common Crawl documents. We kept each document in our dataset iff +> +> np.random.pareto(α) > 1 − document_score +> +> We chose α = 9 in order to take mostly documents the classifier scored highly, but still include some documents that were out of distribution. α was chosen to match the distribution of scores from our classifier on WebText. We found this re-weighting increased quality as measured by loss on a range of out-of-distribution generative text samples. + +### Tokenizers + +- Spark 中的标准 Tokenizer: 根据空白字符分割文本. +- zh/code.sp.model: 使用 sentencepiece BPE 训练得到。 + +### Keep Methods + +- label: `doc_score > 0.5` +- pareto: `doc_score > 1 - np.random.pareto(α), α = 9` diff --git a/tools/quality_classifier/eval.py b/tools/quality_classifier/eval.py new file mode 100644 index 000000000..80e92ff67 --- /dev/null +++ b/tools/quality_classifier/eval.py @@ -0,0 +1,99 @@ +# This tool is used for evaluating a quality classifier on your own datasets +# based on PySpark. +# +# We provide several trained models for you. Please refer to the comments at +# the beginning of predict tool for more details. +# +# This tool needs several arguments: +# - positive_datasets: the paths to the positive datasets. It could be a +# string for a single dataset, e.g. 'pos.parquet', or a list of strings +# for several datasets, e.g. '["pos1.parquet", "pos2.parquet"]'. +# - negative_datasets: the paths to the negative datasets. It could be a +# string for a single dataset, e.g. 'neg.parquet', or a list of strings +# for several datasets, e.g. '["neg1.parquet", "neg2.parquet"]'. +# - model: quality classifier name to apply. It's "gpt3" in default. You can +# use one of ["gpt3", "chinese", "code"] we provided, or you can set it +# to the path to your own model trained using the train.py tool. +# - tokenizer: what tokenizer to use to tokenize texts. It's None in default, +# which means using the standard Tokenizer of PySpark. You can use one of +# ["zh.sp.model", "code.sp.model"] we provided, or you can set it to the +# path to your own sentencepiece model. +# - text_key: the field key name to hold texts to be classified. It's "text" +# in default. + +import fire +from loguru import logger + +from qc_utils import eval, init_spark, load_datasets + + +@logger.catch +def main(positive_datasets=None, + negative_datasets=None, + model='my_quality_model', + tokenizer=None, + text_key='text'): + """ + Evaluate a trained quality classifier using specific positive/negative + datasets + :param positive_datasets: the paths to the positive datasets. It could be a + string for a single dataset, e.g. 'pos.parquet', or a list of strings + for multiple datasets, e.g. '["pos1.parquet", "pos2.parquet"]' + :param negative_datasets: the paths to the negative datasets. It could be a + string for a single dataset, e.g. 'neg.parquet', or a list of strings + for multiple datasets, e.g. '["neg1.parquet", "neg2.parquet"]' + :param model: quality classifier name to apply. It's "my_quality_model" in + default. You can use one of ["gpt3", "chinese", "code"] we provided, or + you can set it to the path to your own model trained using the train.py + tool + :param tokenizer: what tokenizer to use to tokenize texts. It's None in + default, which means using the standard Tokenizer of PySpark. You can + use one of ["zh.sp.model", "code.sp.model"] we provided, or you can set + it to the path to your own sentencepiece model + :param text_key: the field key name to hold texts to be classified. It's + "text" in default + :return: + """ + # convert a single dataset to a dataset list + if positive_datasets is None: + positive_datasets = [] + if negative_datasets is None: + negative_datasets = [] + if isinstance(positive_datasets, str): + positive_datasets = [positive_datasets] + if isinstance(negative_datasets, str): + negative_datasets = [negative_datasets] + + # initialize a spark session + spark = init_spark() + + # load positive and negative datasets + pos = load_datasets(spark, + positive_datasets, + text_key=text_key, + label=1, + only_text=True) + neg = load_datasets(spark, + negative_datasets, + text_key=text_key, + label=0, + only_text=True) + + # merge positive and negative datasets + if pos is not None and neg is not None: + ds = pos.unionAll(neg) + elif pos is not None: + ds = pos + elif neg is not None: + ds = neg + else: + logger.error('Empty dataset.') + exit(0) + + # start evaluation + logger.info(f'Number of samples: {ds.count()}') + eval(model, ds, tokenizer) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/quality_classifier/predict.py b/tools/quality_classifier/predict.py new file mode 100644 index 000000000..80fda65fb --- /dev/null +++ b/tools/quality_classifier/predict.py @@ -0,0 +1,125 @@ +# This tool is used for predicting a document score for text samples using +# quality classifier models we provided, including: +# - gpt3: A GPT3 quality classifier reproduced from scratch by us based on +# PySpark. It's trained over CC as negative samples and Wikipedia-en, +# Books, OpenWebText as positive samples. +# - chinese: A quality classifier for Chinese. It's trained over Chinese +# texts sampled from CC as negative samples and Wudao, Wikipedia-zh as +# positive samples. +# - code: A quality classifier for codes. It's trained over code samples that +# have stars >= 1372 as positive samples and random samples from left +# data as negative samples. Stars count 1372 splits a nearly 700w subset +# with most stars. +# All these 3 classifiers are trained using the same training pipeline as GPT3 +# based on PySpark but with different tokenizers and keeping methods: +# - gpt3: standard Tokenizer from spark & GPT3 keeping method based on pareto +# - chinese: sentencepiece tokenizer for Chinese & label +# - code: sentencepiece tokenizer for code & label +# +# This tool needs several arguments: +# - dataset_path: the path to the dataset you want to predict doc_scores for. +# - result_path: the path to store the predicted result dataset. +# - model: quality classifier name to apply. It's "gpt3" in default. You can +# use one of ["gpt3", "chinese", "code"] we provided, or you can set it +# to the path to your own model trained using the train.py tool. +# - tokenizer: what tokenizer to use to tokenize texts. It's None in default, +# which means using the standard Tokenizer of PySpark. You can use one of +# ["zh.sp.model", "code.sp.model"] we provided, or you can set it to the +# path to your own sentencepiece model. +# - keep_method: the method to label should_keep field for each sample. It's +# "gpt3" in default. Should be one of ["gpt3", "label"]. +# - text_key: the field key name to hold texts to be classified. It's "text" +# in default. +# - overall_stats: whether to output an overall stats report on predicted +# document scores. It's False in default. +# +# Recommended arguments for provided trained models: +# - gpt3: +# - model: gpt3 +# - tokenizer: None +# - keep_method: gpt3 +# - chinese: +# - model: chinese +# - tokenizer: zh.sp.model +# - keep_method: label +# - code: +# - model: code +# - tokenizer: code.sp.model +# - keep_method: label +# +# Notice: +# 1. The configs of SparkSession in function init_spark can be modified to be +# more suitable for your own machine. See function init_spark in +# qc_utils.py. +# 2. Random factors are involved in "gpt3" model. So you might get different +# should_keep label in different running processes. But you should get +# same doc_score predictions in different running processes. + +import os + +import fire +from loguru import logger + +from qc_utils import (export_result, init_spark, load_dataset, predict, + prepare_model) + + +@logger.catch +def main(dataset_path, + result_path, + model='gpt3', + tokenizer=None, + keep_method='gpt3', + text_key='text', + overall_stats=False): + """ + Use specific quality classifier to predict document scores on your dataset + :param dataset_path: the path to the dataset you want to predict for + :param result_path: the path to store the predicted result dataset + :param model: quality classifier name to apply. It's "gpt3" in default. You + can use one of ["gpt3", "chinese", "code"] we provided, or you can set + it to the path to your own model trained using the train.py tool + :param tokenizer: what tokenizer to use to tokenize texts. It's None in + default, which means using the standard Tokenizer of PySpark. You can + use one of ["zh.sp.model", "code.sp.model"] we provided, or you can set + it to the path to your own sentencepiece model + :param keep_method: the method to label should_keep field for each sample. + It's "gpt3" in default. Should be one of ["gpt3", "label"] + :param text_key: the field key name to hold texts to be classified. It's + "text" in default + :param overall_stats: whether to output an overall stats report on + predicted document scores. It's False in default + :return: + """ + # set default tokenizers for default models + if model == 'chinese': + tokenizer = 'zh.sp.model' + keep_method = 'label' + if model == 'code': + tokenizer = 'code.sp.model' + keep_method = 'label' + if model == 'gpt3': + tokenizer = None + keep_method = 'gpt3' + + # initialize a spark session + spark = init_spark() + # load the quality classifier model + model = prepare_model(model_name=model) + # load dataset + ds = load_dataset(spark, dataset_path, text_key=text_key) + # start to predict + pred = predict(model, ds, tokenizer=tokenizer, keep_method=keep_method) + # export prediction result to specific path + export_result(pred, result_path) + + # generate overall statistics on doc scores + if overall_stats: + overall = pred.select('doc_score').toPandas().describe(include='all') + # export to result report file + overall.to_csv(os.path.join(result_path, 'overall.csv')) + overall.to_markdown(os.path.join(result_path, 'overall.md')) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/quality_classifier/qc_utils.py b/tools/quality_classifier/qc_utils.py new file mode 100644 index 000000000..f7baed5c6 --- /dev/null +++ b/tools/quality_classifier/qc_utils.py @@ -0,0 +1,305 @@ +import os +import zipfile + +import numpy as np +import sentencepiece as spm +import wget +from loguru import logger +from pyspark.ml import Pipeline, PipelineModel +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import SparkSession +from pyspark.sql.functions import col, rand, udf +from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType + +from data_juicer.utils.cache_utils import DATA_JUICER_MODELS_CACHE +from data_juicer.utils.model_utils import (MODEL_LINKS, + prepare_sentencepiece_model) + + +def init_spark(): + """ + Initialize a spark session. You can set parameters such as memory, number + of partitions, timeout and so on here + :return: A spark session instance. + """ + spark = (SparkSession.builder.config('spark.driver.memory', '64g').config( + 'spark.executor.memory', + '64g').config('spark.sql.shuffle.partitions', '300').config( + 'spark.sql.execution.arrow.pyspark.enabled', + 'true').config('spark.executor.memoryOverhead', '20000').config( + 'spark.network.timeout', + '10000s').config('spark.executor.heartbeatInterval', + '3600s').getOrCreate()) + logger.info('Spark initialization done.') + return spark + + +def prepare_model(model_name, model_path=DATA_JUICER_MODELS_CACHE): + """ + Prepare the specific model from model cache path or the remote oss + :param model_name: name of the quality classifier model + :param model_path: the path to store the model to be loaded + :return: a loaded PipelineModel + """ + udm = False + if model_name not in ['gpt3', 'chinese', 'code']: + # use user-specific mdoel + real_model_path = model_name + udm = True + else: + # use prepared models we provided + model_name = '%s_quality_model' % model_name + real_model_path = os.path.join(model_path, model_name) + logger.info(f'Preparing scorer model in [{real_model_path}]...') + if os.path.exists(real_model_path) and os.path.isdir(real_model_path): + return PipelineModel.load(real_model_path) + if udm: + logger.error(f'Customized model [{real_model_path}] cannot be loaded.') + exit(0) + # No specific models in local file systems. Download them from remote. + os.makedirs(model_path, exist_ok=True) + wget.download(os.path.join(MODEL_LINKS, f'{model_name}.zip'), + os.path.join(model_path, f'{model_name}.zip')) + # extract the compressed model file into a model directory + with zipfile.ZipFile(os.path.join(model_path, f'{model_name}.zip')) as zp: + zp.extractall(os.path.join(model_path)) + return PipelineModel.load(real_model_path) + + +def load_dataset(spark, ds_path, text_key='text', only_text=False): + """ + Load a single dataset using PySpark. Only support 'json', 'jsonl', or + 'parquet' files for now + :param spark: spark session + :param ds_path: dataset path + :param text_key: the name of the column that stores the contents of texts + :param only_text: whether to load texts only and drop other columns. + :return: a data frame + """ + # load dataset using different methods according to the suffix + logger.info(f'Loading dataset from [{ds_path}]...') + if ds_path.endswith('.json') or ds_path.endswith('.jsonl'): + df = spark.read.json(ds_path) + elif ds_path.endswith('.parquet'): + df = spark.read.parquet(ds_path) + else: + raise NotImplementedError('Dataset type is not supported for now. ' + 'Suffix of dataset file should be one of ' + '[.json, .jsonl, .parquet]') + # rename the column that stores texts to "text" if necessary + if text_key != 'text': + df = df.withColumnRenamed(text_key, 'text') + # whether to keep "text" column only + if only_text: + return df.select('text') + else: + return df + + +def load_datasets(spark, + ds_paths, + text_key='text', + label=None, + only_text=True): + """ + Load a list of datasets. Only support 'json', 'jsonl', or 'parquet' files + for now + :param spark: spark session + :param ds_paths: a list of datasets to be loaded. + :param text_key: the name of the column that stores the contents of texts + :param label: the label set to these datasets. Used in training pipeline + :param only_text: whether to load texts only and drop other columns. + :return: a data frame + """ + if len(ds_paths) == 0: + logger.warning('No dataset path provided.') + return None + # load each dataset in order and union them all + base_ds = load_dataset(spark, ds_paths[0], text_key, only_text) + for i in range(1, len(ds_paths)): + base_ds = base_ds.unionAll( + load_dataset(spark, ds_paths[i], text_key, only_text)) + if label is not None: + # add labels for training pipeline + return base_ds.selectExpr('text', '%d as label' % label) + else: + return base_ds + + +def shuffle(df): + """ + Shuffle a data frame + :param df: input data frame + :return: shuffled data frame + """ + temp_df = df.withColumn('rand', rand(seed=42)) + df_rnd = temp_df.orderBy(temp_df.rand) + return df_rnd.drop(df_rnd.rand) + + +def export_result(ds, res_path): + """ + Export a dataset to specified path. Only support 'json', 'jsonl', or + 'parquet' export formats for now + :param ds: the dataset to be exported + :param res_path: the path to store the exported dataset + :return: + """ + logger.info(f'Exporting predicted result to [{res_path}]') + if res_path.endswith('.json') or res_path.endswith('.jsonl'): + ds.write.mode('overwrite').format('json').save(res_path) + elif res_path.endswith('.parquet'): + ds.write.mode('overwrite').format('parquet').save(res_path) + else: + ds.write.mode('overwrite').save(res_path) + + +def get_keep_method_udf(keep_method): + """ + Given the name of keep method, return a PySpark user-defined function of + this kind of keep method. Only support 'gpt3' or 'label' for now + :param keep_method: name of keep method + :return: a PySpark udf of specified keep method + """ + if keep_method == 'label': + return udf(lambda score: int(score > 0.5), IntegerType()) + elif keep_method == 'gpt3': + pareto = 9 + return udf(lambda score: int(score > 1 - np.random.pareto(pareto)), + IntegerType()) + else: + raise NotImplementedError(f'Keep method [{keep_method}] is not ' + f'implemented for now.') + + +def tokenize_dataset(ds, tokenizer): + """ + Tokenize the texts in input dataset using specified tokenizer + :param ds: dataset to be tokenized + :param tokenizer: tokenizer used to tokenize texts + :return: a dataset with an extra column "words" that stores the tokenized + texts + """ + if os.path.exists(tokenizer): + # if it's a local model + tkn = spm.SentencePieceProcessor() + tkn.load(tokenizer) + else: + # else, try to load it from our remote model list + tkn = prepare_sentencepiece_model(tokenizer, ()) + # create a PySpark udf to tokenize the dataset + tokenizer_udf = udf(lambda text: tkn.encode_as_pieces(text), + ArrayType(StringType())) + logger.info('Tokenize texts using specific tokenizer...') + return ds.withColumn('words', tokenizer_udf(col('text'))) + + +def train(output_model_path, ds, tokenizer=None): + """ + Train a quality classifier with training dataset and export the trained + model to a specified path + :param output_model_path: the path to store the trained model + :param ds: training dataset + :param tokenizer: specified sentencepiece tokenizer. It's None in default, + which means using the standard Tokenizer in PySpark + :return: + """ + logger.info('Preparing training quality classifier model...') + if tokenizer: + # tokenizer is not standard Tokenizer in PySpark, need to apply it + # explicitly + ds = tokenize_dataset(ds, tokenizer) + + # model + hashingTF = HashingTF(inputCol='words', outputCol='features') + lr = LogisticRegression() + if tokenizer is None: + # using standard Tokenizer in PySpark + std_tokenizer = Tokenizer(inputCol='text', outputCol='words') + pipeline = Pipeline(stages=[std_tokenizer, hashingTF, lr]) + else: + # using extra sentencepiece tokenizer, which will not included in the + # final PipelineModel + pipeline = Pipeline(stages=[hashingTF, lr]) + + logger.info('Start training...') + model = pipeline.fit(ds) + + logger.info('Trained model saving...') + model.write().overwrite().save(output_model_path) + + +def eval(model_path, ds, tokenizer=None): + """ + Evaluate a quality classifier model on specified dataset + :param model_path: the path to the model to be evaluated + :param ds: evaluation dataset + :param tokenizer: specified sentencepiece tokenizer. It's None in default, + which means using the standard Tokenizer in PySpark + :return: + """ + logger.info('Preparing to evaluate...') + if tokenizer: + # tokenizer is not standard Tokenizer in PySpark, need to apply it + # explicitly + ds = tokenize_dataset(ds, tokenizer) + + logger.info('Start evaluation...') + model = prepare_model(model_path) + pred = model.transform(ds) + # get positive and negative samples + P = pred.filter('label = 1') + N = pred.filter('label = 0') + # get TP, FP, TN, FN samples + TP = P.filter('prediction = 1').count() + 1 + FP = N.filter('prediction = 1').count() + 1 + TN = N.filter('prediction = 0').count() + 1 + FN = P.filter('prediction = 0').count() + 1 + # compute precision, recall and F1 metrics + precision = 1.0 * TP / (TP + FP) + recall = 1.0 * TP / P.count() + F1 = 2.0 * precision * recall / (precision + recall) + logger.info(f'TP: {TP}, FN: {FN}') + logger.info(f'FP: {FP}, TN: {TN}') + logger.info(f'P: {precision}, R: {recall}, F1: {F1}') + + +def predict(model, ds, tokenizer=None, keep_method='label'): + """ + Predict document scores for a dataset using a trained quality classifier + model + :param model: the model used to predict + :param ds: the dataset to be predicted + :param tokenizer: specified sentencepiece tokenizer. It's None in default, + which means using the standard Tokenizer in PySpark + :param keep_method: name of keep method to label the "should_keep" column + :return: + """ + logger.info('Start scoring dataset...') + if tokenizer: + # tokenizer is not standard Tokenizer in PySpark, need to apply it + # explicitly + ds = tokenize_dataset(ds, tokenizer) + + prediction = model.transform(ds) + + # A UDF to extract doc scores from probability vectors + def extract_prob(v): + try: + return float(v[1]) + except ValueError: + return None + + # extract the predicted probability as the doc_score + extract_prob_udf = udf(extract_prob, DoubleType()) + doc_score = prediction.withColumn('doc_score', + extract_prob_udf(col('probability'))) + + # A UDF to get the bool value indicating whether this sample should be kept + should_keep_label_udf = get_keep_method_udf(keep_method) + should_keep = doc_score.withColumn('should_keep', + should_keep_label_udf(col('doc_score'))) + # drop extra useless columns + return should_keep.drop('words', 'features', 'rawPrediction', + 'probability', 'prediction') diff --git a/tools/quality_classifier/train.py b/tools/quality_classifier/train.py new file mode 100644 index 000000000..3774a9539 --- /dev/null +++ b/tools/quality_classifier/train.py @@ -0,0 +1,117 @@ +# This tool is used for training a quality classifier for your own datasets +# based on PySpark. +# +# After training, this tool will generate a classifier model in a specific +# directory. You can use it to evaluate or predict on other datasets using eval +# and predict tools. +# +# This tool needs several arguments: +# - positive_datasets: the paths to the positive datasets. It could be a +# string for a single dataset, e.g. 'pos.parquet', or a list of strings +# for several datasets, e.g. '["pos1.parquet", "pos2.parquet"]'. +# - negative_datasets: the paths to the negative datasets. It could be a +# string for a single dataset, e.g. 'neg.parquet', or a list of strings +# for several datasets, e.g. '["neg1.parquet", "neg2.parquet"]'. +# - output_model_path: the path to store the trained quality classifier. It's +# "my_quality_model" in default. +# - num_training_samples: number of samples used to train the model. It's 0 +# in default, which means using all samples in datasets to train. +# - train_test_split_ratio: ratio to split train and test set. It's 0.8 in +# default. +# - tokenizer: what tokenizer to use to tokenize texts. It's None in default, +# which means using the standard Tokenizer of PySpark. You can use one of +# ["zh.sp.model", "code.sp.model"] we provided, or you can set it to the +# path to your own sentencepiece model. +# - evaluation: whether to evaluate the model after training using test set. +# It's True in default. +# - text_key: the field key name to hold texts to be classified. It's "text" +# in default. + +import fire +from loguru import logger + +from qc_utils import eval, init_spark, load_datasets, shuffle, train + + +@logger.catch +def main(positive_datasets, + negative_datasets, + output_model_path='my_quality_model', + num_training_samples=0, + train_test_split_ratio=0.8, + tokenizer=None, + evaluation=True, + text_key='text'): + """ + Train a quality classifier using your own pos/neg datasets + :param positive_datasets: the paths to the positive datasets. It could be a + string for a single dataset, e.g. 'pos.parquet', or a list of strings + for several datasets, e.g. '["pos1.parquet", "pos2.parquet"]' + :param negative_datasets: the paths to the negative datasets. It could be a + string for a single dataset, e.g. 'neg.parquet', or a list of strings + for several datasets, e.g. '["neg1.parquet", "neg2.parquet"]' + :param output_model_path: the path to store the trained quality classifier. + It's "my_quality_model" in default + :param num_training_samples: number of samples used to train the model. + It's 0 in default, which means using all samples in datasets to train + :param train_test_split_ratio: ratio to split train and test set. It's 0.8 + in default + :param tokenizer: what tokenizer to use to tokenize texts. It's None in + default, which means using the standard Tokenizer of PySpark. You can + use one of ["zh.sp.model", "code.sp.model"] we provided, or you can set + it to the path to your own sentencepiece model + :param evaluation: whether to evaluate the model after training using test + set. It's True in default + :param text_key: the field key name to hold texts to be classified. It's + "text" in default + :return: + """ + # convert a single dataset to a dataset list + if isinstance(positive_datasets, str): + positive_datasets = [positive_datasets] + if isinstance(negative_datasets, str): + negative_datasets = [negative_datasets] + + # initialize a spark session + spark = init_spark() + + # load positive and negative datasets + pos = load_datasets(spark, + positive_datasets, + text_key=text_key, + label=1, + only_text=True) + neg = load_datasets(spark, + negative_datasets, + text_key=text_key, + label=0, + only_text=True) + + if pos is None or neg is None: + logger.error('Empty dataset in positive/negative dataset list...') + exit(1) + + # sample a part of positive/negative samples to train + if num_training_samples > 0: + logger.info(f'Only use {num_training_samples} pairs samples to train.') + pos = shuffle(pos).limit(num_training_samples) + neg = shuffle(neg).limit(num_training_samples) + + # merge pos and neg samples + ds = pos.unionAll(neg) + # split the merged dataset into training and test set + train_set, test_set = ds.randomSplit( + [train_test_split_ratio, 1.0 - train_test_split_ratio], seed=42) + logger.info(f'Number of training samples: {train_set.count()}, ' + f'test samples: {test_set.count()}') + + # start the ML pipeline to train the classifier + train(output_model_path, train_set, tokenizer) + + # evaluate the trained model on test set + if evaluation: + eval(output_model_path, test_set, tokenizer) + + +if __name__ == '__main__': + fire.Fire(main)