TMElyralab / MuseTalk

MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting


Wrote a version of the training code, but in the synthesized video the upper and lower halves of the face look disjoint, with obvious jitter and discontinuity

gobigrassland opened this issue · comments

Based on the MuseTalk project description, I implemented a version of the training code. Comparing it against the train_codes branch, the two are largely consistent. Training on roughly 350 HDTF videos, the model I obtained shows two obvious problems in its synthesized videos:
(1) The upper and lower halves of the face do not match; a difference in color tone is visible (quite noticeable in some cases).
(2) The lower half of the face jitters strongly; playback looks discontinuous.
Where might the problem be? Do you have any suggestions? How can I reproduce the results of the open-source model released by this project?

Original video:

out.mp4

Synthesized video:
https://github.com/TMElyralab/MuseTalk/assets/23277618/df55aa88-c007-454d-92fc-9c0dc6d51504

Hi,

Thanks for your interest in our work. Based on your problem description, I suggest you check the following issues:

  1. Do you have a validation set separated from the HDTF dataset? This can help you verify if there is overfitting during the training process.
  2. Is the difference between the reference image and the target image too large during the training process? This may cause the reference image information to not be utilized effectively.

Hi,

Thanks for your interest in our work. Based on your problem description, I suggest you check the following issues:

  1. Do you have a validation set separated from the HDTF dataset? This can help you verify if there is overfitting during the training process.
  2. Is the difference between the reference image and the target image too large during the training process? This may cause the reference image information to not be utilized effectively.

(1) I did not split out a separate validation set, but I also tested checkpoints saved at earlier (smaller) steps and they show similar problems. I will strengthen this later.
(2) The reference frame is chosen randomly, consistent with the current train_codes branch setup. This was also discussed with others in #65.
(3) How are the weights for the latent-space loss and the image-space loss set? Do you have any ablation results on this?

Hi,

Thanks for your interest in our work. Based on your problem description, I suggest you check the following issues:

  1. Do you have a validation set separated from the HDTF dataset? This can help you verify if there is overfitting during the training process.
  2. Is the difference between the reference image and the target image too large during the training process? This may cause the reference image information to not be utilized effectively.

Hi, I also trained a model and got roughly the same results: the images generated on validation are never very sharp, and jitter is present as well. Did you add extra data or apply any other optimizations on the model side?

May I ask: (1) how much data did you use for training? (2) What GPU did you use, with what batch size, and how many steps have you trained?

May I ask: (1) how much data did you use for training? (2) What GPU did you use, with what batch size, and how many steps have you trained?

(1) I used 319 HDTF videos.
(2) Training used 8 A800 GPUs with a per-GPU batch size of 16, for more than 100k steps and about 10 epochs.

Below are a few key pieces of my code: preprocessing, the Dataset, and training.
(1) Preprocessing is adapted from the MuseTalk inference code; the audio feature extraction and face bounding-box code is unchanged. The idea is to precompute, for every video frame, the cropped face, its no-mask VAE latent, its masked VAE latent, and the corresponding audio features, which speeds up training.

    # extract feature from an audio
    whisper_feature = audio_processor.audio2feat(f"{save_dir_full}/{video_audio_name}.mp3")
    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)

    # face detection
    print("extracting landmarks...time consuming")
    bbox_shift = 0
    coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)

    nomask_latent_list = []
    mask_latent_list = []
    crop_face_list = []
    num = len(frame_list)
    for ind, (bbox, frame) in enumerate(zip(coord_list, frame_list)):
        height, width, _ = frame.shape
        x1, y1, x2, y2 = bbox_check(bbox, height, width)
        # current crop_frame
        crop_frame = frame[y1:y2, x1:x2]
        face = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
        nomask_latent = vae.get_latents_for_nomask(face)
        mask_latent = vae.get_latents_for_mask(face)
        crop_face_list.append(face)
        nomask_latent_list.append(nomask_latent)
        mask_latent_list.append(mask_latent)

    for ind in range(num):
        meta_dict = dict()
        meta_dict['whisper'] = whisper_chunks[ind]
        meta_dict['face'] = crop_face_list[ind]
        meta_dict['nomask_latent'] = nomask_latent_list[ind].detach().cpu().float().numpy()
        meta_dict['mask_latent'] = mask_latent_list[ind].detach().cpu().float().numpy()
        with open(f"{save_dir_full}/{ind}.pkl", "wb") as pickle_file:
            pickle.dump(meta_dict, pickle_file)
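One detail not shown above: the Dataset in part (2) below loads a per-video `number.npy` holding the frame count, so presumably the preprocessing script also writes it. A one-line sketch continuing the snippet above (reusing its `save_dir_full` and `num`, and assuming `numpy` is imported as `np`):

    # hypothetical: persist the frame count so the Dataset can sample a reference index later
    np.save(f"{save_dir_full}/number.npy", num)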

(2) The Dataset fetches the current frame and randomly selects another frame from the same video as the reference image.

class MuseTalkDataset(Dataset):
    def __init__(self, root_dir, index_file):
        super(MuseTalkDataset, self).__init__()
        self.root_dir = root_dir
        self.index_file = index_file
        with open(os.path.join(self.root_dir, self.index_file), 'r') as fid:
            self.train_data = fid.readlines()
            self.train_data = [ele.strip() for ele in self.train_data]

    def __getitem__(self, index):
        train_file = self.train_data[index]
        with open(os.path.join(self.root_dir, train_file), "rb") as pickle_file:
            sample = pickle.load(pickle_file)
        whisper, gt_nomask_latent, gt_mask_latent, gt_face = sample['whisper'], sample['nomask_latent'], sample['mask_latent'],  sample['face']

        cur_path = os.path.join(self.root_dir, train_file)
        cur_ind = int(os.path.basename(cur_path).split(".")[0])
        num_frames = np.load(os.path.join(os.path.dirname(cur_path), "number.npy"))
        rand_ind = np.random.randint(0,num_frames, (1))
        while abs(rand_ind[0]-cur_ind) < 5:
            rand_ind = np.random.randint(0,num_frames, (1))
        rand_ind = rand_ind[0]

        rand_path = os.path.join(os.path.dirname(cur_path), f"{rand_ind}.pkl")
        with open(rand_path, "rb") as pickle_file:
            rand_sample = pickle.load(pickle_file)
        ref_nomask_latent = rand_sample['nomask_latent']

        whisper = torch.tensor(whisper)
        input_latent = torch.cat([torch.tensor(gt_mask_latent), torch.tensor(ref_nomask_latent)], dim=1).squeeze(dim=0)
        gt_nomask_latent = torch.tensor(gt_nomask_latent).squeeze(dim=0)
        gt_face = torch.tensor(gt_face).to(torch.float)
        return whisper, input_latent, gt_nomask_latent, gt_face
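The training loop in part (3) below calls `self.make_input()`, which is not shown in this thread; presumably it wraps the Dataset above in a `DistributedSampler` and `DataLoader`, roughly like this sketch (the paths and batch size are placeholders, not the values actually used):

    def make_input(self):
        # hypothetical helper: build the dataset, a distributed sampler, and the loader
        dataset = MuseTalkDataset(root_dir="data/hdtf_preprocessed", index_file="train.txt")
        data_sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=16,       # per-GPU batch size mentioned in this thread
            sampler=data_sampler,
            num_workers=4,
            pin_memory=True,
            drop_last=True,
        )
        return data_sampler, data_loader, dataset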

(3) Training. The decoded images in my code take values in the 0-255 range. I tried different loss weights, with little effect on the results.

        device = torch.device('cuda', self.local_rank)
        self.unet.model = torch.nn.parallel.DistributedDataParallel(self.unet.model, device_ids=[self.local_rank])
        self.vae.vae.to(device)

        l1_loss_func = torch.nn.L1Loss()
        optimizer = optim.AdamW(params=self.unet.model.parameters(), lr = 0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500000, eta_min=0.0000001  )

        self.unet.model.train()
        self.vae.vae.requires_grad_(False)

        self.epoch = 0
        data_sampler, data_loader, dataset = self.make_input()
        num_samples = dataset.__len__()
        data_iter = iter(data_loader)
        for step in range(500000):
            time1 = time.time()
            try:
                whisper, input_latent, gt_latent, gt_face = next(data_iter)
            except Exception as err:
                # shuffle every epoch
                self.epoch += 1
                data_sampler.set_epoch(self.epoch)
                data_iter = iter(data_loader)
                whisper, input_latent, gt_latent, gt_face = next(data_iter)
                logging_rank0(f"A new iteration begins. Current epoch: {self.epoch}")


            whisper = self.pe(whisper.to(device))
            input_latent = input_latent.to(device)
            gt_latent = gt_latent.to(device)
            gt_face = gt_face.to(device)
            self.timesteps = self.timesteps.to(device)

            latent_pred = self.unet.model(input_latent, self.timesteps, encoder_hidden_states=whisper).sample
            face_pred = self.vae.decode_latents_for_train(latent_pred)

            # using half face to compute loss
            l1 = l1_loss_func(latent_pred, gt_latent)
            l2 = l1_loss_func(face_pred[:,128:,:,:], gt_face[:,128:,:,:])
            alpha = 10
            beta = 1
            loss = alpha * l1 + beta*l2
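The snippet above stops at the loss computation; presumably each step is completed with the usual backward pass and updates, something like this minimal sketch:

    # hypothetical continuation inside the `for step in range(500000):` loop above
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()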

May I ask: (1) how much data did you use for training? (2) What GPU did you use, with what batch size, and how many steps have you trained?

(1) I used 320 HDTF videos: 260 for training and 60 for testing.
(2) Trained on a single A800 with batch size 32, stopping at 50,000 steps. All other parameters follow the defaults in train.sh.

May I ask: (1) how much data did you use for training? (2) What GPU did you use, with what batch size, and how many steps have you trained?

(1) I used 320 HDTF videos: 260 for training and 60 for testing. (2) Trained on a single A800 with batch size 32, stopping at 50,000 steps. All other parameters follow the defaults in train.sh.

When you trained, did you keep the image-level loss computation consistent with the train_codes branch? (I convert to the 0-255 image space to compute this loss, and I am not sure whether that matters.)
Code from the train_codes branch:

image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
image_pred_img = vae_fp32.decode(image_pred_img).sample
# Mask the top half of the image and calculate the loss only for the lower half of the image.
image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
image = image[:, :, image.shape[2]//2:, :]    
loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images
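For reference, the scale difference matters here: the VAE decode in the train_codes snippet above yields images roughly in [-1, 1], so an L1 loss against ground truth kept in 0-255 is not on a comparable scale. A hedged sketch of bringing a 0-255 BGR crop (as produced by cv2) into that range, assuming the usual Stable Diffusion normalization of x/127.5 - 1 (check against the repo's own preprocessing):

    import cv2
    import torch

    def face_to_vae_range(face_bgr_uint8):
        """Convert a 0-255 BGR face crop (H, W, 3) into a [-1, 1] RGB tensor (1, 3, H, W)."""
        face_rgb = cv2.cvtColor(face_bgr_uint8, cv2.COLOR_BGR2RGB)
        x = torch.from_numpy(face_rgb).float().permute(2, 0, 1).unsqueeze(0)
        return x / 127.5 - 1.0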

May I ask: (1) how much data did you use for training? (2) What GPU did you use, with what batch size, and how many steps have you trained?

(1) I used 320 HDTF videos: 260 for training and 60 for testing. (2) Trained on a single A800 with batch size 32, stopping at 50,000 steps. All other parameters follow the defaults in train.sh.

When you trained, did you keep the image-level loss computation consistent with the train_codes branch? (I convert to the 0-255 image space to compute this loss, and I am not sure whether that matters.) Code from the train_codes branch:

image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
image_pred_img = vae_fp32.decode(image_pred_img).sample
# Mask the top half of the image and calculate the loss only for the lower half of the image.
image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
image = image[:, :, image.shape[2]//2:, :]    
loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images

I kept it consistent with train_codes, but my results are the same as yours: fairly jittery, and the sharpness is not great either.

May I ask: (1) how much data did you use for training? (2) What GPU did you use, with what batch size, and how many steps have you trained?

(1) I used 320 HDTF videos: 260 for training and 60 for testing. (2) Trained on a single A800 with batch size 32, stopping at 50,000 steps. All other parameters follow the defaults in train.sh.

When you trained, did you keep the image-level loss computation consistent with the train_codes branch? (I convert to the 0-255 image space to compute this loss, and I am not sure whether that matters.) Code from the train_codes branch:

image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
image_pred_img = vae_fp32.decode(image_pred_img).sample
# Mask the top half of the image and calculate the loss only for the lower half of the image.
image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
image = image[:, :, image.shape[2]//2:, :]    
loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images

The loss is kept consistent with train_codes.

loss_lip

In our training experience, quite a lot of steps are needed, at least 150k or more. You can try training for more steps.

loss_lip

In our training experience, quite a lot of steps are needed, at least 150k or more. You can try training for more steps.

If I increase the batch size and proportionally reduce the number of training steps, will that affect the results?

Based on the MuseTalk project description, I implemented a version of the training code. Comparing it against the train_codes branch, the two are largely consistent. Training on roughly 350 HDTF videos, the model I obtained shows two obvious problems in its synthesized videos: (1) The upper and lower halves of the face do not match; a difference in color tone is visible (quite noticeable in some cases). (2) The lower half of the face jitters strongly; playback looks discontinuous. Where might the problem be? Do you have any suggestions? How can I reproduce the results of the open-source model released by this project?

Original video:

out.mp4

Synthesized video: https://github.com/TMElyralab/MuseTalk/assets/23277618/df55aa88-c007-454d-92fc-9c0dc6d51504

I'd like to confirm: did you train only on the HDTF dataset? Mixing in lower-quality data may cause such problems.

You can download HDTF following the code here:
https://github.com/universome/HDTF/blob/main/download.py
Also, the videos need to be resampled to 25 FPS before training.
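As a concrete illustration of the resampling step, a minimal sketch using ffmpeg via Python (the directory paths are placeholders; any equivalent command that re-encodes to a constant 25 FPS works):

    import subprocess
    from pathlib import Path

    def resample_to_25fps(src_dir: str, dst_dir: str) -> None:
        """Re-encode every .mp4 under src_dir to a constant 25 FPS, writing into dst_dir."""
        Path(dst_dir).mkdir(parents=True, exist_ok=True)
        for video in Path(src_dir).glob("*.mp4"):
            out_path = Path(dst_dir) / video.name
            subprocess.run(["ffmpeg", "-y", "-i", str(video), "-r", "25", str(out_path)], check=True)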

Based on the MuseTalk project description, I implemented a version of the training code. Comparing it against the train_codes branch, the two are largely consistent. Training on roughly 350 HDTF videos, the model I obtained shows two obvious problems in its synthesized videos: (1) The upper and lower halves of the face do not match; a difference in color tone is visible (quite noticeable in some cases). (2) The lower half of the face jitters strongly; playback looks discontinuous. Where might the problem be? Do you have any suggestions? How can I reproduce the results of the open-source model released by this project?
Original video:
out.mp4
Synthesized video: https://github.com/TMElyralab/MuseTalk/assets/23277618/df55aa88-c007-454d-92fc-9c0dc6d51504

I'd like to confirm: did you train only on the HDTF dataset? Mixing in lower-quality data may cause such problems.

You can download HDTF following the code here: https://github.com/universome/HDTF/blob/main/download.py Also, the videos need to be resampled to 25 FPS before training.

I trained only on the HDTF dataset; because some downloads failed, I kept only about 320 of its videos for training, all resampled to 25 fps.
I am still trying to track down the cause and don't know which stage went wrong. On your side, you are able to reproduce results consistent with the currently released model, correct?

loss_lip

In our training experience, quite a lot of steps are needed, at least 150k or more. You can try training for more steps.

When you trained the current model, roughly what range were the latent loss and image loss in? In my training (316 videos), at epoch 22 the latent loss is about 0.09-0.1 and the image loss about 0.025-0.03; at epoch 10 the latent loss is about 0.10-0.11 and the image loss about 0.03-0.035.


Hello, thanks for sharing. I tried training with your main code, but loss_lip reached over 100. Have you run into anything similar?

@liuzysy I have not seen anything like that. Starting from the authors' original model and training on the HDTF dataset, at step 0 the lip_loss with my formulation is only around 9, and it drops quickly. Check whether something went wrong with the image channels. You can preprocess the images the way the authors' code does before computing the loss.

@liuzysy I have not seen anything like that. Starting from the authors' original model and training on the HDTF dataset, at step 0 the lip_loss with my formulation is only around 9, and it drops quickly. Check whether something went wrong with the image channels. You can preprocess the images the way the authors' code does before computing the loss.

How are your training results now? We retrained once and found the results to be basically on par with the currently released weights.

@czk32611 I still have a few questions about the training code:
(1) The currently released weights were trained on HDTF, right? How many videos were actually used for training?
(2) Does your training code apply any data augmentation?
(3) Does the preprocessing expand the detected face region outward a bit? The face region produced by the inference code does not actually cover the whole face and misses some edges, e.g. the right cheek area in the image below. (See the sketch after this comment.)
image

Recently, training on just 320 HDTF videos, my synthesized videos have improved compared with two weeks ago, but I still cannot reproduce the quality of the released weights; the synthesis quality is still not good, and I don't know where the problem lies.

By adding some videos collected from Chinese video sites and applying data augmentation, the problems above improved a great deal. But there is still some gap compared with the released weights, most noticeably in the teeth region.
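For what it's worth, here is a minimal sketch of what such an outward expansion of the detected face box could look like (the 10% margin is an arbitrary illustrative value, not the margin used by MuseTalk):

    def expand_bbox(x1, y1, x2, y2, frame_h, frame_w, margin=0.1):
        """Expand a face box outward by `margin` of its width/height, clamped to the frame bounds."""
        dw = int((x2 - x1) * margin)
        dh = int((y2 - y1) * margin)
        return (max(0, x1 - dw), max(0, y1 - dh),
                min(frame_w, x2 + dw), min(frame_h, y2 + dh))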

Attached are the losses I get with the authors' released weights and all model parameters frozen, evaluated on the HDTF dataset: the latent loss is basically in the 0.16-0.18 range and the half-face loss between 0.05 and 0.07.
The losses during my own training are lower than this, whether I use only HDTF or additionally collected data.

initialize model's weights from models/musetalk/pytorch_model.bin
initialize model's weights from models/musetalk/pytorch_model.bin
initialize model's weights from models/musetalk/pytorch_model.bin
initialize model's weights from models/musetalk/pytorch_model.bin
initialize model's weights from models/musetalk/pytorch_model.bin
initialize model's weights from models/musetalk/pytorch_model.bin
initialize model's weights from models/musetalk/pytorch_model.bin
initialize model's weights from models/musetalk/pytorch_model.bin
2024-05-27 12:39:54,402: step: 0/(1122759/16/8), epoch: 0, latent loss: 0.17210, alpha: 2, latent loss(*2): 0.34420, face loss: 0.06612, face loss(*1): 0.06612, total loss: 0.41032, lr: 0.00001000, during time: 31.19, samples/sec: 8.21
2024-05-27 12:39:56,030: step: 2/(1122759/16/8), epoch: 0, latent loss: 0.16610, alpha: 2, latent loss(*2): 0.33219, face loss: 0.06022, face loss(*1): 0.06022, total loss: 0.39242, lr: 0.00001000, during time: 0.53, samples/sec: 484.85
2024-05-27 12:39:57,087: step: 4/(1122759/16/8), epoch: 0, latent loss: 0.17786, alpha: 2, latent loss(*2): 0.35571, face loss: 0.06771, face loss(*1): 0.06771, total loss: 0.42342, lr: 0.00001000, during time: 0.53, samples/sec: 485.07
2024-05-27 12:39:58,144: step: 6/(1122759/16/8), epoch: 0, latent loss: 0.17310, alpha: 2, latent loss(*2): 0.34619, face loss: 0.06454, face loss(*1): 0.06454, total loss: 0.41073, lr: 0.00001000, during time: 0.53, samples/sec: 483.89
2024-05-27 12:39:59,202: step: 8/(1122759/16/8), epoch: 0, latent loss: 0.18312, alpha: 2, latent loss(*2): 0.36625, face loss: 0.07405, face loss(*1): 0.07405, total loss: 0.44029, lr: 0.00001000, during time: 0.53, samples/sec: 483.67
2024-05-27 12:40:00,261: step: 10/(1122759/16/8), epoch: 0, latent loss: 0.18585, alpha: 2, latent loss(*2): 0.37169, face loss: 0.07470, face loss(*1): 0.07470, total loss: 0.44639, lr: 0.00001000, during time: 0.53, samples/sec: 484.30
2024-05-27 12:40:01,319: step: 12/(1122759/16/8), epoch: 0, latent loss: 0.16785, alpha: 2, latent loss(*2): 0.33571, face loss: 0.06366, face loss(*1): 0.06366, total loss: 0.39937, lr: 0.00001000, during time: 0.53, samples/sec: 484.03
2024-05-27 12:40:02,377: step: 14/(1122759/16/8), epoch: 0, latent loss: 0.16573, alpha: 2, latent loss(*2): 0.33146, face loss: 0.06019, face loss(*1): 0.06019, total loss: 0.39165, lr: 0.00001000, during time: 0.53, samples/sec: 483.45
2024-05-27 12:40:03,436: step: 16/(1122759/16/8), epoch: 0, latent loss: 0.16840, alpha: 2, latent loss(*2): 0.33679, face loss: 0.06269, face loss(*1): 0.06269, total loss: 0.39949, lr: 0.00001000, during time: 0.53, samples/sec: 484.85
2024-05-27 12:40:04,495: step: 18/(1122759/16/8), epoch: 0, latent loss: 0.18388, alpha: 2, latent loss(*2): 0.36776, face loss: 0.07651, face loss(*1): 0.07651, total loss: 0.44427, lr: 0.00001000, during time: 0.53, samples/sec: 484.19
2024-05-27 12:40:05,555: step: 20/(1122759/16/8), epoch: 0, latent loss: 0.17007, alpha: 2, latent loss(*2): 0.34014, face loss: 0.05818, face loss(*1): 0.05818, total loss: 0.39832, lr: 0.00001000, during time: 0.53, samples/sec: 483.39
2024-05-27 12:40:06,612: step: 22/(1122759/16/8), epoch: 0, latent loss: 0.17128, alpha: 2, latent loss(*2): 0.34257, face loss: 0.06878, face loss(*1): 0.06878, total loss: 0.41135, lr: 0.00001000, during time: 0.53, samples/sec: 484.69
2024-05-27 12:40:07,669: step: 24/(1122759/16/8), epoch: 0, latent loss: 0.16634, alpha: 2, latent loss(*2): 0.33267, face loss: 0.06333, face loss(*1): 0.06333, total loss: 0.39600, lr: 0.00001000, during time: 0.53, samples/sec: 484.31
2024-05-27 12:40:08,728: step: 26/(1122759/16/8), epoch: 0, latent loss: 0.18350, alpha: 2, latent loss(*2): 0.36701, face loss: 0.07385, face loss(*1): 0.07385, total loss: 0.44085, lr: 0.00001000, during time: 0.53, samples/sec: 483.39
2024-05-27 12:40:09,789: step: 28/(1122759/16/8), epoch: 0, latent loss: 0.17434, alpha: 2, latent loss(*2): 0.34868, face loss: 0.06881, face loss(*1): 0.06881, total loss: 0.41749, lr: 0.00001000, during time: 0.53, samples/sec: 483.25
2024-05-27 12:40:10,849: step: 30/(1122759/16/8), epoch: 0, latent loss: 0.16171, alpha: 2, latent loss(*2): 0.32342, face loss: 0.05599, face loss(*1): 0.05599, total loss: 0.37941, lr: 0.00001000, during time: 0.53, samples/sec: 483.49
2024-05-27 12:40:11,909: step: 32/(1122759/16/8), epoch: 0, latent loss: 0.18266, alpha: 2, latent loss(*2): 0.36533, face loss: 0.07115, face loss(*1): 0.07115, total loss: 0.43648, lr: 0.00001000, during time: 0.53, samples/sec: 482.99
2024-05-27 12:40:12,968: step: 34/(1122759/16/8), epoch: 0, latent loss: 0.18183, alpha: 2, latent loss(*2): 0.36365, face loss: 0.07388, face loss(*1): 0.07388, total loss: 0.43753, lr: 0.00001000, during time: 0.53, samples/sec: 483.58
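The frozen-weight check described above can be reproduced by loading the released weights and evaluating the same losses with no gradient updates; a minimal sketch reusing the variable names from the training snippet earlier in the thread:

    # hypothetical evaluation-only pass over a batch, reusing objects from the training code above
    self.unet.model.eval()
    with torch.no_grad():
        latent_pred = self.unet.model(input_latent, self.timesteps,
                                      encoder_hidden_states=whisper).sample
        face_pred = self.vae.decode_latents_for_train(latent_pred)
        latent_loss = l1_loss_func(latent_pred, gt_latent)
        half_face_loss = l1_loss_func(face_pred[:, 128:, :, :], gt_face[:, 128:, :, :])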


Are you using the code in the train_codes branch, or did you create your own training code?


Are you using the code in the train_codes branch, or did you create your own training code?

I did not use the train_codes branch code; I wrote my own code based on the inference code. I have verified that the output from the dataset layer of the training code is consistent with the output from the inference code.

What is the range of your loss output?


Based on the MuseTalk project description, I implemented a version of the training code. Comparing it against the train_codes branch, the two are largely consistent. Training on roughly 350 HDTF videos, the model I obtained shows two obvious problems in its synthesized videos: (1) The upper and lower halves of the face do not match; a difference in color tone is visible (quite noticeable in some cases). (2) The lower half of the face jitters strongly; playback looks discontinuous. Where might the problem be? Do you have any suggestions? How can I reproduce the results of the open-source model released by this project?

Original video:

out.mp4
Synthesized video: https://github.com/TMElyralab/MuseTalk/assets/23277618/df55aa88-c007-454d-92fc-9c0dc6d51504

Hi, could you share the HDTF dataset? After downloading I only got about 150 videos. If convenient, you could send it to wvinzh@qq.com.


May I ask: (1) how much data did you use for training? (2) What GPU did you use, with what batch size, and how many steps have you trained?

(1) I used 320 HDTF videos: 260 for training and 60 for testing. (2) Trained on a single A800 with batch size 32, stopping at 50,000 steps. All other parameters follow the defaults in train.sh.

Hi, could you share the HDTF dataset? After downloading I only got about 150 videos. If convenient, you could send it to wvinzh@qq.com.