openai / consistency_models

Official repo for consistency models.



Not able to obtain results like the checkpoint models (CT, imagenet64) when trying to train from scratch

aarontan-git opened this issue


I tried to train a CT ImageNet 64 model from scratch, but I have not been able to reproduce the image quality of the provided checkpoint models.
Below is the command I ran:

```shell
python cm_train.py \
  --training_mode consistency_training \
  --target_ema_mode adaptive \
  --start_ema 0.95 \
  --scale_mode progressive \
  --start_scales 2 \
  --end_scales 200 \
  --total_training_steps 800000 \
  --loss_norm lpips \
  --lr_anneal_steps 0 \
  --attention_resolutions 32,16,8 \
  --class_cond True \
  --use_scale_shift_norm True \
  --dropout 0.0 \
  --teacher_dropout 0.1 \
  --ema_rate 0.999 \
  --global_batch_size 16 \
  --image_size 64 \
  --lr 0.0001 \
  --num_channels 192 \
  --num_head_channels 64 \
  --num_res_blocks 3 \
  --resblock_updown True \
  --schedule_sampler uniform \
  --use_fp16 True \
  --weight_decay 0.0 \
  --weight_schedule uniform \
  --data_dir /home/ILSVRC2012_img_train
```
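The flags --scale_mode progressive, --start_scales, --end_scales, --target_ema_mode adaptive and --start_ema control how the number of discretization steps N and the target-network EMA decay mu grow during consistency training. The following is a minimal sketch, assuming the schedule formulas from the consistency models paper. The helper ct_schedule is hypothetical and not part of the repository, whose code may differ in off-by-one details. The sketch shows that this schedule depends only on the current step and --total_training_steps, not on the batch size:

```python
import math


def ct_schedule(step, total_steps, start_scales=2, end_scales=200, start_ema=0.95):
    """Hypothetical helper: return (N, mu) at a given training step.

    N grows from start_scales to end_scales + 1 along a square-root curve;
    mu is the target-network EMA decay, which rises towards 1 as N grows.
    Formulas follow our reading of the paper, not the repository code.
    """
    frac = step / total_steps
    n = math.ceil(
        math.sqrt(frac * ((end_scales + 1) ** 2 - start_scales ** 2) + start_scales ** 2) - 1
    ) + 1
    mu = math.exp(start_scales * math.log(start_ema) / n)
    return n, mu


# Schedule values for the flags in the command above (800000 total steps).
for step in (0, 100_000, 400_000, 800_000):
    n, mu = ct_schedule(step, 800_000)
    print(f"step {step:>7}: N = {n:>3}, mu = {mu:.5f}")
```

Under this sketch, N advances once per optimizer step. Shrinking the batch therefore leaves the schedule unchanged, but each value of N is trained on far fewer samples.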

My 3090 can only handle a batch size of 16. Has anyone tried training their own model from scratch?
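For scale, the number of images a run processes is the step count times the global batch size. Here is a quick back-of-the-envelope check. It assumes that each optimizer step consumes one global batch and uses the 1,281,167-image ILSVRC2012 training split. The helper names are illustrative:

```python
# Illustrative helpers (names are ours): how much data a run actually sees.
IMAGENET_TRAIN_IMAGES = 1_281_167  # images in the ILSVRC2012 training split


def images_seen(steps: int, global_batch_size: int) -> int:
    """Images processed, assuming each optimizer step consumes one global batch."""
    return steps * global_batch_size


def epochs_seen(steps: int, global_batch_size: int,
                dataset_size: int = IMAGENET_TRAIN_IMAGES) -> float:
    """Approximate number of passes over the training set."""
    return images_seen(steps, global_batch_size) / dataset_size


print(images_seen(800_000, 16))           # 12800000
print(f"{epochs_seen(800_000, 16):.2f}")  # 9.99
```

At a fixed step count, doubling the batch doubles the data seen. This is one reason a batch-16 run can lag behind runs with much larger batches.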

Training a consistency model from scratch is demanding, and matching the provided checkpoints may require careful experimentation and tuning. Here are a few considerations that may help improve the quality of your trained model:

  1. Check hyperparameters: The values in your command are a reasonable starting point. However, settings such as the learning rate and EMA rate are usually tuned together with the batch size. If they were chosen for a larger batch, they may not transfer unchanged to a batch of 16, so you may need to experiment to find values that work for your setup.

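  2. Train for longer: Your command specifies 800,000 training steps. With a batch of 16, that corresponds to only 12.8 million images, roughly 10 passes over the ~1.28 million ImageNet training images. A run with a larger batch sees many times more data in the same number of steps, so consider training longer. Note that with --scale_mode progressive, --total_training_steps also sets how quickly the number of discretization steps grows from --start_scales to --end_scales, so changing it changes that schedule too.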

  3. Experiment with different learning rates: Your command uses --lr 0.0001. Try other values and check which gives better convergence and image quality; with a batch this small, a lower learning rate is often more stable. Learning rate schedules such as warmup or decay may also help. Your command sets --lr_anneal_steps 0, which, as far as we can tell, keeps the learning rate constant throughout training.

  4. Use a larger batch size if possible: If you have access to more GPU memory or several GPUs, consider a larger global batch size. Larger batches give less noisy gradient estimates and often lead to more stable training. On a single GPU, gradient accumulation can emulate a larger effective batch, at the cost of more compute per optimizer step.

  5. Data augmentation: Appropriate augmentation, such as random horizontal flips, increases the diversity of the training data and can help the model learn more robust features.

  6. Regularization: Experiment with weight decay or dropout, both of which are disabled in your command (--weight_decay 0.0, --dropout 0.0). These techniques can help prevent overfitting and improve generalization.

  7. Pretrained initialization: Instead of starting from random initialization, you can try initializing from the pretrained weights of a model with the same architecture, possibly trained on a different task. This can provide a better starting point and help the model converge faster.

  8. Monitor and visualize training progress: During training, monitor the loss and other relevant metrics to ensure that the model is making progress. Additionally, visualize the generated images at regular intervals to assess the quality and progress of the training process.

  9. Experiment with different architectures: If you are not satisfied with the results obtained with the current architecture, consider experimenting with different model architectures or variations to find one that works well for your specific task and dataset.
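Point 4 can also be approached without more memory. Gradient accumulation averages gradients over several microbatches before each optimizer update, which emulates a larger effective batch at the cost of proportionally more compute per update. The sketch below uses a toy 1-D least-squares model in plain Python. It is not the repository's training loop and does not correspond to any cm_train.py flag. It shows that averaging the gradients of equal-sized microbatches reproduces the full-batch gradient when the loss is a per-sample mean:

```python
import random


def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch):
    """Average the gradients of equal-sized microbatches.

    For a mean-reduced loss this equals the gradient over the whole batch,
    which is what lets a small GPU emulate a larger effective batch.
    """
    if len(xs) % micro_batch != 0:
        raise ValueError("batch size must be a multiple of micro_batch")
    grads = [
        mse_grad(w, xs[i:i + micro_batch], ys[i:i + micro_batch])
        for i in range(0, len(xs), micro_batch)
    ]
    return sum(grads) / len(grads)


# Toy data: y = 3x + noise, forming one effective batch of 64 samples.
random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(64)]
ys = [3.0 * x + random.gauss(0.0, 0.1) for x in xs]

# Each optimizer update uses the effective batch of 64, computed as 4 microbatches of 16.
w, lr = 0.0, 0.1
for _ in range(300):
    w -= lr * accumulated_grad(w, xs, ys, micro_batch=16)

print(abs(mse_grad(0.5, xs, ys) - accumulated_grad(0.5, xs, ys, micro_batch=16)))
print(w)
```

In PyTorch, the same effect is usually obtained by calling loss.backward() on each microbatch loss divided by the number of microbatches, then calling optimizer.step() once. Autograd sums gradients into .grad until they are zeroed. Any schedule keyed to the optimizer step count then advances once per effective batch rather than once per microbatch.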

Remember that training consistency models from scratch can be a time-consuming process, and it may require multiple iterations of experimentation and fine-tuning to achieve satisfactory results.