1010import io
1111import os
1212import tempfile
13+ import unittest
1314
1415import asyncpg
1516from asyncpg import _testbase as tb
@@ -415,7 +416,6 @@ async def test_copy_to_table_basics(self):
415416 '*a5*|b5' ,
416417 '*!**|*n-u-l-l*' ,
417418 'n-u-l-l|bb' ,
418- '_-_filtered_-_value_-_|never-here'
419419 ]).encode ('utf-8' )
420420 )
421421 f .seek (0 )
@@ -432,7 +432,7 @@ async def test_copy_to_table_basics(self):
432432 schema_name = 'public' , format = 'csv' ,
433433 delimiter = '|' , null = 'n-u-l-l' , header = True ,
434434 quote = '*' , escape = '!' , force_not_null = ('a' ,),
435- force_null = force_null , where = 'a <> \' _-_filtered_-_value_-_ \' ' )
435+ force_null = force_null )
436436
437437 self .assertEqual (res , 'COPY 7' )
438438
@@ -636,16 +636,44 @@ async def test_copy_records_to_table_1(self):
636636 ]
637637
638638 records .append (('a-100' , None , None ))
639- records .append (('b-999' , None , None ))
640639
641640 res = await self .con .copy_records_to_table (
642- 'copytab' , records = records , where = 'a <> \' b-999 \' ' )
641+ 'copytab' , records = records )
643642
644643 self .assertEqual (res , 'COPY 101' )
645644
646645 finally :
647646 await self .con .execute ('DROP TABLE copytab' )
648647
648+ async def test_copy_records_to_table_where (self ):
649+ if not self .con ._server_caps .sql_copy_from_where :
650+ raise unittest .SkipTest (
651+ 'COPY WHERE not supported on server' )
652+
653+ await self .con .execute ('''
654+ CREATE TABLE copytab_where(a text, b int, c timestamptz);
655+ ''' )
656+
657+ try :
658+ date = datetime .datetime .now (tz = datetime .timezone .utc )
659+ delta = datetime .timedelta (days = 1 )
660+
661+ records = [
662+ ('a-{}' .format (i ), i , date + delta )
663+ for i in range (100 )
664+ ]
665+
666+ records .append (('a-100' , None , None ))
667+ records .append (('b-999' , None , None ))
668+
669+ res = await self .con .copy_records_to_table (
670+ 'copytab_where' , records = records , where = 'a <> \' b-999\' ' )
671+
672+ self .assertEqual (res , 'COPY 101' )
673+
674+ finally :
675+ await self .con .execute ('DROP TABLE copytab_where' )
676+
649677 async def test_copy_records_to_table_async (self ):
650678 await self .con .execute ('''
651679 CREATE TABLE copytab_async(a text, b int, c timestamptz);
@@ -660,11 +688,9 @@ async def record_generator():
660688 yield ('a-{}' .format (i ), i , date + delta )
661689
662690 yield ('a-100' , None , None )
663- yield ('b-999' , None , None )
664691
665692 res = await self .con .copy_records_to_table (
666693 'copytab_async' , records = record_generator (),
667- where = 'a <> \' b-999\' '
668694 )
669695
670696 self .assertEqual (res , 'COPY 101' )
0 commit comments