Skip to content

Commit

Permalink
[CosmosDb] Add mongo collection copy command (#5506)
Browse files Browse the repository at this point in the history
* Remove preview tag

* Add mongo collection copy

* Rename param

* Add extra space

* Add new version

Co-authored-by: Nitesh Vijay <niteshvijay@microsoft.com>
  • Loading branch information
niteshvijay1995 and niteshvijay-ms authored Nov 3, 2022
1 parent b714fbe commit 1279c3b
Show file tree
Hide file tree
Showing 87 changed files with 1,538 additions and 515 deletions.
4 changes: 4 additions & 0 deletions src/cosmosdb-preview/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Release History
===============
0.21.0
* Add support for mongo data transfer jobs.

++++++
0.20.0
* Add support for Continuous mode restore with user provided identity.

Expand Down
17 changes: 16 additions & 1 deletion src/cosmosdb-preview/azext_cosmosdb_preview/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,14 +567,29 @@
Usage: --dest-sql-container database=XX container=XX'
database: Database name of CosmosDB Sql.
container: Container name of CosmosDB Sql.
- name: --source-mongo
short-summary: "Source mongo collection"
long-summary: |
Usage: --source-mongo database=XX collection=XX'
database: Database name of CosmosDB Mongo.
collection: Collection name of CosmosDB Mongo.
- name: --dest-mongo
short-summary: "Destination mongo collection"
long-summary: |
Usage: --dest-mongo database=XX collection=XX'
database: Database name of CosmosDB Mongo.
collection: Collection name of CosmosDB Mongo.
examples:
- name: Copy sql container
text: |-
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-sql-container database=db1 container=c1 --dest-sql-container database=db2 container=c2
- name: Copy cassandra table
text: |-
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-cassandra-table keyspace=k1 table=t1 --dest-cassandra-table keyspace=k1 table=t1
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-cassandra-table keyspace=k1 table=t1 --dest-cassandra-table keyspace=k2 table=t2
- name: Copy mongo collection
text: |-
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-mongo database=d1 collection=c1 --dest-mongo database=d2 collection=c2
"""

helps['cosmosdb dts'] = """
Expand Down
3 changes: 3 additions & 0 deletions src/cosmosdb-preview/azext_cosmosdb_preview/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CreateGremlinDatabaseRestoreResource,
CreateTableRestoreResource,
AddCassandraTableAction,
AddMongoCollectionAction,
AddSqlContainerAction,
CreateTargetPhysicalPartitionThroughputInfoAction,
CreateSourcePhysicalPartitionThroughputInfoAction,
Expand Down Expand Up @@ -314,8 +315,10 @@ def load_arguments(self, _):
with self.argument_context('cosmosdb dts copy') as c:
c.argument('job_name', job_name_type)
c.argument('source_cassandra_table', nargs='+', action=AddCassandraTableAction, help='Source cassandra table')
c.argument('source_mongo', nargs='+', action=AddMongoCollectionAction, help='Source mongo collection')
c.argument('source_sql_container', nargs='+', action=AddSqlContainerAction, help='Source sql container')
c.argument('dest_cassandra_table', nargs='+', action=AddCassandraTableAction, help='Destination cassandra table')
c.argument('dest_mongo', nargs='+', action=AddMongoCollectionAction, help='Destination mongo collection')
c.argument('dest_sql_container', nargs='+', action=AddSqlContainerAction, help='Destination sql container')
c.argument('worker_count', type=int, help='Worker count')

Expand Down
40 changes: 40 additions & 0 deletions src/cosmosdb-preview/azext_cosmosdb_preview/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DatabaseRestoreResource,
GremlinDatabaseRestoreResource,
CosmosCassandraDataTransferDataSourceSink,
CosmosMongoDataTransferDataSourceSink,
CosmosSqlDataTransferDataSourceSink,
PhysicalPartitionThroughputInfoResource,
PhysicalPartitionId
Expand Down Expand Up @@ -138,6 +139,45 @@ def __call__(self, parser, namespace, values, option_string=None):
namespace.cassandra_table = cassandra_table


class AddMongoCollectionAction(argparse._AppendAction):
def __call__(self, parser, namespace, values, option_string=None):
if not values:
# pylint: disable=line-too-long
raise CLIError(f'usage error: {option_string} [KEY=VALUE ...]')

database_name = None
collection_name = None

for (k, v) in (x.split('=', 1) for x in values):
kl = k.lower()
if kl == 'database':
database_name = v

elif kl == 'collection':
collection_name = v

else:
raise CLIError(
f'Unsupported Key {k} is provided for {option_string} component. All'
' possible keys are: database, collection'
)

if database_name is None:
raise CLIError(f'usage error: missing key database in {option_string} component')

if collection_name is None:
raise CLIError(f'usage error: missing key table in {option_string} component')

mongo_collection = CosmosMongoDataTransferDataSourceSink(database_name=database_name, collection_name=collection_name)

if option_string == "--source-mongo":
namespace.source_mongo = mongo_collection
elif option_string == "--dest-mongo":
namespace.dest_mongo = mongo_collection
else:
namespace.mongo_collection = mongo_collection


class AddSqlContainerAction(argparse._AppendAction):
def __call__(self, parser, namespace, values, option_string=None):
if not values:
Expand Down
50 changes: 34 additions & 16 deletions src/cosmosdb-preview/azext_cosmosdb_preview/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,33 +1114,51 @@ def cosmosdb_data_transfer_copy_job(client,
dest_cassandra_table=None,
source_sql_container=None,
dest_sql_container=None,
source_mongo=None,
dest_mongo=None,
worker_count=0,
job_name=None):
if source_cassandra_table is None and source_sql_container is None:
raise CLIError('source component ismissing')

if source_cassandra_table is not None and source_sql_container is not None:
raise CLIError('Invalid input: multiple source components')

if dest_cassandra_table is None and dest_sql_container is None:
raise CLIError('destination component is missing')

if dest_cassandra_table is not None and dest_sql_container is not None:
raise CLIError('Invalid input: multiple destination components')

job_create_properties = {}

source = None
if source_cassandra_table is not None:
job_create_properties['source'] = source_cassandra_table
if source is not None:
raise CLIError('Invalid input: multiple source components')
source = source_cassandra_table

if source_sql_container is not None:
job_create_properties['source'] = source_sql_container
if source is not None:
raise CLIError('Invalid input: multiple source components')
source = source_sql_container

if source_mongo is not None:
if source is not None:
raise CLIError('Invalid input: multiple source components')
source = source_mongo

if source is None:
raise CLIError('source component is missing')
job_create_properties['source'] = source

destination = None
if dest_cassandra_table is not None:
job_create_properties['destination'] = dest_cassandra_table
if destination is not None:
raise CLIError('Invalid input: multiple destination components')
destination = dest_cassandra_table

if dest_sql_container is not None:
job_create_properties['destination'] = dest_sql_container
if destination is not None:
raise CLIError('Invalid input: multiple destination components')
destination = dest_sql_container

if dest_mongo is not None:
if destination is not None:
raise CLIError('Invalid input: multiple destination components')
destination = dest_mongo

if destination is None:
raise CLIError('destination component is missing')
job_create_properties['destination'] = destination

if worker_count > 0:
job_create_properties['worker_count'] = worker_count
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, TypeVar, Union, cast, overload
from urllib.parse import parse_qs, urljoin, urlparse
import urllib.parse

from azure.core.async_paging import AsyncItemPaged, AsyncList
from azure.core.exceptions import (
Expand Down Expand Up @@ -106,10 +106,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -184,10 +191,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -1067,10 +1081,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, TypeVar, Union, cast, overload
from urllib.parse import parse_qs, urljoin, urlparse
import urllib.parse

from azure.core.async_paging import AsyncItemPaged, AsyncList
from azure.core.exceptions import (
Expand Down Expand Up @@ -108,10 +108,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, TypeVar, Union, cast, overload
from urllib.parse import parse_qs, urljoin, urlparse
import urllib.parse

from azure.core.async_paging import AsyncItemPaged, AsyncList
from azure.core.exceptions import (
Expand Down Expand Up @@ -129,10 +129,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -1177,10 +1184,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -2266,10 +2280,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down
Loading

0 comments on commit 1279c3b

Please sign in to comment.