Everypixel / arshadowgan-like

ARShadowGAN-like implementation in PyTorch.

Example of GAN training on the shadow generation task

Colab Notebook

PyTorch Colab notebook: ARShadowGAN-like

Prerequisites

  • Python 3
  • CPU or NVIDIA GPU + CUDA/cuDNN (see the quick check below)
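To verify that PyTorch can see the GPU before training, you can run a quick check like the one below (optional; everything also runs on CPU, just slower):

import torch  # quick sanity check for CUDA availability

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))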

Getting Started

Installation

  • Clone this repo:
git clone https://github.com/Everypixel/arshadowgan-like.git
cd arshadowgan-like
  • Install dependencies (e.g., segmentation_models_pytorch, ...)
pip install -r requirements.txt

Dataset preparation

Shadow-AR dataset

We use the Shadow-AR dataset for training and testing our model. It is already split into train and test parts. Please download and extract it.

Your own dataset

Your own dataset must have the same structure as the Shadow-AR dataset, where each folder contains images (a minimal loader sketch is shown after the folder list below).

dataset
├── train
│   ├── noshadow ── example1.png, ...
│   ├── shadow ──── example1.png, ...
│   ├── mask ────── example1.png, ...
│   ├── robject ─── example1.png, ...
│   └── rshadow ─── example1.png, ...
└── test
    ├── noshadow ── example2.png, ...
    ├── shadow ──── example2.png, ...
    ├── mask ────── example2.png, ...
    ├── robject ─── example2.png, ...
    └── rshadow ─── example2.png, ...
  • noshadow - images without shadows
  • shadow - images with shadows
  • mask - masks of the inserted objects
  • robject - masks of the occluders
  • rshadow - shadows of the occluders
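
For reference, a minimal PyTorch Dataset that reads this folder layout could look like the sketch below. The class name, transforms and channel handling are illustrative assumptions, not the repository's actual loader (which, per the acknowledgements, uses albumentations for augmentation).

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ShadowARFolder(Dataset):  # hypothetical name, for illustration only
    FOLDERS = ("noshadow", "shadow", "mask", "robject", "rshadow")

    def __init__(self, root, img_size=256):
        self.root = root
        # file names are assumed to match across the five folders
        self.names = sorted(os.listdir(os.path.join(root, "noshadow")))
        self.to_tensor = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        sample = {}
        for folder in self.FOLDERS:
            img = Image.open(os.path.join(self.root, folder, name))
            # masks and shadow maps as single-channel, images as RGB
            mode = "L" if folder in ("mask", "robject", "rshadow") else "RGB"
            sample[folder] = self.to_tensor(img.convert(mode))
        return sample

A DataLoader over ShadowARFolder("dataset/train") then yields dictionaries of tensors keyed by folder name.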

Training

Training attention module

Set the parameters:

  • dataset_path - path to the dataset
  • model_path - path where the attention model will be saved
  • batch_size - number of images per batch
    (reduce it if you get a "CUDA out of memory" error)
  • seed - seed for random functions
  • img_size - image width/height (must be divisible by 32)
  • lr - learning rate
  • n_epoch - number of epochs

For example:

python3 scripts/train_attention.py \
       --dataset_path '/content/arshadowgan/dataset/' \
       --model_path '/content/drive/MyDrive/attention128.pth' \
       --batch_size 200 \
       --seed 42 \
       --img_size 256 \
       --lr 1e-4 \
       --n_epoch 100
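
For orientation, the attention module is a segmentation-style network, and a comparable model can be sketched with segmentation_models_pytorch as below. The exact channel layout (no-shadow image plus inserted-object mask in, occluder and occluder-shadow maps out) is an assumption based on the dataset description, not a quote of scripts/train_attention.py:

import torch
import segmentation_models_pytorch as smp

# segmentation-style network with a ResNet-18 encoder
attention = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=4,       # assumed: no-shadow RGB image (3) + inserted-object mask (1)
    classes=2,           # assumed: occluder mask + occluder-shadow map
    activation="sigmoid",
)

x = torch.randn(1, 4, 256, 256)   # img_size must be divisible by 32
with torch.no_grad():
    print(attention(x).shape)     # torch.Size([1, 2, 256, 256])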

Training shadow-generation module

Set the parameters:

  • dataset_path - path to the dataset
  • Gmodel_path - path where the generator model will be saved
  • Dmodel_path - path where the discriminator model will be saved
  • batch_size - number of images per batch
    (reduce it if you get a "CUDA out of memory" error)
  • seed - seed for random functions
  • img_size - image width/height (must be divisible by 32)
  • lr_G - generator learning rate
  • lr_D - discriminator learning rate
  • n_epoch - number of epochs
  • betta1, betta2, betta3 - loss function coefficients (see the ARShadowGAN paper)

For example:

python3 scripts/train_SG.py \
       --dataset_path '/content/arshadowgan/dataset/' \
       --Gmodel_path '/content/drive/MyDrive/SG_generator.pth' \
       --Dmodel_path '/content/drive/MyDrive/SG_discriminator.pth' \
       --batch_size 64 \
       --seed 42 \
       --img_size 256 \
       --lr_G 1e-4 \
       --lr_D 1e-6 \
       --n_epoch 600 \
       --betta1 10 \
       --betta2 1 \
       --betta3 1e-2 \
       --patience 10 \
       --encoder 'resnet18'
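
To make the betta coefficients concrete, the generator objective can be thought of as a weighted sum of a per-pixel term, a content (perceptual) term and an adversarial term, roughly as sketched below. Which coefficient multiplies which term, and the exact loss definitions, are assumptions based on the ARShadowGAN paper and the acknowledgements (piq ContentLoss); scripts/train_SG.py is the authoritative reference:

import torch
import torch.nn as nn
from piq import ContentLoss

l2 = nn.MSELoss()                    # per-pixel term
content = ContentLoss()              # VGG-based perceptual loss from piq (expects inputs in [0, 1])
adversarial = nn.BCEWithLogitsLoss() # term against the discriminator output

def generator_loss(fake_shadow, real_shadow, d_out_fake,
                   betta1=10.0, betta2=1.0, betta3=1e-2):
    # assumed mapping of the betta coefficients to the three terms
    loss_pix = l2(fake_shadow, real_shadow)
    loss_content = content(fake_shadow, real_shadow)
    loss_adv = adversarial(d_out_fake, torch.ones_like(d_out_fake))
    return betta1 * loss_pix + betta2 * loss_content + betta3 * loss_adv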

Run

Run inference and save the results.

For example:

python3 scripts/test.py \
       --batch_size 1 \
       --img_size 256 \
       --dataset_path '/content/arshadowgan/dataset/test' \
       --result_path '/content/arshadowgan/results' \
       --path_att '/content/drive/MyDrive/ARShadowGAN-like/attention.pth' \
       --path_SG '/content/drive/MyDrive/ARShadowGAN-like/SG_generator.pth'
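
For a rough picture of what the test script does, inference boils down to loading the two checkpoints and chaining them: the attention module predicts occluder/occluder-shadow maps, and the generator then produces the shadowed result. The sketch below assumes the .pth files contain full serialized models and a particular input wiring; both are assumptions, so treat scripts/test.py as the authoritative version:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# assumes the checkpoints store whole models; if they store state_dicts,
# build the networks first and call load_state_dict instead
attention = torch.load("attention.pth", map_location=device).eval()
generator = torch.load("SG_generator.pth", map_location=device).eval()

noshadow = torch.rand(1, 3, 256, 256, device=device)  # no-shadow image
mask = torch.rand(1, 1, 256, 256, device=device)      # inserted-object mask

with torch.no_grad():
    att = attention(torch.cat([noshadow, mask], dim=1))           # occluder / occluder-shadow maps
    result = generator(torch.cat([noshadow, mask, att], dim=1))   # shadowed output (assumed wiring)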

Acknowledgements

We thank the ARShadowGAN authors for their amazing work.
We also thank segmentation_models.pytorch for the network architectures, albumentations for augmentations, PyTorch-GAN for the discriminator architecture, and piq for the content loss.

License

MIT License

