How to train aot? Where is aot's train.py? #135

Open
22236 opened this issue Mar 12, 2024 · 1 comment

22236 commented Mar 12, 2024

I want to train AOT first, before moving on to DeAOT.

22236 commented Mar 13, 2024

Use GPU 0 for training VOS.
Build VOS model.
Use Frozen BN in Encoder!
Build optimizer.
Total Param: 5.73M
Process dataset...
Video Num: 29 X 1
<class 'numpy.ndarray'>
(512, 512, 3) (512, 512, 3)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Video Num: 29 X 1
Done!
Remove ['features.0.1.num_batches_tracked', 'features.1.conv.0.1.num_batches_tracked', 'features.1.conv.2.num_batches_tracked', 'features.2.conv.0.1.num_batches_tracked', 'features.2.conv.1.1.num_batches_tracked', 'features.2.conv.3.num_batches_tracked', 'features.3.conv.0.1.num_batches_tracked', 'features.3.conv.1.1.num_batches_tracked', 'features.3.conv.3.num_batches_tracked', 'features.4.conv.0.1.num_batches_tracked', 'features.4.conv.1.1.num_batches_tracked', 'features.4.conv.3.num_batches_tracked', 'features.5.conv.0.1.num_batches_tracked', 'features.5.conv.1.1.num_batches_tracked', 'features.5.conv.3.num_batches_tracked', 'features.6.conv.0.1.num_batches_tracked', 'features.6.conv.1.1.num_batches_tracked', 'features.6.conv.3.num_batches_tracked', 'features.7.conv.0.1.num_batches_tracked', 'features.7.conv.1.1.num_batches_tracked', 'features.7.conv.3.num_batches_tracked', 'features.8.conv.0.1.num_batches_tracked', 'features.8.conv.1.1.num_batches_tracked', 'features.8.conv.3.num_batches_tracked', 'features.9.conv.0.1.num_batches_tracked', 'features.9.conv.1.1.num_batches_tracked', 'features.9.conv.3.num_batches_tracked', 'features.10.conv.0.1.num_batches_tracked', 'features.10.conv.1.1.num_batches_tracked', 'features.10.conv.3.num_batches_tracked', 'features.11.conv.0.1.num_batches_tracked', 'features.11.conv.1.1.num_batches_tracked', 'features.11.conv.3.num_batches_tracked', 'features.12.conv.0.1.num_batches_tracked', 'features.12.conv.1.1.num_batches_tracked', 'features.12.conv.3.num_batches_tracked', 'features.13.conv.0.1.num_batches_tracked', 'features.13.conv.1.1.num_batches_tracked', 'features.13.conv.3.num_batches_tracked', 'features.14.conv.0.1.num_batches_tracked', 'features.14.conv.1.1.num_batches_tracked', 'features.14.conv.3.num_batches_tracked', 'features.15.conv.0.1.num_batches_tracked', 'features.15.conv.1.1.num_batches_tracked', 'features.15.conv.3.num_batches_tracked', 'features.16.conv.0.1.num_batches_tracked', 
'features.16.conv.1.1.num_batches_tracked', 'features.16.conv.3.num_batches_tracked', 'features.17.conv.0.1.num_batches_tracked', 'features.17.conv.1.1.num_batches_tracked', 'features.17.conv.3.num_batches_tracked', 'features.18.1.num_batches_tracked', 'classifier.1.weight', 'classifier.1.bias'] from pretrained model.
Load pretrained backbone model from Segment-and-Track-Anything-main/aot/pretrain_models/mobilenet_v2-b0353104.pth.
Start training:
Traceback (most recent call last):
File "Segment-and-Track-Anything-main/my_code/aot/train.py", line 111, in <module>
main() # run the main function
File "Segment-and-Track-Anything-main/my_code/aot/train.py", line 108, in main
main_worker(0, cfg, args.amp) # single-process training
File "Segment-and-Track-Anything-main/my_code/aot/train.py", line 52, in main_worker
trainer.sequential_training()
File "Segment-and-Track-Anything-main/./aot/networks/managers/trainer.py", line 419, in sequential_training
for sample in enumerate(train_loader):
File "anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
data.reraise()
File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
raise exception
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataset.py", line 308, in __getitem__
return self.datasets[dataset_idx][sample_idx]
KeyError: 4
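
A note on reading the last frame: the traceback stops at the concatenated dataset's indexing line, with no frame from a sub-dataset's own `__getitem__` below it. One possible reading (a guess, not confirmed from the project code) is that the object being indexed there is dict-like, so an integer sample index is looked up as a key and fails. The sketch below is a simplified, pure-Python stand-in for that kind of index mapping, not the PyTorch class itself; `ConcatSketch`, `first_failing_index`, and the sample data are hypothetical names used only for illustration.

```python
import bisect


class ConcatSketch:
    """Simplified stand-in for concatenated-dataset indexing (not the PyTorch class)."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # Running totals of sub-dataset lengths, e.g. lengths [3, 2] -> [3, 5].
        self.cumulative_sizes = []
        total = 0
        for dataset in self.datasets:
            total += len(dataset)
            self.cumulative_sizes.append(total)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Map the global index to (which sub-dataset, index inside it).
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]


def first_failing_index(concat):
    """Return (global index, missing key) for the first KeyError, or None."""
    for idx in range(len(concat)):
        try:
            concat[idx]
        except KeyError as err:
            return idx, err.args[0]
    return None


# Hypothetical data: a list works with integer indices, a dict keyed by name does not.
list_like = ["clip_a", "clip_b", "clip_c"]
dict_like = {"video_0": "frames_0", "video_1": "frames_1"}

print(first_failing_index(ConcatSketch([list_like, dict_like])))
```

If this is what is happening, the number in the KeyError is just whichever sample index the sampler drew first, which would explain why it changes between runs when shuffling is on.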

The log above is the error I get when training on my own dataset, and I don't know how to fix it. The DataLoader raises a KeyError, and I don't understand what it means. When I change the len or obj parameters in default.py, the key reported in the KeyError changes: sometimes it is 6, sometimes 0, sometimes 2, sometimes 4. I would like to know what idx, max_obj_n, seq_len, and the related parameters refer to in the code.
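
For context, here is how I currently understand these names; this is my assumption and is not taken from the AOT source. In a map-style PyTorch dataset, idx is the integer sample index passed to `__getitem__`. Names like seq_len and max_obj_n usually mean the number of frames drawn per training clip and the maximum number of objects kept per sample. The helpers `sample_clip` and `cap_objects` below are hypothetical, and the real sampling logic in the repository may differ (for example, it may not pick consecutive frames).

```python
import random


def sample_clip(num_frames, seq_len, rng):
    """Pick seq_len consecutive frame indices from a video with num_frames frames."""
    if num_frames < seq_len:
        raise ValueError(f"video has {num_frames} frames, need at least {seq_len}")
    start = rng.randrange(num_frames - seq_len + 1)
    return list(range(start, start + seq_len))


def cap_objects(mask, max_obj_n):
    """Keep at most max_obj_n object ids in a label mask.

    Kept ids are relabelled 1..k in ascending order of the original id;
    every other pixel becomes background (0).
    """
    ids = sorted({value for row in mask for value in row if value != 0})
    relabel = {old: new for new, old in enumerate(ids[:max_obj_n], start=1)}
    return [[relabel.get(value, 0) for value in row] for row in mask]


rng = random.Random(0)
clip = sample_clip(num_frames=10, seq_len=5, rng=rng)

# A tiny 2x3 label mask with object ids 3, 7 and 9; only two objects are kept.
capped = cap_objects([[0, 3, 3], [7, 7, 9]], max_obj_n=2)
print(clip, capped)
```

Under this reading, a video shorter than seq_len or a mask with more ids than max_obj_n would be handled inside the dataset, which is why I suspected these settings when the error changed.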

Thanks in advance for any answers.
