Zhongdao / gcn_clustering

Code for CVPR'19 paper Linkage-based Face Clustering via GCN

Extracting features from ArcFace

saedr opened this issue

Hi, I have a question about feature extraction, because I cannot reproduce the results with my own preprocessed files. Using IJB-B-512, your checkpoint for CASIA, and a PyTorch implementation of ArcFace, I came up with the following code:

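```python
import numpy as np
from tqdm import tqdm

import torch
import torchvision.transforms as transforms
from torchvision import datasets

from model import Backbone  # backbone from the PyTorch ArcFace implementation

# Preprocessing: random horizontal flip, resize to 112x112, scale pixels to [-1, 1]
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(size=(112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

data_path = "../data/IJB-B-512/"
batch_size = 16
num_workers = 16

# ImageFolder: one sub-directory per identity; labels are class indices
# assigned from the sorted folder names
data = datasets.ImageFolder(data_path, transform=transform)
# shuffle=True: batches arrive in random order, not in data.samples order
loader = torch.utils.data.DataLoader(data,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     shuffle=True,
                                     pin_memory=True)

# IR-SE-50 backbone loaded with the pretrained ArcFace checkpoint
model = Backbone(50, 0.6, 'ir_se')
ckpt = torch.load("../pretrained/model_ir_se50.pth")
model.load_state_dict(ckpt)
model.cuda()
model.eval()

features = []
def hook(module, input, output):
    # Average the (N, C, H, W) activation over its spatial positions -> (N, C)
    N, C, H, W = output.shape
    output = output.reshape(N, C, -1)
    features.append(output.mean(dim=2).cpu().detach().numpy())

# Capture the output of fc2 in res_layer[5] of body unit 23
handle = model._modules['body'][23].res_layer[5].fc2.register_forward_hook(hook)
with torch.no_grad():  # inference only, no autograd graph needed
    for i_batch, inputs in tqdm(enumerate(loader), total=len(loader)):
        # The forward output is discarded; the hook stores the features
        _ = model(inputs[0].cuda())

features = np.concatenate(features)
handle.remove()
```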
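A small numpy sketch of the two array operations involved, using a random toy activation in place of real network output: the hook's reshape-then-mean is global average pooling over the spatial grid (for an (N, C, 1, 1) output it reduces to a plain reshape), and L2 normalization is the usual step before comparing face embeddings by cosine similarity. The helper names `pool_activation` and `l2_normalize` are illustrative only, not part of either repository.

```python
import numpy as np


def pool_activation(output: np.ndarray) -> np.ndarray:
    """Average an (N, C, H, W) activation over its H*W positions -> (N, C)."""
    n, c = output.shape[:2]
    return output.reshape(n, c, -1).mean(axis=2)


def l2_normalize(features: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm so dot products equal cosine similarities."""
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / np.maximum(norms, eps)


# Toy activation: 2 images, 4 channels, 3x3 spatial grid (random stand-in data).
rng = np.random.default_rng(0)
act = rng.standard_normal((2, 4, 3, 3))

pooled = pool_activation(act)
# Same result as averaging over both spatial axes at once.
assert np.allclose(pooled, act.mean(axis=(2, 3)))

# For an (N, C, 1, 1) output, the pooling step is just a reshape.
act_1x1 = act[:, :, :1, :1]
assert np.allclose(pool_activation(act_1x1), act_1x1.reshape(2, 4))

emb = l2_normalize(pooled)
cosine = emb @ emb.T  # pairwise cosine similarity matrix
print(cosine.round(3))
```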

Could you please let me know whether my approach makes sense and how it differs from yours, or could you kindly share your preprocessing module?
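With `shuffle=True`, the rows of the concatenated `features` come out in a random order, so comparing them against ground-truth identities requires collecting the labels in the same loop. A minimal numpy sketch of that idea, assuming the loader yields `(images, labels)` pairs as torchvision's `ImageFolder` does; `collect_features` and the toy `embed` function are hypothetical stand-ins for the model forward pass, and the "images" are just integer ids.

```python
import numpy as np


def collect_features(batches, embed):
    """Run `embed` over (images, labels) batches, keeping labels row-aligned.

    `embed` is a hypothetical stand-in for the network forward pass; it must
    return one feature row per input image.
    """
    feats, labels = [], []
    for images, batch_labels in batches:
        feats.append(np.asarray(embed(images)))
        labels.append(np.asarray(batch_labels))
    return np.concatenate(feats), np.concatenate(labels)


def embed(images):
    """Toy 2-d 'embedding' whose first column is the image id."""
    return np.stack([images, images * 10], axis=1).astype(np.float64)


# Toy data: 6 "images" (ids 0..5); identity label = id // 2.
ids = np.arange(6)
true_labels = ids // 2

# Simulate a shuffled loader: a random permutation split into batches of 4.
rng = np.random.default_rng(1)
order = rng.permutation(len(ids))
batches = [(ids[order[i:i + 4]], true_labels[order[i:i + 4]])
           for i in range(0, len(ids), 4)]

features, labels = collect_features(batches, embed)
# The row order is shuffled, but each label still matches its own image.
assert np.array_equal(labels, features[:, 0].astype(int) // 2)
```

Setting `shuffle=False` is the other option: a sequential loader visits `data.samples` in order, so `data.targets` then lines up row for row with `features`.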

Hi Saed (@saedr), how did you get the code to run? The code is written for PyTorch 0.4 and Python 2.7, and it is difficult to make it run on newer GPUs. Any suggestions would be helpful.

Hi @saedr, could you share the IJB-B dataset?