forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathraw_data_processor.py
230 lines (210 loc) · 9.12 KB
/
raw_data_processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for processing crawled content and generating tfrecords."""
import collections
import json
import multiprocessing
import os
import urllib.parse
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
class RawDataProcessor(object):
"""Data converter for story examples."""
def __init__(self,
vocab: str,
do_lower_case: bool,
len_title: int = 15,
len_passage: int = 200,
max_num_articles: int = 5,
include_article_title_in_passage: bool = False,
include_text_snippet_in_example: bool = False):
"""Constructs a RawDataProcessor.
Args:
vocab: Filepath of the BERT vocabulary.
do_lower_case: Whether the vocabulary is uncased or not.
len_title: Maximum number of tokens in story headline.
len_passage: Maximum number of tokens in article passage.
max_num_articles: Maximum number of articles in a story.
include_article_title_in_passage: Whether to include article title in
article passage.
include_text_snippet_in_example: Whether to include text snippet (headline
and article content) in generated tensorflow Examples, for debug usage.
If include_article_title_in_passage=True, title and body will be
separated by [SEP].
"""
self.articles = dict()
self.tokenizer = tokenization.FullTokenizer(
vocab, do_lower_case=do_lower_case, split_on_punc=False)
self.len_title = len_title
self.len_passage = len_passage
self.max_num_articles = max_num_articles
self.include_article_title_in_passage = include_article_title_in_passage
self.include_text_snippet_in_example = include_text_snippet_in_example
# ex_index=5 deactivates printing inside convert_single_example.
self.ex_index = 5
# Parameters used in InputExample, not used in NHNet.
self.label = 0
self.guid = 0
self.num_generated_examples = 0
def read_crawled_articles(self, folder_path):
"""Reads crawled articles under folder_path."""
for path, _, files in os.walk(folder_path):
for name in files:
if not name.endswith(".json"):
continue
url, article = self._get_article_content_from_json(
os.path.join(path, name))
if not article.text_a:
continue
self.articles[RawDataProcessor.normalize_url(url)] = article
if len(self.articles) % 5000 == 0:
print("Number of articles loaded: %d\r" % len(self.articles), end="")
print()
return len(self.articles)
def generate_examples(self, input_file, output_files):
"""Loads story from input json file and exports examples in output_files."""
writers = []
story_partition = []
for output_file in output_files:
writers.append(tf.io.TFRecordWriter(output_file))
story_partition.append(list())
with tf.io.gfile.GFile(input_file, "r") as story_json_file:
stories = json.load(story_json_file)
writer_index = 0
for story in stories:
articles = []
for url in story["urls"]:
normalized_url = RawDataProcessor.normalize_url(url)
if normalized_url in self.articles:
articles.append(self.articles[normalized_url])
if not articles:
continue
story_partition[writer_index].append((story["label"], articles))
writer_index = (writer_index + 1) % len(writers)
lock = multiprocessing.Lock()
pool = multiprocessing.pool.ThreadPool(len(writers))
data = [(story_partition[i], writers[i], lock) for i in range(len(writers))]
pool.map(self._write_story_partition, data)
return len(stories), self.num_generated_examples
@classmethod
def normalize_url(cls, url):
"""Normalize url for better matching."""
url = urllib.parse.unquote(
urllib.parse.urlsplit(url)._replace(query=None).geturl())
output, part = [], None
for part in url.split("//"):
if part == "http:" or part == "https:":
continue
else:
output.append(part)
return "//".join(output)
def _get_article_content_from_json(self, file_path):
"""Returns (url, InputExample) keeping content extracted from file_path."""
with tf.io.gfile.GFile(file_path, "r") as article_json_file:
article = json.load(article_json_file)
if self.include_article_title_in_passage:
return article["url"], classifier_data_lib.InputExample(
guid=self.guid,
text_a=article["title"],
text_b=article["maintext"],
label=self.label)
else:
return article["url"], classifier_data_lib.InputExample(
guid=self.guid, text_a=article["maintext"], label=self.label)
def _write_story_partition(self, data):
"""Writes stories in a partition into file."""
for (story_headline, articles) in data[0]:
story_example = tf.train.Example(
features=tf.train.Features(
feature=self._get_single_story_features(story_headline,
articles)))
data[1].write(story_example.SerializeToString())
data[2].acquire()
try:
self.num_generated_examples += 1
if self.num_generated_examples % 1000 == 0:
print(
"Number of stories written: %d\r" % self.num_generated_examples,
end="")
finally:
data[2].release()
def _get_single_story_features(self, story_headline, articles):
"""Converts a list of articles to a tensorflow Example."""
def get_text_snippet(article):
if article.text_b:
return " [SEP] ".join([article.text_a, article.text_b])
else:
return article.text_a
story_features = collections.OrderedDict()
story_headline_feature = classifier_data_lib.convert_single_example(
ex_index=self.ex_index,
example=classifier_data_lib.InputExample(
guid=self.guid, text_a=story_headline, label=self.label),
label_list=[self.label],
max_seq_length=self.len_title,
tokenizer=self.tokenizer)
if self.include_text_snippet_in_example:
story_headline_feature.label_id = story_headline
self._add_feature_with_suffix(
feature=story_headline_feature,
suffix="a",
story_features=story_features)
for (article_index, article) in enumerate(articles):
if article_index == self.max_num_articles:
break
article_feature = classifier_data_lib.convert_single_example(
ex_index=self.ex_index,
example=article,
label_list=[self.label],
max_seq_length=self.len_passage,
tokenizer=self.tokenizer)
if self.include_text_snippet_in_example:
article_feature.label_id = get_text_snippet(article)
suffix = chr(ord("b") + article_index)
self._add_feature_with_suffix(
feature=article_feature, suffix=suffix, story_features=story_features)
# Adds empty features as placeholder.
for article_index in range(len(articles), self.max_num_articles):
suffix = chr(ord("b") + article_index)
empty_article = classifier_data_lib.InputExample(
guid=self.guid, text_a="", label=self.label)
empty_feature = classifier_data_lib.convert_single_example(
ex_index=self.ex_index,
example=empty_article,
label_list=[self.label],
max_seq_length=self.len_passage,
tokenizer=self.tokenizer)
if self.include_text_snippet_in_example:
empty_feature.label_id = ""
self._add_feature_with_suffix(
feature=empty_feature, suffix=suffix, story_features=story_features)
return story_features
def _add_feature_with_suffix(self, feature, suffix, story_features):
"""Appends suffix to feature names and fills in the corresponding values."""
def _create_int_feature(values):
return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
def _create_string_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
story_features["input_ids_%c" % suffix] = _create_int_feature(
feature.input_ids)
story_features["input_mask_%c" % suffix] = _create_int_feature(
feature.input_mask)
story_features["segment_ids_%c" % suffix] = _create_int_feature(
feature.segment_ids)
if self.include_text_snippet_in_example:
story_features["text_snippet_%c" % suffix] = _create_string_feature(
bytes(feature.label_id.encode()))