Skip to content

Commit

Permalink
Update async (#12978)
Browse files Browse the repository at this point in the history
* updated async code to be in line with sync methods, added connection_string

* removing async from two methods

* added name to codeowners for tables

* removed async from table service client from conn str method

* linting fixes

* fixed weird indentation

* changed parse_connection_str method to remove some repeated code, other fixes from krista/issy
  • Loading branch information
seankane-msft authored Aug 14, 2020
1 parent 87a6e32 commit 9138465
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 38 deletions.
3 changes: 3 additions & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@
# PRLabel: %Cognitive - Form Recognizer
/sdk/formrecognizer/ @kristapratico @iscai-msft @rakshith91

# PRLabel: %Tables
/sdk/tables/ @seankane-msft

# Smoke Tests
/common/smoketest/ @lmazuel @chlowell @annatisch @rakshith91 @shurd @southpolesteve

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def format_shared_key_credential(account, credential):
return credential


def parse_connection_str(conn_str, credential, service):
def parse_connection_str(conn_str, credential, service, keyword_args):
conn_str = conn_str.rstrip(";")
conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
if any(len(tup) != 2 for tup in conn_settings):
Expand Down Expand Up @@ -378,7 +378,11 @@ def parse_connection_str(conn_str, credential, service):
)
except KeyError:
raise ValueError("Connection string missing required connection details.")
return primary, secondary, credential

if 'secondary_hostname' not in keyword_args:
keyword_args['secondary_hostname'] = secondary

return primary, credential


