bdqnghi / infercode

[ICSE 2021] - InferCode: Self-Supervised Learning of Code Representations by Predicting Subtrees

Question: Is there a reason that the batch size is capped at 5?

Peter-Devine opened this issue · comments

Hey, first off, love the repo, thanks for providing it.

I just have a question about the batch size that the model can handle.

When I put in a list of more than 5 pieces of code like so:

from infercode.client.infercode_client import InferCodeClient
import os
import logging
logging.basicConfig(level=logging.INFO)

# Set to "0" to enable the GPU (use "-1" for CPU-only)
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

infercode = InferCodeClient(language="java")
infercode.init_from_config()

# Here we pass in 6 identical initializations of `i`
vectors = infercode.encode(["int i = 0;"] * 6)

I get the following error:

AssertionError                            Traceback (most recent call last)
Input In [20], in <cell line: 1>()
----> 1 vectors = infercode.encode(["int i = 0;"] * 6)

File ~\Anaconda3\envs\infercode_new_env\lib\site-packages\infercode\client\infercode_client.py:76, in InferCodeClient.encode(self, batch_code_snippets)
     75 def encode(self, batch_code_snippets):
---> 76     tensors = self.snippets_to_tensors(batch_code_snippets)
     77     embeddings = self.sess.run(
     78         [self.infercode_model.code_vector],
     79         feed_dict={
   (...)
     87         }
     88     )
     89     return embeddings[0]

File ~\Anaconda3\envs\infercode_new_env\lib\site-packages\infercode\client\infercode_client.py:62, in InferCodeClient.snippets_to_tensors(self, batch_code_snippets)
     60 def snippets_to_tensors(self, batch_code_snippets):
     61     batch_tree_indexes = []
---> 62     assert len(batch_code_snippets) <= 5
     63     for code_snippet in batch_code_snippets:
     64         # tree-sitter parser requires bytes as the input, not string
     65         code_snippet_to_byte = str.encode(code_snippet)

AssertionError: 

This stems from an assert in the code that caps the number of inputs at 5:

assert len(batch_code_snippets) <= 5

Is there a reason this is hard-coded? Or would it make sense to add a batch_size parameter that defaults to 5 but can be adjusted to match the available computational capacity?
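In the meantime, a possible workaround is to split the snippet list into chunks of at most 5, call `encode` on each chunk, and stack the results. Below is a minimal sketch; `encode_in_chunks` is a hypothetical helper (not part of InferCode), and it assumes `encode` returns a 2D array with one row per snippet:

```python
import numpy as np

def encode_in_chunks(encode_fn, snippets, max_batch=5):
    # Hypothetical helper: split `snippets` into chunks of at most `max_batch`
    # (the hard-coded limit in snippets_to_tensors), encode each chunk
    # separately, and concatenate the per-chunk embedding arrays row-wise.
    chunks = [snippets[i:i + max_batch]
              for i in range(0, len(snippets), max_batch)]
    return np.concatenate([encode_fn(chunk) for chunk in chunks], axis=0)
```

With the client from the example above, this would look like `vectors = encode_in_chunks(infercode.encode, ["int i = 0;"] * 6)`, at the cost of one session run per chunk rather than a single batched call.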