ZeroCodePro / U2Net-Multi-Gpus-Training

U-2-Net multi-GPU training code in PyTorch Lightning

U-2-Net multi-GPU training

This repository is for researchers who have access to multiple GPUs!

** This training code requires PyTorch Lightning.

UPDATES!

  • Pretrained weights (DUTS-TR):
    checkpoint: u2net.ckpt
    state_dict: u2net.pth
  • I have updated the inference.py code for testing U-2-Net
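The two weight files are stored in different formats. A PyTorch Lightning checkpoint is a dictionary that keeps the model weights under a "state_dict" key, next to the trainer state. A plain .pth state_dict usually holds only the weights. The sketch below converts the first format into the second. It assumes the LightningModule stores the network in an attribute named `model`, which would give every checkpoint key a `model.` prefix. That attribute name and both helper names are hypothetical, so inspect the real keys before relying on them.

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def lightning_to_plain_state_dict(checkpoint: Mapping[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    """Extract bare network weights from a Lightning-style checkpoint dict.

    Assumption (hypothetical): the LightningModule keeps U-2-Net in
    `self.model`, so each weight key starts with `prefix`.
    """
    # Lightning checkpoints keep weights under "state_dict"; if that key is
    # absent, treat the input as an already-plain state_dict.
    state = checkpoint.get("state_dict", checkpoint)
    plain: Dict[str, Any] = OrderedDict()
    for key, value in state.items():
        # Strip the prefix so keys match a bare U-2-Net nn.Module.
        plain[key[len(prefix):] if key.startswith(prefix) else key] = value
    return plain


def load_plain_state_dict(path: str, prefix: str = "model.") -> Dict[str, Any]:
    """Load u2net.ckpt or u2net.pth and return bare weights."""
    import torch  # imported lazily so the helper above works without torch

    checkpoint = torch.load(path, map_location="cpu")
    return lightning_to_plain_state_dict(checkpoint, prefix)
```

For a hypothetical `U2NET` model class, usage would look like `net.load_state_dict(load_plain_state_dict("u2net.ckpt"))`. Loading `u2net.pth` should need no key rewriting if it was saved from the bare network.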

Required Libraries

Python 3.8
PyTorch 1.12.0+cu102
PyTorch Lightning 1.8.6
NumPy 1.23.5
OpenCV-Python 4.6.0
Albumentations 1.3.0
wandb 0.13.8 (not included in requirements.txt)

Install

  1. Conda
conda create -n <env-name> python=3.8
conda activate <env-name>
  2. pip
pip install -r requirements.txt

Training

  1. Arguments
argument             type
--min_epoch          int
--max_epoch          int
--batch_size         int
--lr                 float
--epsilon            float
--tr_im_path         string
--tr_gt_path         string
--vd_im_path         string
--vd_gt_path         string
--pretrained_path    string
--save_weight_path   string
  2. Script
    ** To make configuration easier, the arguments are set inside the Python script. Edit them there, then run:
python train_u2net.py
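The argument table maps directly onto `argparse`. The sketch below is hypothetical and is not taken from train_u2net.py. It registers the listed flags with their listed types and leaves every default unset. It also treats `--epsilon` as the optimizer's epsilon, which is a guess. Finally, it converts the epoch limits into keyword arguments for `pytorch_lightning.Trainer`. Using the `"ddp"` strategy for multi-GPU runs is also an assumption.

```python
import argparse
from typing import Any, Dict


def build_parser() -> argparse.ArgumentParser:
    """Parser mirroring the argument table; defaults are deliberately unset."""
    parser = argparse.ArgumentParser(description="U-2-Net multi-GPU training (sketch)")
    parser.add_argument("--min_epoch", type=int)
    parser.add_argument("--max_epoch", type=int)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--epsilon", type=float)  # assumed: optimizer epsilon
    # Training/validation image and ground-truth paths, plus weight paths.
    for name in ("tr_im_path", "tr_gt_path", "vd_im_path", "vd_gt_path",
                 "pretrained_path", "save_weight_path"):
        parser.add_argument(f"--{name}", type=str)
    return parser


def trainer_kwargs(args: argparse.Namespace, num_gpus: int) -> Dict[str, Any]:
    """Map parsed arguments onto pytorch_lightning.Trainer keyword arguments.

    Assumption: runs with more than one GPU use DistributedDataParallel ("ddp").
    """
    kwargs: Dict[str, Any] = {
        "min_epochs": args.min_epoch,
        "max_epochs": args.max_epoch,
        "accelerator": "gpu",
        "devices": num_gpus,
    }
    if num_gpus > 1:
        kwargs["strategy"] = "ddp"
    return kwargs
```

With PyTorch Lightning 1.8.x installed, you could pass the dictionary on as `pl.Trainer(**trainer_kwargs(args, torch.cuda.device_count()))`. Under DDP each process builds its own data loader, so `--batch_size` applies per GPU and the effective batch size is `batch_size * num_gpus`.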

License: Apache License 2.0