U-2-Net multi-GPU training
This repository is for researchers who have access to multiple GPUs.
** Multi-GPU training in this repository requires PyTorch Lightning.
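For reference, multi-GPU training in PyTorch Lightning 1.8.x is usually configured through Trainer arguments. The sketch below is an assumption about how such a setup can look, not the code in train_u2net.py: it builds keyword arguments for pytorch_lightning.Trainer using single-node DDP (one process per GPU) and computes the effective batch size, assuming the batch size is passed unchanged to each process's DataLoader. The helper names build_trainer_kwargs and effective_batch_size are hypothetical.

```python
# Hypothetical helpers sketching a PyTorch Lightning 1.8.x multi-GPU setup;
# train_u2net.py may configure its Trainer differently.
from typing import Any, Dict


def build_trainer_kwargs(num_gpus: int, min_epoch: int, max_epoch: int) -> Dict[str, Any]:
    """Keyword arguments for pytorch_lightning.Trainer on a single multi-GPU node."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return {
        "accelerator": "gpu",
        "devices": num_gpus,
        # DDP launches one process per GPU; with a single GPU no strategy is needed.
        "strategy": "ddp" if num_gpus > 1 else None,
        "min_epochs": min_epoch,
        "max_epochs": max_epoch,
    }


def effective_batch_size(per_gpu_batch_size: int, num_gpus: int) -> int:
    """Samples per optimizer step when every DDP process loads its own batch."""
    return per_gpu_batch_size * num_gpus


# Usage (needs pytorch_lightning installed; `model`, `train_loader` and
# `val_loader` are hypothetical placeholders for this repo's objects):
# import pytorch_lightning as pl
# trainer = pl.Trainer(**build_trainer_kwargs(num_gpus=4, min_epoch=1, max_epoch=100))
# trainer.fit(model, train_loader, val_loader)
```

Because each DDP process draws its own batch, the number of samples per optimizer step grows with the number of GPUs, which matters when comparing learning rates across different GPU counts.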
- Pretrained weights (trained on DUTS-TR):
checkpoint: u2net.ckpt
state_dict: u2net.pth
- I have updated inference.py, the script for testing U-2-Net.
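The two pretrained files listed above presumably hold the same weights in different containers: a Lightning .ckpt stores the parameters under a "state_dict" key next to training metadata, while a .pth state_dict can be passed straight to the network's load_state_dict. A minimal sketch of converting one into the other follows. It assumes the U-2-Net network is stored on the LightningModule as an attribute named model, so its keys begin with "model."; that prefix and the function name ckpt_to_state_dict are hypothetical.

```python
# Sketch: turn a Lightning checkpoint (like u2net.ckpt) into a plain
# state_dict (like u2net.pth). The "model." prefix is an assumption about
# the attribute name used in this repo's LightningModule.
from typing import Any, Dict


def ckpt_to_state_dict(ckpt: Dict[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    # Lightning checkpoints wrap the weights under "state_dict" alongside
    # metadata such as the epoch; a plain .pth state_dict has no wrapper.
    state_dict = ckpt.get("state_dict", ckpt)
    # If no key carries the prefix, treat the input as an already-plain state_dict.
    if not any(key.startswith(prefix) for key in state_dict):
        return dict(state_dict)
    # Keep only the network's parameters and strip the attribute prefix.
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Usage (file names are hypothetical):
# import torch
# ckpt = torch.load("u2net.ckpt", map_location="cpu")
# torch.save(ckpt_to_state_dict(ckpt), "u2net_converted.pth")
```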
- Requirements
Python 3.8
PyTorch 1.12.0+cu102
PyTorch Lightning 1.8.6
NumPy 1.23.5
opencv-python 4.6.0
Albumentations 1.3.0
wandb 0.13.8 (not listed in requirements.txt; install it separately)
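To confirm that an environment matches the versions listed above, a small check with the standard library's importlib.metadata can help. This is a sketch: the PyPI distribution names in PINNED are assumptions about how the packages were installed, and an installed version counts as matching when it equals the pinned string or only appends further components (opencv-python, for example, reports four-part versions such as 4.6.0.66).

```python
# Sketch: compare installed package versions with the pins listed above.
# Distribution names are assumed PyPI names; adjust them if yours differ.
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

PINNED: Dict[str, str] = {
    "torch": "1.12.0+cu102",
    "pytorch-lightning": "1.8.6",
    "numpy": "1.23.5",
    "opencv-python": "4.6.0",
    "albumentations": "1.3.0",
    "wandb": "0.13.8",
}


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def matches(wanted: str, found: Optional[str]) -> bool:
    """True if found equals wanted or only appends components (4.6.0 -> 4.6.0.66)."""
    return found is not None and (found == wanted or found.startswith(wanted + "."))


def find_mismatches(
    pinned: Dict[str, str], installed: Dict[str, Optional[str]]
) -> Dict[str, Optional[str]]:
    """Map each mismatched package to its installed version (None if missing)."""
    return {
        name: installed.get(name)
        for name, wanted in pinned.items()
        if not matches(wanted, installed.get(name))
    }


if __name__ == "__main__":
    current = {name: installed_version(name) for name in PINNED}
    for name, found in find_mismatches(PINNED, current).items():
        print(f"{name}: expected {PINNED[name]}, found {found}")
```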
- Conda
conda create -n <env-name> python=3.8
conda activate <env-name>
- pip
pip install -r requirements.txt
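- Arguments
Argument | Type | Description |
---|---|---|
--min_epoch | int | Minimum number of training epochs |
--max_epoch | int | Maximum number of training epochs |
--batch_size | int | Batch size |
--lr | float | Learning rate |
--epsilon | float | Epsilon value (presumably the optimizer's numerical-stability term) |
--tr_im_path | string | Path to the training images |
--tr_gt_path | string | Path to the training ground-truth masks |
--vd_im_path | string | Path to the validation images |
--vd_gt_path | string | Path to the validation ground-truth masks |
--pretrained_path | string | Path to the pretrained weights to start from |
--save_weight_path | string | Path where the trained weights are saved |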
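The flag names in the argument table suggest a standard argparse interface. A minimal sketch of an equivalent parser follows; the names and types come from the table, while every default value is a hypothetical placeholder rather than the value used in train_u2net.py.

```python
# Sketch of a parser mirroring the argument table; all defaults are
# hypothetical placeholders, not the values used in train_u2net.py.
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="U-2-Net multi-GPU training")
    # Training schedule and optimizer settings.
    parser.add_argument("--min_epoch", type=int, default=1)
    parser.add_argument("--max_epoch", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--epsilon", type=float, default=1e-8)
    # Training (tr) and validation (vd) images and ground-truth (gt) masks.
    for flag in ("--tr_im_path", "--tr_gt_path", "--vd_im_path", "--vd_gt_path"):
        parser.add_argument(flag, type=str, default=None)
    # Weights to start from and where to save trained weights.
    parser.add_argument("--pretrained_path", type=str, default=None)
    parser.add_argument("--save_weight_path", type=str, default=None)
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # argv=None falls back to sys.argv[1:], as argparse normally does.
    return build_parser().parse_args(argv)


# Usage: args = parse_args(["--max_epoch", "50", "--batch_size", "4"])
```

Passing an explicit list to parse_args makes it easy to try settings from a notebook or test without touching sys.argv.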
- Script
** I made the arguments easy to change: edit their values directly in train_u2net.py, then run:
python train_u2net.py