diff --git a/README.md b/README.md
index c99e5f15..017eae52 100644
--- a/README.md
+++ b/README.md
@@ -39,30 +39,30 @@ python -m spacy download en
## fastNLP Tutorials
-Chinese [documentation](https://fastnlp.readthedocs.io/) and [tutorials](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)
+Chinese [documentation](http://www.fastnlp.top/docs/fastNLP/) and [tutorials](http://www.fastnlp.top/docs/fastNLP/user/quickstart.html)
### Quick Start
-- [0. Quick start](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
+- [Quick-1. Text classification](http://www.fastnlp.top/docs/fastNLP/tutorials/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB.html)
+- [Quick-2. Sequence labeling](http://www.fastnlp.top/docs/fastNLP/tutorials/%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8.html)
### Detailed Tutorials
-- [1. Preprocessing text with DataSet](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
-- [2. Converting between text and indices with Vocabulary](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_vocabulary.html)
-- [3. Turning text into vectors with the Embedding module](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
-- [4. Loading and processing datasets with Loader and Pipe](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)
-- [5. Building a text classifier I: fast training and testing with Trainer and Tester](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_loss_optimizer.html)
-- [6. Building a text classifier II: custom training loops with DataSetIter](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_datasetiter.html)
-- [7. Evaluating your model quickly with Metric](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_metrics.html)
-- [8. Building custom models quickly with Modules and Models](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_modules_models.html)
-- [9. Implementing a sequence labeling model quickly](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_seq_labeling.html)
-- [10. Customizing your training loop with Callback](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_callback.html)
+- [1. Preprocessing text with DataSet](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_1_data_preprocess.html)
+- [2. Converting between text and indices with Vocabulary](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html)
+- [3. Turning text into vectors with the Embedding module](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_3_embedding.html)
+- [4. Loading and processing datasets with Loader and Pipe](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_4_load_dataset.html)
+- [5. Building a text classifier I: fast training and testing with Trainer and Tester](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_5_loss_optimizer.html)
+- [6. Building a text classifier II: custom training loops with DataSetIter](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html)
+- [7. Evaluating your model quickly with Metric](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_7_metrics.html)
+- [8. Building custom models quickly with Modules and Models](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_8_modules_models.html)
+- [9. Customizing your training loop with Callback](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_9_callback.html)
### Extended Tutorials
-- [Extend-1. Various uses of BertEmbedding](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
-- [Extend-2. An introduction to distributed training](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
-- [Extend-3. Using fitlog to assist fastNLP research](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)
+- [Extend-1. Various uses of BertEmbedding](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_1_bert_embedding.html)
+- [Extend-2. An introduction to distributed training](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_2_dist.html)
+- [Extend-3. Using fitlog to assist fastNLP research](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_3_fitlog.html)
## Built-in Components
diff --git a/docs/requirements.txt b/docs/requirements.txt
index c7d94486..cfa9c93a 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,8 +1,4 @@
-numpy>=1.14.2
-http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
-torchvision>=0.1.8
-sphinx-rtd-theme==0.4.1
-tensorboardX>=1.4
-tqdm>=4.28.1
-ipython>=6.4.0
-ipython-genutils>=0.2.0
\ No newline at end of file
+sphinx==3.2.1
+docutils==0.16
+sphinx-rtd-theme==0.5.0
+readthedocs-sphinx-search==0.1.0rc3
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4db6dea6..ff77a6fc 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -4,6 +4,10 @@ fastNLP Chinese Documentation
`fastNLP `_ is a lightweight natural language processing (NLP) toolkit. You can use it both to quickly complete NLP tasks
and to rapidly build more complex models in your research.
+.. hint::
+
+    If you are reading this documentation from readthedocs, please switch to our `latest site `_
+
fastNLP has the following features:
- A unified tabular data container that simplifies data preprocessing;
@@ -41,7 +45,7 @@ API Documentation
fitlog Documentation
--------------------
-You can view the fitlog documentation `here `_.
+You can view the fitlog documentation `here `_.
fitlog is a logging and code-management tool developed by our team.
Index and Search
diff --git a/docs/source/tutorials/extend_3_fitlog.rst b/docs/source/tutorials/extend_3_fitlog.rst
index 0fa24143..152e18fe 100644
--- a/docs/source/tutorials/extend_3_fitlog.rst
+++ b/docs/source/tutorials/extend_3_fitlog.rst
@@ -4,7 +4,7 @@
This article describes how to use fastNLP together with fitlog for research.
-First, we need to install `fitlog `_. Make sure there is no other command named `fitlog` on your machine.
+First, we need to install `fitlog `_. Make sure there is no other command named `fitlog` on your machine.
From the command line, enter a directory; we will now create our fastNLP project in it. You can type `fitlog init test1` on the command line,
and you will then see the following message::
@@ -15,7 +15,7 @@
Fitlog project test1 is initialized.
This indicates that the project directory was created successfully and that Git has been initialized in it. If you do not want to initialize Git,
-see the documentation `Command Line Tools `_
+see the documentation `Command Line Tools `_
Now enter the test1 project directory you created. You will see a folder named logs, where your experiment records will be stored later.
There is also a file named main.py, which is the training entry file we recommend you use. Its contents are as follows::
@@ -37,7 +37,7 @@
fitlog.finish() # finish the logging
We recommend keeping the four non-comment lines of code; they are helpful for your experiments,
-and their specific uses are described in the documentation `User API `_
+and their specific uses are described in the documentation `User API `_
We assume you want to run the experiments from the previous two tutorials and have already copied the data to the tutorial_sample_dataset.csv file in the project root.
Now we write the following training code, using :class:`~fastNLP.core.callback.FitlogCallback` to save the experiment records::
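For reference, the four lines the tutorial refers to follow fitlog's standard template; a minimal sketch of the generated main.py (comments paraphrased, the exact template text may differ):

```python
import fitlog

fitlog.commit(__file__)             # auto-commit your code
fitlog.add_hyper_in_file(__file__)  # record the hyper-parameters defined in this file

"""
Your training code
"""

fitlog.finish()                     # finish the logging
```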
diff --git "a/docs/source/tutorials/\346\226\207\346\234\254\345\210\206\347\261\273.rst" "b/docs/source/tutorials/\346\226\207\346\234\254\345\210\206\347\261\273.rst"
index 73686916..30f6cf4f 100644
--- "a/docs/source/tutorials/\346\226\207\346\234\254\345\210\206\347\261\273.rst"
+++ "b/docs/source/tutorials/\346\226\207\346\234\254\345\210\206\347\261\273.rst"
@@ -291,7 +291,7 @@ fastNLP provides the Trainer object to organize the training process, including loss computation
PS: Text classification with Bert
-~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
@@ -368,7 +368,7 @@ PS: Text classification with Bert
PS: Word-based text classification
-~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Since written Chinese has no explicit boundaries between words, sentences generally need to be segmented with a tokenizer first.
The following example demonstrates how to do text classification without relying on fastNLP's existing data loading and preprocessing code.
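That example is not part of this hunk; as a hedged illustration of the segmentation step it relies on (jieba is one common tokenizer choice, not something the tutorial mandates), producing the space-separated form fastNLP expects:

```python
import jieba  # a widely used Chinese tokenizer; any segmenter works

sentence = "他喜欢吃苹果"
words = list(jieba.cut(sentence))  # e.g. ['他', '喜欢', '吃', '苹果']
print(" ".join(words))             # "他 喜欢 吃 苹果", the pre-tokenized input form
```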
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 9c9e505b..45a488d9 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -53,7 +53,7 @@
from fastNLP import DataSet
from fastNLP import Instance
instances = []
- winstances.append(Instance(sentence="This is the first instance",
+ instances.append(Instance(sentence="This is the first instance",
-                                ords=['this', 'is', 'the', 'first', 'instance', '.'],
+                                words=['this', 'is', 'the', 'first', 'instance', '.'],
seq_len=6))
instances.append(Instance(sentence="Second instance .",
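With the typo fixed (and `ords` restored to `words`), the docstring example becomes runnable; for reference, a self-contained version:

```python
from fastNLP import DataSet, Instance

instances = []
instances.append(Instance(sentence="This is the first instance",
                          words=['this', 'is', 'the', 'first', 'instance', '.'],
                          seq_len=6))
instances.append(Instance(sentence="Second instance .",
                          words=['second', 'instance', '.'],
                          seq_len=3))
dataset = DataSet(instances)
print(len(dataset))  # 2
```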
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 55ffd9cf..cb05f82d 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -148,7 +148,7 @@ def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=No
self._predict_func = self._model.predict
self._predict_func_wrapper = self._model.predict
else:
- if _model_contains_inner_module(model):
+ if _model_contains_inner_module(self._model):
self._predict_func_wrapper = self._model.forward
self._predict_func = self._model.module.forward
else:
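For context on this fix: the helper must be called on `self._model` (the model the Tester actually stores, which may have been moved or re-wrapped during `__init__`), not on the raw `model` argument. A sketch of what `_model_contains_inner_module` checks, assuming it detects DataParallel-style wrappers as in fastNLP's core utilities:

```python
import torch.nn as nn

def _model_contains_inner_module(model):
    # True when the real module is nested at model.module, as with
    # nn.DataParallel and nn.parallel.DistributedDataParallel.
    if isinstance(model, nn.Module):
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return True
    return False
```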
diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index c7893f38..bbf3de1e 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -103,6 +103,11 @@
"yelp-review-polarity": "yelp_review_polarity.tar.gz",
"sst-2": "SST-2.zip",
"sst": "SST.zip",
+    "mr": "mr.zip",
+ "R8": "R8.zip",
+ "R52": "R52.zip",
+ "20ng": "20ng.zip",
+ "ohsumed": "ohsumed.zip",
# Classification, Chinese
"chn-senti-corp": "chn_senti_corp.zip",
diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py
index 94784515..35965ca3 100644
--- a/fastNLP/io/pipe/__init__.py
+++ b/fastNLP/io/pipe/__init__.py
@@ -23,15 +23,15 @@
"ChnSentiCorpPipe",
"THUCNewsPipe",
"WeiboSenti100kPipe",
- "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Loader",
-
+ "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe",
+
"Conll2003NERPipe",
"OntoNotesNERPipe",
"MsraNERPipe",
"WeiboNERPipe",
"PeopleDailyPipe",
"Conll2003Pipe",
-
+
"MatchingBertPipe",
"RTEBertPipe",
"SNLIBertPipe",
@@ -53,14 +53,20 @@
"RenamePipe",
"GranularizePipe",
"MachingTruncatePipe",
-
+
"CoReferencePipe",
- "CMRC2018BertPipe"
+ "CMRC2018BertPipe",
+
+ "R52PmiGraphPipe",
+ "R8PmiGraphPipe",
+ "OhsumedPmiGraphPipe",
+ "NG20PmiGraphPipe",
+ "MRPmiGraphPipe"
]
from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
- WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Loader
+ WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
from .conll import Conll2003Pipe
from .coreference import CoReferencePipe
@@ -70,3 +76,5 @@
LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe
from .pipe import Pipe
from .qa import CMRC2018BertPipe
+
+from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe
diff --git a/fastNLP/io/pipe/construct_graph.py b/fastNLP/io/pipe/construct_graph.py
new file mode 100644
index 00000000..d597da9d
--- /dev/null
+++ b/fastNLP/io/pipe/construct_graph.py
@@ -0,0 +1,268 @@
+
+__all__ = [
+ 'MRPmiGraphPipe',
+ 'R8PmiGraphPipe',
+ 'R52PmiGraphPipe',
+ 'OhsumedPmiGraphPipe',
+ 'NG20PmiGraphPipe'
+]
+try:
+ import networkx as nx
+ from sklearn.feature_extraction.text import CountVectorizer
+ from sklearn.feature_extraction.text import TfidfTransformer
+ from sklearn.pipeline import Pipeline
+except ImportError:
+    pass  # networkx and scikit-learn are optional dependencies; errors surface on first use
+from collections import defaultdict
+import itertools
+import math
+from tqdm import tqdm
+import numpy as np
+
+from ..data_bundle import DataBundle
+from ...core.const import Const
+from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader
+
+
+def _get_windows(content_lst: list, window_size: int):
+    r"""
+    Slide a window over each text and collect word frequencies and word co-occurrence counts.
+    :param content_lst: the texts, each a string or a list of tokens
+    :param window_size: the sliding-window size
+    :return: word frequencies, co-occurrence frequencies, and the total number of windows
+    """
+    word_window_freq = defaultdict(int)  # w(i): number of windows in which word i appears
+    word_pair_count = defaultdict(int)  # w(i, j): number of windows in which words i and j co-occur
+ windows_len = 0
+ for words in tqdm(content_lst, desc="Split by window"):
+ windows = list()
+
+ if isinstance(words, str):
+ words = words.split()
+ length = len(words)
+
+ if length <= window_size:
+ windows.append(words)
+ else:
+ for j in range(length - window_size + 1):
+ window = words[j: j + window_size]
+ windows.append(list(set(window)))
+
+ for window in windows:
+ for word in window:
+ word_window_freq[word] += 1
+
+ for word_pair in itertools.combinations(window, 2):
+ word_pair_count[word_pair] += 1
+
+ windows_len += len(windows)
+ return word_window_freq, word_pair_count, windows_len
+
+def _cal_pmi(W_ij, W, word_freq_i, word_freq_j):
+    r"""
+    :param W_ij: co-occurrence count of words i and j within windows
+    :param W: total number of windows
+    :param word_freq_i: frequency of word i
+    :param word_freq_j: frequency of word j
+    :return: the PMI value of words i and j
+    """
+ p_i = word_freq_i / W
+ p_j = word_freq_j / W
+ p_i_j = W_ij / W
+ pmi = math.log(p_i_j / (p_i * p_j))
+
+ return pmi
+
+def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold):
+    r"""
+    :param windows_len: total number of windows
+    :param word_pair_count: dict of word-pair co-occurrence counts
+    :param word_window_freq: dict of per-word window frequencies
+    :param threshold: PMI threshold; pairs whose PMI is at or below it are dropped
+    :return: a list of [word1, word2, pmi] triples
+    """
+ word_pmi_lst = list()
+ for word_pair, W_i_j in tqdm(word_pair_count.items(), desc="Calculate pmi between words"):
+ word_freq_1 = word_window_freq[word_pair[0]]
+ word_freq_2 = word_window_freq[word_pair[1]]
+
+ pmi = _cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2)
+ if pmi <= threshold:
+ continue
+ word_pmi_lst.append([word_pair[0], word_pair[1], pmi])
+ return word_pmi_lst
+
+class GraphBuilderBase:
+    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
+ self.graph = nx.Graph()
+ self.word2id = dict()
+ self.graph_type = graph_type
+        self.window_size = window_size
+ self.doc_node_num = 0
+ self.tr_doc_index = None
+ self.te_doc_index = None
+ self.dev_doc_index = None
+ self.doc = None
+ self.threshold = threshold
+
+ def _get_doc_edge(self, data_bundle: DataBundle):
+        r'''
+        Process the input DataBundle and compute tf-idf weights for the document-word edges.
+        :param data_bundle: texts look like ['This is the first document.'] for English, or ['他 喜欢 吃 苹果'] for pre-tokenized Chinese
+        :return: a sparse matrix with tf-idf-weighted document-word edges
+        '''
+ tr_doc = list(data_bundle.get_dataset("train").get_field(Const.RAW_WORD))
+ val_doc = list(data_bundle.get_dataset("dev").get_field(Const.RAW_WORD))
+ te_doc = list(data_bundle.get_dataset("test").get_field(Const.RAW_WORD))
+ doc = tr_doc + val_doc + te_doc
+ self.doc = doc
+ self.tr_doc_index = [ind for ind in range(len(tr_doc))]
+ self.dev_doc_index = [ind+len(tr_doc) for ind in range(len(val_doc))]
+ self.te_doc_index = [ind+len(tr_doc)+len(val_doc) for ind in range(len(te_doc))]
+ text_tfidf = Pipeline([('count', CountVectorizer(token_pattern=r'\S+', min_df=1, max_df=1.0)),
+ ('tfidf', TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False))])
+
+ tfidf_vec = text_tfidf.fit_transform(doc)
+ self.doc_node_num = tfidf_vec.shape[0]
+ vocab_lst = text_tfidf['count'].get_feature_names()
+ for ind, word in enumerate(vocab_lst):
+ self.word2id[word] = ind
+ for ind, row in enumerate(tfidf_vec):
+ for col_index, value in zip(row.indices, row.data):
+ self.graph.add_edge(ind, self.doc_node_num+col_index, weight=value)
+ return nx.to_scipy_sparse_matrix(self.graph)
+
+ def _get_word_edge(self):
+ word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size)
+ pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold)
+ for edge_item in pmi_edge_lst:
+ word_indx1 = self.doc_node_num + self.word2id[edge_item[0]]
+ word_indx2 = self.doc_node_num + self.word2id[edge_item[1]]
+ if word_indx1 == word_indx2:
+ continue
+ self.graph.add_edge(word_indx1, word_indx2, weight=edge_item[2])
+
+ def build_graph(self, data_bundle: DataBundle):
+        r"""
+        Process the input DataBundle and return the adjacency matrix as a scipy sparse matrix.
+
+        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
+        :return:
+        """
+ raise NotImplementedError
+
+ def build_graph_from_file(self, path: str):
+        r"""
+        Build the processed scipy sparse matrix from a file path. The supported path formats are the same as for :meth:`fastNLP.io.Loader.load`
+
+        :param path:
+        :return: scipy sparse matrix
+        """
+ raise NotImplementedError
+
+
+class MRPmiGraphPipe(GraphBuilderBase):
+
+    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
+        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)
+
+ def build_graph(self, data_bundle: DataBundle):
+        r'''
+        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
+        :return: the graph as a csr sparse matrix, together with the indices of the train, dev, and test documents in the graph.
+        '''
+ self._get_doc_edge(data_bundle)
+ self._get_word_edge()
+ return nx.to_scipy_sparse_matrix(self.graph,
+ nodelist=list(range(self.graph.number_of_nodes())),
+ weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)
+
+ def build_graph_from_file(self, path: str):
+ data_bundle = MRLoader().load(path)
+ return self.build_graph(data_bundle)
+
+class R8PmiGraphPipe(GraphBuilderBase):
+
+    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
+        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)
+
+ def build_graph(self, data_bundle: DataBundle):
+        r'''
+        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
+        :return: the graph as a csr sparse matrix, together with the indices of the train, dev, and test documents in the graph.
+        '''
+ self._get_doc_edge(data_bundle)
+ self._get_word_edge()
+ return nx.to_scipy_sparse_matrix(self.graph,
+ nodelist=list(range(self.graph.number_of_nodes())),
+ weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)
+
+ def build_graph_from_file(self, path: str):
+ data_bundle = R8Loader().load(path)
+ return self.build_graph(data_bundle)
+
+class R52PmiGraphPipe(GraphBuilderBase):
+
+    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
+        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)
+
+ def build_graph(self, data_bundle: DataBundle):
+        r'''
+        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
+        :return: the graph as a csr sparse matrix, together with the indices of the train, dev, and test documents in the graph.
+        '''
+ self._get_doc_edge(data_bundle)
+ self._get_word_edge()
+ return nx.to_scipy_sparse_matrix(self.graph,
+ nodelist=list(range(self.graph.number_of_nodes())),
+ weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)
+
+ def build_graph_from_file(self, path: str):
+ data_bundle = R52Loader().load(path)
+ return self.build_graph(data_bundle)
+
+class OhsumedPmiGraphPipe(GraphBuilderBase):
+
+    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
+        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)
+
+ def build_graph(self, data_bundle: DataBundle):
+        r'''
+        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
+        :return: the graph as a csr sparse matrix, together with the indices of the train, dev, and test documents in the graph.
+        '''
+ self._get_doc_edge(data_bundle)
+ self._get_word_edge()
+ return nx.to_scipy_sparse_matrix(self.graph,
+ nodelist=list(range(self.graph.number_of_nodes())),
+ weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)
+
+ def build_graph_from_file(self, path: str):
+ data_bundle = OhsumedLoader().load(path)
+ return self.build_graph(data_bundle)
+
+
+class NG20PmiGraphPipe(GraphBuilderBase):
+
+    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
+        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)
+
+ def build_graph(self, data_bundle: DataBundle):
+        r'''
+        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
+        :return: the graph as a csr sparse matrix, together with the indices of the train, dev, and test documents in the graph.
+        '''
+ self._get_doc_edge(data_bundle)
+ self._get_word_edge()
+ return nx.to_scipy_sparse_matrix(self.graph,
+ nodelist=list(range(self.graph.number_of_nodes())),
+ weight='weight', dtype=np.float32, format='csr'), (
+ self.tr_doc_index, self.dev_doc_index, self.te_doc_index)
+
+ def build_graph_from_file(self, path: str):
+        r'''
+        :param path: the path to the dataset.
+        :return: the graph as a csr sparse matrix, together with the indices of the train, dev, and test documents in the graph.
+        '''
+ data_bundle = NG20Loader().load(path)
+ return self.build_graph(data_bundle)
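To show how the new pipes fit together, a hedged usage sketch (the dataset path is a placeholder; `build_graph_from_file` loads the data with the matching loader internally):

```python
from fastNLP.io.pipe.construct_graph import R8PmiGraphPipe

pipe = R8PmiGraphPipe(window_size=10, threshold=0.0)
# Returns the csr adjacency matrix over document and word nodes, plus the
# positions of the train/dev/test documents among the graph's nodes.
adj, (train_idx, dev_idx, test_idx) = pipe.build_graph_from_file('path/to/R8')
print(adj.shape, len(train_idx), len(dev_idx), len(test_idx))
```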
diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index dff4809c..cd874e7c 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -376,7 +376,7 @@ def forward(self, words1, words2, seq_len, target1=None):
if self.encoder_name.endswith('lstm'):
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
x = x[sort_idx]
- x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
+ x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
feat, _ = self.encoder(x) # -> [N,L,C]
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
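For context: recent PyTorch releases (1.7 and later) require the lengths passed to `pack_padded_sequence` to live on the CPU, so length tensors computed on the GPU must be moved back, which is exactly what the added `.cpu()` calls here and in `_elmo.py` below do. A minimal sketch of the pattern:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 5, 8)        # [batch, seq_len, features], possibly on GPU
seq_len = torch.tensor([5, 3])  # lengths often live on the same device as x

sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
x = x[sort_idx]
# pack_padded_sequence wants CPU lengths, hence .cpu():
packed = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
padded, lens = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
```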
diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py
index 13843f83..7a2cf4bc 100644
--- a/fastNLP/modules/encoder/_elmo.py
+++ b/fastNLP/modules/encoder/_elmo.py
@@ -251,7 +251,7 @@ def __init__(self, config):
def forward(self, inputs, seq_len):
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
inputs = inputs[sort_idx]
- inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first)
+ inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=self.batch_first)
output, hx = self.encoder(inputs, None) # -> [N,L,C]
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
@@ -316,7 +316,7 @@ def forward(self, inputs, seq_len):
max_len = inputs.size(1)
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
inputs = inputs[sort_idx]
- inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True)
+ inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=True)
output, _ = self._lstm_forward(inputs, None)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
output = output[:, unsort_idx]
diff --git a/readthedocs.yml b/readthedocs.yml
index e6d5bafd..5ff803a0 100644
--- a/readthedocs.yml
+++ b/readthedocs.yml
@@ -7,10 +7,11 @@ build:
image: latest
python:
- version: 3.6
+ version: 3.8
install:
+ - requirements: docs/requirements.txt
- method: setuptools
path: .
formats:
- - htmlzip
\ No newline at end of file
+ - htmlzip
diff --git a/setup.py b/setup.py
index d4a71c33..a3f47009 100644
--- a/setup.py
+++ b/setup.py
@@ -23,7 +23,7 @@
long_description_content_type='text/markdown',
license='Apache License',
author='Fudan FastNLP Team',
- python_requires='>=3.6',
+ python_requires='>=3.7',
packages=pkgs,
install_requires=reqs.strip().split('\n'),
)