Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _open(self, timeout_time=None): # pylint:disable=unused-argument # TODO: to
else:
alt_creds = {}
self._create_handler()
self._handler.open(connection=self.client._conn_manager.get_connection(
self._handler.open(connection=self.client._conn_manager.get_connection( # pylint: disable=protected-access
self.client.address.hostname,
self.client.get_auth(**alt_creds)
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import functools
import asyncio

from typing import Any, List, Dict, Union, TYPE_CHECKING

from uamqp import authentication, constants # type: ignore
Expand Down Expand Up @@ -47,6 +48,7 @@ class EventHubClient(EventHubClientAbstract):
def __init__(self, host, event_hub_path, credential, **kwargs):
# type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None
super(EventHubClient, self).__init__(host=host, event_hub_path=event_hub_path, credential=credential, **kwargs)
self._lock = asyncio.Lock()
self._conn_manager = get_connection_manager(**kwargs)

async def __aenter__(self):
Expand Down Expand Up @@ -105,10 +107,17 @@ async def _close_connection(self):
await self._conn_manager.reset_connection_if_broken()

async def _management_request(self, mgmt_msg, op_type):
if self._is_iothub and not self._iothub_redirect_info:
await self._iothub_redirect()

alt_creds = {
"username": self._auth_config.get("iot_username"),
"password": self._auth_config.get("iot_password")
}
max_retries = self.config.max_retries
retry_count = 0
while True:
mgmt_auth = self._create_auth()
mgmt_auth = self._create_auth(**alt_creds)
mgmt_client = AMQPClientAsync(self.mgmt_target, auth=mgmt_auth, debug=self.config.network_tracing)
try:
conn = await self._conn_manager.get_connection(self.host, mgmt_auth)
Expand All @@ -126,6 +135,18 @@ async def _management_request(self, mgmt_msg, op_type):
finally:
await mgmt_client.close_async()

async def _iothub_redirect(self):
async with self._lock:
if self._is_iothub and not self._iothub_redirect_info:
if not self._redirect_consumer:
self._redirect_consumer = self.create_consumer(consumer_group='$default',
partition_id='0',
event_position=EventPosition('-1'),
operation='/messages/events')
async with self._redirect_consumer:
await self._redirect_consumer._open_with_retry(timeout=self.config.receive_timeout) # pylint: disable=protected-access
self._redirect_consumer = None

async def get_properties(self):
# type:() -> Dict[str, Any]
"""
Expand All @@ -139,6 +160,8 @@ async def get_properties(self):
:rtype: dict
:raises: ~azure.eventhub.ConnectError
"""
if self._is_iothub and not self._iothub_redirect_info:
await self._iothub_redirect()
mgmt_msg = Message(application_properties={'name': self.eh_name})
response = await self._management_request(mgmt_msg, op_type=b'com.microsoft:eventhub')
output = {}
Expand Down Expand Up @@ -178,6 +201,8 @@ async def get_partition_properties(self, partition):
:rtype: dict
:raises: ~azure.eventhub.ConnectError
"""
if self._is_iothub and not self._iothub_redirect_info:
await self._iothub_redirect()
mgmt_msg = Message(application_properties={'name': self.eh_name,
'partition': partition})
response = await self._management_request(mgmt_msg, op_type=b'com.microsoft:partition')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ async def __anext__(self):

def _create_handler(self):
alt_creds = {
"username": self.client._auth_config.get("iot_username"), # pylint:disable=protected-access
"password": self.client._auth_config.get("iot_password")} # pylint:disable=protected-access
"username": self.client._auth_config.get("iot_username") if self.redirected else None, # pylint:disable=protected-access
"password": self.client._auth_config.get("iot_password") if self.redirected else None # pylint:disable=protected-access
}

source = Source(self.source)
if self.offset is not None:
source.set_filter(self.offset._selector()) # pylint:disable=protected-access
Expand All @@ -134,19 +136,25 @@ async def _redirect(self, redirect):
self.messages_iter = None
await super(EventHubConsumer, self)._redirect(redirect)

async def _open(self, timeout_time=None):
async def _open(self, timeout_time=None, **kwargs):
"""
Open the EventHubConsumer using the supplied connection.
If the handler has previously been redirected, the redirect
context will be used to create a new handler before opening it.

"""
# pylint: disable=protected-access
self.redirected = self.redirected or self.client._iothub_redirect_info

if not self.running and self.redirected:
self.client._process_redirect_uri(self.redirected)
self.source = self.redirected.address
await super(EventHubConsumer, self)._open(timeout_time)

@_retry_decorator
async def _open_with_retry(self, timeout_time=None, **kwargs):
return await self._open(timeout_time=timeout_time, **kwargs)

async def _receive(self, timeout_time=None, max_batch_size=None, **kwargs):
last_exception = kwargs.get("last_exception")
data_batch = kwargs.get("data_batch")
Expand Down Expand Up @@ -254,4 +262,5 @@ async def close(self, exception=None):
self.error = EventHubError(str(exception))
else:
self.error = EventHubError("This receive handler is now closed.")
await self._handler.close_async()
if self._handler:
await self._handler.close_async()
26 changes: 25 additions & 1 deletion sdk/eventhub/azure-eventhubs/azure/eventhub/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import logging
import datetime
import functools
import threading

from typing import Any, List, Dict, Union, TYPE_CHECKING

import uamqp # type: ignore
Expand Down Expand Up @@ -46,6 +48,7 @@ class EventHubClient(EventHubClientAbstract):
def __init__(self, host, event_hub_path, credential, **kwargs):
# type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None
super(EventHubClient, self).__init__(host=host, event_hub_path=event_hub_path, credential=credential, **kwargs)
self._lock = threading.RLock()
self._conn_manager = get_connection_manager(**kwargs)

def __enter__(self):
Expand Down Expand Up @@ -106,10 +109,15 @@ def _close_connection(self):
self._conn_manager.reset_connection_if_broken()

def _management_request(self, mgmt_msg, op_type):
alt_creds = {
"username": self._auth_config.get("iot_username"),
"password": self._auth_config.get("iot_password")
}

max_retries = self.config.max_retries
retry_count = 0
while retry_count <= self.config.max_retries:
mgmt_auth = self._create_auth()
mgmt_auth = self._create_auth(**alt_creds)
mgmt_client = uamqp.AMQPClient(self.mgmt_target)
try:
conn = self._conn_manager.get_connection(self.host, mgmt_auth) #pylint:disable=assignment-from-none
Expand All @@ -127,6 +135,18 @@ def _management_request(self, mgmt_msg, op_type):
finally:
mgmt_client.close()

def _iothub_redirect(self):
with self._lock:
if self._is_iothub and not self._iothub_redirect_info:
if not self._redirect_consumer:
self._redirect_consumer = self.create_consumer(consumer_group='$default',
partition_id='0',
event_position=EventPosition('-1'),
operation='/messages/events')
with self._redirect_consumer:
self._redirect_consumer._open_with_retry(timeout=self.config.receive_timeout) # pylint: disable=protected-access
self._redirect_consumer = None

def get_properties(self):
# type:() -> Dict[str, Any]
"""
Expand All @@ -140,6 +160,8 @@ def get_properties(self):
:rtype: dict
:raises: ~azure.eventhub.ConnectError
"""
if self._is_iothub and not self._iothub_redirect_info:
self._iothub_redirect()
mgmt_msg = Message(application_properties={'name': self.eh_name})
response = self._management_request(mgmt_msg, op_type=b'com.microsoft:eventhub')
output = {}
Expand Down Expand Up @@ -179,6 +201,8 @@ def get_partition_properties(self, partition):
:rtype: dict
:raises: ~azure.eventhub.ConnectError
"""
if self._is_iothub and not self._iothub_redirect_info:
self._iothub_redirect()
mgmt_msg = Message(application_properties={'name': self.eh_name,
'partition': partition})
response = self._management_request(mgmt_msg, op_type=b'com.microsoft:partition')
Expand Down
12 changes: 11 additions & 1 deletion sdk/eventhub/azure-eventhubs/azure/eventhub/client_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import functools
from abc import abstractmethod
from typing import Dict, Union, Any, TYPE_CHECKING
from azure.eventhub import __version__

from azure.eventhub import __version__, EventPosition
from azure.eventhub.configuration import _Configuration
from .common import EventHubSharedKeyCredential, EventHubSASTokenCredential, _Address

Expand Down Expand Up @@ -153,6 +154,8 @@ def __init__(self, host, event_hub_path, credential, **kwargs):
self.get_auth = functools.partial(self._create_auth)
self.config = _Configuration(**kwargs)
self.debug = self.config.network_tracing
self._is_iothub = False
self._iothub_redirect_info = None

log.info("%r: Created the Event Hub client", self.container_id)

Expand All @@ -173,6 +176,11 @@ def _from_iothub_connection_string(cls, conn_str, **kwargs):
'iot_password': key,
'username': username,
'password': password}
client._is_iothub = True
client._redirect_consumer = client.create_consumer(consumer_group='$default',
partition_id='0',
event_position=EventPosition('-1'),
operation='/messages/events')
return client

@abstractmethod
Expand Down Expand Up @@ -213,6 +221,8 @@ def _process_redirect_uri(self, redirect):
self.auth_uri = "sb://{}{}".format(self.address.hostname, self.address.path)
self.eh_name = self.address.path.lstrip('/')
self.mgmt_target = redirect_uri
if self._is_iothub:
self._iothub_redirect_info = redirect

@classmethod
def from_connection_string(cls, conn_str, **kwargs):
Expand Down
14 changes: 11 additions & 3 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def __next__(self):

def _create_handler(self):
alt_creds = {
"username": self.client._auth_config.get("iot_username"), # pylint:disable=protected-access
"password": self.client._auth_config.get("iot_password")} # pylint:disable=protected-access
"username": self.client._auth_config.get("iot_username") if self.redirected else None, # pylint:disable=protected-access
"password": self.client._auth_config.get("iot_password") if self.redirected else None # pylint:disable=protected-access
}

source = Source(self.source)
if self.offset is not None:
source.set_filter(self.offset._selector()) # pylint:disable=protected-access
Expand All @@ -129,19 +131,25 @@ def _redirect(self, redirect):
self.messages_iter = None
super(EventHubConsumer, self)._redirect(redirect)

def _open(self, timeout_time=None):
def _open(self, timeout_time=None, **kwargs):
"""
Open the EventHubConsumer using the supplied connection.
If the handler has previously been redirected, the redirect
context will be used to create a new handler before opening it.

"""
# pylint: disable=protected-access
self.redirected = self.redirected or self.client._iothub_redirect_info

if not self.running and self.redirected:
self.client._process_redirect_uri(self.redirected)
self.source = self.redirected.address
super(EventHubConsumer, self)._open(timeout_time)

@_retry_decorator
def _open_with_retry(self, timeout_time=None, **kwargs):
return self._open(timeout_time=timeout_time, **kwargs)

def _receive(self, timeout_time=None, max_batch_size=None, **kwargs):
last_exception = kwargs.get("last_exception")
data_batch = kwargs.get("data_batch")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,82 @@
# license information.
#--------------------------------------------------------------------------

import os
import asyncio
import pytest
import time

from azure.eventhub.aio import EventHubClient
from azure.eventhub import EventData, EventPosition, EventHubError
from azure.eventhub import EventPosition


async def pump(receiver, sleep=None):
messages = 0
if sleep:
await asyncio.sleep(sleep)
async with receiver:
batch = await receiver.receive(timeout=1)
batch = await receiver.receive(timeout=3)
messages += len(batch)
return messages


async def get_partitions(iot_connection_str):
client = EventHubClient.from_connection_string(iot_connection_str, network_tracing=False)
receiver = client.create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), prefetch=1000, operation='/messages/events')
async with receiver:
partitions = await client.get_properties()
return partitions["partition_ids"]


