update lstm
sunlanchang committed Jun 20, 2020
1 parent 4d72f2a commit f6bf458
Showing 2 changed files with 19 additions and 15 deletions.
11 changes: 8 additions & 3 deletions LSTM_age_gender.py
@@ -83,13 +83,17 @@ def get_embedding(feature_name, tokenizer):
     wv = KeyedVectors.load(path, mmap='r')
     feature_tokens = list(wv.vocab.keys())
     embedding_dim = 128
-    embedding_matrix = np.random.randn(
-        len(feature_tokens)+1, embedding_dim)
+    # embedding_matrix = np.random.randn(
+    #     len(feature_tokens)+1, embedding_dim, dtype=float)
+    embedding_matrix = np.zeros(
+        (len(feature_tokens)+1, embedding_dim), dtype=float)
     # get the index of each word
     for word, i in tokenizer.word_index.items():
         embedding_vector = wv[word]
         if embedding_vector is not None:
             embedding_matrix[i] = embedding_vector
+        else:
+            print(str(word)+' not found')
     return embedding_matrix

 # Extract array-format data from the sequence files
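For context, a minimal sketch, not part of this commit, of how a pretrained matrix like the one get_embedding returns is typically wired into a frozen Keras Embedding layer. The feature name 'creative_id' and the tokenizer variable are assumptions standing in for whatever the surrounding script passes in.

# Sketch only: feeding get_embedding's matrix into a frozen Keras Embedding layer.
# 'creative_id' and tokenizer are assumed names, not values confirmed by this diff.
from tensorflow.keras.initializers import Constant
from tensorflow.keras.layers import Embedding

embedding_matrix = get_embedding('creative_id', tokenizer)
embedding_layer = Embedding(
    input_dim=embedding_matrix.shape[0],    # vocabulary size + 1
    output_dim=embedding_matrix.shape[1],   # 128-dim word2vec vectors
    embeddings_initializer=Constant(embedding_matrix),
    trainable=False)                        # keep the pretrained vectors fixed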
@@ -303,7 +307,8 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
                   output_y)

     model.compile(loss='categorical_crossentropy',
-                  optimizer='adam', metrics=['accuracy'])
+                  optimizer=optimizers.Adam(1e-3),
+                  metrics=['accuracy'])

     model.summary()
     return model
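Replacing the 'adam' string with an explicit optimizer object keeps the same behaviour for now (Keras's Adam defaults to a 1e-3 learning rate) while making the rate easy to tune later. The diff does not show the matching import; assuming TensorFlow's bundled Keras, it would look like:

# Assumed import for the optimizers module used above; not shown in this diff.
from tensorflow.keras import optimizers

opt = optimizers.Adam(1e-3)  # explicit learning rate; 1e-3 matches the 'adam' default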
23 changes: 11 additions & 12 deletions word2vec_creative_id.py
@@ -11,31 +11,30 @@
 import pickle
 from mymail import mail
 # %%
-df_train = pd.read_csv(
+df = pd.read_csv(
     'data/click_log_ad.csv')
 # df_test = pd.read_csv('data/test/clicklog_ad_user_test.csv')
-columns = ['user_id', 'creative_id', 'time']
-frame = [df_train[columns], df_test[columns]]
-df_train_test = pd.concat(frame, ignore_index=True)
-df_train_test_sorted = df_train_test.sort_values(
-    ["user_id", "time"], ascending=(True, True))
+columns = ['user_id', 'creative_id']
+# frame = [df_train[columns], df_test[columns]]
+# df_train_test = pd.concat(frame, ignore_index=True)
+# df_train_test_sorted = df_train_test.sort_values(
+#     ["user_id", "time"], ascending=(True, True))
 # %%
 with open('word2vec/df_train_test_sorted.pkl', 'wb') as f:
     pickle.dump(df_train_test_sorted, f)
 # %%
 with open('word2vec/df_train_test_sorted.pkl', 'rb') as f:
     df_train_test_sorted = pickle.load(f)
 # %%
-userid_creative_ids = df_train_test_sorted.groupby(
-    'user_id')['creative_id'].apply(list).reset_index(name='creative_ids')
+userid_creative_id = df.groupby(
+    'user_id')['creative_id'].apply(list).reset_index(name='creative_id')
 # %%
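For orientation, a toy illustration of what the groupby above produces: one row per user whose creative_id cell holds that user's click sequence as a list. The IDs here are invented for the example, not project data.

# Toy example with made-up IDs showing the per-user grouping.
import pandas as pd

toy = pd.DataFrame({'user_id': [10, 10, 11],
                    'creative_id': [821, 17, 4302]})
toy_grouped = toy.groupby('user_id')['creative_id'].apply(
    list).reset_index(name='creative_id')
print(toy_grouped)
#    user_id creative_id
# 0       10   [821, 17]
# 1       11      [4302]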
-with open('word2vec/userid_creative_ids.txt', 'w')as f:
-    for ids in userid_creative_ids.creative_ids:
+with open('word2vec_new/creative_id.txt', 'w')as f:
+    for ids in userid_creative_id.creative_id:
         ids = [str(e) for e in ids]
         line = ' '.join(ids)
         f.write(line+'\n')
 # %%
-sentences = LineSentence('word2vec/userid_creative_ids.txt')
+sentences = LineSentence('word2vec_new/creative_id.txt')
 dimension_embedding = 128
 model = Word2Vec(sentences, size=dimension_embedding,
                  window=10, min_count=1, workers=-1, iter=10, sg=1)
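The diff ends before any save step, so how the trained vectors reach the KeyedVectors.load call in LSTM_age_gender.py is not shown; a minimal sketch under that assumption follows, with the output path made up to match the renamed word2vec_new/ folder. Note also that gensim treats workers as a literal thread count (it does not interpret -1 as "all cores" the way scikit-learn's n_jobs does), so a small positive value is the usual choice.

# Sketch only: persisting the trained vectors. The path is an assumption;
# the actual save step is not part of this diff.
model.wv.save('word2vec_new/creative_id.kv')

# get_embedding in LSTM_age_gender.py could then read them back with:
# wv = KeyedVectors.load('word2vec_new/creative_id.kv', mmap='r')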
