14
14
# pylint: disable=unnecessary-dunder-call
15
15
16
16
from logging import getLogger
17
- from typing import Any , Collection , Dict , Optional
17
+ from typing import Any , Collection , Dict , Optional , Union
18
18
19
19
import pika
20
20
import wrapt
24
24
BlockingChannel ,
25
25
_QueueConsumerGeneratorInfo ,
26
26
)
27
+ from pika .channel import Channel
28
+ from pika .connection import Connection
27
29
28
30
from opentelemetry import trace
29
31
from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
@@ -53,12 +55,16 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
53
55
54
56
# pylint: disable=attribute-defined-outside-init
55
57
@staticmethod
56
- def _instrument_blocking_channel_consumers (
57
- channel : BlockingChannel ,
58
+ def _instrument_channel_consumers (
59
+ channel : Union [ BlockingChannel , Channel ] ,
58
60
tracer : Tracer ,
59
61
consume_hook : utils .HookT = utils .dummy_callback ,
60
62
) -> Any :
61
- for consumer_tag , consumer_info in channel ._consumer_infos .items ():
63
+ if isinstance (channel , BlockingChannel ):
64
+ consumer_infos = channel ._consumer_infos
65
+ elif isinstance (channel , Channel ):
66
+ consumer_infos = channel ._consumers
67
+ for consumer_tag , consumer_info in consumer_infos .items ():
62
68
callback_attr = PikaInstrumentor .CONSUMER_CALLBACK_ATTR
63
69
consumer_callback = getattr (consumer_info , callback_attr , None )
64
70
if consumer_callback is None :
@@ -79,7 +85,7 @@ def _instrument_blocking_channel_consumers(
79
85
80
86
@staticmethod
81
87
def _instrument_basic_publish (
82
- channel : BlockingChannel ,
88
+ channel : Union [ BlockingChannel , Channel ] ,
83
89
tracer : Tracer ,
84
90
publish_hook : utils .HookT = utils .dummy_callback ,
85
91
) -> None :
@@ -93,7 +99,7 @@ def _instrument_basic_publish(
93
99
94
100
@staticmethod
95
101
def _instrument_channel_functions (
96
- channel : BlockingChannel ,
102
+ channel : Union [ BlockingChannel , Channel ] ,
97
103
tracer : Tracer ,
98
104
publish_hook : utils .HookT = utils .dummy_callback ,
99
105
) -> None :
@@ -103,7 +109,9 @@ def _instrument_channel_functions(
103
109
)
104
110
105
111
@staticmethod
106
- def _uninstrument_channel_functions (channel : BlockingChannel ) -> None :
112
+ def _uninstrument_channel_functions (
113
+ channel : Union [BlockingChannel , Channel ],
114
+ ) -> None :
107
115
for function_name in _FUNCTIONS_TO_UNINSTRUMENT :
108
116
if not hasattr (channel , function_name ):
109
117
continue
@@ -115,7 +123,7 @@ def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
115
123
@staticmethod
116
124
# Make sure that the spans are created inside hash them set as parent and not as brothers
117
125
def instrument_channel (
118
- channel : BlockingChannel ,
126
+ channel : Union [ BlockingChannel , Channel ] ,
119
127
tracer_provider : Optional [TracerProvider ] = None ,
120
128
publish_hook : utils .HookT = utils .dummy_callback ,
121
129
consume_hook : utils .HookT = utils .dummy_callback ,
@@ -133,7 +141,7 @@ def instrument_channel(
133
141
tracer_provider ,
134
142
schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
135
143
)
136
- PikaInstrumentor ._instrument_blocking_channel_consumers (
144
+ PikaInstrumentor ._instrument_channel_consumers (
137
145
channel , tracer , consume_hook
138
146
)
139
147
PikaInstrumentor ._decorate_basic_consume (channel , tracer , consume_hook )
@@ -178,16 +186,17 @@ def wrapper(wrapped, instance, args, kwargs):
178
186
return channel
179
187
180
188
wrapt .wrap_function_wrapper (BlockingConnection , "channel" , wrapper )
189
+ wrapt .wrap_function_wrapper (Connection , "channel" , wrapper )
181
190
182
191
@staticmethod
183
192
def _decorate_basic_consume (
184
- channel : BlockingChannel ,
193
+ channel : Union [ BlockingChannel , Channel ] ,
185
194
tracer : Optional [Tracer ],
186
195
consume_hook : utils .HookT = utils .dummy_callback ,
187
196
) -> None :
188
197
def wrapper (wrapped , instance , args , kwargs ):
189
198
return_value = wrapped (* args , ** kwargs )
190
- PikaInstrumentor ._instrument_blocking_channel_consumers (
199
+ PikaInstrumentor ._instrument_channel_consumers (
191
200
channel , tracer , consume_hook
192
201
)
193
202
return return_value
@@ -236,6 +245,7 @@ def _uninstrument(self, **kwargs: Dict[str, Any]) -> None:
236
245
if hasattr (self , "__opentelemetry_tracer_provider" ):
237
246
delattr (self , "__opentelemetry_tracer_provider" )
238
247
unwrap (BlockingConnection , "channel" )
248
+ unwrap (Connection , "channel" )
239
249
unwrap (_QueueConsumerGeneratorInfo , "__init__" )
240
250
241
251
def instrumentation_dependencies (self ) -> Collection [str ]:
0 commit comments