Sense-GVT / DeCLIP

Supervision Exists Everywhere: A Data Efficient Contrastive Language-Image Pre-training Paradigm

worked (simple) example of loading model and transforms?

ColinConwell opened this issue

Thank you for this exciting repository. Can you provide a simple example of how to load the models from your model zoo?

Something along the lines of what is provided by the timm (pytorch-image-models) model repository:

import timm
model_name = 'ghostnet_100'
model = timm.create_model(model_name, pretrained=True)
model.eval()

from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# resolve_data_config reads the data config off the model instance, not its name
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
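
For completeness, applying the resulting transform and model to an image then looks like the following (a minimal sketch; 'dog.jpg' stands in for any local image):

from PIL import Image
import torch

image = Image.open('dog.jpg').convert('RGB')
with torch.no_grad():
    logits = model(transform(image).unsqueeze(0))
print(logits.shape)  # torch.Size([1, 1000]) for ghostnet_100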

Ideally, this would allow us to use the models in a Jupyter notebook or other interactive context.

Thanks in advance!

By way of example, here's a little script I worked out. If this looks incorrect, let me know!

import os, sys, torch
from PIL import Image
from torchvision import transforms

# clone the repo if it isn't already present
# ("!" is a notebook shell escape; run `git clone` in a terminal for a plain script)
if not os.path.exists('DeCLIP'):
    !git clone https://github.com/Sense-GVT/DeCLIP/

sys.path.append('DeCLIP')

sample_image = Image.open('dog.jpg').convert('RGB')

# load the experiment config that matches the checkpoint
from prototype.utils.misc import parse_config

config_path = 'DeCLIP/experiments/declip_experiments/declip88m/declip88m_r50_declip/config.yaml'
config = parse_config(config_path)

from prototype.model.declip import declip_res50

# point the text encoder at the BPE vocab shipped with the repo and turn off text masking
bpe_path = 'DeCLIP/prototype/text_info/bpe_simple_vocab_16e6.txt.gz'
config['model']['kwargs']['text_encode']['bpe_path'] = bpe_path
config['model']['kwargs']['clip']['text_mask_type'] = None

# load the checkpoint on CPU and strip the DataParallel 'module.' prefix from the keys
weights = torch.load('DeCLIP/weights/declip_88m/r50.pth.tar', map_location='cpu')['model']
weights = {k.replace('module.', ''): v for k, v in weights.items()}
weights['logit_scale'] = weights['logit_scale'].unsqueeze(0)  # match the shape the model expects

model = declip_res50(**config['model']['kwargs'])
model.load_state_dict(weights, strict=False)
model.eval()

# ImageNet-style preprocessing; crop to the fixed 224x224 input the ResNet-50 visual encoder expects
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

inputs = preprocess(sample_image).unsqueeze(0)
with torch.no_grad():
    image_features = model.visual(inputs)
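
Assuming model.visual returns one embedding vector per image (as it appears to above), a quick sanity check is to compare two images by cosine similarity; 'cat.jpg' here is just a hypothetical second image:

import torch.nn.functional as F

other_image = Image.open('cat.jpg').convert('RGB')  # hypothetical second image
with torch.no_grad():
    other_features = model.visual(preprocess(other_image).unsqueeze(0))

# cosine similarity between the two image embeddings
print(F.cosine_similarity(image_features, other_features).item())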