StableDiffusion example doesn't currently work without adjustment
mdvthu opened this issue
Issue Type
Documentation Bug
Source
binary
Keras Version
keras_cv 0.9.0
Custom Code
No
OS Platform and Distribution
macOS, Windows, Ubuntu
Python version
3.12.4
GPU model and memory
CPU, Nvidia 3080, and M1 (METAL)
Current Behavior?
Following the example verbatim from the Keras documentation pages, including https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/, produces a shape mismatch error:
ValueError: Exception encountered when calling DiffusionModelV2.call().
Invalid input shape for input Tensor("data_2:0", shape=(1, 77, 1024), dtype=float32). Expected shape (None, 96, 96, 4), but input has incompatible shape (1, 77, 1024)
Arguments received by DiffusionModelV2.call():
• inputs={'latent': 'tf.Tensor(shape=(1, 96, 96, 4), dtype=float32)', 'timestep_embedding': 'tf.Tensor(shape=(1, 320), dtype=float32)', 'context': 'tf.Tensor(shape=(1, 77, 1024), dtype=float32)'}
• training=False
• mask={'latent': 'None', 'timestep_embedding': 'None', 'context': 'None'}
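The shapes in the message can be checked by hand. The sketch below assumes the latent grid is the image size divided by 8, with 4 channels; that is a common Stable Diffusion convention and is assumed here, not taken from the keras_cv source. `expected_latent_shape` and `is_compatible` are hypothetical helpers, not keras_cv functions. Under that assumption the `latent` tensor is fine, and it looks as if the `context` tensor is being validated against the latent input's spec.

```python
# Hypothetical helpers, not part of keras_cv: they only reproduce the shape arithmetic.
LATENT_DOWNSAMPLE = 8  # assumption: each spatial side is reduced by a factor of 8
LATENT_CHANNELS = 4    # matches the trailing 4 in (None, 96, 96, 4)


def expected_latent_shape(img_height, img_width):
    """Return the latent shape a model built for this image size would expect."""
    return (
        None,  # batch dimension is unconstrained
        img_height // LATENT_DOWNSAMPLE,
        img_width // LATENT_DOWNSAMPLE,
        LATENT_CHANNELS,
    )


def is_compatible(expected, actual):
    """True if `actual` matches `expected`, treating None as a wildcard."""
    if len(expected) != len(actual):
        return False
    return all(e is None or e == a for e, a in zip(expected, actual))


latent = (1, 96, 96, 4)   # 'latent' from the traceback
context = (1, 77, 1024)   # 'context' from the traceback

print(expected_latent_shape(768, 768))                          # (None, 96, 96, 4)
print(is_compatible(expected_latent_shape(768, 768), latent))   # True
print(is_compatible(expected_latent_shape(768, 768), context))  # False
print(expected_latent_shape(512, 512))                          # (None, 64, 64, 4)
```

Note that, under the same assumption, a 96 x 96 latent corresponds to a 768 x 768 image, while a 512 x 512 model would expect 64 x 64.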
Standalone code to reproduce the issue or tutorial link
Follow the guide at https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/ in a fresh virtual environment:
python3 -m venv venv
. ./venv/bin/activate
python3 -m pip install tensorflow keras_cv IPython
python3 -m IPython
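Because the install step above pins no versions, it may help to record what pip actually resolved. This is a minimal sketch using the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper, and the distribution names (in particular `keras-cv` for the `keras_cv` module) are assumed to match the PyPI names.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or 'not installed'."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    # Distribution names as assumed on PyPI; adjust if your environment differs.
    for name, version in installed_versions(["tensorflow", "keras", "keras-cv"]).items():
        print(f"{name}: {version}")
```

Pasting this output alongside the traceback makes it easier to tell whether the failure depends on the Keras version that was pulled in.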
```python
import keras_cv

# Build the 512x512 model with XLA compilation disabled.
model = keras_cv.models.StableDiffusion(
    img_width=512, img_height=512, jit_compile=False
)
# Generate three images from a single prompt.
images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)
```
### Relevant log output
```shell
ValueError: Exception encountered when calling DiffusionModelV2.call().
Invalid input shape for input Tensor("data_2:0", shape=(1, 77, 1024), dtype=float32). Expected shape (None, 96, 96, 4), but input has incompatible shape (1, 77, 1024)
Arguments received by DiffusionModelV2.call():
• inputs={'latent': 'tf.Tensor(shape=(1, 96, 96, 4), dtype=float32)', 'timestep_embedding': 'tf.Tensor(shape=(1, 320), dtype=float32)', 'context': 'tf.Tensor(shape=(1, 77, 1024), dtype=float32)'}
• training=False
• mask={'latent': 'None', 'timestep_embedding': 'None', 'context': 'None'}
```