@@ -827,7 +827,7 @@ async def copy_to_table(self, table_name, *, source,
827827 delimiter = None , null = None , header = None ,
828828 quote = None , escape = None , force_quote = None ,
829829 force_not_null = None , force_null = None ,
830- encoding = None ):
830+ encoding = None , where = None ):
831831 """Copy data to the specified table.
832832
833833 :param str table_name:
@@ -846,6 +846,15 @@ async def copy_to_table(self, table_name, *, source,
846846 :param str schema_name:
847847 An optional schema name to qualify the table.
848848
849+ :param str where:
850+ An optional condition used to filter rows when copying.
851+
852+ .. note::
853+
854+ Usage of this parameter requires support for the
855+ ``COPY FROM ... WHERE`` syntax, introduced in
856+ PostgreSQL version 12.
857+
849858 :param float timeout:
850859 Optional timeout value in seconds.
851860
@@ -873,6 +882,9 @@ async def copy_to_table(self, table_name, *, source,
873882 https://www.postgresql.org/docs/current/static/sql-copy.html
874883
875884 .. versionadded:: 0.11.0
885+
886+ .. versionadded:: 0.27.0
887+ Added ``where`` parameter.
876888 """
877889 tabname = utils ._quote_ident (table_name )
878890 if schema_name :
@@ -884,21 +896,22 @@ async def copy_to_table(self, table_name, *, source,
884896 else :
885897 cols = ''
886898
899+ cond = self ._format_copy_where (where )
887900 opts = self ._format_copy_opts (
888901 format = format , oids = oids , freeze = freeze , delimiter = delimiter ,
889902 null = null , header = header , quote = quote , escape = escape ,
890903 force_not_null = force_not_null , force_null = force_null ,
891904 encoding = encoding
892905 )
893906
894- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
895- tab = tabname , cols = cols , opts = opts )
907+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
908+ tab = tabname , cols = cols , opts = opts , cond = cond )
896909
897910 return await self ._copy_in (copy_stmt , source , timeout )
898911
899912 async def copy_records_to_table (self , table_name , * , records ,
900913 columns = None , schema_name = None ,
901- timeout = None ):
914+ timeout = None , where = None ):
902915 """Copy a list of records to the specified table using binary COPY.
903916
904917 :param str table_name:
@@ -915,6 +928,16 @@ async def copy_records_to_table(self, table_name, *, records,
915928 :param str schema_name:
916929 An optional schema name to qualify the table.
917930
931+ :param str where:
932+ An optional condition used to filter rows when copying.
933+
934+ .. note::
935+
936+ Usage of this parameter requires support for the
937+ ``COPY FROM ... WHERE`` syntax, introduced in
938+ PostgreSQL version 12.
939+
940+
918941 :param float timeout:
919942 Optional timeout value in seconds.
920943
@@ -959,6 +982,9 @@ async def copy_records_to_table(self, table_name, *, records,
959982
960983 .. versionchanged:: 0.24.0
961984 The ``records`` argument may be an asynchronous iterable.
985+
986+ .. versionadded:: 0.27.0
987+ Added ``where`` parameter.
962988 """
963989 tabname = utils ._quote_ident (table_name )
964990 if schema_name :
@@ -976,14 +1002,27 @@ async def copy_records_to_table(self, table_name, *, records,
9761002
9771003 intro_ps = await self ._prepare (intro_query , use_cache = True )
9781004
1005+ cond = self ._format_copy_where (where )
9791006 opts = '(FORMAT binary)'
9801007
981- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
982- tab = tabname , cols = cols , opts = opts )
1008+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
1009+ tab = tabname , cols = cols , opts = opts , cond = cond )
9831010
9841011 return await self ._protocol .copy_in (
9851012 copy_stmt , None , None , records , intro_ps ._state , timeout )
9861013
1014+ def _format_copy_where (self , where ):
1015+ if where and not self ._server_caps .sql_copy_from_where :
1016+ raise exceptions .UnsupportedServerFeatureError (
1017+ 'the `where` parameter requires PostgreSQL 12 or later' )
1018+
1019+ if where :
1020+ where_clause = 'WHERE ' + where
1021+ else :
1022+ where_clause = ''
1023+
1024+ return where_clause
1025+
9871026 def _format_copy_opts (self , * , format = None , oids = None , freeze = None ,
9881027 delimiter = None , null = None , header = None , quote = None ,
9891028 escape = None , force_quote = None , force_not_null = None ,
@@ -2308,7 +2347,7 @@ class _ConnectionProxy:
23082347ServerCapabilities = collections .namedtuple (
23092348 'ServerCapabilities' ,
23102349 ['advisory_locks' , 'notifications' , 'plpgsql' , 'sql_reset' ,
2311- 'sql_close_all' ])
2350+ 'sql_close_all' , 'sql_copy_from_where' ])
23122351ServerCapabilities .__doc__ = 'PostgreSQL server capabilities.'
23132352
23142353
@@ -2320,27 +2359,31 @@ def _detect_server_capabilities(server_version, connection_settings):
23202359 plpgsql = False
23212360 sql_reset = True
23222361 sql_close_all = False
2362+ sql_copy_from_where = False
23232363 elif hasattr (connection_settings , 'crdb_version' ):
23242364 # CockroachDB detected.
23252365 advisory_locks = False
23262366 notifications = False
23272367 plpgsql = False
23282368 sql_reset = False
23292369 sql_close_all = False
2370+ sql_copy_from_where = False
23302371 elif hasattr (connection_settings , 'crate_version' ):
23312372 # CrateDB detected.
23322373 advisory_locks = False
23332374 notifications = False
23342375 plpgsql = False
23352376 sql_reset = False
23362377 sql_close_all = False
2378+ sql_copy_from_where = False
23372379 else :
23382380 # Standard PostgreSQL server assumed.
23392381 advisory_locks = True
23402382 notifications = True
23412383 plpgsql = True
23422384 sql_reset = True
23432385 sql_close_all = True
2386+ sql_copy_from_where = server_version .major >= 12
23442387
23452388 return ServerCapabilities (
23462389 advisory_locks = advisory_locks ,
0 commit comments