Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Camoprompt #5

Open
gggg01hhhh opened this issue Nov 25, 2024 · 1 comment
Open

Camoprompt #5

gggg01hhhh opened this issue Nov 25, 2024 · 1 comment

Comments

@gggg01hhhh
Copy link

I'd like to ask you about calculating the accuracy of CamoPrompts. Could you provide some calculation code? Thank you very much.

@lartpang
Copy link
Owner

lartpang commented Nov 25, 2024

@gggg01hhhh

This process is modified directly from the original CLIP:

class CLIP(nn.Module):
    """Inference-only zero-shot classifier built on ``ConvNeXtCLIP``.

    ``forward`` delegates to ``test_forward`` in eval mode and raises
    ``NotImplementedError`` in training mode.
    """

    def __init__(self, template_set="basev3", **kwargs):
        """Create the CLIP backbone and the matching pixel normalizer.

        Args:
            template_set: Passed through to ``ConvNeXtCLIP(template_set=...)``.
            **kwargs: Accepted but ignored; a warning is emitted when any
                are supplied.
        """
        super().__init__()
        # Warn only when something was actually passed; an unconditional
        # warning about an empty dict is just noise.
        if kwargs:
            warnings.warn(f"kwargs: {kwargs} do not work!")
        self.clip = ConvNeXtCLIP(template_set=template_set)
        self.normalizer = PixelNormalizer(mean=self.clip.mean, std=self.clip.std)
        # Text embeddings are computed lazily on the first ``test_forward``
        # call and reused for as long as the class names stay the same.
        self.test_class_embs = None
        self._test_class_names = None

    @torch.no_grad()
    def test_forward(self, data, class_names, *, use_map=True, **kwargs):
        """Predict one class name per image.

        Args:
            data: Mapping that provides ``"image"`` and ``"mask"`` entries.
            class_names: Sequence of candidate class names; the prediction
                for each image is one of its elements.
            use_map: If true, the dense visual features are resized to the
                mask's last two dimensions and averaged with the mask as
                weights. If false, the dense features are passed on as-is.
            **kwargs: Ignored.

        Returns:
            A list holding the element of ``class_names`` with the highest
            logit for each image.
        """
        image = data["image"]
        mask = data["mask"]

        image = self.normalizer(image)
        image_feats = self.clip.get_visual_feats(image)
        image_deep = image_feats["clip_vis_dense"]

        if use_map:
            image_deep = resize_to(image_deep, tgt_hw=mask.shape[-2:])
            # Mask-weighted average over the last two dimensions.
            # NOTE(review): an all-zero mask makes this 0/0 (NaN) -- confirm
            # that callers never pass an empty mask.
            image_embs = (mask * image_deep).sum((-1, -2), keepdim=True) / mask.sum((-1, -2), keepdim=True)
        else:
            image_embs = image_deep
        image_embs = self.clip.visual_feats_to_embs(image_embs, normalize=True)

        # [N=num_classes, 768]
        # Recompute the cached text embeddings when the class names differ
        # from the ones the cache was built for; otherwise a second call with
        # another label set would silently reuse stale embeddings. A cache
        # whose names are unknown (e.g. assigned externally) is trusted.
        names_key = tuple(class_names)
        cached_names = getattr(self, "_test_class_names", None)
        if self.test_class_embs is None or (cached_names is not None and cached_names != names_key):
            self.test_class_embs = self.clip.get_text_embs_by_template(class_names)
        self._test_class_names = names_key
        class_embs = self.test_class_embs

        class_logits = image_embs @ class_embs.T  # B,N
        cls_id_per_image = torch.argmax(class_logits, dim=-1)
        pred_classes = [class_names[i] for i in cls_id_per_image]
        return pred_classes

    def forward(self, *arg, **kwargs):
        """Run ``test_forward`` in eval mode; training is not implemented."""
        if self.training:
            raise NotImplementedError("CLIP only implements inference; call .eval() before forward().")
        return self.test_forward(*arg, **kwargs)

It follows a process similar to that of `map_classifier`:

def map_classifier(self, logits, image_deep, normed_class_embs):
    """Compute class logits from dense features pooled by a predicted map.

    The sigmoid of ``logits`` is used as per-pixel weights to average
    ``image_deep`` over its last two dimensions; the pooled feature is then
    converted to a normalized embedding and compared with the class
    embeddings.

    Args:
        logits: Map logits; their last two dimensions set the size that
            ``image_deep`` is resized to.
        image_deep: Dense visual features.
        normed_class_embs: Class embeddings (presumably already
            L2-normalized, judging by the name -- confirm with the caller).

    Returns:
        Class logits scaled by ``exp(logit_scale)`` of the wrapped CLIP model.
    """
    prob = logits.sigmoid()
    image_embs = resize_to(image_deep, tgt_hw=prob.shape[-2:])
    # image_embs (B,C)
    image_embs = (prob * image_embs).sum((-1, -2)) / prob.sum((-1, -2))
    # Append two trailing singleton dimensions before the embedding step.
    image_embs = image_embs[..., None, None]
    # B,C => B,D
    normed_image_embs = self.clip.visual_feats_to_embs(image_embs, normalize=True)
    class_logits = normed_image_embs @ normed_class_embs.T  # B,N
    class_logits = self.clip.clip_model.logit_scale.exp() * class_logits
    return class_logits

With the predicted classes `pred_classes`, you can compute the classification accuracy by comparing them against the ground-truth class labels.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants