peymanbateni / simple-cnaps

Source codes for "Improved Few-Shot Visual Classification" (CVPR 2020), "Enhancing Few-Shot Image Classification with Unlabelled Examples" (WACV 2022), and "Beyond Simple Meta-Learning: Multi-Purpose Models for Multi-Domain, Active and Continual Few-Shot Learning" (SSRN Electronic Journal)


How to Predict

FDwangchao opened this issue · comments

After training the Simple CNAPS model, how do I use it to make predictions on a new dataset?

You need to create a few-shot task. Basically, you provide the inputs to the model's def forward(self, context_images, context_labels, target_images): function and call it using a loaded model checkpoint. context_images and context_labels are your new dataset and its labels, while target_images are the images you want the model to produce labels for. All images need to be resized to 84x84 and the pixels normalized (i.e. divided by 255; see how the Meta-Dataset, miniImageNet, and tieredImageNet images are prepared). Labels are simply integers indicating which class each context image belongs to.

Hope this helps! I tried to keep the answer very concise. Let me know if you need any additional clarification.
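To make the above concrete, here is a minimal sketch of packing raw images into the tensors that forward() expects. The helper make_task is hypothetical (not part of the repository), and the random arrays stand in for real 84x84 RGB images; the normalization shown maps pixels to [-1, 1], matching the Meta-Dataset-style preprocessing used elsewhere in this thread.

```python
import numpy as np
import torch

def make_task(context_imgs, context_labels, target_imgs):
    """Hypothetical helper: pack lists of 84x84x3 uint8 images into the
    (context_images, context_labels, target_images) tensors that a
    Simple CNAPS checkpoint's forward() consumes."""
    def to_tensor(imgs):
        arr = np.stack(imgs).astype(np.float32)  # (N, 84, 84, 3)
        arr = 2.0 * (arr / 255.0 - 0.5)          # normalize pixels to [-1, 1]
        arr = arr.transpose(0, 3, 1, 2)          # channels first: (N, 3, 84, 84)
        return torch.from_numpy(arr)
    return (to_tensor(context_imgs),
            torch.tensor(context_labels, dtype=torch.long),
            to_tensor(target_imgs))

# Toy 5-way task: 2 labeled context images per class, 4 unlabeled targets.
ctx = [np.random.randint(0, 256, (84, 84, 3), dtype=np.uint8) for _ in range(10)]
tgt = [np.random.randint(0, 256, (84, 84, 3), dtype=np.uint8) for _ in range(4)]
context_images, context_labels, target_images = make_task(
    ctx, [i // 2 for i in range(10)], tgt)

# With a loaded checkpoint you would then call:
# logits = model(context_images, context_labels, target_images)
```

The integer labels only need to be consistent within the task (0..K-1 for a K-way task); they do not have to match any training-time class IDs.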

Thank you. I also want to ask: when I have a small amount of labeled data, do I need to retrain the Simple CNAPS model on this small dataset? Or can I directly load an already trained model (e.g. trained on tieredImageNet), pass the small labeled set as context_images and context_labels, and then predict on new data?

The latter. You should load the trained model, then use your small labeled dataset as context_images and context_labels, and have the images you want to predict as target_images to get predicted labels for. Hope that helps :)

Thanks a lot, it works!

I used a small amount of labeled data as context data (45 images in total across 5 categories), then predicted on that same data, and found that the accuracy was only 70%, not 100%. Is there a problem somewhere?
step 1: load the model
model.load_state_dict(torch.load('simple-cnaps/model-checkpoints/meta-dataset-checkpoints/best_simple_cnaps.pt', map_location=torch.device('cpu')))
step 2: convert the context images to a normalized array (the target images are prepared the same way)
images = []
for img in img_files:
    im = Image.open(os.path.join(path, img)).convert('RGB')
    im = im.resize((84, 84), resample=Image.LANCZOS)
    images.append(np.array(im))
# stack into (N, 84, 84, 3), normalize to [-1, 1], then move channels first
context_images_np = 2 * (np.stack(images) / 255.0 - 0.5)
context_images_np = context_images_np.transpose([0, 3, 1, 2])  # (N, 3, 84, 84)
context_images = torch.from_numpy(context_images_np).to(torch.float32)
step 3: prediction
test_logits_sample = model(context_images, context_labels, target_images)
averaged_predictions = torch.logsumexp(test_logits_sample, dim=0)  # average over logit samples
predict_label.extend(torch.argmax(averaged_predictions, dim=-1))

Is there any problem with this?

This looks reasonable. Keep in mind that even fully supervised, data-rich classifiers never reach 100% accuracy. In fact, although our method is one of the best-performing on most datasets, accuracy levels are around 70%. See https://openaccess.thecvf.com/content_CVPR_2020/html/Bateni_Improved_Few-Shot_Visual_Classification_CVPR_2020_paper.html for more details.

Thank you very much for your detailed explanation