update LSTM
sunlanchang committed Jun 15, 2020
1 parent 9c92d33 commit 01d8b2f
Showing 4 changed files with 519 additions and 151 deletions.
90 changes: 29 additions & 61 deletions LSTM_age_multi_input.py
@@ -18,6 +18,7 @@
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# %%
# Count the length of the creative_id sequences; this only needs to be done once
# f = open('word2vec/userid_creative_ids.txt')
# LEN_creative_id = -1
# for line in f:
@@ -27,10 +28,10 @@
# %%
parser = argparse.ArgumentParser()
parser.add_argument('--load_from_npy', action='store_true',
                    help='Load data from npy files',
                    help='Load training data from npy files instead of regenerating the arrays for every run',
default=False)
parser.add_argument('--not_train_embedding', action='store_false',
                    help='Load data from npy files',
                    help='Do not train the embedding layer; in general, results are worse with this flag',
default=True)

parser.add_argument('--epoch', type=int,
@@ -40,12 +41,12 @@
                    help='batch size',
default=256)
parser.add_argument('--examples', type=int,
                    help='Number of training examples; defaults to the full training set, excluding the validation set',
                    help='Number of training examples; defaults to the full training set, excluding the validation set; can be set to 1000 when debugging',
default=810000)


parser.add_argument('--num_lstm', type=int,
                    help='Number of LSTM heads',
                    help='Number of stacked LSTM layers; so far 3 layers beat 5, and 1 layer is still being tested...',
default=1)

args = parser.parse_args()
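For reference, the flags above might be exercised as sketched below; both the invocation and the stacking loop are illustrative, not code from this commit (the layer width and the variable embedded_sequence are assumptions):

# Hypothetical debugging run:
#   python LSTM_age_multi_input.py --epoch 5 --batch_size 256 --examples 1000 --num_lstm 3
#
# One plausible way --num_lstm could drive stacking: every layer except the
# last returns full sequences so the next LSTM can consume them.
from tensorflow.keras.layers import LSTM

x = embedded_sequence  # output of the Embedding layer (assumed name)
for i in range(args.num_lstm):
    x = LSTM(256, return_sequences=(i < args.num_lstm - 1))(x)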
@@ -57,6 +58,22 @@

def get_train_val():

    # Load the word-vector file for the given feature
def get_embedding(feature_name):
path = f"word2vec/wordvectors_{feature_name}.kv"
wv = KeyedVectors.load(path, mmap='r')
feature_tokens = list(wv.vocab.keys())
embedding_dim = 128
embedding_matrix = np.random.randn(
len(feature_tokens)+1, embedding_dim)
for feature in feature_tokens:
embedding_vector = wv[feature]
if embedding_vector is not None:
index = tokenizer.texts_to_sequences([feature])[0][0]
embedding_matrix[index] = embedding_vector
return embedding_matrix
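Each matrix returned above can seed an Embedding layer; a minimal sketch, assuming the model wires it in with trainability governed by --not_train_embedding (the layer names and LSTM width are illustrative, not taken from this diff):

# Illustrative wiring only; LEN_creative_id and args are defined in this script:
from tensorflow.keras.layers import Embedding, Input, LSTM

emb_matrix = get_embedding(feature_name='creative_id')
inp = Input(shape=(LEN_creative_id,), name='creative_id')
x = Embedding(input_dim=emb_matrix.shape[0],
              output_dim=emb_matrix.shape[1],
              weights=[emb_matrix],                     # pretrained word2vec vectors
              trainable=args.not_train_embedding)(inp)  # argparse flag above
x = LSTM(256)(x)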

    # First input
    # Build the creative_id feature
# f = open('tmp/userid_creative_ids.txt')
f = open('word2vec/userid_creative_ids.txt')
@@ -72,23 +89,9 @@ def get_train_val():
X1_train = pad_sequences(
sequences, maxlen=LEN_creative_id, padding='post')

    # Build the creative_id embedding
def get_creative_id_emb():
path = "word2vec/wordvectors_creative_id.kv"
wv = KeyedVectors.load(path, mmap='r')
creative_id_tokens = list(wv.vocab.keys())
embedding_dim = 128
embedding_matrix = np.random.randn(
len(creative_id_tokens)+1, embedding_dim)
for creative_id in creative_id_tokens:
embedding_vector = wv[creative_id]
if embedding_vector is not None:
index = tokenizer.texts_to_sequences([creative_id])[0][0]
embedding_matrix[index] = embedding_vector
return embedding_matrix

creative_id_emb = get_creative_id_emb()
creative_id_emb = get_embedding(feature_name='creative_id')

    # Second input
    # Build the ad_id feature
f = open('word2vec/userid_ad_ids.txt')
tokenizer = Tokenizer(num_words=NUM_ad_id)
@@ -103,22 +106,9 @@ def get_creative_id_emb():
X2_train = pad_sequences(
sequences, maxlen=LEN_ad_id, padding='post')

def get_ad_id_emb():
path = "word2vec/wordvectors_ad_id.kv"
wv = KeyedVectors.load(path, mmap='r')
ad_id_tokens = list(wv.vocab.keys())
embedding_dim = 128
embedding_matrix = np.random.randn(
len(ad_id_tokens)+1, embedding_dim)
for ad_id in ad_id_tokens:
embedding_vector = wv[ad_id]
if embedding_vector is not None:
index = tokenizer.texts_to_sequences([ad_id])[0][0]
embedding_matrix[index] = embedding_vector
return embedding_matrix

ad_id_emb = get_ad_id_emb()
ad_id_emb = get_embedding(feature_name='ad_id')

    # Third input
    # Build the product_id feature
# f = open('tmp/userid_product_ids.txt')
f = open('word2vec/userid_product_ids.txt')
@@ -134,24 +124,10 @@ def get_ad_id_emb():
X3_train = pad_sequences(
sequences, maxlen=LEN_product_id, padding='post')

    # Build the product_id embedding
def get_product_id_emb():
path = "word2vec/wordvectors_product_id.kv"
wv = KeyedVectors.load(path, mmap='r')
product_id_tokens = list(wv.vocab.keys())
embedding_dim = 128
embedding_matrix = np.random.randn(
len(product_id_tokens)+1, embedding_dim)
for product_id in product_id_tokens:
embedding_vector = wv[product_id]
if embedding_vector is not None:
index = tokenizer.texts_to_sequences([product_id])[0][0]
embedding_matrix[index] = embedding_vector
return embedding_matrix

product_id_emb = get_product_id_emb()
product_id_emb = get_embedding(feature_name='product_id')

    # Get the age labels
    # Build the training labels for the outputs
    # Get the age and gender labels
user_train = pd.read_csv(
'data/train_preliminary/user.csv').sort_values(['user_id'], ascending=(True,))
Y_gender = user_train['gender'].values
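Downstream, the age head is a multi-class classifier, so the integer labels need one-hot encoding; a brief sketch (assuming, as in this competition's data, that age takes the integer values 1-10):

# Sketch (assumption: age labels are 1..10, shifted to 0..9 for to_categorical):
from tensorflow.keras.utils import to_categorical

Y_age = user_train['age'].values
y_train = to_categorical(Y_age - 1, num_classes=10)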
@@ -268,15 +244,6 @@ def save_data(datas):
# %%
checkpoint = ModelCheckpoint("tmp/age_epoch_{epoch:02d}.hdf5", monitor='val_loss', verbose=1,
save_best_only=False, mode='auto', period=1)
# %%
# model.fit(
# {'creative_id': x1_train, 'ad_id': x2_train},
# y_train,
# validation_data=([x1_val, x2_val], y_val),
# epochs=5,
# batch_size=256,
# callbacks=[checkpoint],
# )
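With product_id added as a third input, the fit call presumably takes all three arrays; a hedged sketch of the updated call (x3_train and x3_val are assumed by analogy with the inputs built above):

# Hypothetical three-input version of the removed two-input call:
model.fit(
    {'creative_id': x1_train, 'ad_id': x2_train, 'product_id': x3_train},
    y_train,
    validation_data=([x1_val, x2_val, x3_val], y_val),
    epochs=args.epoch,
    batch_size=args.batch_size,
    callbacks=[checkpoint],
)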

# %%
try:
@@ -298,6 +265,7 @@ def save_data(datas):


# %%
# The code below is the prediction step; commented out and unused for now, but do not delete it
# model.load_weights('tmp\gender_epoch_01.hdf5')

