kohya-ss / sd-scripts

I have reason to believe "scale v-loss like epsilon loss" and Min-SNR-Gamma are implemented wrong.

drhead opened this issue · comments

commented

I've been training a model using Kohya's implementation of Min-SNR-Gamma and the more recent option for scaling v-prediction loss like epsilon loss. I am also training with v-prediction and zero terminal SNR, which is important.

I first found that the v-loss rescaling actually prevents a zero-terminal-SNR model from learning to produce fully black images even after about 5 million training samples, while the model learned it almost immediately once I turned that setting off. However, others noticed that the option nevertheless improved quality in other areas, suggesting that there was likely a proper way to correct the flaw.

Looking further into the paper, it seems the authors of Min-SNR-Gamma state that the formula should be modified for v-loss, but they may have been somewhat unclear in their wording:
[screenshot of the paper's loss-weighting equation: a min{SNR(t), γ} numerator over an SNR(t) denominator, with the simplified form min{γ/SNR(t), 1} on the right-hand side]

Kohya implements the simplified formula on the right-hand side.
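
For clarity, here is a restatement in LaTeX (my notation, not a quote from the paper) of the two weightings under discussion: the simplified epsilon-prediction weight that the current code implements, and the v-prediction weight proposed in this issue, respectively:

    w_\epsilon(t) = \frac{\min\{\mathrm{SNR}(t),\, \gamma\}}{\mathrm{SNR}(t)} = \min\left\{\frac{\gamma}{\mathrm{SNR}(t)},\, 1\right\}
    \qquad
    w_v(t) = \frac{\min\{\mathrm{SNR}(t),\, \gamma\}}{\mathrm{SNR}(t) + 1}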

I have implemented and tested the following alternative function for min_snr_gamma, based on the middle formula; it is identical except that the denominator is replaced with SNR(t) + 1. My implementation is in JAX, since that is what my current training script uses, but converting it to PyTorch should mostly be a matter of removing the expand_dims line and replacing jnp with torch:

    import jax.numpy as jnp

    def apply_snr_weight_alt(loss, timesteps, noise_scheduler, gamma):
        # per-sample SNR, looked up from the scheduler's precomputed table
        snr = jnp.stack([noise_scheduler.all_snr[t] for t in timesteps])
        min_snr_gamma = jnp.minimum(snr, gamma)
        # min(SNR, gamma) / (SNR + 1): the Min-SNR weight for v-prediction
        snr_weight = jnp.divide(min_snr_gamma, snr + 1).astype(jnp.float32)
        snr_weight = jnp.expand_dims(snr_weight, axis=(1, 2, 3))  # broadcast over (C, H, W); likely unnecessary for PyTorch
        loss = loss * snr_weight
        return loss

This is, as far as I am aware, the correct function for min_snr_gamma for v-loss. It has performed well in my tests and has improved the quality of my outputs without compromising on contrast range. It serves the same purpose as the "scale v-loss like epsilon loss" option and results in loss metrics that are in the same range as epsilon loss. It should be how Min-SNR-Gamma behaves under v-prediction, and it should fully replace the "scale v-loss like epsilon loss" option.
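
As a quick, purely illustrative sanity check on the range of this weight (gamma = 5 is an arbitrary choice here), one can evaluate min(SNR, gamma) / (SNR + 1) at a few SNR values and compare it against the current epsilon-prediction weight min(gamma / SNR, 1):

    import torch

    gamma = 5.0
    snr = torch.tensor([0.0, 0.1, 1.0, 5.0, 50.0, 1000.0])
    v_weight = torch.minimum(snr, torch.full_like(snr, gamma)) / (snr + 1)
    eps_weight = torch.minimum(gamma / snr, torch.ones_like(snr))
    print(v_weight)    # finite everywhere, bounded above by gamma / (gamma + 1)
    print(eps_weight)  # bounded above by 1; hits the bound wherever SNR <= gamma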

Others who worked on this problem with me have tested this function and found that it improves performance compared to the current implementation of the aforementioned options. If you need a model to test it on, I can release one of my prototypes (an SD 1.5 model trained with v-loss and zero terminal SNR) for testing purposes. I would imagine SD 2.1 768-v would work as well.

There's some discussion about this in the original pull request for Min-SNR-gamma, and AI-Casanova (the PR's author) also thinks it should probably be min(SNR, gamma) / (SNR + 1) for v-prediction.

Thank you very much for this!

I am not a math person and my understanding may be incorrect, but does this mean we can modify the code as follows?

def apply_snr_weight_noise_pred(loss, timesteps, noise_scheduler, gamma):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
    loss = loss * snr_weight
    return loss


def apply_snr_weight_alt(v_prediction, loss, timesteps, noise_scheduler, gamma):
    if not v_prediction:
        return apply_snr_weight_noise_pred(loss, timesteps, noise_scheduler, gamma)

    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))  # gamma as a tensor; torch.minimum expects two tensors
    snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    loss = loss * snr_weight
    return loss


# we can remove this function
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)

    loss = loss * scale
    return loss
commented

That should work, but I think the cleanest way to implement it would be to change the denominator based on whether v-prediction is used: for epsilon prediction it should be snr, and for v-prediction it should be snr + 1. The epsilon-prediction case is then equivalent to the original implementation.

def apply_snr_weight_alt(v_prediction, loss, timesteps, noise_scheduler, gamma):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))  # gamma as a tensor; torch.minimum expects two tensors
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss
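
For illustration, a call site in the training loop could then look like the sketch below (hypothetical, not the actual sd-scripts code; it assumes the per-pixel MSE loss has already been reduced to one value per sample, which is what the 1-D snr_weight expects, and that args.v_parameterization and args.min_snr_gamma carry the existing options):

    import torch.nn.functional as F

    loss = F.mse_loss(model_pred, target, reduction="none")  # (B, C, H, W)
    loss = loss.mean(dim=[1, 2, 3])  # one value per sample: (B,)
    loss = apply_snr_weight_alt(args.v_parameterization, loss, timesteps, noise_scheduler, args.min_snr_gamma)
    loss = loss.mean()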

Thank you for the clarification!

The formulas seem to say that if we apply the current apply_snr_weight (apply_snr_weight_noise_pred above) and scale_v_prediction_loss_like_noise_prediction at the same time, we should be fine for the v-prediction case. Is this correct?

Both options can be specified at the same time.

commented

No, scale_v_prediction_loss_like_noise_prediction doesn't function the same as the v-prediction path of my apply_snr_weight_alt, in part due to the clipping at timestep 0, which is the likely cause of the interference with zero terminal SNR I mentioned. apply_snr_weight_alt ensures that infinite SNR at timestep 0 doesn't cause problems.

edit: I'd also like to emphasize that the formula used in the v-prediction code path should simply be the correct implementation of min-SNR-gamma; there shouldn't be a separate, optional loss-rescale function for v-prediction. min-SNR-gamma used with v-prediction should always behave like this. Compatibility is a possible concern, but at least from the testing I've seen so far, this implementation gives better results than the two loss rescales.
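
To see concretely how the existing two-rescale path compares to the proposed single weight, one could evaluate both over a range of SNR values, for example (a rough sketch using a synthetic SNR range as a stand-in for the scheduler's all_snr table):

    import torch

    gamma = 5.0
    snr = torch.logspace(-4, 4, steps=1000)  # synthetic stand-in for noise_scheduler.all_snr
    # current path: apply_snr_weight followed by scale_v_prediction_loss_like_noise_prediction
    eps_weight = torch.minimum(gamma / snr, torch.ones_like(snr))
    snr_clipped = torch.minimum(snr, torch.ones_like(snr) * 1000)
    combined = eps_weight * snr_clipped / (snr_clipped + 1)
    # proposed: min(SNR, gamma) / (SNR + 1)
    proposed = torch.minimum(snr, torch.full_like(snr, gamma)) / (snr + 1)
    print((combined - proposed).abs().max())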

By the way, is the clipping necessary?
The SNR at timestep 0 is the SNR at x_1, not x_0, so it is not infinite.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")

def get_snr(
    scheduler, 
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:

    sqrt_alpha_prod = scheduler.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()

    sqrt_one_minus_alpha_prod = (1 - scheduler.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()

    return (sqrt_alpha_prod / sqrt_one_minus_alpha_prod) ** 2

get_snr(scheduler, torch.tensor(0))

# tensor([1175.4406])

Any updates?

As laksjdjf wrote, I believe it is OK if we specify both the apply_snr_weight and scale_v_prediction_loss_like_noise_prediction options.

Unfortunately not. Combining both leads to the loss being scaled twice: scale_v_prediction_loss_like_noise_prediction applies the v-prediction version of the formula from the paper, but with a hardcoded gamma = 1000.
They really should be one function, like @drhead's example.

commented

I should reiterate that apply_snr_weight, as it currently exists, is an incorrect implementation of the paper when training with v-prediction. Using both it and scale_v_prediction_loss_like_noise_prediction is not mathematically equivalent to the implementation I have provided, and our testing has shown that the corrected version performs better.

I can confirm this after having discussed it with Tian, one of the original paper authors.

Additionally, I've implemented the fix in SimpleTuner as a non-conditional fix for v_prediction-type models when min-SNR gamma is in use.

Maybe it's my lack of knowledge, but I had always wondered why, for my dataset, setting min-SNR gamma seemed to yield worse results in a way I couldn't really explain. Glad to know I wasn't imagining things. Hope this can be solved soon.

Fixed with merge of #934