@pytest.mark.liveTest
@pytest.mark.asyncio
async def test_iothub_receive_multiple_async(iot_connection_str):
pytest.skip("This will get AuthenticationError. We're investigating...")
partitions = await get_partitions(iot_connection_str)
client = EventHubClient.from_connection_string(iot_connection_str, network_tracing=False)
partitions = await client.get_partition_ids()
receivers = []
for p in partitions:
receivers.append(client.create_consumer(consumer_group="$default", partition_id=p, event_position=EventPosition("-1"), prefetch=10, operation='/messages/events'))
outputs = await asyncio.gather(*[pump(r) for r in receivers])

assert isinstance(outputs[0], int) and outputs[0] <= 10
assert isinstance(outputs[1], int) and outputs[1] <= 10


@pytest.mark.liveTest
@pytest.mark.asyncio
async def test_iothub_get_properties_async(iot_connection_str, device_id):
client = EventHubClient.from_connection_string(iot_connection_str, network_tracing=False)
properties = await client.get_properties()
assert properties["partition_ids"] == ["0", "1", "2", "3"]


@pytest.mark.liveTest
@pytest.mark.asyncio
async def test_iothub_get_partition_ids_async(iot_connection_str, device_id):
client = EventHubClient.from_connection_string(iot_connection_str, network_tracing=False)
partitions = await client.get_partition_ids()
assert partitions == ["0", "1", "2", "3"]


