diff --git a/Transformer_keras.py b/Transformer_keras.py index 18eb341..d9bac3e 100644 --- a/Transformer_keras.py +++ b/Transformer_keras.py @@ -290,27 +290,27 @@ def get_model_head_concat(DATA): # first input input_creative_id = Input(shape=(None,), name='creative_id') x1 = TokenAndPositionEmbedding( - maxlen, NUM_creative_id, embed_dim, DATA['creative_id_emb'])(input_creative_id) + maxlen, NUM_creative_id+1, embed_dim, DATA['creative_id_emb'])(input_creative_id) input_ad_id = Input(shape=(None,), name='ad_id') x2 = TokenAndPositionEmbedding( - maxlen, NUM_ad_id, embed_dim, DATA['ad_id_emb'])(input_ad_id) + maxlen, NUM_ad_id+1, embed_dim, DATA['ad_id_emb'])(input_ad_id) input_product_id = Input(shape=(None,), name='product_id') x3 = TokenAndPositionEmbedding( - maxlen, NUM_product_id, embed_dim, DATA['product_id_emb'])(input_product_id) + maxlen, NUM_product_id+1, embed_dim, DATA['product_id_emb'])(input_product_id) input_advertiser_id = Input(shape=(None,), name='advertiser_id') x4 = TokenAndPositionEmbedding( - maxlen, NUM_advertiser_id, embed_dim, DATA['advertiser_id_emb'])(input_advertiser_id) + maxlen, NUM_advertiser_id+1, embed_dim, DATA['advertiser_id_emb'])(input_advertiser_id) input_industry = Input(shape=(None,), name='industry') x5 = TokenAndPositionEmbedding( - maxlen, NUM_industry, embed_dim, DATA['industry_emb'])(input_industry) + maxlen, NUM_industry+1, embed_dim, DATA['industry_emb'])(input_industry) input_product_category = Input(shape=(None,), name='product_category') x6 = TokenAndPositionEmbedding( - maxlen, NUM_product_category, embed_dim, DATA['product_category_emb'])(input_product_category) + maxlen, NUM_product_category+1, embed_dim, DATA['product_category_emb'])(input_product_category) # concat # x = x1 + x2 + x3