def create_configuration(**kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from ._entity import EntityProperty, EdmType, TableEntity
from ._common_conversion import _decode_base64_to_bytes
from ._generated.models import TableProperties
from ._error import TableErrorCode

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,9 @@ def from_connection_string(
:returns: A table client.
:rtype: ~azure.data.tables.TableClient
"""
account_url, secondary, credential = parse_connection_str(
conn_str=conn_str, credential=None, service='table')
if 'secondary_hostname' not in kwargs:
kwargs['secondary_hostname'] = secondary
return cls(account_url, table_name=table_name, credential=credential, **kwargs) # type: ignore
account_url, credential = parse_connection_str(
conn_str=conn_str, credential=None, service='table', keyword_args=kwargs)
return cls(account_url, table_name=table_name, credential=credential, **kwargs)

@classmethod
def from_table_url(cls, table_url, credential=None, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ def from_connection_string(
:returns: A Table service client.
:rtype: ~azure.data.tables.TableServiceClient
"""
account_url, secondary, credential = parse_connection_str(
conn_str=conn_str, credential=None, service='table')
if 'secondary_hostname' not in kwargs:
kwargs['secondary_hostname'] = secondary
account_url, credential = parse_connection_str(
conn_str=conn_str, credential=None, service='table', keyword_args=kwargs)
return cls(account_url, credential=credential, **kwargs)

@distributed_trace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@
Any,
)

try:
from urllib.parse import urlparse, unquote
except ImportError:
from urlparse import urlparse # type: ignore
from urllib2 import unquote # type: ignore

from azure.core.async_paging import AsyncItemPaged
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async

from .. import VERSION
from .._base_client import parse_connection_str
from .._entity import TableEntity
from .._generated.aio import AzureTable
from .._generated.models import SignedIdentifier, TableProperties, QueryOptions
Expand Down Expand Up @@ -67,6 +74,66 @@ def __init__(
self._client._config.version = kwargs.get('api_version', VERSION) # pylint: disable = W0212
self._loop = loop

@classmethod
def from_connection_string(
cls, conn_str, # type: str
table_name, # type: str
**kwargs # type: Any
):
# type: (...) -> TableClient
"""Create TableClient from a Connection String.
:param conn_str:
A connection string to an Azure Storage or Cosmos account.
:type conn_str: str
:param table_name: The table name.
:type table_name: str
:returns: A table client.
:rtype: ~azure.data.tables.TableClient
"""
account_url, credential = parse_connection_str(
conn_str=conn_str, credential=None, service='table', keyword_args=kwargs)
return cls(account_url, table_name=table_name, credential=credential, **kwargs)

@classmethod
def from_table_url(cls, table_url, credential=None, **kwargs):
# type: (str, Optional[Any], Any) -> TableClient
"""A client to interact with a specific Table.
:param table_url: The full URI to the table, including SAS token if used.
:type table_url: str
:param credential:
The credentials with which to authenticate. This is optional if the
account URL already has a SAS token. The value can be a SAS token string, an account
shared access key.
:type credential: str
:returns: A table client.
:rtype: ~azure.data.tables.TableClient
"""
try:
if not table_url.lower().startswith('http'):
table_url = "https://" + table_url
except AttributeError:
raise ValueError("Table URL must be a string.")
parsed_url = urlparse(table_url.rstrip('/'))

if not parsed_url.netloc:
raise ValueError("Invalid URL: {}".format(table_url))

table_path = parsed_url.path.lstrip('/').split('/')
account_path = ""
if len(table_path) > 1:
account_path = "/" + "/".join(table_path[:-1])
account_url = "{}://{}{}?{}".format(
parsed_url.scheme,
parsed_url.netloc.rstrip('/'),
account_path,
parsed_url.query)
table_name = unquote(table_path[-1])
if not table_name:
raise ValueError("Invalid URL. Please provide a URL with a valid table name")
return cls(account_url, table_name=table_name, credential=credential, **kwargs)

@distributed_trace_async
async def get_table_access_policy(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from azure.core.tracing.decorator_async import distributed_trace_async

from .. import VERSION, LocationMode
from .._base_client import parse_connection_str
from .._generated.aio._azure_table_async import AzureTable
from .._generated.models import TableServiceProperties, TableProperties, QueryOptions
from .._models import service_stats_deserialize, service_properties_deserialize
from .._error import _validate_table_name, _process_table_error
from .._error import _process_table_error
from .._table_service_client_base import TableServiceClientBase
from .._models import Table
from ._policies_async import ExponentialRetry
Expand Down Expand Up @@ -86,6 +87,23 @@ def __init__(
self._client._config.version = kwargs.get('api_version', VERSION) # pylint: disable=protected-access
self._loop = loop

@classmethod
def from_connection_string(
cls, conn_str, # type: any
**kwargs # type: Any
): # type: (...) -> TableServiceClient
"""Create TableServiceClient from a Connection String.
:param conn_str:
A connection string to an Azure Storage or Cosmos account.
:type conn_str: str
:returns: A Table service client.
:rtype: ~azure.data.tables.TableServiceClient
"""
account_url, credential = parse_connection_str(
conn_str=conn_str, credential=None, service='table', keyword_args=kwargs)
return cls(account_url, credential=credential, **kwargs)

@distributed_trace_async
async def get_service_stats(self, **kwargs):
# type: (...) -> dict[str,object]
Expand Down Expand Up @@ -175,11 +193,8 @@ async def create_table(
:rtype: ~azure.data.tables.TableClient or None
:raises: ~azure.core.exceptions.HttpResponseError
"""
_validate_table_name(table_name)

table_properties = TableProperties(table_name=table_name, **kwargs)
await self._client.table.create(table_properties=table_properties, **kwargs)
table = self.get_table_client(table=table_name)
table = self.get_table_client(table_name=table_name)
await table.create_table(**kwargs)
return table

@distributed_trace_async
Expand All @@ -196,9 +211,8 @@ async def delete_table(
:return: None
:rtype: ~None
"""
_validate_table_name(table_name)

await self._client.table.delete(table=table_name, **kwargs)
table = self.get_table_client(table_name=table_name)
await table.delete_table(**kwargs)

@distributed_trace
def list_tables(
Expand Down Expand Up @@ -231,8 +245,7 @@ def list_tables(

@distributed_trace
def query_tables(
self,
filter, # pylint: disable=W0622
self, filter, # type: str pylint: disable=W0622
**kwargs # type: Any
):
# type: (...) -> AsyncItemPaged[Table]
Expand Down Expand Up @@ -262,24 +275,23 @@ def query_tables(
page_iterator_class=TablePropertiesPaged
)

def get_table_client(self, table, **kwargs):
# type: (Union[TableProperties, str], Optional[Any]) -> TableClient
def get_table_client(
self, table_name, # type: str
**kwargs # type: Optional[Any]
):
# type: (...) -> TableClient
"""Get a client to interact with the specified table.
The table need not already exist.
The table need not already exist.
:param table:
The queue. This can either be the name of the queue,
or an instance of QueueProperties.
:type table: str or ~azure.storage.table.TableProperties
:returns: A :class:`~azure.data.tables.TableClient` object.
:rtype: ~azure.data.tables.TableClient
:param table:
The queue. This can either be the name of the queue,
or an instance of QueueProperties.
:type table: str or ~azure.storage.table.TableProperties
:returns: A :class:`~azure.data.tables.TableClient` object.
:rtype: ~azure.data.tables.TableClient
"""
try:
table_name = table.name
except AttributeError:
table_name = table
"""

_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
Expand Down
2 changes: 1 addition & 1 deletion sdk/tables/azure-data-tables/tests/test_table_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ async def test_set_table_acl_with_signed_identifiers(self, resource_group, locat
pytest.skip("Cosmos endpoint does not support this")
ts = TableServiceClient(url, storage_account_key)
table = await self._create_table(ts)
client = ts.get_table_client(table=table.table_name)
client = ts.get_table_client(table_name=table.table_name)

# Act
identifiers = dict()
Expand Down

0 comments on commit 9138465

Please sign in to comment.