salesforce / UniControl

Unified Controllable Visual Generation Model

Home Page:https://canqin001.github.io/UniControl-Page/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Multiple conditions / multiple tasks as input at the same time

zdxpan opened this issue · comments

How can I use multiple task conditions in a single inference run?

for example:

# condition 1 canny
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, C = img.shape
        if condition_mode == True:
            detected_map = apply_canny(img, low_threshold, high_threshold)
            detected_map = HWC3(detected_map)
        else:
            detected_map = 255 - img

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
      
       control1 = einops.rearrange(control, 'b h w c -> b c h w').clone()

# condition 2  depth 
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape
        if condition_mode == True:
            detected_map = apply_hed(resize_image(input_image, detect_resolution))
            detected_map = HWC3(detected_map)
        else:
            detected_map = img
            
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control2 = einops.rearrange(control, 'b h w c -> b c h w').clone()

      cond = {"c_concat": [control1, control2], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic}

DDIM Sampler: 0%| | 0/31 [00:00<?, ?it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in :53 │
│ │
│ 50 │ if config.save_memory: │
│ 51 │ │ model.low_vram_shift(is_diffusing=True) │
│ 52 │ │
│ ❱ 53 │ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, │
│ 54 │ │ │ │ │ │ │ │ │ │ │ │ shape, cond, verbose=False, eta=eta, │
│ 55 │ │ │ │ │ │ │ │ │ │ │ │ unconditional_guidance_scale=scale, │
│ 56 │ │ │ │ │ │ │ │ │ │ │ │ unconditional_conditioning=un_cond) │
│ │
│ /home/dell/.conda/envs/sd/lib/python3.8/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /home/dell/workspace/UniControl/cldm/ddim_unicontrol_hacked.py:113 in sample │
│ │
│ 110 │ │ size = (batch_size, C, H, W) │
│ 111 │ │ print(f'Data shape for DDIM sampling is {size}, eta {eta}') │
│ 112 │ │ │
│ ❱ 113 │ │ samples, intermediates = self.ddim_sampling(conditioning, size, │
│ 114 │ │ │ │ │ │ │ │ │ │ │ │ │ callback=callback, │
│ 115 │ │ │ │ │ │ │ │ │ │ │ │ │ img_callback=img_callback, │
│ 116 │ │ │ │ │ │ │ │ │ │ │ │ │ quantize_denoised=quantize_x0, │
│ │
│ /home/dell/.conda/envs/sd/lib/python3.8/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /home/dell/workspace/UniControl/cldm/ddim_unicontrol_hacked.py:173 in ddim_sampling │
│ │
│ 170 │ │ │ │ assert len(ucg_schedule) == len(time_range) │
│ 171 │ │ │ │ unconditional_guidance_scale = ucg_schedule[i] │
│ 172 │ │ │ │
│ ❱ 173 │ │ │ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddi │
│ 174 │ │ │ │ │ │ │ │ │ quantize_denoised=quantize_denoised, temperature=t │
│ 175 │ │ │ │ │ │ │ │ │ noise_dropout=noise_dropout, score_corrector=score │
│ 176 │ │ │ │ │ │ │ │ │ corrector_kwargs=corrector_kwargs, │
│ │
│ /home/dell/.conda/envs/sd/lib/python3.8/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /home/dell/workspace/UniControl/cldm/ddim_unicontrol_hacked.py:211 in p_sample_ddim │
│ │
│ 208 │ │ │ │ │ if k == 'task': │
│ 209 │ │ │ │ │ │ continue │
│ 210 │ │ │ │ │ if isinstance(c[k], list): │
│ ❱ 211 │ │ │ │ │ │ c_in[k] = [torch.cat([ │
│ 212 │ │ │ │ │ │ │ unconditional_conditioning[k][i], │
│ 213 │ │ │ │ │ │ │ c[k][i]]) for i in range(len(c[k]))] │
│ 214 │ │ │ │ │ else: │
│ │
│ /home/dell/workspace/UniControl/cldm/ddim_unicontrol_hacked.py:212 in │
│ │
│ 209 │ │ │ │ │ │ continue │
│ 210 │ │ │ │ │ if isinstance(c[k], list): │
│ 211 │ │ │ │ │ │ c_in[k] = [torch.cat([ │
│ ❱ 212 │ │ │ │ │ │ │ unconditional_conditioning[k][i], │
│ 213 │ │ │ │ │ │ │ c[k][i]]) for i in range(len(c[k]))] │
│ 214 │ │ │ │ │ else: │
│ 215 │ │ │ │ │ │ c_in[k] = torch.cat([ │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: 'NoneType' object is not subscriptable

For multi-condition control, please combine the conditions' MoE features using a weighted average, and concatenate their task prompts.