Assert prevents mixed_precision?
mattmisk opened this issue
mattmisk commented
Line 2661 of imagen_pytorch:
assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'
This assert triggers when using mixed_precision = True. Am I doing something wrong, or is this a bug? Everything works fine if I comment the line out.
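For anyone reproducing this, here is a minimal sketch of why the check can fail, assuming the images reach that line as half-precision tensors (the issue does not confirm where any cast happens). `torch.float` is an alias for `torch.float32`, so a `torch.float16` tensor does not compare equal to it. The helper names `strict_check` and `relaxed_check` are hypothetical and not part of imagen_pytorch; the relaxed version is only one possible alternative, not the library's fix.

```python
import torch

def strict_check(images: torch.Tensor) -> bool:
    # Mirrors the quoted assert: passes only for float32,
    # because torch.float is an alias for torch.float32.
    return images.dtype == torch.float

def relaxed_check(images: torch.Tensor) -> bool:
    # Hypothetical alternative: accept any floating-point dtype
    # (float16, bfloat16, float32, float64), reject integer dtypes.
    return images.is_floating_point()

full = torch.zeros(1, 3, 8, 8)                         # float32 by default
half = full.half()                                     # float16 (assumed mixed-precision input)
as_bytes = torch.zeros(1, 3, 8, 8, dtype=torch.uint8)  # not a float tensor

for name, t in [("float32", full), ("float16", half), ("uint8", as_bytes)]:
    print(name, strict_check(t), relaxed_check(t))
```

If the inputs are float16, the strict comparison rejects them while a floating-point check would still catch integer images, which seems to be what the assert is guarding against.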