Custom Label Problem #14

Open
savasy opened this issue Jan 10, 2022 · 4 comments

Comments

@savasy

savasy commented Jan 10, 2022

When I train the model with custom labels, the training code works well. However, adapting the Inference.py code to my custom-trained model does not work.

I changed the Inference.ipynb code to handle my 11 labels as follows:

    LABELS=["Adjective","API","Core","GUI","Hardware","Language","Platform","Standard","User","Verb","O"]
    template_list=[" is a %s entity"%(e) for e in LABELS]
    entity_dict={i:e for i, e in enumerate(LABELS)}
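Deriving the templates and the index-to-label map from the single LABELS list keeps them aligned, and gives one place to read the template count from. A small sketch (the name `num_templates` is hypothetical, not taken from the notebook):

```python
# Custom labels from this issue; "O" is the non-entity class.
LABELS = ["Adjective", "API", "Core", "GUI", "Hardware", "Language",
          "Platform", "Standard", "User", "Verb", "O"]

# One template per label, exactly as in the issue.
template_list = [" is a %s entity" % e for e in LABELS]

# Template index -> label, so template_list[i] and entity_dict[i] describe the same class.
entity_dict = {i: e for i, e in enumerate(LABELS)}

# Hypothetical name: how many templates are scored per candidate span.
num_templates = len(template_list)

assert num_templates == len(entity_dict)
print(num_templates)               # 11
print(repr(template_list[1]))      # ' is a API entity'
```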

Here is how I load the checkpoint:

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    model = BartForConditionalGeneration.from_pretrained('./outputs/best_model')

Here is the inference call and the resulting error:

    prediction("As a user I should be able to use the attribute type User in my queries.")

    RuntimeError
    ----> 2 prediction("As a user I should be able to use the attribute type User in my queries.")
    /usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py in _shape(self, tensor, seq_len, bsz)
        157     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    --> 158         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
    RuntimeError: shape '[88, -1, 16, 64]' is invalid for input of size 778240
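The numbers in this error are consistent with a batch-size mismatch between the encoder and decoder inputs. The decoder batch of 88 equals 11 templates × 8 spans, while 778240 divides exactly into 40 × 19 × 16 × 64, i.e. 5 × 8 encoder rows of 19 tokens. This rests on an ASSUMPTION not shown in the issue: that `prediction()` scores up to 8 candidate spans per start word, and that the encoder input is still replicated with a hard-coded 5. A quick arithmetic check, with no model needed:

```python
# Arithmetic check of the error message; no model or GPU needed.
NUM_HEADS, HEAD_DIM = 16, 64   # from the shape '[88, -1, 16, 64]' in the traceback
SPANS = 8                      # ASSUMPTION: candidate spans scored per start word
HARDCODED, CUSTOM = 5, 11      # template count assumed by the code vs. len(template_list)

decoder_batch = CUSTOM * SPANS       # one decoder row per (span, template) pair
encoder_batch = HARDCODED * SPANS    # source sentence replicated with the literal 5
total = 778240                       # "input of size 778240" from the traceback

def fits(batch: int) -> bool:
    """True if `total` elements can be viewed as [batch, -1, NUM_HEADS, HEAD_DIM]."""
    return total % (batch * NUM_HEADS * HEAD_DIM) == 0

print(decoder_batch, fits(decoder_batch))               # 88 False
print(encoder_batch, fits(encoder_batch))               # 40 True
print(total // (encoder_batch * NUM_HEADS * HEAD_DIM))  # 19
```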

@HuangZhenyang

Hi, I ran into the same problem and solved it by replacing the number 5 with len(template_list) in inference.py.
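Replacing 5 with len(template_list) works because the number of encoder copies of the sentence and the number of decoder candidates must both equal (number of spans) × (number of templates). A minimal sketch of just that pairing logic in plain Python; `build_batches` and `spans` are hypothetical names, and this is not the repository's template_entity():

```python
from typing import List, Tuple

def build_batches(sentence: str, spans: List[str],
                  template_list: List[str]) -> Tuple[List[str], List[str]]:
    """Pair every candidate span with every template (hypothetical helper).

    Row k of both returned lists belongs to the same (span, template) pair,
    so the encoder and decoder batch sizes always agree.
    """
    num_templates = len(template_list)   # instead of a hard-coded 5
    encoder_texts = [sentence] * (num_templates * len(spans))
    decoder_texts = [span + t for span in spans for t in template_list]
    return encoder_texts, decoder_texts

if __name__ == "__main__":
    labels = ["Adjective", "API", "Core", "GUI", "Hardware", "Language",
              "Platform", "Standard", "User", "Verb", "O"]
    template_list = [" is a %s entity" % e for e in labels]
    sentence = "As a user I should be able to use the attribute type User in my queries."
    words = sentence.split()
    spans = [" ".join(words[0:j]) for j in range(1, 9)]  # up to 8 spans starting at word 0
    enc, dec = build_batches(sentence, spans, template_list)
    print(len(enc), len(dec))  # 88 88
    print(dec[0])              # As is a Adjective entity
```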

@savasy
Author

savasy commented Jan 11, 2022

Hi @HuangZhenyang,
Did you change every 5? What about the 4s; do they need to become X-1?

@HuangZhenyang

HuangZhenyang commented Jan 11, 2022

@savasy I changed every 5 to len(template_list) in the template_entity() function. I'm not sure this change is correct, so maybe we should wait for the author to reply.
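On the question about the 4s: if the scores are kept as one flat list with num_templates consecutive entries per span, the best flat index decodes with `//` and `%` by the template count, and the last template of span i sits at i × num_templates + (num_templates - 1). So a literal 4 that means "index of the last template" would become len(template_list) - 1; a 4 used for any other purpose should be checked case by case rather than replaced blindly. A sketch with hypothetical helper names (`decode_best`, `last_template_index`), assuming spans are ordered by length 1, 2, ... from the start word:

```python
from typing import Dict, List, Tuple

def decode_best(scores: List[float], start: int,
                entity_dict: Dict[int, str]) -> Tuple[int, int, str, float]:
    """Map the best flat score index back to (start, end, label, score).

    Hypothetical helper: assumes len(entity_dict) consecutive template scores
    per candidate span, with span k covering words start..start+k (inclusive).
    """
    num_templates = len(entity_dict)                # replaces the literal 5
    best = max(range(len(scores)), key=scores.__getitem__)
    span_idx, template_idx = divmod(best, num_templates)
    return start, start + span_idx, entity_dict[template_idx], scores[best]

def last_template_index(span_idx: int, num_templates: int) -> int:
    """Flat index of a span's final template; equals 4 when num_templates == 5 and span_idx == 0."""
    return span_idx * num_templates + (num_templates - 1)

if __name__ == "__main__":
    labels = ["Adjective", "API", "Core", "GUI", "Hardware", "Language",
              "Platform", "Standard", "User", "Verb", "O"]
    entity_dict = {i: e for i, e in enumerate(labels)}
    scores = [0.0] * (2 * len(labels))   # two spans x 11 templates
    scores[11 + 8] = 0.9                 # second span, template 8 ("User")
    print(decode_best(scores, start=3, entity_dict=entity_dict))  # (3, 4, 'User', 0.9)
    print(last_template_index(0, 5), last_template_index(0, 11))  # 4 10
```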

@savasy
Author

savasy commented Jan 11, 2022

@HuangZhenyang
Aha, awesome, it worked. Thanks!
Hi @Nealcly, maybe you could change this accordingly, or we could open a pull request.
