
Where is the fine-tuning in the code? #55

Open
JianqunZhang opened this issue Oct 15, 2021 · 4 comments

Comments

@JianqunZhang

Hi Liu, I appreciate your work very much. In the course of my study, I ran into a problem that is difficult to understand; could you give me some help? I think fine-tuning should be implemented using the test data set. However, there seems to be no fine-tuning step for the test data in the code. Can you tell me where the fine-tuning steps are implemented? In addition, I am using your PyTorch code.

@yaoyao-liu
Owner

Thanks for your interest in our work.

Could you please clarify which test data you refer to? We have meta-train and meta-test stages. In each stage, we sample episodes, so each episode has an episode-train (support) set and an episode-test (query) set. Do you mean fine-tuning on the episode-test data?
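To make the terminology concrete, here is a minimal sketch of how an N-way K-shot episode is typically sampled. It is not the repository's actual data loader; the function name sample_episode, the default 5-way 1-shot setting, and the dataset format are assumptions for illustration only.

import random
from collections import defaultdict

def sample_episode(dataset, n_way=5, k_shot=1, q_query=15):
    """dataset: iterable of (image, class_label) pairs.
    Returns (support, query), each a list of (image, episode_label) pairs."""
    by_class = defaultdict(list)
    for image, label in dataset:
        by_class[label].append(image)
    classes = random.sample(list(by_class.keys()), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        images = random.sample(by_class[cls], k_shot + q_query)
        support += [(img, episode_label) for img in images[:k_shot]]  # episode-train (support) set
        query += [(img, episode_label) for img in images[k_shot:]]    # episode-test (query) set
    return support, query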

@JianqunZhang
Author

Thank you for your timely reply. Yes, I do mean fine-tuning on the episode-test data. I use the test split of the miniImageNet dataset in the meta-test stage. The support set in the test data should be fine-tuned FTN + 1, is that right? Thank you very much!

@yaoyao-liu
Owner

The fine-tuning on the episode-test data is implemented by the following function:

def meta_forward(self, data_shot, label_shot, data_query):
    """The function to forward the meta-train phase.
    Args:
      data_shot: train images for the task.
      label_shot: train labels for the task.
      data_query: test images for the task.
    Returns:
      logits_q: the predictions for the test samples.
    """
    embedding_query = self.encoder(data_query)
    embedding_shot = self.encoder(data_shot)
    logits = self.base_learner(embedding_shot)
    loss = F.cross_entropy(logits, label_shot)
    grad = torch.autograd.grad(loss, self.base_learner.parameters())
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.base_learner.parameters())))
    logits_q = self.base_learner(embedding_query, fast_weights)
    for _ in range(1, self.update_step):
        logits = self.base_learner(embedding_shot, fast_weights)
        loss = F.cross_entropy(logits, label_shot)
        grad = torch.autograd.grad(loss, fast_weights)
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
        logits_q = self.base_learner(embedding_query, fast_weights)
    return logits_q
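As a note on how this inner loop behaves: the fast weights are computed as theta' = theta - update_lr * grad on the support (episode-train) set and are used only for the query (episode-test) prediction, so the base learner's original parameters are never modified in place. Below is a self-contained toy version of that update, with a plain linear layer standing in for base_learner; the shapes, the 5-way 1-shot setting, and all variable names are assumptions for illustration, not the repository's code.

# Toy version of the inner loop above, run on random data.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
feat_dim, n_way, update_lr, update_step = 64, 5, 0.01, 5

w = torch.randn(n_way, feat_dim, requires_grad=True)  # stands in for base_learner's weight
b = torch.zeros(n_way, requires_grad=True)            # stands in for base_learner's bias

emb_shot = torch.randn(n_way * 1, feat_dim)    # support (episode-train) embeddings, 5-way 1-shot
lbl_shot = torch.arange(n_way)                 # support labels
emb_query = torch.randn(n_way * 15, feat_dim)  # query (episode-test) embeddings

fast = [w, b]
for _ in range(update_step):
    loss = F.cross_entropy(F.linear(emb_shot, fast[0], fast[1]), lbl_shot)
    grads = torch.autograd.grad(loss, fast)
    # Fast weights: theta' = theta - update_lr * grad; the originals w and b stay untouched.
    fast = [p - update_lr * g for p, g in zip(fast, grads)]

logits_q = F.linear(emb_query, fast[0], fast[1])  # predictions for the query set
print(logits_q.shape)  # torch.Size([75, 5])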

If you have further questions, feel free to add more comments to this issue.

@JianqunZhang
Author

Thank you very much! I got it! Thank you!
