Zasder3 / train-CLIP

A PyTorch Lightning solution to training OpenAI's CLIP from scratch.

Loading checkpoint

turicumoperarius opened this issue

Hi there. Can you give some advice on how to load a checkpoint from a trained model with your PyTorch Lightning wrapper for inference? I used the common PyTorch Lightning method "load_from_checkpoint" but have not had any luck so far. Thanks

Sure thing! Would you mind sharing your code to load in the model?

In the meantime I played around with the code a little and found a working solution for me. Maybe you have some additional ideas/tweaks to make my current code better.

First of all, I used the standard "train.py" training code and the PyTorch dataset class provided in the repository. To make things easier, I added a "torch.save(model.model, )" after the "trainer.fit" call for testing, then loaded the model back with "torch.load" for inference purposes (roughly as sketched below). The rest is then similar to the checkpoint-based code further down.
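For reference, that save/load detour was roughly the following sketch (the file name is just a placeholder, and model is the CLIPWrapper instance that train.py fits):

import torch

# after trainer.fit(model) in train.py: dump the inner CLIP model for testing
torch.save(model.model, 'clip_model.pt')  # placeholder file name

# later, load it back for inference
model = torch.load('clip_model.pt').to('cuda')
model.eval()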

After the "torch.load" variant worked, I tried to use the default checkpoint created by PyTorch Lightning:

import yaml
import torch
import clip
from PIL import Image
import torchvision.transforms as T
from models import CLIPWrapper  # wrapper class from this repository

CHECKPOINT = 'lightning_logs/version_1/checkpoints/epoch=31-step=3487.ckpt'
TEST_IMG = 'test.jpg'
MODEL_NAME = 'ViT-B/16'
DEVICE = 'cuda'

config_dir = 'models/configs/ViT.yaml' if 'ViT' in MODEL_NAME else 'models/configs/RN.yaml'
with open(config_dir) as fin:
    config = yaml.safe_load(fin)[MODEL_NAME]

model = CLIPWrapper.load_from_checkpoint(CHECKPOINT, model_name=MODEL_NAME, config=config, minibatch_size=1).model.to(DEVICE)

def fix_img(img):
    return img.convert('RGB') if img.mode != 'RGB' else img

image_transform = T.Compose([
    T.Lambda(fix_img),
    T.RandomResizedCrop(224,
                        scale=(0.75, 1.),
                        ratio=(1., 1.)),
    T.ToTensor(),
    T.Normalize((0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711))
])

with torch.no_grad():
    img_emb = image_transform(Image.open(TEST_IMG)).unsqueeze(0).to(DEVICE)
    img_enc = model.encode_image(img_emb)

    text_emb = clip.tokenize(["test query"], truncate=True).to(DEVICE)
    text_enc = model.encode_text(text_emb)

    # ... do things ...

So this code works for me, at least with the CLIPWrapper. I have not tested the CustomCLIPWrapper yet, but I think it should work pretty similarly. Maybe you have some advice on getting the image transform and/or the tokenizer into the inference code a little more easily.

Hey, I also trained a model and have a ckpt file. I'm now looking for inference code; would you be so kind as to share yours, since it's working? Thank you

Actually, the inference code is the code shared above. Is there anything in particular you are looking for or struggling with?

I was wondering how to complete the with torch.no_grad(): part to get an image caption as output. So basically, what to insert in place of the # ... do things ... placeholder after computing img_enc and text_enc.

The variables img_enc and text_enc contain everything that you would get from a CLIP model: the features for the image and for the text. What you do with them from there is up to you.

For example, you could do a simple zero-shot classification, like in this example taken from the official CLIP GitHub repo:

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

Or you could simply calculate the cosine similarity between the vectors and do whatever you like with it. What I mostly do is put the features in some kind of vector store with an index, like FAISS, Spotify Annoy, or LanceDB, to make them searchable.
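As a rough sketch of the FAISS route (variable names borrowed from the inference code above; the rest is just illustrative):

import faiss
import numpy as np

# image features from the inference code above, L2-normalised so that
# inner product equals cosine similarity
feats = img_enc.cpu().numpy().astype('float32')
feats /= np.linalg.norm(feats, axis=1, keepdims=True)

index = faiss.IndexFlatIP(feats.shape[1])  # flat inner-product index
index.add(feats)                           # store the image vectors

# normalised text features as the query
query = text_enc.cpu().numpy().astype('float32')
query /= np.linalg.norm(query, axis=1, keepdims=True)

scores, ids = index.search(query, 5)       # top-5 (or fewer) most similar stored images per query

The same normalised dot product is also exactly the cosine similarity mentioned above, so you can compute it directly with img_enc @ text_enc.T after normalising both tensors.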