Skip to content

Commit 5303764

Browse files
committed
FIX: make load_balance work again
1 parent 3daff30 commit 5303764

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

intelmq/lib/pipeline.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,22 @@
2929
class PipelineFactory(object):
3030

3131
@staticmethod
32-
def create(logger, broker=None, direction=None, queues=None, pipeline_args={}, load_balance=False, is_multithreaded=False):
32+
def create(logger, broker=None, direction=None, queues=None, pipeline_args=None, load_balance=False, is_multithreaded=False):
3333
"""
3434
direction: "source" or "destination", optional, needed for queues
3535
queues: needs direction to be set, calls set_queues
3636
bot: Bot instance
3737
"""
38+
if pipeline_args is None:
39+
pipeline_args = {}
40+
3841
if direction not in [None, "source", "destination"]:
3942
raise exceptions.InvalidArgument("direction", got=direction,
4043
expected=["destination", "source"])
4144

45+
if 'load_balance' not in pipeline_args:
46+
pipeline_args['load_balance'] = load_balance
47+
4248
if direction == 'source' and 'source_pipeline_broker' in pipeline_args:
4349
broker = pipeline_args['source_pipeline_broker'].title()
4450
if direction == 'destination' and 'destination_pipeline_broker' in pipeline_args:
@@ -96,10 +102,7 @@ def set_queues(self, queues: Optional[str], queues_type: str):
96102
"""
97103
if queues_type == "source":
98104
self.source_queue = queues
99-
if queues is not None:
100-
self.internal_queue = queues + "-internal"
101-
else:
102-
self.internal_queue = None
105+
self.internal_queue = None if queues is None else f'{queues}-internal'
103106

104107
elif queues_type == "destination":
105108
type_ = type(queues)
@@ -109,8 +112,7 @@ def set_queues(self, queues: Optional[str], queues_type: str):
109112
q = {"_default": queues.split()}
110113
elif type_ is dict:
111114
q = queues
112-
for key, val in queues.items():
113-
q[key] = val if type(val) is list else val.split()
115+
q.update({key: (val if isinstance(val, list) else val.split()) for key, val in queues.items()})
114116
else:
115117
raise exceptions.InvalidArgument(
116118
'queues', got=queues,
@@ -187,15 +189,12 @@ class Redis(Pipeline):
187189
destination_pipeline_password = None
188190

189191
def load_configurations(self, queues_type):
190-
self.host = self.pipeline_args.get("{}_pipeline_host".format(queues_type),
191-
"127.0.0.1")
192-
self.port = self.pipeline_args.get("{}_pipeline_port".format(queues_type), "6379")
193-
self.db = self.pipeline_args.get("{}_pipeline_db".format(queues_type), 2)
194-
self.password = self.pipeline_args.get("{}_pipeline_password".format(queues_type),
195-
None)
192+
self.host = self.pipeline_args.get(f"{queues_type}_pipeline_host", "127.0.0.1")
193+
self.port = self.pipeline_args.get(f"{queues_type}_pipeline_port", "6379")
194+
self.db = self.pipeline_args.get(f"{queues_type}_pipeline_db", 2)
195+
self.password = self.pipeline_args.get(f"{queues_type}_pipeline_password", None)
196196
# socket_timeout is None by default, which means no timeout
197-
self.socket_timeout = self.pipeline_args.get("{}_pipeline_socket_timeout".format(queues_type),
198-
None)
197+
self.socket_timeout = self.pipeline_args.get(f"{queues_type}_pipeline_socket_timeout", None)
199198
self.load_balance = self.pipeline_args.get("load_balance", False)
200199
self.load_balance_iterator = 0
201200

@@ -241,8 +240,7 @@ def send(self, message: str, path: str = "_default",
241240
if self.load_balance:
242241
queues = [queues[self.load_balance_iterator]]
243242
self.load_balance_iterator += 1
244-
if self.load_balance_iterator == len(self.destination_queues[path]):
245-
self.load_balance_iterator = 0
243+
self.load_balance_iterator %= len(self.destination_queues[path])
246244

247245
for destination_queue in queues:
248246
try:
@@ -559,8 +557,7 @@ def send(self, message: str, path: str = "_default",
559557
if self.load_balance:
560558
queues = [queues[self.load_balance_iterator]]
561559
self.load_balance_iterator += 1
562-
if self.load_balance_iterator == len(self.destination_queues[path]):
563-
self.load_balance_iterator = 0
560+
self.load_balance_iterator %= len(self.destination_queues[path])
564561

565562
for destination_queue in queues:
566563
self._send(destination_queue, message)

0 commit comments

Comments
 (0)