@pytest.mark.liveTest
@pytest.mark.asyncio
async def test_iothub_get_partition_properties_async(iot_connection_str, device_id):
client = EventHubClient.from_connection_string(iot_connection_str, network_tracing=False)
partition_properties = await client.get_partition_properties("0")
assert partition_properties["id"] == "0"


@pytest.mark.liveTest
@pytest.mark.asyncio
async def test_iothub_receive_after_mgmt_ops_async(iot_connection_str, device_id):
client = EventHubClient.from_connection_string(iot_connection_str, network_tracing=False)
partitions = await client.get_partition_ids()
assert partitions == ["0", "1", "2", "3"]
receiver = client.create_consumer(consumer_group="$default", partition_id=partitions[0], event_position=EventPosition("-1"), operation='/messages/events')
async with receiver:
received = await receiver.receive(timeout=5)
assert len(received) == 0


@pytest.mark.liveTest
@pytest.mark.asyncio
async def test_iothub_mgmt_ops_after_receive_async(iot_connection_str, device_id):
client = EventHubClient.from_connection_string(iot_connection_str, network_tracing=False)
receiver = client.create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), operation='/messages/events')
async with receiver:
received = await receiver.receive(timeout=5)
assert len(received) == 0

partitions = await client.get_partition_ids()
assert partitions == ["0", "1", "2", "3"]

Loading