Skip to content

binary representation of numpy.complex and postgresql composite #1060

Closed
@GFuhr

Description

@GFuhr
  • asyncpg version: 0.28.0
  • PostgreSQL version: 15 (latest docker container)
  • Do you use a PostgreSQL SaaS? If so, which? Can you reproduce
    the issue with a local PostgreSQL install?
    :
  • Python version: 3.10
  • Platform: windows or debian in WSL
  • Do you use pgbouncer?: no
  • Did you install asyncpg with pip?: yes
  • If you built asyncpg locally, which version of Cython did you use?:
  • Can the issue be reproduced under both asyncio and
    uvloop?
    : yes

Hi,

I need for a project to store complex numbers in a database (numpy.complex64 exactly) and for that I created a postgresql composite datatype.

I was able to make it work with asyncpg in text format for simple INSERT, SELECT... after some works and patches found on previous issues report.

However for the binary format I have weird errors : asyncpg.exceptions.DatatypeMismatchError: wrong number of columns: 1082549862, expected 2

I made a "simple" python script to reproduce the issue :

from __future__ import annotations

import asyncpg
import asyncio
import numpy as np

from ast import literal_eval
import struct

SQL_CREATE = """
DO $$ BEGIN
    IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'complex') THEN
        CREATE TYPE complex AS (
            r float4,
            i float4
        );
    END IF;
    CREATE TABLE dummy_table (dummy_column complex);
    DROP TABLE dummy_table;
END $$;

DROP TABLE IF EXISTS "poc_asyncpg";

CREATE TABLE "poc_asyncpg" (
    "id" SERIAL PRIMARY KEY,
    "float_value" float4 NULL,
    "complex_value" complex NULL,
    "complex_array" complex[] NULL
)
"""


def _cplx_decode(val) -> np.complex64:
    cplx = complex(*literal_eval(val))
    return np.complex64(cplx)


def _cplx_encode(val: np.complex64 | complex) -> str:
    return str((np.float32(val.real), np.float32(val.imag)))


async def set_type_codec(conn):
    """
    had to use this patch since the conn.set_type_codec does not work for scalar variables
    """
    schema = 'public'
    format = 'text'
    conn._check_open()
    typenames = ('complex',)
    for typename in typenames:
        typeinfo = await conn.fetchrow(
            asyncpg.introspection.TYPE_BY_NAME, typename, schema)
        if not typeinfo:
            raise ValueError('unknown type: {}.{}'.format(schema, typename))

        oid = typeinfo['oid']
        conn._protocol.get_settings().add_python_codec(
            oid, typename, schema, 'scalar',
            lambda a: _cplx_encode(a), lambda a: _cplx_decode(a), format)

# if this part is commented, error message is : 
# asyncpg.exceptions._base.InternalClientError: no binary format encoder for type complex    
        conn._protocol.get_settings().add_python_codec(
            oid, typename, schema, 'scalar',
            encoder=lambda x: struct.pack('!2f', x.real, x.imag),
            decoder=lambda x: np.frombuffer(x, dtype=np.complex64)[0],
            format="binary")

    # Statement cache is no longer valid due to codec changes.
    conn._drop_local_statement_cache()


async def init_connection(conn):
    await set_type_codec(conn)
    await conn.set_type_codec(
        'numeric', encoder=str, decoder=np.float32,
        schema='pg_catalog', format='text'
    )
    await conn.set_type_codec(
        'float4', encoder=str, decoder=np.float32,
        schema='pg_catalog', format='text'
    )
    await conn.set_type_codec(
        'float4', encoder=struct.Struct("!f").pack, decoder=struct.Struct("!f").unpack,
        schema='pg_catalog', format='binary'
    )


async def trunc(pool):
    async with pool.acquire() as conn:
        async with conn.transaction():
            query = "TRUNCATE poc_asyncpg"
            await conn.execute(query)


async def worker_copy(pool, column_name, data):
    await trunc(pool)
    async with pool.acquire() as conn:
        async with conn.transaction():
            await conn.copy_records_to_table("poc_asyncpg",
                                             records=[(data,)],
                                             columns=(column_name,)
                                             )


try:
    from common import dbinfo
except ImportError as e:
    class DB:
        user = "user"
        password = "password"
        database = "db"


    dbinfo = DB()


def create_pool(db_info):
    pool = asyncpg.create_pool(
        user=db_info.user,
        password=db_info.password,
        database=dbinfo.database,
        host="127.0.0.1",
        init=init_connection
    )
    return pool


async def main(info):
    pool = await create_pool(info)
    async with pool.acquire() as conn:
        async with conn.transaction():
            await conn.execute(SQL_CREATE)

    await worker_copy(pool, "float_value", np.float32(4.2))
    await worker_copy(pool, "complex_value", np.complex64(4.2 + 1j * 4.2))


if __name__ == '__main__':
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main(dbinfo))

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions