diff --git a/real/train_code/architecture/MST_Plus_Plus.py b/real/train_code/architecture/MST_Plus_Plus.py index 2b93eb1..def71c9 100755 --- a/real/train_code/architecture/MST_Plus_Plus.py +++ b/real/train_code/architecture/MST_Plus_Plus.py @@ -277,27 +277,12 @@ def __init__(self, in_channels=3, out_channels=28, n_feat=28, stage=3): self.body = nn.Sequential(*modules_body) self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=3, padding=(3 - 1) // 2,bias=False) - def initial_x(self, y): - """ - :param y: [b,1,256,310] - :param Phi: [b,28,256,310] - :return: z: [b,28,256,310] - """ - nC, step = 28, 2 - bs, row, col = y.shape - x = torch.zeros(bs, nC, row, row).cuda().float() - for i in range(nC): - x[:, i, :, :] = y[:, :, step * i:step * i + col - (nC - 1) * step] - x = self.fution(x) - return x def forward(self, x, input_mask=None): """ x: [b,c,h,w] return out:[b,c,h,w] """ - x = self.initial_x(x) - b, c, h_inp, w_inp = x.shape hb, wb = 8, 8 pad_h = (hb - h_inp % hb) % hb