visionhong / SD-LoRA-MLflow



MLflow tracking server with HTTP basic authentication

The Dockerfile builds an MLflow tracking server protected by HTTP basic authentication. The blog post below explains how to deploy and use it.

Blog: https://visionhong.github.io/aws/AWS-MLflow-SD-LoRA/
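For reference, HTTP basic authentication (RFC 7617) means each request carries an `Authorization: Basic ...` header whose value is the Base64 encoding of `username:password`. The MLflow client adds this header automatically when `MLFLOW_TRACKING_USERNAME` and `MLFLOW_TRACKING_PASSWORD` are set. The minimal sketch below shows that encoding, plus a hypothetical `check_server` helper (not part of this repository) that prints the HTTP status returned by the server's `/health` path. It assumes the whole server sits behind the authentication layer and that `MLFLOW_TRACKING_URI` includes its `http://` or `https://` scheme.

```shell
#!/usr/bin/env bash
# Minimal sketch: what HTTP basic auth sends, and a quick credential check.
# Both helpers are hypothetical, not part of the repository.

# Build the Authorization header value for a username/password pair.
# printf avoids the trailing newline that `echo` would add before encoding;
# `tr -d '\n'` removes the line wrapping some base64 implementations insert.
basic_auth_header() {
  local user="$1" pass="$2"
  printf 'Basic %s' "$(printf '%s:%s' "$user" "$pass" | base64 | tr -d '\n')"
}

# Print the HTTP status code for an authenticated request to /health.
# 401 means the credentials were rejected.
# Assumes MLFLOW_TRACKING_URI includes the scheme (http:// or https://).
check_server() {
  curl -s -o /dev/null -w '%{http_code}\n' \
    -u "${MLFLOW_TRACKING_USERNAME}:${MLFLOW_TRACKING_PASSWORD}" \
    "${MLFLOW_TRACKING_URI%/}/health"
}
```

Run `check_server` after exporting the variables shown in the next section. A `401` means the credentials were rejected, while a `200` means the request got through.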


Stable Diffusion LoRA with MLflow and Ray

Hyper-parameter Tuning

export MLFLOW_TRACKING_URI="<ECS task public IP>"
export MLFLOW_TRACKING_USERNAME="<dashboard username>"
export MLFLOW_TRACKING_PASSWORD="<dashboard password>"
export AWS_ACCESS_KEY_ID="<IAM user access key>"
export AWS_SECRET_ACCESS_KEY="<IAM user secret key>"
python train_tune.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
  --dataset_name="zoheb/sketch-scene" \
  --dataloader_num_workers=8 \
  --width=256 --height=256 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=1000 \
  --learning_rate=1e-04 \
  --lr_scheduler="cosine" --lr_warmup_steps=0 \
  --experiments_name='sketch_ray_tune' \
  --seed=1337 \
  --mixed_precision='fp16' \
  --enable_xformers_memory_efficient_attention \
  --tune \
  --gpus_per_trial=1
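`train_tune.py` only sees exported variables, so a quick preflight check can catch a missing credential before Ray starts launching trials. The sketch below is a hypothetical helper, not part of the repository. It reads each variable with `printenv`, so an unexported shell variable counts as missing. It also checks the scheme of `MLFLOW_TRACKING_URI`, because MLflow interprets a URI without a scheme as a local path rather than a remote server.

```shell
#!/usr/bin/env bash
# Minimal sketch: fail fast if the tracking/S3 variables are missing or the
# tracking URI lacks an http(s):// scheme. Hypothetical helper.

check_tracking_env() {
  local missing=0 name uri
  for name in MLFLOW_TRACKING_URI MLFLOW_TRACKING_USERNAME \
              MLFLOW_TRACKING_PASSWORD AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; do
    # printenv only sees exported variables, i.e. what the child process gets.
    if [ -z "$(printenv "$name")" ]; then
      echo "missing: $name" >&2
      missing=1
    fi
  done
  # MLflow treats a URI without a scheme as a local path, not a server.
  uri="$(printenv MLFLOW_TRACKING_URI)"
  case "$uri" in
    http://*|https://*) ;;
    *) echo "MLFLOW_TRACKING_URI should start with http:// or https://" >&2
       missing=1 ;;
  esac
  return "$missing"
}
```

Usage: `check_tracking_env && python train_tune.py ...` starts the tuning run only when every variable is present.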

Result:

[MLflow screenshots: mlflow-res2, mlflow-res1]

Training

accelerate launch --mixed_precision="fp16" train_tune.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
  --dataset_name="zoheb/sketch-scene" \
  --dataloader_num_workers=8 \
  --width=256 --height=256 --center_crop --random_flip \
  --train_batch_size=2 \
  --gradient_accumulation_steps=4 \
  --num_train_epochs=10 \
  --learning_rate=1e-03 \
  --lr_scheduler="cosine" --lr_warmup_steps=500 \
  --output_dir="LoRA_sketch_output" \
  --experiments_name='sketch' \
  --checkpointing_steps=5000 \
  --validation_prompt="a man swimming in the sea" \
  --validation_epochs=1 \
  --num_validation_images=2 \
  --seed=1337 \
  --enable_xformers_memory_efficient_attention
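To sanity-check epoch-based settings, it helps to estimate how many optimizer steps a run will take. The sketch below assumes the script follows the usual diffusers/Accelerate convention, where one training step is one optimizer update after `gradient_accumulation_steps` batches per process. The number of processes comes from your Accelerate configuration. The 1,000-sample dataset size used in the example is hypothetical and is not the size of `zoheb/sketch-scene`.

```shell
#!/usr/bin/env bash
# Minimal sketch: estimate effective batch size and total optimizer steps,
# assuming one "step" = one optimizer update after GRAD_ACCUM batches.
# Hypothetical helpers, not part of the repository.

# Integer ceiling division: ceil(a / b).
ceil_div() { echo $(( ($1 + $2 - 1) / $2 )); }

# Usage: estimate_steps NUM_SAMPLES BATCH_SIZE GRAD_ACCUM EPOCHS [NUM_PROCESSES]
estimate_steps() {
  local samples=$1 batch=$2 accum=$3 epochs=$4 procs=${5:-1}
  local batches_per_epoch updates_per_epoch
  # Approximate batches each process sees per epoch (data split across processes).
  batches_per_epoch=$(ceil_div "$(ceil_div "$samples" "$batch")" "$procs")
  # One optimizer update per GRAD_ACCUM batches.
  updates_per_epoch=$(ceil_div "$batches_per_epoch" "$accum")
  echo "effective_batch=$(( batch * accum * procs ))"
  echo "total_steps=$(( updates_per_epoch * epochs ))"
}
```

With the training command's settings on a single process, `estimate_steps 1000 2 4 10` prints `effective_batch=8` and `total_steps=1250`. With those hypothetical numbers, `--checkpointing_steps=5000` would never trigger an intermediate checkpoint; whether it does for the real dataset depends on its size. Under the same assumption, the tuning command above (`--train_batch_size=1`, `--gradient_accumulation_steps=4`, one GPU per trial) processes 4 images per optimizer step.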

Result:

[MLflow screenshots: mlflow-loss, mlflow-res3]



Languages

Python 98.7%, Dockerfile 0.7%, Shell 0.6%