From bda64b084a25c14fca08f5cebde4c00745795131 Mon Sep 17 00:00:00 2001
From: sunlanchang
Date: Fri, 19 Jun 2020 00:32:05 +0800
Subject: [PATCH] update LSTM

---
 LSTM_age_gender_old.py | 245 ++++++++++++++++++++++++-----------------
 1 file changed, 142 insertions(+), 103 deletions(-)

diff --git a/LSTM_age_gender_old.py b/LSTM_age_gender_old.py
index 81b0d68..14864ab 100644
--- a/LSTM_age_gender_old.py
+++ b/LSTM_age_gender_old.py
@@ -130,7 +130,7 @@ def get_train(feature_name, vocab_size, len_feature):
 # first input
 print('loading creative_id feature')
 X1_train, tokenizer = get_train(
-    'creative_id', NUM_creative_id, LEN_creative_id)
+    'creative_id', NUM_creative_id+1, LEN_creative_id)  # +1 reserves a slot for the UNK creative_id
 creative_id_emb = get_embedding('creative_id', tokenizer)
 
 DATA['X1_train'] = X1_train[:train_examples]
@@ -216,7 +216,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
                    input_length=LEN_creative_id,
                    mask_zero=True)(input_creative_id)
     for _ in range(args.num_lstm):
-        x1 = Bidirectional(LSTM(128, return_sequences=True))(x1)
+        x1 = Bidirectional(LSTM(256, return_sequences=True))(x1)
     x1 = layers.GlobalMaxPooling1D()(x1)
 
     # second input
@@ -228,7 +228,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
                    input_length=LEN_ad_id,
                    mask_zero=True)(input_ad_id)
     for _ in range(args.num_lstm):
-        x2 = Bidirectional(LSTM(128, return_sequences=True))(x2)
+        x2 = Bidirectional(LSTM(256, return_sequences=True))(x2)
     x2 = layers.GlobalMaxPooling1D()(x2)
 
     # third input
@@ -240,7 +240,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
                    input_length=LEN_product_id,
                    mask_zero=True)(input_product_id)
     for _ in range(args.num_lstm):
-        x3 = Bidirectional(LSTM(128, return_sequences=True))(x3)
+        x3 = Bidirectional(LSTM(256, return_sequences=True))(x3)
     x3 = layers.GlobalMaxPooling1D()(x3)
 
     # third input
@@ -252,7 +252,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
                    input_length=LEN_advertiser_id,
                    mask_zero=True)(input_advertiser_id)
     for _ in range(args.num_lstm):
-        x4 = Bidirectional(LSTM(128, return_sequences=True))(x4)
+        x4 = Bidirectional(LSTM(256, return_sequences=True))(x4)
     x4 = layers.GlobalMaxPooling1D()(x4)
 
     # third input
@@ -264,7 +264,7 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
                    input_length=LEN_industry,
                    mask_zero=True)(input_industry)
     for _ in range(args.num_lstm):
-        x5 = Bidirectional(LSTM(128, return_sequences=True))(x5)
+        x5 = Bidirectional(LSTM(256, return_sequences=True))(x5)
     x5 = layers.GlobalMaxPooling1D()(x5)
 
     # third input
@@ -276,12 +276,12 @@ def get_tail_concat_model(DATA, predict_age=True, predict_gender=False):
                    input_length=LEN_product_category,
                    mask_zero=True)(input_product_category)
     for _ in range(args.num_lstm):
-        x6 = Bidirectional(LSTM(128, return_sequences=True))(x6)
+        x6 = Bidirectional(LSTM(256, return_sequences=True))(x6)
     x6 = layers.GlobalMaxPooling1D()(x6)
 
-    x = layers.Concatenate(axis=2)([x1, x2, x3, x4, x5, x6])
+    x = layers.Concatenate(axis=1)([x1, x2, x3, x4, x5, x6])
     # x = layers.GlobalMaxPooling1D()(x)
-
+
     if predict_age and predict_gender:
         output_gender = Dense(2, activation='softmax', name='gender')(x)
         output_age = Dense(10, activation='softmax', name='age')(x)
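A note on the `Concatenate` change in the hunk above: once each branch has gone through `GlobalMaxPooling1D()`, its output is a rank-2 tensor of shape `(batch, 2*units)` because the time axis is pooled away, so the branches can only be joined along the feature axis, `axis=1` (equivalently `axis=-1`); `axis=2` no longer exists for these tensors. A minimal, self-contained sketch that confirms the resulting shape — the vocabulary sizes, sequence length and embedding width here are made up for illustration and are not the project's real values:

import tensorflow as tf
from tensorflow.keras import layers

def branch(vocab_size, seq_len, units=256):
    # One input branch: Embedding -> BiLSTM -> global max pool over time.
    inp = layers.Input(shape=(seq_len,), dtype='int32')
    x = layers.Embedding(vocab_size, 64)(inp)
    x = layers.Bidirectional(layers.LSTM(units, return_sequences=True))(x)
    x = layers.GlobalMaxPooling1D()(x)   # shape: (batch, 2*units); no time axis left
    return inp, x

in1, x1 = branch(1000, 100)
in2, x2 = branch(500, 100)
merged = layers.Concatenate(axis=1)([x1, x2])   # feature axis; axis=2 would raise an error here
model = tf.keras.Model([in1, in2], merged)
print(model.output_shape)                       # (None, 1024) for units=256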
@@ -489,109 +489,148 @@ def scheduler(epoch):
 checkpoint = ModelCheckpoint("tmp/lstm_tail_concat_epoch_{epoch:02d}.hdf5",
                              monitor='val_loss', verbose=1, save_best_only=False, mode='auto', period=1)
 # %%
-try:
-    train_examples = args.train_examples
-    val_examples = args.val_examples
-    mail('start train lstm')
-    if args.head_concat:
-        model.fit(
-            {
-                'creative_id': DATA['X1_train'][:train_examples],
-                'ad_id': DATA['X2_train'][:train_examples],
-                'product_id': DATA['X3_train'][:train_examples],
-                'advertiser_id': DATA['X4_train'][:train_examples],
-                'industry': DATA['X5_train'][:train_examples],
-                'product_category': DATA['X6_train'][:train_examples]
-            },
-            {
-                'gender': DATA['Y_gender_train'][:train_examples],
-                'age': DATA['Y_age_train'][:train_examples],
-            },
-            validation_data=(
-                {
-                    'creative_id': DATA['X1_val'][:val_examples],
-                    'ad_id': DATA['X2_val'][:val_examples],
-                    'product_id': DATA['X3_val'][:val_examples],
-                    'advertiser_id': DATA['X4_val'][:val_examples],
-                    'industry': DATA['X5_val'][:val_examples],
-                    'product_category': DATA['X6_val'][:val_examples]
-                },
-                {
-                    'gender': DATA['Y_gender_val'][:val_examples],
-                    'age': DATA['Y_age_val'][:val_examples],
-                },
-            ),
-            epochs=args.epoch,
-            batch_size=args.batch_size,
-            callbacks=[checkpoint],
-        )
-    elif args.tail_concat:
-        model.fit(
-            {
-                'creative_id': DATA['X1_train'][:train_examples],
-                'ad_id': DATA['X2_train'][:train_examples],
-                'product_id': DATA['X3_train'][:train_examples],
-                'advertiser_id': DATA['X4_train'][:train_examples],
-                'industry': DATA['X5_train'][:train_examples],
-                'product_category': DATA['X6_train'][:train_examples]
-            },
-            {
-                'gender': DATA['Y_gender_train'][:train_examples],
-                'age': DATA['Y_age_train'][:train_examples],
-            },
-            validation_data=(
-                {
-                    'creative_id': DATA['X1_val'][:val_examples],
-                    'ad_id': DATA['X2_val'][:val_examples],
-                    'product_id': DATA['X3_val'][:val_examples],
-                    'advertiser_id': DATA['X4_val'][:val_examples],
-                    'industry': DATA['X5_val'][:val_examples],
-                    'product_category': DATA['X6_val'][:val_examples]
-                },
-                {
-                    'gender': DATA['Y_gender_val'][:val_examples],
-                    'age': DATA['Y_age_val'][:val_examples],
-                },
-            ),
-            epochs=args.epoch,
-            batch_size=args.batch_size,
-            callbacks=[checkpoint],
-        )
-
-    mail('train lstm done!!!')
-except Exception as e:
-    e = str(e)
-    mail('train lstm failed!!! ' + e)
+# try:
+#     train_examples = args.train_examples
+#     val_examples = args.val_examples
+#     mail('start train lstm')
+#     if args.head_concat:
+#         model.fit(
+#             {
+#                 'creative_id': DATA['X1_train'][:train_examples],
+#                 'ad_id': DATA['X2_train'][:train_examples],
+#                 'product_id': DATA['X3_train'][:train_examples],
+#                 'advertiser_id': DATA['X4_train'][:train_examples],
+#                 'industry': DATA['X5_train'][:train_examples],
+#                 'product_category': DATA['X6_train'][:train_examples]
+#             },
+#             {
+#                 'gender': DATA['Y_gender_train'][:train_examples],
+#                 'age': DATA['Y_age_train'][:train_examples],
+#             },
+#             validation_data=(
+#                 {
+#                     'creative_id': DATA['X1_val'][:val_examples],
+#                     'ad_id': DATA['X2_val'][:val_examples],
+#                     'product_id': DATA['X3_val'][:val_examples],
+#                     'advertiser_id': DATA['X4_val'][:val_examples],
+#                     'industry': DATA['X5_val'][:val_examples],
+#                     'product_category': DATA['X6_val'][:val_examples]
+#                 },
+#                 {
+#                     'gender': DATA['Y_gender_val'][:val_examples],
+#                     'age': DATA['Y_age_val'][:val_examples],
+#                 },
+#             ),
+#             epochs=args.epoch,
+#             batch_size=args.batch_size,
+#             callbacks=[checkpoint],
+#         )
+#     elif args.tail_concat:
+#         model.fit(
+#             {
+#                 'creative_id': DATA['X1_train'][:train_examples],
+#                 'ad_id': DATA['X2_train'][:train_examples],
+#                 'product_id': DATA['X3_train'][:train_examples],
+#                 'advertiser_id': DATA['X4_train'][:train_examples],
+#                 'industry': DATA['X5_train'][:train_examples],
+#                 'product_category': DATA['X6_train'][:train_examples]
+#             },
+#             {
+#                 'gender': DATA['Y_gender_train'][:train_examples],
+#                 'age': DATA['Y_age_train'][:train_examples],
+#             },
+#             validation_data=(
+#                 {
+#                     'creative_id': DATA['X1_val'][:val_examples],
+#                     'ad_id': DATA['X2_val'][:val_examples],
+#                     'product_id': DATA['X3_val'][:val_examples],
+#                     'advertiser_id': DATA['X4_val'][:val_examples],
+#                     'industry': DATA['X5_val'][:val_examples],
+#                     'product_category': DATA['X6_val'][:val_examples]
+#                 },
+#                 {
+#                     'gender': DATA['Y_gender_val'][:val_examples],
+#                     'age': DATA['Y_age_val'][:val_examples],
+#                 },
+#             ),
+#             epochs=args.epoch,
+#             batch_size=args.batch_size,
+#             callbacks=[checkpoint],
+#         )
+
+#     mail('train lstm done!!!')
+# except Exception as e:
+#     e = str(e)
+#     mail('train lstm failed!!! ' + e)
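The commented-out block above passes both the six inputs and the two label arrays as dicts keyed by layer name ('creative_id', ..., 'gender', 'age'). That only works if the model was built with outputs of those names and compiled with matching per-output losses; the compile call itself is outside this patch. The following self-contained sketch shows the same dict-keyed multi-output pattern on a toy model with random data — only the output names and class counts mirror the patch, everything else is illustrative:

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Toy stand-in for the real two-headed network: one named input, two named softmax heads.
inp = layers.Input(shape=(16,), name='creative_id')
h = layers.Dense(32, activation='relu')(inp)
out_gender = layers.Dense(2, activation='softmax', name='gender')(h)
out_age = layers.Dense(10, activation='softmax', name='age')(h)
toy_model = tf.keras.Model(inp, [out_gender, out_age])

toy_model.compile(optimizer='adam',
                  loss={'gender': 'categorical_crossentropy',
                        'age': 'categorical_crossentropy'},
                  metrics=['accuracy'])

x = np.random.random((64, 16)).astype('float32')
y_gender = tf.keras.utils.to_categorical(np.random.randint(0, 2, 64), 2)
y_age = tf.keras.utils.to_categorical(np.random.randint(0, 10, 64), 10)

# fit() matches each dict entry to the layer of the same name, as in the block above.
toy_model.fit({'creative_id': x}, {'gender': y_gender, 'age': y_age},
              epochs=1, batch_size=32, verbose=0)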
 # %%
 # The following is the prediction step; temporarily commented out and unused, but do not delete
-# model.load_weights('tmp\gender_epoch_01.hdf5')
-
+# model.load_weights('tmp/model_age_0.443.hdf5')
+model.load_weights('tmp/model_gender_0.929.hdf5')
+print('load model successful')
 # # %%
-# if debug:
-#     sequences = tokenizer.texts_to_sequences(
-#         creative_id_seq[900000:])
-# else:
-#     sequences = tokenizer.texts_to_sequences(
-#         creative_id_seq[900000:])
-
-# X_test = pad_sequences(sequences, maxlen=LEN_creative_id)
-# # %%
-# y_pred = model.predict(X_test, batch_size=4096)
-# y_pred = np.where(y_pred > 0.5, 1, 0)
-# y_pred = y_pred.flatten()
-# # %%
-# y_pred = y_pred+1
-# # %%
-# res = pd.DataFrame({'predicted_gender': y_pred})
-# res.to_csv(
-#     'data/ans/lstm_gender.csv', header=True, columns=['predicted_gender'], index=False)
+def get_test(feature_name, vocab_size, len_feature):
+    print(f'processing {feature_name}')
+    f = open(f'word2vec/userid_{feature_name}s.txt')
+    tokenizer = Tokenizer(num_words=vocab_size)
+    tokenizer.fit_on_texts(f)
+    f.close()
+
+    feature_seq = []
+    with open(f'word2vec/userid_{feature_name}s.txt') as f:
+        for text in f:
+            feature_seq.append(text.strip())
+
+    sequences = tokenizer.texts_to_sequences(feature_seq[900000:])
+    X_test = pad_sequences(
+        sequences, maxlen=len_feature, padding='post')
+    return X_test
+
+
+X1_test = get_test('creative_id', NUM_creative_id+1, 100)
+X2_test = get_test('ad_id', NUM_ad_id+1, 100)
+X3_test = get_test('product_id', NUM_product_id+1, 100)
+X4_test = get_test('advertiser_id', NUM_advertiser_id+1, 100)
+X5_test = get_test('industry', NUM_industry+1, 100)
+X6_test = get_test('product_category', NUM_product_category+1, 100)
 # # %%
+y_pred = model.predict([X1_test,
+                        X2_test,
+                        X3_test,
+                        X4_test,
+                        X5_test,
+                        X6_test, ], batch_size=1024)
+
+y_pred = np.argmax(y_pred, axis=1)
+y_pred = y_pred.flatten()
+y_pred = y_pred+1
+# %%
+# %%
+ans = pd.DataFrame({'predicted_gender': y_pred})
+ans.to_csv(
+    'data/ans/lstm_gender.csv', header=True, columns=['predicted_gender'], index=False)
+# %%
+mail('predict lstm gender done')
 # %%
+
+
+def merge_gender_age_csv():
+    user_id_test = pd.read_csv(
+        'data/test/clicklog_ad_user_test.csv').sort_values(['user_id'], ascending=(True,)).user_id.unique()
+    ans = pd.DataFrame({'user_id': user_id_test})
+
+    gender = pd.read_csv('data/ans/lstm_gender.csv')
+    age = pd.read_csv('data/ans/lstm_age.csv')
+    ans['predicted_gender'] = gender.predicted_gender
+    ans['predicted_age'] = age.predicted_age
+    ans.to_csv('data/ans/LSTM.csv', header=True, index=False,
+               columns=['user_id', 'predicted_age', 'predicted_gender'])
+
+
+merge_gender_age_csv()
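One caveat worth flagging in the new get_test(): it fits a fresh Tokenizer on the id file at prediction time, so the resulting word-to-index mapping matches training only if the file contents and num_words are exactly what get_train() saw. A safer pattern is to persist the tokenizer fitted during training and reuse it on the test split; below is a minimal sketch under that assumption — the pickle path and maxlen are illustrative, not taken from the repository:

import pickle
from tensorflow.keras.preprocessing.sequence import pad_sequences

# At training time, right after fit_on_texts() in get_train(), the tokenizer could be saved:
# with open('tmp/creative_id_tokenizer.pkl', 'wb') as f:
#     pickle.dump(tokenizer, f)

def load_test_sequences(tokenizer_path, text_path, maxlen=100):
    """Tokenize the test lines with the *training* tokenizer so indices match the embeddings."""
    with open(tokenizer_path, 'rb') as f:
        tokenizer = pickle.load(f)
    with open(text_path) as f:
        lines = [line.strip() for line in f]
    sequences = tokenizer.texts_to_sequences(lines[900000:])  # same 900000 split as the patch
    return pad_sequences(sequences, maxlen=maxlen, padding='post')

# Example call (hypothetical pickle path, real id file name from the patch):
# X1_test = load_test_sequences('tmp/creative_id_tokenizer.pkl',
#                               'word2vec/userid_creative_ids.txt')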