entrpn / maxdiffusion

Overview

WARNING: The training code is experimental and under active development.

MaxDiffusion is a Latent Diffusion model written in pure Python/Jax that targets Google Cloud TPUs. MaxDiffusion aims to be a launching point for ambitious diffusion projects in both research and production. We encourage users to start by experimenting with MaxDiffusion out of the box, then fork and modify it to meet their needs.

MaxDiffusion supports:

  • Stable Diffusion 2.1 (training and inference)
  • Stable Diffusion XL (inference)

Getting Started

We recommend starting with a single host first and then moving to multihost.

Getting Started: Local Development for a Single Host

Local development is a convenient way to run MaxDiffusion on a single host.

  1. Create and SSH to the single-host TPU of your choice. We recommend a v4-8.
  2. Clone MaxDiffusion onto that TPU VM.
  3. Within the root directory of that git repo, install dependencies by running:
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install -e .
  4. After installation completes, run training with the command:
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base.yml run_name="my_run" base_output_directory="gs://your-bucket/"
  5. If you want to generate images, you can do so as follows.
  • Stable Diffusion 2.1

    python -m src.maxdiffusion.generate src/maxdiffusion/configs/base.yml
  • Stable Diffusion XL

    Multi-host generation is supported via sharding annotations:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml

    Single-host pmap version (see the sketch after this list):

    python -m src.maxdiffusion.generate_sdxl_replicated
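
For context, the replicated script follows the standard JAX pmap pattern: model parameters are replicated across all local devices and the same sampling step runs on per-device batches. Below is a minimal, single-host sketch of that pattern; the toy denoise_step stands in for the real UNet and is not MaxDiffusion's actual API.

    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate

    # Toy stand-in for one UNet denoising step (NOT MaxDiffusion's API).
    def denoise_step(params, latents):
        return latents * params["scale"]

    n_devices = jax.local_device_count()
    params = {"scale": jnp.float32(0.7)}

    # One latent batch per device: the leading axis must equal device count.
    latents = jnp.ones((n_devices, 4, 64, 64))

    p_step = jax.pmap(denoise_step)           # same step on every device
    out = p_step(replicate(params), latents)  # params broadcast to devices
    print(out.shape)  # (n_devices, 4, 64, 64)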

Getting Started: Multihost development

Multihost training can be run as follows.

TPU_NAME=<your-tpu-name>
ZONE=<your-zone>
PROJECT_ID=<your-project-id>
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --project $PROJECT_ID --worker=all --command="
git clone https://github.com/google/maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base.yml run_name=my_run base_output_directory=gs://your-bucket/"
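
Multihost scaling is built on JAX sharding annotations: a device mesh spans every host in the slice, and arrays annotated with a PartitionSpec are partitioned across that mesh. The sketch below shows the mechanism and is runnable on a single host; the axis name is illustrative, not MaxDiffusion's actual mesh configuration.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # 1-D mesh over all devices (spans every host on a real TPU slice).
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    # Annotate a batch so its leading axis is split across the mesh.
    batch = jnp.ones((jax.device_count() * 2, 128))
    batch = jax.device_put(batch, NamedSharding(mesh, P("data")))

    @jax.jit
    def step(x):
        return jnp.mean(x ** 2)  # compiled once, runs partitioned

    print(step(batch))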

Comparison to Alternatives

MaxDiffusion started as a fork of Diffusers, a Hugging Face diffusion library written in Python, PyTorch and Jax. MaxDiffusion is compatible with Hugging Face Jax models. MaxDiffusion takes on more complexity in order to run distributed across TPU Pods.
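
Because of that compatibility, Hub checkpoints that ship Flax weights can also be driven through diffusers' Flax pipeline. A rough sketch using diffusers' documented JAX API follows; the model id and bf16 revision are illustrative, so verify that Flax weights exist for the checkpoint you choose.

    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard
    from diffusers import FlaxStableDiffusionPipeline

    # Model id / revision are illustrative; confirm Flax weights on the Hub.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
    )

    prompts = ["a photo of an astronaut riding a horse"] * jax.device_count()
    prompt_ids = shard(pipeline.prepare_inputs(prompts))
    rngs = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

    # jit=True pmaps the sampling loop across local devices.
    images = pipeline(prompt_ids, replicate(params), rngs, jit=True).images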

Development

Whether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, we offer simple testing recipes.

To run unit tests and lint, simply run:

python -m pytest
ruff check --fix .
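
To iterate on a single test during development, pytest's usual selection flags apply (the expression below is a placeholder):

python -m pytest -k "your_test_name"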

The full suite of end-to-end tests is in tests and src/maxdiffusion/tests. We run them at a nightly cadence.

About

License: Apache License 2.0

