omnilib / aiosqlite

asyncio bridge to the standard sqlite3 module

Home page: https://aiosqlite.omnilib.dev

Cannot set `row_factory` if I use the connection as a context manager

decorator-factory opened this issue

Description

It seems that I cannot set the row_factory attribute if I also want to use the connection in an async with block later.

  1. If I don't await the connection, I can't set row_factory on it, because it hasn't created an sqlite3.Connection object yet.
  2. If I do await the connection, I can set row_factory on it. But if I then use the connection in an async with statement, Connection.__aenter__ awaits the aiosqlite.Connection object again, which calls start() a second time on a thread that is already running, and that leads to a nasty error.
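The double-start failure can be reproduced with the standard threading module alone. A minimal sketch, assuming (as the report implies) that the aiosqlite connection owns a worker thread that is started every time the connection is awaited; `Worker` and `start_twice` are hypothetical names used only for illustration:

```python
import threading


class Worker(threading.Thread):
    # Hypothetical stand-in for a connection object that owns a worker thread.
    def run(self) -> None:
        pass  # a real worker would process queued database calls here


def start_twice() -> str:
    worker = Worker()
    worker.start()  # first await: the thread starts
    worker.join()
    try:
        worker.start()  # second await (e.g. from __aenter__): not allowed
    except RuntimeError as exc:
        return str(exc)
    return "no error"


print(start_twice())  # threads can only be started once
```

A Thread object can be started at most once, even after it has finished, so any code path that awaits the same connection twice hits this error.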

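As a workaround, I did this:

import sqlite3
from pathlib import Path
from typing import Any, Union

import aiosqlite


def connect(database: Union[str, Path], iter_chunk_size: int = 64, **kwargs: Any) -> aiosqlite.Connection:
    # Build the sqlite3 connection ourselves so it can be configured before use
    def connector() -> sqlite3.Connection:
        conn = sqlite3.connect(str(database), **kwargs)
        conn.row_factory = sqlite3.Row
        return conn
    return aiosqlite.Connection(connector, iter_chunk_size)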

Maybe aiosqlite.connect should accept a parameter to configure the row_factory, or even a callback that runs when the connection is made?

Details

(I don't think these matter.)

  • OS: Manjaro, 21.0.7
  • Python version: 3.9.2
  • aiosqlite version: 0.17.0
  • Can you repro on 'main' branch? Yes
  • Can you repro in a clean virtualenv? Yes

Hi @decorator-factory,

I have also come across this issue and decided to resolve it this way:

import sqlite3
import typing as t

import aiosqlite


def connect(
    database: str,
    iter_chunk_size: int = 64,
    on_connect: t.Callable[[sqlite3.Connection], None] = lambda _: None,
    **kwargs: t.Any,
) -> aiosqlite.Connection:
    def connector() -> sqlite3.Connection:
        connection = sqlite3.connect(database, **kwargs)
        on_connect(connection)
        return connection

    return aiosqlite.Connection(connector, iter_chunk_size)

This way, you can not only set the row_factory but also call any other method of sqlite3.Connection before the connection is used.
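Because the connector closure is plain synchronous sqlite3 code, an on_connect hook can be exercised without an event loop. A minimal sketch; `make_connector` and `configure` are hypothetical names, and the `double` SQL function simply stands in for any sqlite3.Connection method the hook might call:

```python
import sqlite3
import typing as t


def make_connector(
    database: str,
    on_connect: t.Callable[[sqlite3.Connection], None] = lambda _: None,
    **kwargs: t.Any,
) -> t.Callable[[], sqlite3.Connection]:
    # Hypothetical helper: returns the same zero-argument connector the wrapper
    # above hands to aiosqlite.Connection, so it can be tested synchronously.
    def connector() -> sqlite3.Connection:
        connection = sqlite3.connect(database, **kwargs)
        on_connect(connection)  # configure before any query runs
        return connection

    return connector


def configure(conn: sqlite3.Connection) -> None:
    conn.row_factory = sqlite3.Row
    conn.create_function("double", 1, lambda x: 2 * x)


conn = make_connector(":memory:", on_connect=configure)()
row = conn.execute("SELECT double(21) AS answer").fetchone()
result = (type(row).__name__, row[0], row["answer"])
conn.close()
print(result)  # ('Row', 42, 42)
```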

So, in a complete example, the following is now possible:

import asyncio
import sqlite3
import typing as t
import uuid

import aiosqlite


def connect(
    database: str,
    iter_chunk_size: int = 64,
    on_connect: t.Callable[[sqlite3.Connection], None] = lambda _: None,
    **kwargs: t.Any,
) -> aiosqlite.Connection:
    def connector() -> sqlite3.Connection:
        connection = sqlite3.connect(database, **kwargs)
        on_connect(connection)
        return connection

    return aiosqlite.Connection(connector, iter_chunk_size)


def uuid_str() -> str:
    return str(uuid.uuid4())


def on_connect(conn: sqlite3.Connection) -> None:
    # Runs on the raw sqlite3 connection, inside aiosqlite's worker thread
    conn.row_factory = sqlite3.Row
    conn.create_function("uuid4", 0, uuid_str)
    conn.set_trace_callback(print)  # print every SQL statement executed


async def main() -> None:
    async with connect(":memory:", on_connect=on_connect) as db:
        cur = await db.execute("SELECT uuid4() AS my_uuid;")
        row = await cur.fetchone()
        print(type(row))
        print(row[0])
        print(row["my_uuid"])
        await cur.close()


asyncio.run(main())

Running it prints something like this to the console:

SELECT uuid4() AS my_uuid;
<class 'sqlite3.Row'>
4af29e4b-1d71-4ef0-b7c3-4ebbc93812c0
4af29e4b-1d71-4ef0-b7c3-4ebbc93812c0

What do you think?

@jreese, would you be open to me submitting a PR with this enhancement?

Or would the preferred solution be a custom @asynccontextmanager that does all the initialization work we want?

After all, aiosqlite.Connection exposes all the methods mentioned above; it just executes them asynchronously. Something like this:

import asyncio
import sqlite3
import typing as t
import uuid
from contextlib import asynccontextmanager

import aiosqlite


def uuid_str() -> str:
    return str(uuid.uuid4())


async def aon_conn(conn: aiosqlite.Connection) -> None:
    # aiosqlite.Connection exposes the sqlite3.Connection methods as coroutines
    conn.row_factory = sqlite3.Row
    await conn.create_function("uuid4", 0, uuid_str)
    await conn.set_trace_callback(print)  # print every SQL statement executed


@asynccontextmanager
async def custom_connect(
    database: str,
    iter_chunk_size: int = 64,
    on_connect: t.Optional[t.Callable[[aiosqlite.Connection], t.Awaitable[None]]] = None,
    **kwargs: t.Any,
) -> t.AsyncGenerator[aiosqlite.Connection, None]:
    # aiosqlite.connect() starts the connection exactly once; configure it afterwards
    async with aiosqlite.connect(database, iter_chunk_size=iter_chunk_size, **kwargs) as connection:
        if on_connect:
            await on_connect(connection)
        yield connection


async def main() -> None:
    async with custom_connect(":memory:", on_connect=aon_conn) as db:
        cur = await db.execute("SELECT uuid4() AS my_uuid;")
        row = await cur.fetchone()
        print(type(row))
        print(row[0])
        print(row["my_uuid"])
        await cur.close()


asyncio.run(main())

Output:

SELECT uuid4() AS my_uuid;
<class 'sqlite3.Row'>
4ad03421-6dd8-4137-8a2c-9b4176b9edba
4ad03421-6dd8-4137-8a2c-9b4176b9edba