leeesangwon / PyTorch-Image-Retrieval

A PyTorch framework for an image retrieval task including implementation of N-pair Loss (NIPS 2016) and Angular Loss (ICCV 2017).

A question about how to test retrieval

gimpong opened this issue · comments

Hi, thank you for building such a useful and concise framework for the image retrieval task!

I just read the file inference.py and got confused about the testing process.

query_loader = DataLoader(query_img_dataset, batch_size=infer_batch_size,
                          shuffle=False, num_workers=4, pin_memory=True)
reference_loader = DataLoader(reference_img_dataset, batch_size=infer_batch_size,
                              shuffle=False, num_workers=4, pin_memory=True)

My question is: why is the database for retrieval divided into batches? Shouldn't we keep it as a whole for every query, since we want to rank all the reference images in the whole database by their pair-wise similarity to the query?
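For what it's worth, the batching only affects how the features are *extracted* (to fit in GPU memory); since feature extraction is applied independently per image, running it batch-by-batch and concatenating gives exactly the same matrix as one pass over the whole database. A minimal NumPy sketch of that equivalence (the `extract` function here is a toy stand-in, not the repo's model):

```python
import numpy as np

def extract(batch):
    # Toy stand-in for a per-image feature extractor (hypothetical,
    # not from the repo): any deterministic row-wise function works.
    return batch * 2.0

reference = np.arange(12, dtype=np.float32).reshape(6, 2)

# Whole-database pass in one go.
whole = extract(reference)

# Batched pass (batch_size=2), then concatenate -- what inference.py does.
batched = np.concatenate([extract(reference[i:i + 2]) for i in range(0, 6, 2)])

assert np.allclose(whole, batched)
```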

In the batch_process function, all feature vectors are concatenated.

query_paths, query_vecs = batch_process(model, query_loader)
reference_paths, reference_vecs = batch_process(model, reference_loader)

def batch_process(model, loader):
    feature_vecs = []
    img_paths = []
    for data in loader:
        paths, inputs = data
        feature_vec = _get_feature(model, inputs.cuda())
        feature_vec = feature_vec.detach().cpu().numpy()  # (batch_size, channels)
        for i in range(feature_vec.shape[0]):
            feature_vecs.append(feature_vec[i])
        img_paths = img_paths + paths
    return img_paths, np.asarray(feature_vecs)
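So after `batch_process` returns, `reference_vecs` again holds the features of the *entire* database, and the ranking can be done against all references at once. A sketch of that final step, assuming cosine similarity (the repo may use a different measure; shapes and names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
query_vecs = rng.normal(size=(3, 8)).astype(np.float32)       # (num_queries, channels)
reference_vecs = rng.normal(size=(10, 8)).astype(np.float32)  # (num_references, channels)

# L2-normalise so the dot product equals cosine similarity.
q = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
r = reference_vecs / np.linalg.norm(reference_vecs, axis=1, keepdims=True)

sim = q @ r.T                       # (num_queries, num_references)
ranking = np.argsort(-sim, axis=1)  # best-matching reference first, per query
```

Each row of `ranking` orders every reference image in the database for one query, which is exactly the ranking the question is after.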