wpeebles / gangealing

Official PyTorch Implementation of "GAN-Supervised Dense Visual Alignment" (CVPR 2022 Oral, Best Paper Finalist)

Home Page: https://www.wpeebles.com/gangealing

single color channel

petercmh01 opened this issue

Hi! I've seen many other posts mention that batch size is essential here to keep training stable. However, I only have one GPU, which limits me to a batch size under 8. I wonder if there is any way to reduce the number of RGB channels in the model to 1? The task/dataset I'm working on is relatively simple and the colors can basically be ignored.

Thanks in advance!

Hi @petercmh01. I think reducing the number of image channels won't affect memory too much. I believe the vast bulk of memory consumption comes from the intermediate layers of the StyleGAN generator and the Spatial Transformer, which aren't affected by the number of image channels. If you do want to use 1 channel, you would need to modify line 434 of spatial_transformer.py to use 1 input channel instead of 3. You'd also need to change line 360 of networks.py so StyleGAN produces single-channel outputs. A rough sketch of what those changes amount to is below.
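
For illustration only, this is roughly what a single-channel setup involves; the layer names and shapes here are hypothetical, and the actual definitions live in spatial_transformer.py and networks.py:

```python
# Rough sketch only (not the repo's actual code); layer names and shapes are hypothetical.
import torch
import torch.nn as nn

# 1) The Spatial Transformer's first conv would take 1 input channel instead of 3
#    (the real definition is around line 434 of spatial_transformer.py):
first_conv = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)

# 2) Input images would then be collapsed to a single luminance channel:
def to_grayscale(rgb):  # rgb: (N, 3, H, W)
    weights = torch.tensor([0.299, 0.587, 0.114], device=rgb.device).view(1, 3, 1, 1)
    return (rgb * weights).sum(dim=1, keepdim=True)  # (N, 1, H, W)

# 3) StyleGAN's output layers (networks.py, line 360) would similarly emit 1 channel
#    instead of 3 so the generator and Spatial Transformer stay consistent.
```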

There are a few things you can try to get good performance if you have 1 GPU. Probably the most surefire way is to use gradient accumulation (you can find an untested code snippet in Issue #10 which uses 8x gradient accumulation; a generic sketch is also shown below). Using gradient accumulation will increase your effective batch size. The downside is that training will be quite a bit slower, since each optimizer step now requires several forward and backward passes. For what it's worth, it may be that a batch size of 16 or 24 is enough, so you could try using only 2x or 3x rounds of accumulation.
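
For reference, a generic gradient-accumulation loop looks roughly like this (this is not the snippet from Issue #10; `model`, `loss_fn`, `optimizer`, and `loader` are placeholders for your own setup):

```python
def train_with_accumulation(model, loss_fn, optimizer, loader, accum_steps=8):
    """Generic sketch: one optimizer step per `accum_steps` mini-batches."""
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        loss = loss_fn(model(batch)) / accum_steps  # scale so accumulated grads average out
        loss.backward()                             # gradients add up across the window
        if (step + 1) % accum_steps == 0:
            optimizer.step()                        # update once per accum_steps batches
            optimizer.zero_grad()
```

With a per-GPU batch size of 8, `accum_steps=2` or `3` would give you an effective batch size of 16 or 24.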

Another thing you could try is reducing the --stn_channel_multiplier argument to, e.g., 0.25. This argument controls the number of channels in the intermediate conv layers of the Spatial Transformer. Using a smaller value should let you use a bigger batch size (a toy example of how such a multiplier scales conv widths is below). My guess is that you can get away with a smaller STN, although it isn't something I've tried before.
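
Conceptually, a channel multiplier just scales the width of each conv layer; this toy example (not the repo's code) shows the idea:

```python
import torch.nn as nn

def scaled_conv(in_ch, base_out_ch, channel_multiplier=1.0):
    """Toy example: a channel multiplier shrinks the width of an intermediate conv layer."""
    out_ch = max(1, int(base_out_ch * channel_multiplier))
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

conv = scaled_conv(64, 256, channel_multiplier=0.25)  # 64 output channels instead of 256
```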

Finally, I would recommend using our training scripts, which use LPIPS as the perceptual loss. It seems that LPIPS may be more stable at small batch sizes than the SimCLR VGG perceptual loss, although in my experiments it still shows degraded performance at a batch size of 5.
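
In case it helps to see what that loss computes, here is a small standalone sketch using the `lpips` package (random tensors stand in for the generator and Spatial Transformer outputs; the training scripts wire this up for you and may configure it differently):

```python
import lpips
import torch

percep = lpips.LPIPS(net='vgg')            # LPIPS distance with a VGG backbone
fake = torch.rand(2, 3, 256, 256) * 2 - 1  # stand-in images in [-1, 1]
real = torch.rand(2, 3, 256, 256) * 2 - 1
loss = percep(fake, real).mean()           # lower = more perceptually similar
```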

I'll go ahead and close this issue now but feel free to reopen or make a new issue if you have other questions.