Skip to content

Commit

Permalink
update predict transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
sunlanchang committed Jun 21, 2020
1 parent a5a8c3f commit 09a4a70
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions Transformer_keras_6_input_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,6 @@ def save_npy(datas, name):
y_pred_gender = y_pred_gender.flatten()
y_pred_gender += 1

ans_gender = pd.DataFrame({'predicted_gender': y_pred_gender})
ans_gender.to_csv(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/transformer_gender.csv', header=True, columns=['predicted_gender'], index=False)

y_pred_age = model_age.predict(
{
'creative_id': DATA['X1_test'],
Expand All @@ -466,20 +462,13 @@ def save_npy(datas, name):
y_pred_age = y_pred_age.flatten()
y_pred_age += 1

ans_age = pd.DataFrame({'predicted_age': y_pred_age})
ans_age.to_csv(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/transformer_age.csv', header=True, columns=['predicted_age'], index=False)

# slc
user_id_test = pd.read_csv(
'data/test/clicklog_ad.csv').sort_values(['user_id'], ascending=(True,)).user_id.unique()
ans = pd.DataFrame({'user_id': user_id_test})

gender = pd.read_csv(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/transformer_gender.csv')
age = pd.read_csv(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/transformer_age.csv')
ans['predicted_gender'] = gender.predicted_gender
ans['predicted_age'] = age.predicted_age
ans['predicted_gender'] = y_pred_gender
ans['predicted_age'] = y_pred_age
ans.to_csv('C:/Users/yrqun/Desktop/TMP/trans/tmp/submission.csv', header=True, index=False,
columns=['user_id', 'predicted_age', 'predicted_gender'])

0 comments on commit 09a4a70

Please sign in to comment.