
OCM

Here is the official implementation of the paper "Online Continual Learning through Mutual Information Maximization". The paper was accepted at ICML 2022 as a spotlight paper.

Requirements

pytorch<=1.6.0
numpy==1.19.5
scipy==1.4.1
apex==0.1
tensorboardX
diffdist

Usage

To reproduce the results in the CIFAR10 setting (2 classes per task):

            python test_cifar10.py --buffer_size 1000

To reproduce the results in other settings (e.g., CIFAR100):

            python test_<dataset name>.py --buffer_size xxx

Note that the dataset name is written in lowercase; you can check the available names in the OCM repository.

Pseudocode for OCM (the simplest form for understanding the method)

Representation learning part

    x, y = x.cuda(), y.cuda()  # the incoming batch of new data

    rotate_x, rotate_y = Rotation(x, y)  # rotate the images to create additional pseudo-classes

    hidden, hidden_aug = Basic_model(rotate_x, is_simclr=True), Basic_model(Aug(rotate_x), is_simclr=True)
    # Aug is the data augmentation applied to the rotated batch

    sim_matrix = torch.matmul(normalize(hidden), normalize(hidden_aug).t())  # similarity matrix between the two views

    InfoNce_loss_new = Supervised_NT_xent_uni(sim_matrix, labels=rotate_y, temperature=0.07)
    # the same loss can be computed for the buffer data
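`Rotation`, `Aug`, `normalize`, and `Supervised_NT_xent_uni` are helpers defined in this repository. For readers who only want the idea, below is a minimal, non-authoritative sketch of rotation-based pseudo-class expansion and a supervised NT-Xent (InfoNCE) loss; the function names, the `num_classes` argument, and the masking details are assumptions, not the repository's exact implementation.

    import torch

    def rotation_sketch(x, y, num_classes=10):
        # Hypothetical sketch: rotate each image by 0/90/180/270 degrees and shift
        # the labels so every rotation of a class becomes its own pseudo-class.
        xs, ys = [], []
        for k in range(4):
            xs.append(torch.rot90(x, k, dims=(2, 3)))  # rotate in the H/W plane (NCHW input)
            ys.append(y + k * num_classes)
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)

    def supervised_nt_xent_sketch(sim_matrix, labels, temperature=0.07):
        # Hypothetical sketch of a supervised NT-Xent loss: any row/column pair of
        # the cross-view similarity matrix with matching labels counts as a positive.
        logits = sim_matrix / temperature
        pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
        return -((pos_mask * log_prob).sum(dim=1) / pos_mask.sum(dim=1)).mean()

The repository's `Supervised_NT_xent_uni` and `Supervised_NT_xent_pre` differ in how they mask positives, so refer to the code in this repo for exact reproduction.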

Forgetting loss part

    mem_x, mem_y = mem_x.cuda(), mem_y.cuda()
    # the buffer data; you can choose the retrieval strategy yourself

    hidden_mem, hidden_mem_prev = Basic_model(mem_x, is_simclr=True), Previous_model(mem_x, is_simclr=True)

    sim_matrix_prev = torch.matmul(normalize(hidden_mem), normalize(hidden_mem_prev).t())
    # similarity between the current and previous model's features on the buffer data

    InfoNce_loss_prev = Supervised_NT_xent_pre(sim_matrix_prev, labels=mem_y, temperature=0.07)
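In the full training step, the two InfoNCE terms are combined with a standard classification loss and optimized together. The sketch below is only a rough illustration: the loss weighting, the `optimizer`, the plain `Basic_model(x)` classification forward pass, and the way `Previous_model` is refreshed are assumptions; the actual loop is in the `test_<dataset name>.py` scripts.

    import copy
    import torch.nn as nn

    ce = nn.CrossEntropyLoss()

    loss = (InfoNce_loss_new                    # representation learning on the new batch
            + InfoNce_loss_prev                 # forgetting loss against the previous model
            + ce(Basic_model(x), y)             # classification on new data (assumed forward pass)
            + ce(Basic_model(mem_x), mem_y))    # classification on buffer data

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    Previous_model = copy.deepcopy(Basic_model)  # snapshot used as the "previous model" later (assumed schedule)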
