minimaxir / gpt-2-cloud-run

Text-generation API via GPT-2 for Cloud Run

TypeError: not all arguments converted during string formatting

vaibhavpras opened this issue

Hi, first of all thank you for this amazing tutorial and repo. Awesome work!

I've been trying to generate a custom model with code similar to what you have in /examples/hacker_news.py. My app.py is shown below. I've set return_as_list to True. I take the list, remove the unnecessary prefixes from each string in it, and return the list as JSON. This code throws the error shown below. However, when I generate text locally without calling the API (i.e. without the HTTPS requests involved), the code works without errors. I can't figure out what I'm doing wrong. I'd highly appreciate any help.

app.py

from starlette.applications import Starlette
from starlette.responses import UJSONResponse
import gpt_2_simple as gpt2
import tensorflow as tf
import uvicorn
import os
import gc

app = Starlette(debug=False)

sess = gpt2.start_tf_sess(threads=1)
gpt2.load_gpt2(sess)


response_header = {
    'Access-Control-Allow-Origin': '*'
}

generate_count = 0


@app.route('/', methods=['GET', 'POST', 'HEAD'])
async def homepage(request):
    global generate_count
    global sess

    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return UJSONResponse({'text': ''},
                             headers=response_header)

    
    text = gpt2.generate(sess,
                         length=55,
                         temperature=1.0,
                         top_k=int(params.get('top_k', 0)),
                         top_p=float(params.get('top_p', 0)),
                         prefix='<|startoftext|>' + params.get('prefix', ''),
                         truncate='<|endoftext|>',
                         include_prefix=True,
                         nsamples=params.get('nsamples', 1),
                         return_as_list=True
                         )

    for x in text:
        x = x.replace('<|startoftext|>', '')
        x = x.replace('<|endoftext|>', '')
        x = x.replace('  ', ' ')


    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    return UJSONResponse({'text_list': text})

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
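
A side note on the cleanup loop in app.py above: assigning to x inside `for x in text` only rebinds the loop variable, so the strings stored in `text` stay unchanged and the tokens would still appear in the JSON response. A minimal sketch of one way around this is to build a new list; `clean_samples` is a hypothetical helper name used for illustration and is not part of gpt-2-simple.

```python
def clean_samples(samples):
    """Return a new list with the delimiter tokens stripped from each sample."""
    cleaned = []
    for sample in samples:
        sample = sample.replace('<|startoftext|>', '')
        sample = sample.replace('<|endoftext|>', '')
        sample = sample.replace('  ', ' ')  # collapse double spaces
        cleaned.append(sample)
    return cleaned


# Made-up sample standing in for the list returned by gpt2.generate(...)
text = ['<|startoftext|>Show HN:  a demo<|endoftext|>']
text = clean_samples(text)
```

In the handler, this would amount to reassigning `text = clean_samples(text)` before building the response.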

Error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/uvicorn/protocols/http/httptools_impl.py", line 385, in run_asgi
    result = await app(self.scope, self.receive, self.send)
  File "/usr/local/lib/python3.7/site-packages/uvicorn/middleware/proxy_headers.py", line 45, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.7/site-packages/starlette/applications.py", line 102, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.7/site-packages/starlette/middleware/errors.py", line 181, in __call__
    raise exc from None
  File "/usr/local/lib/python3.7/site-packages/starlette/middleware/errors.py", line 159, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.7/site-packages/starlette/exceptions.py", line 82, in __call__
    raise exc from None
  File "/usr/local/lib/python3.7/site-packages/starlette/exceptions.py", line 71, in __call__
    await self.app(scope, receive, sender)
  File "/usr/local/lib/python3.7/site-packages/starlette/routing.py", line 550, in __call__
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.7/site-packages/starlette/routing.py", line 227, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.7/site-packages/starlette/routing.py", line 41, in app
    response = await func(request)
  File "app.py", line 44, in homepage
    nsamples=params.get('nsamples', 1),
  File "/usr/local/lib/python3.7/site-packages/gpt_2_simple/gpt_2.py", line 428, in generate
    assert nsamples % batch_size == 0
TypeError: not all arguments converted during string formatting
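
A likely cause, judging from the last two frames of the traceback: values read from a URL query string are strings, so `params.get('nsamples', 1)` returns something like `'2'` rather than `2` when the parameter is supplied in a GET request. With a string on the left, `%` means printf-style formatting rather than modulo, which would explain both the error message and why a local call with an integer works. The sketch below reproduces this with a plain dict standing in for the request parameters; `parse_nsamples` is a hypothetical helper, and `batch_size = 1` is an assumption about the value used inside `generate`.

```python
def parse_nsamples(params, default=1):
    """Hypothetical helper: coerce the 'nsamples' parameter to an int."""
    return int(params.get('nsamples', default))


# A dict of strings stands in for the query parameters of a GET request.
params = {'nsamples': '2'}
batch_size = 1  # assumed value, for illustration only

error_message = None
try:
    # str % int is string formatting, so this raises TypeError
    params.get('nsamples', 1) % batch_size
except TypeError as exc:
    error_message = str(exc)

# After converting to int, the same check is ordinary modulo arithmetic.
nsamples = parse_nsamples(params)
assert nsamples % batch_size == 0
```

If this is indeed the cause, wrapping the value as `int(params.get('nsamples', 1))` in app.py, the same way `top_k` and `top_p` are already converted, should avoid the error.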