GEM-benchmark / NL-Augmenter

NL-Augmenter 🦎 → 🐍 A Collaborative Repository of Natural Language Transformations

Change batch size and number of visible devices for text-style-transfer

guanqun-yang opened this issue

Hi @Filco306

Thank you for your great work making the powerful paraphrasing model easily accessible through HuggingFace! Now it is much easier for me to work with it without the hassle of handling complicated dependencies!

But is there any way to use a larger batch size and more GPUs to speed up the paraphrasing process? Right now I can use only one GPU and a small batch size. I read your implementation here, but there does not seem to be an easy way to do either.

Thank you. I am looking forward to your reply.

Hello @guanqun-yang,

Thank you for your kind words! Yes, I think that could be a good enhancement. Note that there is a WIP issue, #298, to fix some things: for the method to really work, the paraphrasers need to be run as a two-step process (unless you are using a basic one). I think we could integrate these enhancements there; I just have not fixed it yet due to a lack of response and a lack of time on my side. If you have any suggestions on how to do this, we can try to integrate them.
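
To make the two-step process concrete, here is a minimal sketch that simply composes two paraphrasers. It assumes the setup of the original paper: a basic paraphraser first rewrites the input into a more neutral form, and a style-specific paraphraser then rewrites that output into the target style. The `Paraphraser` type, `two_step_transfer`, and both callables are hypothetical placeholders, not part of NL-Augmenter's API; in practice each callable would wrap a model's `generate` call.

```python
from typing import Callable, List

# Hypothetical type: any function mapping a batch of sentences to paraphrases.
Paraphraser = Callable[[List[str]], List[str]]


def two_step_transfer(
    sentences: List[str],
    basic_paraphraser: Paraphraser,
    style_paraphraser: Paraphraser,
) -> List[str]:
    """Apply the basic paraphraser first, then the style-specific one."""
    # Step 1: the basic paraphraser rewrites the input, which (assumed, per
    # the paper's setup) normalizes away the original sentence's style.
    normalized = basic_paraphraser(sentences)
    # Step 2: the style-specific paraphraser rewrites the normalized text
    # into the target style.
    return style_paraphraser(normalized)
```

The order matters: skipping step 1 feeds the raw, still-stylized input straight into the style-specific model, which is the situation the two-step process is meant to avoid.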

One thing I am wondering, however, is whether your desired change would move the API design too far from NL-Augmenter's style. Instead, you might simply use the built-in GPT2 model from the transformers library directly rather than going through NL-Augmenter; that way, you can customize your pipeline however you want without any restrictions from the NL-Augmenter package. Just remember to do the two-step process as in the original paper :) As shown in #298, you can use the code below, which @martiansideofthemoon pasted, to generate outputs from a GPT2 model loaded with the weights I have uploaded.

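# gpt2: a GPT2 language model loaded with the uploaded paraphraser weights.
# gpt2_sentences / segments: batched token IDs and segment (token type) IDs, prepared as in #298.
out = gpt2.generate(
    input_ids=gpt2_sentences[:, 0:init_context_size],  # prompt: the first init_context_size tokens
    max_length=gpt2_sentences.shape[1],
    return_dict_in_generate=True,  # return an output object; generated IDs are in out.sequences
    eos_token_id=eos_token_id,
    output_scores=True,  # also return the per-step scores in out.scores
    do_sample=top_k > 0 or top_p > 0.0,  # sample only if top-k or nucleus sampling is enabled
    top_k=top_k,
    top_p=top_p,
    temperature=temperature,
    num_beams=beam_size,
    token_type_ids=segments[:, 0:init_context_size]  # segment IDs matching the prompt
)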
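
On the original question of batch size and multiple GPUs: when calling `generate` directly, one common data-parallel approach is to split the sentences into fixed-size batches and give each GPU its own model replica. The helpers below are a hypothetical sketch of that bookkeeping only (`make_batches`, `shard_batches`, and the device names are illustrative, not NL-Augmenter or transformers APIs); they do not load models or touch the GPU.

```python
from typing import Dict, List, Sequence


def make_batches(sentences: Sequence[str], batch_size: int) -> List[List[str]]:
    """Split sentences into consecutive batches of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return [
        list(sentences[i:i + batch_size])
        for i in range(0, len(sentences), batch_size)
    ]


def shard_batches(
    batches: List[List[str]], devices: Sequence[str]
) -> Dict[str, List[List[str]]]:
    """Assign batches to devices round-robin (one model replica per device)."""
    if not devices:
        raise ValueError("need at least one device")
    shards: Dict[str, List[List[str]]] = {d: [] for d in devices}
    for i, batch in enumerate(batches):
        shards[devices[i % len(devices)]].append(batch)
    return shards


batches = make_batches(["a", "b", "c", "d", "e"], batch_size=2)
shards = shard_batches(batches, ["cuda:0", "cuda:1"])
```

Round-robin assignment keeps the shards within one batch of each other in size. Each shard could then be processed by a separate worker, for example one process per GPU with `CUDA_VISIBLE_DEVICES` set, each calling `generate` as above on its own batches. Batched generation with a decoder-only model like GPT2 generally also needs left padding (e.g. `tokenizer.padding_side = "left"`) plus an `attention_mask`, so that padding tokens do not end up between the prompt and the generated continuation.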