smrfeld / diff-pytorch

Diffusion model in PyTorch

Diffusion in PyTorch

An implementation of a denoising diffusion model in pure PyTorch.

Adapted from the denoising diffusion model notebook accompanying Generative Deep Learning, 2nd Edition: https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/b91507a769bc40d87f1428f3eabba11dda6ea8c0/notebooks/08_diffusion/01_ddm/ddm.ipynb#L104
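For orientation, the forward (noising) process and its inversion can be sketched without any framework. This is a minimal sketch, assuming the port keeps the notebook's offset cosine schedule, in which the signal rate falls from 0.95 at t = 0 to 0.02 at t = 1 while noise_rate² + signal_rate² = 1. The function names and constants are illustrative and not taken from this repository's code; in PyTorch the same arithmetic applies element-wise to image tensors.

```python
import math
import random

# Assumed schedule bounds, taken from the book's notebook, not from this repo's conf.yml.
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95


def offset_cosine_schedule(t: float) -> tuple[float, float]:
    """Map a diffusion time t in [0, 1] to (noise_rate, signal_rate).

    The angle is interpolated linearly, so noise_rate**2 + signal_rate**2 == 1
    for every t (variance-preserving).
    """
    start_angle = math.acos(MAX_SIGNAL_RATE)
    end_angle = math.acos(MIN_SIGNAL_RATE)
    angle = start_angle + t * (end_angle - start_angle)
    return math.sin(angle), math.cos(angle)


def add_noise(image, noise, t):
    """Forward process: blend a flattened image with Gaussian noise at time t."""
    noise_rate, signal_rate = offset_cosine_schedule(t)
    return [signal_rate * x + noise_rate * e for x, e in zip(image, noise)]


def denoise(noisy, predicted_noise, t):
    """Invert add_noise given a noise prediction (what the network learns to supply)."""
    noise_rate, signal_rate = offset_cosine_schedule(t)
    return [(y - noise_rate * e) / signal_rate for y, e in zip(noisy, predicted_noise)]


# Usage: with a perfect noise prediction, the clean image is recovered.
rng = random.Random(0)
image = [rng.uniform(-1.0, 1.0) for _ in range(4)]
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]
noisy = add_noise(image, noise, t=0.7)
recovered = denoise(noisy, noise, t=0.7)
```

During training the network is given the noisy image and learns to predict the noise. In the notebook's sampler, generation starts from pure noise and alternates a denoise step with re-noising at a slightly smaller t until t reaches 0.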

Running

  • Data: download the flower dataset from https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset. In the commands below, /path/to/data is the folder that contains the train and test folders.

  • Install:

    conda create -n diffusion python=3.11
    conda activate diffusion
    pip install -r requirements.txt

  • Training:

    python main.py --conf conf.yml --command train --mnt-dir /path/to/data

    See conf.yml for more options.

  • Generate samples:

    python main.py --conf conf.yml --command generate

    See conf.yml for more options.

  • Plot loss:

    python main.py --conf conf.yml --command loss --show

  • Tests:

    cd tests
    pytest
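All of the commands above go through a single entry point. As a hedged sketch of that interface, the hypothetical argparse parser below accepts only the flags shown in this README (--conf, --command, --mnt-dir, --show) and checks that the data folder contains train and test before training. It is not this repository's main.py, the real script may handle these flags differently, and the options inside conf.yml are not modeled.

```python
import argparse
from pathlib import Path

# Hypothetical sketch of the main.py interface: only the flags shown in this README.
COMMANDS = ("train", "generate", "loss")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Diffusion model in PyTorch")
    parser.add_argument("--conf", type=Path, required=True, help="YAML config file, e.g. conf.yml")
    parser.add_argument("--command", choices=COMMANDS, required=True, help="what to run")
    parser.add_argument("--mnt-dir", type=Path, default=None,
                        help="folder containing the train and test folders")
    parser.add_argument("--show", action="store_true", help="display the loss plot")
    return parser


def missing_data_dirs(mnt_dir: Path) -> list[str]:
    """Return which of the expected dataset folders are absent from mnt_dir."""
    return [name for name in ("train", "test") if not (mnt_dir / name).is_dir()]


def validate(args: argparse.Namespace) -> argparse.Namespace:
    """Require a usable data folder for training; other commands pass through."""
    if args.command == "train":
        if args.mnt_dir is None:
            raise SystemExit("--mnt-dir is required for --command train")
        missing = missing_data_dirs(args.mnt_dir)
        if missing:
            raise SystemExit(f"{args.mnt_dir} is missing: {', '.join(missing)}")
    return args


# Usage: parse the loss-plotting command from this README.
args = validate(build_parser().parse_args(["--conf", "conf.yml", "--command", "loss", "--show"]))
```

Checking the data folder up front makes a wrong --mnt-dir fail immediately instead of partway through model setup.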

About

License: MIT License

Languages

Language: Python 100.0%