A problem about the code, thanks
aosong01 opened this issue
It seems that you change all the BasicTransformerBlock modules in the down_blocks, mid_block, and up_blocks. Why do you then change the up_blocks of the UNet again?
```python
def register_extended_attention(model):
    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
```
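For reference, a minimal sketch of what the isinstance_str helper used above presumably does (assumed behavior, not copied from the repo): it matches a module by class name anywhere in its ancestry, which is how the first loop finds every BasicTransformerBlock in the UNet.

```python
def isinstance_str(x: object, cls_name: str) -> bool:
    # Assumed sketch: treat x as an instance of cls_name if any class in its
    # MRO has that name, so no import of the diffusers class is needed.
    return any(c.__name__ == cls_name for c in type(x).__mro__)
```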
I have the same query
The first for loop modifies the following blocks:
down_blocks.0.attentions.0.transformer_blocks.0
down_blocks.0.attentions.1.transformer_blocks.0
down_blocks.1.attentions.0.transformer_blocks.0
down_blocks.1.attentions.1.transformer_blocks.0
down_blocks.2.attentions.0.transformer_blocks.0
down_blocks.2.attentions.1.transformer_blocks.0
up_blocks.1.attentions.0.transformer_blocks.0
up_blocks.1.attentions.1.transformer_blocks.0
up_blocks.1.attentions.2.transformer_blocks.0
up_blocks.2.attentions.0.transformer_blocks.0
up_blocks.2.attentions.1.transformer_blocks.0
up_blocks.2.attentions.2.transformer_blocks.0
up_blocks.3.attentions.0.transformer_blocks.0
up_blocks.3.attentions.1.transformer_blocks.0
up_blocks.3.attentions.2.transformer_blocks.0
mid_block.attentions.0.transformer_blocks.0
The second for loop modifies:
up_blocks.1.attentions.1.transformer_blocks.0.attn1
up_blocks.1.attentions.2.transformer_blocks.0.attn1
up_blocks.2.attentions.0.transformer_blocks.0.attn1
up_blocks.2.attentions.1.transformer_blocks.0.attn1
up_blocks.2.attentions.2.transformer_blocks.0.attn1
up_blocks.3.attentions.0.transformer_blocks.0.attn1
up_blocks.3.attentions.1.transformer_blocks.0.attn1
up_blocks.3.attentions.2.transformer_blocks.0.attn1
This is a subset of the modules already patched by the first loop.
According to the comment, the first block of the lowest resolution shouldn't have extended attention registered, yet the first for loop registers extended attention for that block as well.
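To make the overlap concrete, here is a small hypothetical check (model and isinstance_str are assumed to be the same objects used in register_extended_attention above) that lists the attn1 modules each loop would patch:

```python
# Hypothetical verification, not part of the repo.
first_loop = set()
for name, module in model.unet.named_modules():
    if isinstance_str(module, "BasicTransformerBlock"):
        first_loop.add(f"{name}.attn1")  # encoder, mid block and decoder alike

res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
second_loop = {
    f"up_blocks.{res}.attentions.{block}.transformer_blocks.0.attn1"
    for res, blocks in res_dict.items()
    for block in blocks
}

print(second_loop <= first_loop)  # True: the second loop re-patches a subset of the first
# The block the comment says should be skipped is nevertheless patched by the first loop:
print("up_blocks.1.attentions.0.transformer_blocks.0.attn1" in first_loop)  # True
```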
Same question here.
I think the function that actually takes effect is register_extended_attention_pnp, where a list injection_schedule is defined (Lines 203 to 214 in 8ae24e9). The injection is activated according to injection_schedule (Lines 124 to 130 and Lines 86 to 91 in 8ae24e9).
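For anyone following along, here is a rough, self-contained sketch of how a timestep-based injection_schedule typically gates injection; all names here are illustrative, and the actual gating lives in the lines referenced above:

```python
import torch

class GatedInjection(torch.nn.Module):
    # Illustrative only: injection fires only at timesteps listed in injection_schedule.
    def __init__(self, injection_schedule=None):
        super().__init__()
        self.injection_schedule = injection_schedule  # e.g. a subset of the DDIM timesteps
        self.t = None                                 # set by the sampling loop at each step

    def forward(self, x, source_feats=None):
        if self.injection_schedule is not None and self.t in self.injection_schedule and source_feats is not None:
            x = source_feats  # placeholder for copying source features into this branch
        return x              # placeholder for the actual extended-attention computation

layer = GatedInjection(injection_schedule=[981, 961, 941])
layer.t = 981
print(layer(torch.zeros(1), torch.ones(1)))  # in the schedule  -> tensor([1.])
layer.t = 501
print(layer(torch.zeros(1), torch.ones(1)))  # not in schedule  -> tensor([0.])
```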
BTW, I tried removing the first loop (L203-L206) and found the result did not change. However, when removing the second loop (L208-L214), the result got worse.