Tensor dimension issue occurring partway through training #8

Open
olliestanley opened this issue Apr 13, 2022 · 4 comments

@olliestanley

olliestanley commented Apr 13, 2022

Hi, I have encountered a strange issue when attempting to train a DetCo model (ResNet18 backbone). In short, the model trained perfectly well for 9 epochs, with the loss on a downward trend. Then, partway through the 10th epoch, an error occurred:

RuntimeError('The expanded size of the tensor (8) must match the existing size (16) at non-singleton dimension 2. Target sizes: [8, 128, 8]. Tensor sizes: [8, 128, 16]')

Running it again with the code wrapped in a try-except confirms that, once the error first occurs, it is raised for every forward call from that point onwards, on all of the GPUs. The stack trace points to the line

self.queue[:, :, ptr:ptr + batch_size] = keys.permute(1, 2, 0)

in this function, which is called at the end of the forward function:

@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
    # gather keys before updating queue
    keys = concat_all_gather(keys)
    batch_size = keys.shape[0]

    ptr = int(self.queue_ptr)
    assert self.K % batch_size == 0  # for simplicity

    # replace the keys at ptr (dequeue and enqueue)
    self.queue[:, :, ptr:ptr + batch_size] = keys.permute(1, 2, 0)
    ptr = (ptr + batch_size) % self.K  # move pointer

    self.queue_ptr[0] = ptr
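
For reference, the shape mismatch itself can be reproduced in isolation whenever the slice at the end of the queue is shorter than the gathered keys. This is just a standalone sketch to illustrate the error, not code from DetCo; the queue length K and the other sizes here are made up:

import torch

K, dim, n_reps = 24, 128, 8          # hypothetical sizes; the real K is much larger
queue = torch.randn(n_reps, dim, K)  # queue laid out as [reps, dim, K], as in the snippet above
keys = torch.randn(16, n_reps, dim)  # gathered keys with batch_size = 16
ptr = K - 8                          # only 8 slots remain before the end of the queue
# the slice below is clamped to 8 columns, so the assignment raises:
# RuntimeError: The expanded size of the tensor (8) must match the existing size (16)
# at non-singleton dimension 2. Target sizes: [8, 128, 8]. Tensor sizes: [8, 128, 16]
queue[:, :, ptr:ptr + 16] = keys.permute(1, 2, 0)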

It seems like this cannot be a problem with the data, as the exact same data had already been passed through the model several times, and once the error first occurs it is raised for every subsequent batch. I also believe the exact point at which it occurs differs between runs: if I remember correctly, it happened after 8 epochs the first time I encountered it and after 9 epochs the second time.

I wondered whether it could be due to using 4 GPUs while the code was tested with 8, but that would still not really explain why it only begins partway through training. Running repeated experiments to test this or other hypotheses would require a lot of GPU time, so I decided to post here first in case you have any insight. Thanks!

@shuuchen
Owner

Hi,

Have you tested it on 4 GPUs? The code was tested on 8 GPUs, so you might need to change this:

https://github.com/shuuchen/DetCo.pytorch/blob/main/main_detco.py#L136

and then check with nvidia-smi whether each GPU is working correctly.

@olliestanley
Author

Thanks for your response! I am using an instance with 4 GPUs and am attempting to train using all 4 of them currently.

I am using slightly different launch code, as I only need the single-node DDP functionality, but when my worker function is called (via PyTorch multiprocessing's spawn) I use the gpu argument it receives as the value passed into set_device(), cuda(), device_ids=[], etc.
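Roughly, the setup I mean looks like this (a simplified sketch rather than my exact code; the init address/port and the stand-in model are placeholders):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def main_worker(gpu, ngpus_per_node, args):
    # `gpu` is the process index handed to this function by mp.spawn;
    # I use it directly as the CUDA device index (0..3 on a 4-GPU node)
    dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:23456",
                            world_size=ngpus_per_node, rank=gpu)
    torch.cuda.set_device(gpu)
    model = torch.nn.Linear(128, 128).cuda(gpu)  # stand-in for the DetCo model
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    # ... training loop ...

if __name__ == "__main__":
    ngpus_per_node = torch.cuda.device_count()  # 4 on my instance
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, None))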

Based on that code, it looks like you are taking the process index passed as gpu to your worker function by spawn and adding 4 to it. I thought this was because you have 8 GPUs but want to train on only 4 of them, so you spawn torch.cuda.device_count() // 2 processes and then use GPUs 4 through 7. Is that understanding incorrect? On my end I have been spawning 4 processes and using all of GPUs 0 through 3.

I am also not sure a GPU problem like this would explain the issue: if the distributed training were improperly configured, I imagine we would expect a failure much earlier in the process?

@shuuchen
Owner

Yes. I used GPUs 4 through 7, with GPU 4 as the master GPU.
If your machine contains 4 GPUs, just use torch.cuda.device_count() and set the master GPU to 0.

@olliestanley
Author

In that case I believe this lines up with what I am doing currently, so it shouldn't be related to the problem. I have now tried another, much smaller dataset on the same GPU configuration, and was able to make it past 10 epochs without problems both times. So perhaps it is a data-related issue somehow.
