
Mistake in the code? #18

Open
Twice22 opened this issue Jul 2, 2020 · 3 comments

Comments


Twice22 commented Jul 2, 2020

Hello!

Thank you for releasing your implementation. However, it looks like fba_fusion doesn't do what you intend, or am I missing something?

Indeed, before calling the fba_fusion function, you've defined alpha, F, and B as follows:

alpha = torch.clamp(output[:, 0][:, None], 0, 1)

F = torch.sigmoid(output[:, 1:4])
B = torch.sigmoid(output[:, 4:7])

alpha, F, B = fba_fusion(alpha, img, F, B)

So alpha is of size (B, 1, H, W) (the [:, None] re-inserts the channel dimension), while F and B are each of size (B, 3, H, W).
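
To make this concrete, here is a minimal standalone sketch (dummy batch and spatial sizes, not taken from the repo) confirming those shapes:

import torch

# Hypothetical 7-channel network output: 1 alpha channel + 3 F channels + 3 B channels
output = torch.randn(2, 7, 32, 32)

alpha = torch.clamp(output[:, 0][:, None], 0, 1)  # [:, None] re-inserts the channel dim
F = torch.sigmoid(output[:, 1:4])
B = torch.sigmoid(output[:, 4:7])

print(alpha.shape)  # torch.Size([2, 1, 32, 32])
print(F.shape)      # torch.Size([2, 3, 32, 32])
print(B.shape)      # torch.Size([2, 3, 32, 32])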

Now, if we look at how you compute alpha in the fba_fusion function, we have:

def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1)) / (torch.sum((F - B) * (F - B), 1) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B

So, by the broadcasting rules, we have:

size = ((B, 1, H, W) * scalar + sum((B, 3, H, W), dim=1)) / (sum((B, 3, H, W), dim=1) + scalar)
     = ((B, 1, H, W) + (B, H, W)) / (B, H, W)
     = ((B, 1, H, W) + (1, B, H, W)) / (1, B, H, W)
     = (B, B, H, W) / (1, B, H, W)
     = (B, B, H, W)

So, in the end, alpha is of size (B, B, H, W) instead of the expected (B, 1, H, W).
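
Here is a standalone sketch (again with dummy sizes) reproducing the broadcast, and showing that keepdim=True avoids it:

import torch

alpha = torch.rand(2, 1, 32, 32)    # (B, 1, H, W)
num = torch.rand(2, 3, 32, 32)      # stands in for (img - B) * (F - B)
la = 0.1

s = torch.sum(num, 1)                         # (B, H, W): channel dim dropped
print(((alpha * la + s) / (s + la)).shape)    # torch.Size([2, 2, 32, 32])

s = torch.sum(num, 1, keepdim=True)           # (B, 1, H, W): channel dim kept
print(((alpha * la + s) / (s + la)).shape)    # torch.Size([2, 1, 32, 32])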

Weren't you supposed to pass keepdim=True to torch.sum?
Was your final .pth model trained with this flawed operation?

I hope you can answer my questions.
Thank you

@raphychek

Well, there actually is a keepdim=True in the torch.sum calls. In networks/models.py, on line 256, the code is as follows:

def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B
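
A quick shape check with dummy tensors (not part of the repo) confirms that alpha stays (B, 1, H, W) with the keepdim=True version above:

import torch

img = torch.rand(2, 3, 32, 32)
alpha = torch.rand(2, 1, 32, 32)
F = torch.rand(2, 3, 32, 32)
B = torch.rand(2, 3, 32, 32)

alpha, F, B = fba_fusion(alpha, img, F, B)  # the corrected function above
print(alpha.shape)  # torch.Size([2, 1, 32, 32]) -- no spurious (B, B, H, W) broadcast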


Twice22 commented Jul 2, 2020

Oh, OK. I hadn't seen this because I was working with the implementation from before your last commit.

@MarcoForte (Owner)

Hi, thanks for your interest and for taking the time to inform me of this issue. As raphychek pointed out, it has already been corrected; see #7.
