Skip to content

Commit

Permalink
update predict transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
sunlanchang committed Jun 21, 2020
1 parent c981ad2 commit a5a8c3f
Showing 1 changed file with 12 additions and 24 deletions.
36 changes: 12 additions & 24 deletions Transformer_keras_6_input_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_gender_model(DATA):
input_creative_id = Input(shape=(max_seq_len,), name='creative_id')
x1 = Embedding(input_dim=NUM_creative_id+1,
output_dim=256,
weights=[DATA['creative_id_emb']],
# weights=[DATA['creative_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -116,7 +116,7 @@ def get_gender_model(DATA):
input_ad_id = Input(shape=(max_seq_len,), name='ad_id')
x2 = Embedding(input_dim=NUM_ad_id+1,
output_dim=256,
weights=[DATA['ad_id_emb']],
# weights=[DATA['ad_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -125,7 +125,7 @@ def get_gender_model(DATA):
input_product_id = Input(shape=(max_seq_len,), name='product_id')
x3 = Embedding(input_dim=NUM_product_id+1,
output_dim=32,
weights=[DATA['product_id_emb']],
# weights=[DATA['product_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -134,7 +134,7 @@ def get_gender_model(DATA):
input_advertiser_id = Input(shape=(max_seq_len,), name='advertiser_id')
x4 = Embedding(input_dim=NUM_advertiser_id+1,
output_dim=64,
weights=[DATA['advertiser_id_emb']],
# weights=[DATA['advertiser_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -143,7 +143,7 @@ def get_gender_model(DATA):
input_industry = Input(shape=(max_seq_len,), name='industry')
x5 = Embedding(input_dim=NUM_industry+1,
output_dim=16,
weights=[DATA['industry_emb']],
# weights=[DATA['industry_emb']],
trainable=True,
# trainable=False,
input_length=150,
Expand All @@ -153,7 +153,7 @@ def get_gender_model(DATA):
shape=(max_seq_len,), name='product_category')
x6 = Embedding(input_dim=NUM_product_category+1,
output_dim=8,
weights=[DATA['product_category_emb']],
# weights=[DATA['product_category_emb']],
trainable=True,
# trainable=False,
input_length=150,
Expand Down Expand Up @@ -214,7 +214,7 @@ def get_age_model(DATA):
input_creative_id = Input(shape=(max_seq_len,), name='creative_id')
x1 = Embedding(input_dim=NUM_creative_id+1,
output_dim=256,
weights=[DATA['creative_id_emb']],
# weights=[DATA['creative_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -225,7 +225,7 @@ def get_age_model(DATA):
input_ad_id = Input(shape=(max_seq_len,), name='ad_id')
x2 = Embedding(input_dim=NUM_ad_id+1,
output_dim=256,
weights=[DATA['ad_id_emb']],
# weights=[DATA['ad_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -234,7 +234,7 @@ def get_age_model(DATA):
input_product_id = Input(shape=(max_seq_len,), name='product_id')
x3 = Embedding(input_dim=NUM_product_id+1,
output_dim=32,
weights=[DATA['product_id_emb']],
# weights=[DATA['product_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -243,7 +243,7 @@ def get_age_model(DATA):
input_advertiser_id = Input(shape=(max_seq_len,), name='advertiser_id')
x4 = Embedding(input_dim=NUM_advertiser_id+1,
output_dim=64,
weights=[DATA['advertiser_id_emb']],
# weights=[DATA['advertiser_id_emb']],
trainable=args.not_train_embedding,
# trainable=False,
input_length=150,
Expand All @@ -252,7 +252,7 @@ def get_age_model(DATA):
input_industry = Input(shape=(max_seq_len,), name='industry')
x5 = Embedding(input_dim=NUM_industry+1,
output_dim=16,
weights=[DATA['industry_emb']],
# weights=[DATA['industry_emb']],
trainable=True,
# trainable=False,
input_length=150,
Expand All @@ -262,7 +262,7 @@ def get_age_model(DATA):
shape=(max_seq_len,), name='product_category')
x6 = Embedding(input_dim=NUM_product_category+1,
output_dim=8,
weights=[DATA['product_category_emb']],
# weights=[DATA['product_category_emb']],
trainable=True,
# trainable=False,
input_length=150,
Expand Down Expand Up @@ -421,18 +421,6 @@ def save_npy(datas, name):
DATA['X_test6'] = np.load(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/test_5.npy', allow_pickle=True)

DATA['creative_id_emb'] = np.load(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/embeddings_0.npy', allow_pickle=True)
DATA['ad_id_emb'] = np.load(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/embeddings_1.npy', allow_pickle=True)
DATA['product_id_emb'] = np.load(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/embeddings_2.npy', allow_pickle=True)
DATA['advertiser_id_emb'] = np.load(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/embeddings_3.npy', allow_pickle=True)
DATA['industry_emb'] = np.load(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/embeddings_4.npy', allow_pickle=True)
DATA['product_category_emb'] = np.load(
'C:/Users/yrqun/Desktop/TMP/trans/tmp/embeddings_5.npy', allow_pickle=True)

# %%
model_gender = get_gender_model(DATA)
Expand Down

0 comments on commit a5a8c3f

Please sign in to comment.