Skip to content

Commit

Permalink
[ServiceBus] Enable FQDNs and connection strings to support newlines …
Browse files Browse the repository at this point in the history
…and protocol prefixing (e.g. sb://) (#15212)

* Enable FQDNs and connection strings to support trailing newlines and protocol prefixing (e.g. sb://, http://) and add tests to this effect.
  • Loading branch information
KieranBrantnerMagee authored Nov 18, 2020
1 parent f111ffc commit 10d4675
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 18 deletions.
7 changes: 6 additions & 1 deletion sdk/servicebus/azure-servicebus/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
## 7.0.0b9 (Unreleased)

**Breaking Changes**
* `ServiceBusSender` and `ServiceBusReceiver` are no more reusable and will raise `ValueError` when trying to operate on a closed handler.

* `ServiceBusSender` and `ServiceBusReceiver` are no longer reusable and will raise `ValueError` when trying to operate on a closed handler.

**BugFixes**

* FQDNs and Connection strings are now supported even with strippable whitespace or protocol headers (e.g. 'sb://').

**Bug Fixes**

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
OperationTimeoutError,
_create_servicebus_exception
)
from ._common.utils import create_properties
from ._common.utils import create_properties, strip_protocol_from_uri
from ._common.constants import (
CONTAINER_PREFIX,
MANAGEMENT_PATH_SUFFIX,
Expand All @@ -50,7 +50,7 @@ def _parse_conn_str(conn_str):
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.split(";"):
for element in conn_str.strip().split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
endpoint = value.rstrip("/")
Expand Down Expand Up @@ -79,11 +79,7 @@ def _parse_conn_str(conn_str):
"\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key"
)
entity = cast(str, entity_path)
left_slash_pos = cast(str, endpoint).find("//")
if left_slash_pos != -1:
host = cast(str, endpoint)[left_slash_pos + 2:]
else:
host = str(endpoint)
host = cast(str, strip_protocol_from_uri(cast(str, endpoint)))

return (host,
str(shared_access_key_name) if shared_access_key_name else None,
Expand Down Expand Up @@ -163,7 +159,8 @@ def __init__(
**kwargs
):
# type: (str, str, TokenCredential, Any) -> None
self.fully_qualified_namespace = fully_qualified_namespace
# If the user provided http:// or sb://, let's be polite and strip that.
self.fully_qualified_namespace = strip_protocol_from_uri(fully_qualified_namespace.strip())
self._entity_name = entity_name

subscription_name = kwargs.get("subscription_name")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,12 @@ def transform_messages_to_sendable_if_needed(messages):
return messages._to_outgoing_message()
except AttributeError:
return messages


def strip_protocol_from_uri(uri):
# type: (str) -> str
"""Removes the protocol (e.g. http:// or sb://) from a URI, such as the FQDN."""
left_slash_pos = uri.find("//")
if left_slash_pos != -1:
return uri[left_slash_pos + 2:]
return uri
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._servicebus_sender import ServiceBusSender
from ._servicebus_receiver import ServiceBusReceiver
from ._common._configuration import Configuration
from ._common.utils import create_authentication, generate_dead_letter_entity_name
from ._common.utils import create_authentication, generate_dead_letter_entity_name, strip_protocol_from_uri
from ._common.constants import SubQueue

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,7 +70,9 @@ def __init__(
**kwargs
):
# type: (str, TokenCredential, Any) -> None
self.fully_qualified_namespace = fully_qualified_namespace
# If the user provided http:// or sb://, let's be polite and strip that.
self.fully_qualified_namespace = strip_protocol_from_uri(fully_qualified_namespace.strip())

self._credential = credential
self._config = Configuration(**kwargs)
self._connection = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .._base_handler import _generate_sas_token, BaseHandler as BaseHandlerSync
from .._common._configuration import Configuration
from .._common.utils import create_properties
from .._common.utils import create_properties, strip_protocol_from_uri
from .._common.constants import (
TOKEN_TYPE_SASTOKEN,
MGMT_REQUEST_OP_TYPE_ENTITY_MGMT,
Expand Down Expand Up @@ -81,7 +81,8 @@ def __init__(
credential: "TokenCredential",
**kwargs: Any
) -> None:
self.fully_qualified_namespace = fully_qualified_namespace
# If the user provided http:// or sb://, let's be polite and strip that.
self.fully_qualified_namespace = strip_protocol_from_uri(fully_qualified_namespace.strip())
self._entity_name = entity_name

subscription_name = kwargs.get("subscription_name")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ._servicebus_sender_async import ServiceBusSender
from ._servicebus_receiver_async import ServiceBusReceiver
from .._common._configuration import Configuration
from .._common.utils import generate_dead_letter_entity_name
from .._common.utils import generate_dead_letter_entity_name, strip_protocol_from_uri
from .._common.constants import SubQueue
from ._async_utils import create_authentication

Expand Down Expand Up @@ -66,7 +66,8 @@ def __init__(
credential: "TokenCredential",
**kwargs: Any
) -> None:
self.fully_qualified_namespace = fully_qualified_namespace
# If the user provided http:// or sb://, let's be polite and strip that.
self.fully_qualified_namespace = strip_protocol_from_uri(fully_qualified_namespace.strip())
self._credential = credential
self._config = Configuration(**kwargs)
self._connection = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def _schedule_interval_logger(self, end_time, description="", interval_seconds=3
def _do_interval_logging():
if end_time > datetime.utcnow() and not self._should_stop:
self._state.populate_process_stats()
_logger.critical("{} RECURRENT STATUS:".format(description))
_logger.critical(self._state)
_logger.critical("{} RECURRENT STATUS: {}".format(description, self._state))
self._schedule_interval_logger(end_time, description, interval_seconds)

t = threading.Timer(interval_seconds, _do_interval_logging)
Expand Down
41 changes: 40 additions & 1 deletion sdk/servicebus/azure-servicebus/tests/test_sb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,43 @@ def test_client_sas_credential(self,
#with client:
# assert len(client._handlers) == 0
# with client.get_queue_sender(servicebus_queue.name) as sender:
# sender.send_messages(ServiceBusMessage("foo"))
# sender.send_messages(ServiceBusMessage("foo"))

@pytest.mark.liveTest
@pytest.mark.live_test_only
@CachedResourceGroupPreparer()
@CachedServiceBusNamespacePreparer(name_prefix='servicebustest')
@CachedServiceBusQueuePreparer(name_prefix='servicebustest')
def test_client_credential(self,
servicebus_queue,
servicebus_namespace,
servicebus_namespace_key_name,
servicebus_namespace_primary_key,
servicebus_namespace_connection_string,
**kwargs):
# This should "just work" to validate known-good.
credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key)
hostname = "{}.servicebus.windows.net".format(servicebus_namespace.name)

client = ServiceBusClient(hostname, credential)
with client:
assert len(client._handlers) == 0
with client.get_queue_sender(servicebus_queue.name) as sender:
sender.send_messages(ServiceBusMessage("foo"))

hostname = "sb://{}.servicebus.windows.net".format(servicebus_namespace.name)

client = ServiceBusClient(hostname, credential)
with client:
assert len(client._handlers) == 0
with client.get_queue_sender(servicebus_queue.name) as sender:
sender.send_messages(ServiceBusMessage("foo"))

hostname = "https://{}.servicebus.windows.net \
".format(servicebus_namespace.name)

client = ServiceBusClient(hostname, credential)
with client:
assert len(client._handlers) == 0
with client.get_queue_sender(servicebus_queue.name) as sender:
sender.send_messages(ServiceBusMessage("foo"))

0 comments on commit 10d4675

Please sign in to comment.