PrithivirajDamodaran / Alt-ZSC

Alternate implementation for zero-shot text classification: instead of reframing NLI/XNLI, this reframes the text backbone of CLIP models to do ZSC. Hence it can be lightweight and support more languages without trading off accuracy. (Super simple; a 10th-grader could totally write this, but since no 10th-grader did, I did) - Prithivi Da

Unable to run inference on longer texts

smiraldr opened this issue · comments

Issue:
Since the library uses OpenAI's CLIP as its backbone, it cannot tokenize and use longer strings due to the context-length constraint in the original model. I hit the same issue when running inference on languages other than English.
Is there a way to use longer texts for zero-shot inference without replacing the underlying model or truncating the text?

Reference: from the original CLIP model source code, https://github.com/openai/CLIP/blob/67fc250eb6aa84ef9ad19a020e3f8eb4e698feb4/clip/clip.py

def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
Stacktrace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_12712/3648834823.py in <module>
----> 1 english_df[['predicted_cat', 'prob', 'res']] = english_df.apply(lambda x: evaluate(x['clean_title'],categories), axis=1, result_type="expand")

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/pandas/core/frame.py in apply(self, func, axis, raw, result_type, args, **kwargs)
   8738             kwargs=kwargs,
   8739         )
-> 8740         return op.apply()
   8741 
   8742     def applymap(

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/pandas/core/apply.py in apply(self)
    686             return self.apply_raw()
    687 
--> 688         return self.apply_standard()
    689 
    690     def agg(self):

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/pandas/core/apply.py in apply_standard(self)
    810 
    811     def apply_standard(self):
--> 812         results, res_index = self.apply_series_generator()
    813 
    814         # wrap results

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/pandas/core/apply.py in apply_series_generator(self)
    826             for i, v in enumerate(series_gen):
    827                 # ignore SettingWithCopy here in case the user mutates
--> 828                 results[i] = self.f(v)
    829                 if isinstance(results[i], ABCSeries):
    830                     # If we have a view on v, we need to make a copy because

/tmp/ipykernel_12712/3648834823.py in <lambda>(x)
----> 1 english_df[['predicted_cat', 'prob', 'res']] = english_df.apply(lambda x: evaluate(x['clean_title'],categories), axis=1, result_type="expand")

/tmp/ipykernel_12712/32427261.py in evaluate(clean_title, labels)
      7 
      8     preds = zstc(text=clean_title,
----> 9                 candidate_labels=labels,
     10                 )
     11     d = dict(zip(preds['labels'],preds['scores']))

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/AltZSC/ZeroShotTextClassification.py in __call__(self, text, candidate_labels, *args, **kwargs)
    102 
    103         if str(type(self.model)) == "<class 'clip.model.CLIP'>":
--> 104             text_tokens = clip.tokenize(text).to(device)
    105             label_tokens = clip.tokenize(labels).to(device)
    106             text_features = self.model.encode_text(text_tokens)

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/clip/clip.py in tokenize(texts, context_length, truncate)
    226                 tokens[-1] = eot_token
    227             else:
--> 228                 raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
    229         result[i, :len(tokens)] = torch.tensor(tokens)
    230 

RuntimeError: Input PM NARENDRA MODI  TU KON  RAHUL GANDHI  MAIN FIROZ KHAN KA POTA PM NARENDRA MODI  KYA KAAM AATA HAI  RAHUL GANDHI  SCAM KARNANEPOTISM KARNA COMMUNAL KARNA  ANTI NATIONAL BANNA MUJHE AATA HAI PM NARENDRA MODI  PUNJAB MAIN  RAHUL GANDHI  SAB PUNJABI HAINNAFRAT KARKE EK DUSRE KO JEET LENGE PM NARENDRA MODI  PAKISTAN NIKAL LE IAS IFTIKHARUDDIN IS THE EXAMPLE OF MADARSA ISLA is too long for context length 77

Split the text into chunks that fit the 77-token context window, pass the chunks as a list, and average the scores across chunks.
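The suggested workaround could be sketched roughly as below. Here `classify` is a placeholder for whatever call produces per-label scores (e.g. a thin wrapper around `zstc(text=..., candidate_labels=...)` that zips `preds['labels']` with `preds['scores']` into a dict, as the `evaluate` helper in the stacktrace does); the function name and the word-count chunking heuristic are assumptions for illustration, not part of the AltZSC API. Note that word count only approximates BPE token count, so for some scripts a smaller `max_words` may be needed to stay under 77 tokens:

```python
from typing import Callable, Dict, List


def chunk_text(text: str, max_words: int = 50) -> List[str]:
    """Split text into word-based chunks small enough to (heuristically)
    stay under CLIP's 77-token context length."""
    words = text.split()
    chunks = [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
    return chunks or [""]  # keep at least one (empty) chunk for empty input


def classify_long_text(
    text: str,
    candidate_labels: List[str],
    classify: Callable[[str, List[str]], Dict[str, float]],
    max_words: int = 50,
) -> Dict[str, float]:
    """Run the classifier on each chunk and average the per-label scores."""
    chunks = chunk_text(text, max_words)
    totals = {label: 0.0 for label in candidate_labels}
    for chunk in chunks:
        scores = classify(chunk, candidate_labels)  # one zero-shot call per chunk
        for label in candidate_labels:
            totals[label] += scores[label]
    return {label: total / len(chunks) for label, total in totals.items()}
```

Averaging treats every chunk as equally informative; weighting by chunk length is another reasonable choice if chunks vary widely in size.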