imlixinyang / HiSD

Official pytorch implementation of paper "Image-to-image Translation via Hierarchical Style Disentanglement" (CVPR 2021 Oral).

How many tags can this project train at the same time?

datar001 opened this issue · comments

Hi, thanks for sharing your work.
How many tags have you tried to train? What is the relation between the number of tags and the number of training iterations?
And how many tags would you recommend training at once?

I've successfully trained 6 tags at the same time. In my experiments, I found 50k per tag is enough (i.e., 20k for 6 tags).
HiSD supports various numbers of tags, but you should increase the number of training iterations and the model capacity accordingly.
Using gradient accumulation and training all tags in one iteration is also important (so you need to change the code a little).
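
For illustration, here is a minimal sketch of what that change to the training loop could look like. All names here (`loaders`, `sample_batch`, `trainer`, `max_iterations`) are hypothetical placeholders, not the original train.py code:

    # Hypothetical sketch: visit every tag once per effective iteration so the
    # gradients of all tags are accumulated before the optimizers are stepped
    # (the step itself is taken once every tag_num calls inside trainer.update,
    # as in the modified update() shown later in this thread).
    tag_num = len(loaders)                             # one data loader per tag (assumed)
    for it in range(max_iterations):
        for i in range(tag_num):
            x, y, j, j_trg = sample_batch(loaders[i])  # hypothetical helper: batch plus
                                                       # source/target attributes for tag i
            trainer.update(x, y, i, j, j_trg, iterations=it * tag_num + i)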

Thanks for your reply.
Is this the right way to implement "gradient accumulation and all tags in one iteration"?
[screenshots of the modified code]
And is "20k for 6 tags" a typo? The official repo uses 200k iterations for 3 tags with 7 attributes.
Also, is the performance better when we train fewer tags?

Sorry for the typo, it should be 200k for 3 tags with 7 attributes.
You've got the right idea about the gradient accumulation, and you can rewrite the update code like this:

    def update(self, x, y, i, j, j_trg, iterations):

        this_model = self.models.module if self.multi_gpus else self.models

        # generator pass: freeze the discriminator, unfreeze the generator
        for p in this_model.dis.parameters():
            p.requires_grad = False
        for p in this_model.gen.parameters():
            p.requires_grad = True

        # the backward pass runs inside the model's forward (mode='gen'),
        # so the generator gradients accumulate in .grad across calls
        self.loss_gen_adv, self.loss_gen_sty, self.loss_gen_rec, \
        x_trg, x_cyc, s, s_trg = self.models((x, y, i, j, j_trg), mode='gen')

        self.loss_gen_adv = self.loss_gen_adv.mean()
        self.loss_gen_sty = self.loss_gen_sty.mean()
        self.loss_gen_rec = self.loss_gen_rec.mean()

        # discriminator pass: unfreeze the discriminator, freeze the generator
        for p in this_model.dis.parameters():
            p.requires_grad = True
        for p in this_model.gen.parameters():
            p.requires_grad = False

        # likewise, the backward pass runs inside the forward (mode='dis')
        self.loss_dis_adv = self.models((x, x_trg, x_cyc, s, s_trg, y, i, j, j_trg), mode='dis')
        self.loss_dis_adv = self.loss_dis_adv.mean()

        # gradient accumulation: step and reset the optimizers only after all
        # tags have contributed, i.e., once every tag_num calls
        if (iterations + 1) % self.tag_num == 0:
            nn.utils.clip_grad_norm_(this_model.gen.parameters(), 100)
            nn.utils.clip_grad_norm_(this_model.dis.parameters(), 100)
            self.gen_opt.step()
            self.dis_opt.step()
            self.gen_opt.zero_grad()
            self.dis_opt.zero_grad()

            # keep the moving-average generator used at test time in sync
            update_average(this_model.gen_test, this_model.gen)

        return self.loss_gen_adv.item(), \
               self.loss_gen_sty.item(), \
               self.loss_gen_rec.item(), \
               self.loss_dis_adv.item()

And you need to decrease the learning rate (maybe lr/tag_num), since the accumulated gradient is the 'sum' over tags rather than the 'average'.
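
For example, a minimal sketch of that adjustment, assuming it is applied once wherever the trainer's optimizers are created (`self.gen_opt`, `self.dis_opt`, and `self.tag_num` mirror the snippet above):

    # Hypothetical sketch: divide the learning rate by the number of tags,
    # because the gradients of tag_num update() calls are summed (not averaged)
    # before each optimizer step.
    for opt in (self.gen_opt, self.dis_opt):
        for group in opt.param_groups:
            group['lr'] = group['lr'] / self.tag_num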