update LSTM
sunlanchang committed Jun 18, 2020
1 parent 2edb478 commit bda64b0
Showing 1 changed file with 142 additions and 103 deletions.
245 changes: 142 additions & 103 deletions LSTM_age_gender_old.py
@@ -130,7 +130,7 @@ def get_train(feature_name, vocab_size, len_feature):
# first input
print('getting creative_id feature')
X1_train, tokenizer = get_train(
    'creative_id', NUM_creative_id+1, LEN_creative_id)  # +1 for the UNK creative_id
    'creative_id', NUM_creative_id+1, LEN_creative_id)  # +1 for the UNK creative_id
creative_id_emb = get_embedding('creative_id', tokenizer)

DATA['X1_train'] = X1_train[:train_examples]
@@ -216,7 +216,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
input_length=LEN_creative_id,
mask_zero=True)(input_creative_id)
for _ in range(args.num_lstm):
x1 = Bidirectional(LSTM(128, return_sequences=True))(x1)
x1 = Bidirectional(LSTM(256, return_sequences=True))(x1)
x1 = layers.GlobalMaxPooling1D()(x1)
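
Each of the six id inputs in this model goes through the same pattern: a pretrained embedding, a stack of bidirectional LSTMs (widened from 128 to 256 units in this commit), and a global max pool. A minimal sketch of one such branch, assuming a tf.keras setup and an embedding matrix emb_matrix of shape (vocab_size, emb_dim) (the helper name id_branch and its arguments are illustrative, not from the repository):

from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, GlobalMaxPooling1D

def id_branch(name, vocab_size, emb_dim, seq_len, emb_matrix, num_lstm=1):
    # pretrained embedding -> stacked BiLSTM over the click sequence -> global max pooling
    inp = Input(shape=(seq_len,), dtype='int32', name=name)
    x = Embedding(vocab_size, emb_dim,
                  weights=[emb_matrix],
                  input_length=seq_len,
                  mask_zero=True)(inp)
    for _ in range(num_lstm):
        x = Bidirectional(LSTM(256, return_sequences=True))(x)
    x = GlobalMaxPooling1D()(x)  # shape (batch, 2 * 256)
    return inp, x
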

# second input
@@ -228,7 +228,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
input_length=LEN_ad_id,
mask_zero=True)(input_ad_id)
for _ in range(args.num_lstm):
x2 = Bidirectional(LSTM(128, return_sequences=True))(x2)
x2 = Bidirectional(LSTM(256, return_sequences=True))(x2)
x2 = layers.GlobalMaxPooling1D()(x2)

# third input
@@ -240,7 +240,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
input_length=LEN_product_id,
mask_zero=True)(input_product_id)
for _ in range(args.num_lstm):
x3 = Bidirectional(LSTM(128, return_sequences=True))(x3)
x3 = Bidirectional(LSTM(256, return_sequences=True))(x3)
x3 = layers.GlobalMaxPooling1D()(x3)

# fourth input
@@ -252,7 +252,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
input_length=LEN_advertiser_id,
mask_zero=True)(input_advertiser_id)
for _ in range(args.num_lstm):
x4 = Bidirectional(LSTM(128, return_sequences=True))(x4)
x4 = Bidirectional(LSTM(256, return_sequences=True))(x4)
x4 = layers.GlobalMaxPooling1D()(x4)

# fifth input
@@ -264,7 +264,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
input_length=LEN_industry,
mask_zero=True)(input_industry)
for _ in range(args.num_lstm):
x5 = Bidirectional(LSTM(128, return_sequences=True))(x5)
x5 = Bidirectional(LSTM(256, return_sequences=True))(x5)
x5 = layers.GlobalMaxPooling1D()(x5)

# sixth input
@@ -276,12 +276,12 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
input_length=LEN_product_category,
mask_zero=True)(input_product_category)
for _ in range(args.num_lstm):
x6 = Bidirectional(LSTM(128, return_sequences=True))(x6)
x6 = Bidirectional(LSTM(256, return_sequences=True))(x6)
x6 = layers.GlobalMaxPooling1D()(x6)

x = layers.Concatenate(axis=2)([x1, x2, x3, x4, x5, x6])
x = layers.Concatenate(axis=1)([x1, x2, x3, x4, x5, x6])
# x = layers.GlobalMaxPooling1D()(x)

if predict_age and predict_gender:
output_gender = Dense(2, activation='softmax', name='gender')(x)
output_age = Dense(10, activation='softmax', name='age')(x)
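
The concatenation change above follows from the pooling: after GlobalMaxPooling1D each branch output is a rank-2 tensor of shape (batch, 2 * units), so there is no axis 2 left and the only meaningful join is along the feature axis, axis=1 (which is also why the extra GlobalMaxPooling1D after the concat is now commented out). An illustrative two-branch shape check, not taken from the file:

from tensorflow.keras import layers, Input

a = Input(shape=(512,))            # one pooled branch: 2 * 256 features
b = Input(shape=(512,))
merged = layers.Concatenate(axis=1)([a, b])
print(merged.shape)                # (None, 1024): features stack up, batch axis untouched
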
@@ -489,109 +489,148 @@ def scheduler(epoch):
checkpoint = ModelCheckpoint("tmp/lstm_tail_concat_epoch_{epoch:02d}.hdf5", monitor='val_loss', verbose=1,
save_best_only=False, mode='auto', period=1)
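
One version note: on recent tf.keras releases the period argument of ModelCheckpoint is deprecated in favour of save_freq, so an equivalent callback there would look roughly like this (a sketch, assuming TF 2.x):

checkpoint = ModelCheckpoint("tmp/lstm_tail_concat_epoch_{epoch:02d}.hdf5",
                             monitor='val_loss', verbose=1,
                             save_best_only=False, mode='auto',
                             save_freq='epoch')  # replaces period=1
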
# %%
try:
train_examples = args.train_examples
val_examples = args.val_examples
mail('start train lstm')
if args.head_concat:
model.fit(
{
'creative_id': DATA['X1_train'][:train_examples],
'ad_id': DATA['X2_train'][:train_examples],
'product_id': DATA['X3_train'][:train_examples],
'advertiser_id': DATA['X4_train'][:train_examples],
'industry': DATA['X5_train'][:train_examples],
'product_category': DATA['X6_train'][:train_examples]
},
{
'gender': DATA['Y_gender_train'][:train_examples],
'age': DATA['Y_age_train'][:train_examples],
},
validation_data=(
{
'creative_id': DATA['X1_val'][:val_examples],
'ad_id': DATA['X2_val'][:val_examples],
'product_id': DATA['X3_val'][:val_examples],
'advertiser_id': DATA['X4_val'][:val_examples],
'industry': DATA['X5_val'][:val_examples],
'product_category': DATA['X6_val'][:val_examples]
},
{
'gender': DATA['Y_gender_val'][:val_examples],
'age': DATA['Y_age_val'][:val_examples],
},
),
epochs=args.epoch,
batch_size=args.batch_size,
callbacks=[checkpoint],
)
elif args.tail_concat:
model.fit(
{
'creative_id': DATA['X1_train'][:train_examples],
'ad_id': DATA['X2_train'][:train_examples],
'product_id': DATA['X3_train'][:train_examples],
'advertiser_id': DATA['X4_train'][:train_examples],
'industry': DATA['X5_train'][:train_examples],
'product_category': DATA['X6_train'][:train_examples]
},
{
'gender': DATA['Y_gender_train'][:train_examples],
'age': DATA['Y_age_train'][:train_examples],
},
validation_data=(
{
'creative_id': DATA['X1_val'][:val_examples],
'ad_id': DATA['X2_val'][:val_examples],
'product_id': DATA['X3_val'][:val_examples],
'advertiser_id': DATA['X4_val'][:val_examples],
'industry': DATA['X5_val'][:val_examples],
'product_category': DATA['X6_val'][:val_examples]
},
{
'gender': DATA['Y_gender_val'][:val_examples],
'age': DATA['Y_age_val'][:val_examples],
},
),
epochs=args.epoch,
batch_size=args.batch_size,
callbacks=[checkpoint],
)

mail('train lstm done!!!')
except Exception as e:
e = str(e)
mail('train lstm failed!!! ' + e)
# try:
# train_examples = args.train_examples
# val_examples = args.val_examples
# mail('start train lstm')
# if args.head_concat:
# model.fit(
# {
# 'creative_id': DATA['X1_train'][:train_examples],
# 'ad_id': DATA['X2_train'][:train_examples],
# 'product_id': DATA['X3_train'][:train_examples],
# 'advertiser_id': DATA['X4_train'][:train_examples],
# 'industry': DATA['X5_train'][:train_examples],
# 'product_category': DATA['X6_train'][:train_examples]
# },
# {
# 'gender': DATA['Y_gender_train'][:train_examples],
# 'age': DATA['Y_age_train'][:train_examples],
# },
# validation_data=(
# {
# 'creative_id': DATA['X1_val'][:val_examples],
# 'ad_id': DATA['X2_val'][:val_examples],
# 'product_id': DATA['X3_val'][:val_examples],
# 'advertiser_id': DATA['X4_val'][:val_examples],
# 'industry': DATA['X5_val'][:val_examples],
# 'product_category': DATA['X6_val'][:val_examples]
# },
# {
# 'gender': DATA['Y_gender_val'][:val_examples],
# 'age': DATA['Y_age_val'][:val_examples],
# },
# ),
# epochs=args.epoch,
# batch_size=args.batch_size,
# callbacks=[checkpoint],
# )
# elif args.tail_concat:
# model.fit(
# {
# 'creative_id': DATA['X1_train'][:train_examples],
# 'ad_id': DATA['X2_train'][:train_examples],
# 'product_id': DATA['X3_train'][:train_examples],
# 'advertiser_id': DATA['X4_train'][:train_examples],
# 'industry': DATA['X5_train'][:train_examples],
# 'product_category': DATA['X6_train'][:train_examples]
# },
# {
# 'gender': DATA['Y_gender_train'][:train_examples],
# 'age': DATA['Y_age_train'][:train_examples],
# },
# validation_data=(
# {
# 'creative_id': DATA['X1_val'][:val_examples],
# 'ad_id': DATA['X2_val'][:val_examples],
# 'product_id': DATA['X3_val'][:val_examples],
# 'advertiser_id': DATA['X4_val'][:val_examples],
# 'industry': DATA['X5_val'][:val_examples],
# 'product_category': DATA['X6_val'][:val_examples]
# },
# {
# 'gender': DATA['Y_gender_val'][:val_examples],
# 'age': DATA['Y_age_val'][:val_examples],
# },
# ),
# epochs=args.epoch,
# batch_size=args.batch_size,
# callbacks=[checkpoint],
# )

# mail('train lstm done!!!')
# except Exception as e:
# e = str(e)
# mail('train lstm failed!!! ' + e)
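
Both fit calls above pass dict targets keyed by the output names 'gender' and 'age', which implies the model is compiled with one loss per named head. A minimal sketch of such a compile step, assuming one-hot labels; the optimizer and loss weights are illustrative and not taken from this repository:

model.compile(optimizer='adam',
              loss={'gender': 'categorical_crossentropy',
                    'age': 'categorical_crossentropy'},
              loss_weights={'gender': 0.5, 'age': 0.5},
              metrics=['accuracy'])
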

# %%
# The following is the prediction stage; temporarily commented out and unused, but do not delete
# model.load_weights('tmp/gender_epoch_01.hdf5')

# model.load_weights('tmp/model_age_0.443.hdf5')
model.load_weights('tmp/model_gender_0.929.hdf5')
print('model loaded successfully')

# # %%
# if debug:
# sequences = tokenizer.texts_to_sequences(
# creative_id_seq[900000:])
# else:
# sequences = tokenizer.texts_to_sequences(
# creative_id_seq[900000:])

# X_test = pad_sequences(sequences, maxlen=LEN_creative_id)
# # %%
# y_pred = model.predict(X_test, batch_size=4096)

# y_pred = np.where(y_pred > 0.5, 1, 0)
# y_pred = y_pred.flatten()

# # %%
# y_pred = y_pred+1
# # %%
# res = pd.DataFrame({'predicted_gender': y_pred})
# res.to_csv(
# 'data/ans/lstm_gender.csv', header=True, columns=['predicted_gender'], index=False)
def get_test(feature_name, vocab_size, len_feature):
print(f'processing {feature_name}')
f = open(f'word2vec/userid_{feature_name}s.txt')
tokenizer = Tokenizer(num_words=vocab_size)
tokenizer.fit_on_texts(f)
f.close()

feature_seq = []
with open(f'word2vec/userid_{feature_name}s.txt') as f:
for text in f:
feature_seq.append(text.strip())

sequences = tokenizer.texts_to_sequences(feature_seq[900000:])
X_test = pad_sequences(
sequences, maxlen=len_feature, padding='post')
return X_test


X1_test = get_test('creative_id', NUM_creative_id+1, 100)
X2_test = get_test('ad_id', NUM_ad_id+1, 100)
X3_test = get_test('product_id', NUM_product_id+1, 100)
X4_test = get_test('advertiser_id', NUM_advertiser_id+1, 100)
X5_test = get_test('industry', NUM_industry+1, 100)
X6_test = get_test('product_category', NUM_product_category+1, 100)
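
Note that get_test pads every feature to maxlen=100, so these arrays only match the trained model if the LEN_* constants used as the embedding input_length were also 100. A small sanity check along those lines (a sketch using the LEN_* names defined earlier in the file):

# test-time padding must match the input_length the model was built with
assert X1_test.shape[1] == LEN_creative_id
assert X2_test.shape[1] == LEN_ad_id
assert X3_test.shape[1] == LEN_product_id
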


# # %%
# mail('predict lstm gender done')
y_pred = model.predict([X1_test,
X2_test,
X3_test,
X4_test,
X5_test,
X6_test, ], batch_size=1024)

y_pred = np.argmax(y_pred, axis=1)
y_pred = y_pred.flatten()
y_pred = y_pred+1
# %%
# %%
ans = pd.DataFrame({'predicted_gender': y_pred})
ans.to_csv(
'data/ans/lstm_gender.csv', header=True, columns=['predicted_gender'], index=False)
# %%
mail('predict lstm gender done')
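
merge_gender_age_csv below also expects data/ans/lstm_age.csv, which the gender run above does not produce. A minimal sketch of the analogous age pass, assuming the script is run in its age configuration (10-way age head) and reusing the age checkpoint referenced earlier; this is not part of the commit:

model.load_weights('tmp/model_age_0.443.hdf5')
y_pred_age = model.predict([X1_test, X2_test, X3_test,
                            X4_test, X5_test, X6_test], batch_size=1024)
y_pred_age = np.argmax(y_pred_age, axis=1) + 1   # classes 0..9 -> age labels 1..10
ans_age = pd.DataFrame({'predicted_age': y_pred_age})
ans_age.to_csv('data/ans/lstm_age.csv', header=True,
               columns=['predicted_age'], index=False)
mail('predict lstm age done')
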

# %%


def merge_gender_age_csv():
user_id_test = pd.read_csv(
'data/test/clicklog_ad_user_test.csv').sort_values(['user_id'], ascending=(True,)).user_id.unique()
ans = pd.DataFrame({'user_id': user_id_test})

gender = pd.read_csv('data/ans/lstm_gender.csv')
age = pd.read_csv('data/ans/lstm_age.csv')
ans['predicted_gender'] = gender.predicted_gender
ans['predicted_age'] = age.predicted_age
ans.to_csv('data/ans/LSTM.csv', header=True, index=False,
columns=['user_id', 'predicted_age', 'predicted_gender'])


merge_gender_age_csv()
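
The merge relies on all three files being row-aligned with the sorted test user ids; a small defensive check along those lines (a sketch, not part of the commit) could be run before writing:

def check_merge_alignment():
    user_ids = pd.read_csv(
        'data/test/clicklog_ad_user_test.csv').sort_values('user_id').user_id.unique()
    gender = pd.read_csv('data/ans/lstm_gender.csv')
    age = pd.read_csv('data/ans/lstm_age.csv')
    # one prediction row per test user, in the same (sorted) order
    assert len(gender) == len(age) == len(user_ids)
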
