dome272 / Diffusion-Models-pytorch

Pytorch implementation of Diffusion Models (https://arxiv.org/pdf/2006.11239.pdf)

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

When I change the image_size to 128, an error occurs

lion-ops opened this issue · comments

image

How can I solve it?

I have the same question.

Take a look here: #7
Let me know if that helps

Thank you.

Modify the code in modules.py as follows:
class UNet(nn.Module):
    """U-Net noise-prediction backbone for the diffusion model.

    Modified (per this issue) to take an ``imgsize`` parameter so each
    SelfAttention block is built for the feature-map resolution actually
    produced at its level, instead of hard-coding a 64x64 input.
    The default ``imgsize=64`` preserves the original behavior.

    NOTE(review): the pasted snippet was mangled by markdown — the
    ``__init__`` dunder underscores and the leading indentation were
    stripped; they are restored here.
    """

    def __init__(self, c_in=3, c_out=3, imgsize=64, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # Encoder: each Down halves the spatial resolution, so the
        # attention size at depth k is imgsize // 2**k.
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, imgsize // 2)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, imgsize // 4)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, imgsize // 8)

        # Bottleneck at the lowest resolution.
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Decoder: each Up doubles the resolution back toward imgsize.
        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, imgsize // 4)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, imgsize // 2)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, imgsize)
        # 1x1 conv projects the 64 feature channels back to c_out.
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

and

class UNet_conditional(nn.Module):
    """Class-conditional variant of the U-Net noise-prediction backbone.

    Modified (per this issue) to take an ``imgsize`` parameter so each
    SelfAttention block matches the feature-map resolution at its level
    rather than assuming 64x64 inputs. ``imgsize=64`` keeps the original
    behavior.

    NOTE(review): the pasted snippet was mangled by markdown — the
    ``__init__`` dunder underscores and the leading indentation were
    stripped; they are restored here. The snippet also ends at ``outc``;
    the upstream class presumably defines a label embedding from
    ``num_classes`` after this point — confirm against modules.py.
    """

    def __init__(self, c_in=3, c_out=3, imgsize=64, time_dim=256,
                 num_classes=None, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # Encoder: each Down halves the spatial resolution, so the
        # attention size at depth k is imgsize // 2**k.
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, imgsize // 2)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, imgsize // 4)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, imgsize // 8)

        # Bottleneck at the lowest resolution.
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Decoder: each Up doubles the resolution back toward imgsize.
        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, imgsize // 4)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, imgsize // 2)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, imgsize)
        # 1x1 conv projects the 64 feature channels back to c_out.
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)