Skip to content

Commit

Permalink
refactored to add rate limit retry
Browse files Browse the repository at this point in the history
  • Loading branch information
WildDogOne committed Aug 25, 2024
1 parent 08920b7 commit 212c2ee
Showing 1 changed file with 58 additions and 27 deletions.
85 changes: 58 additions & 27 deletions src/droid/platforms/ms_xdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import msal
import requests
import json
import time
from datetime import datetime, timedelta, timezone
from droid.abstracts import AbstractPlatform
from droid.color import ColorLogger


logger = ColorLogger("droid.platforms.msxdr")


Expand Down Expand Up @@ -90,7 +92,9 @@ def mssp_run_xdr_search(
def run_xdr_search(self, rule_converted, rule_file):
payload = {"Query": rule_converted, "Timespan": "P1D"}
try:
results = self._post(url="/security/runHuntingQuery", payload=payload)
results, status_code = self._post(
url="/security/runHuntingQuery", payload=payload
)
if "error" in results:
self.logger.error(
f"Error while running the query {results['error']['message']}"
Expand All @@ -108,7 +112,9 @@ def get_rule(self, rule_id):
"""
try:
params = {"$filter": f"contains(displayName, '{rule_id}')"}
rule = self._get(url="/security/rules/detectionRules", params=params)
rule, status_code = self._get(
url="/security/rules/detectionRules", params=params
)
if len(rule["value"]) > 0:
return rule["value"][0]
else:
Expand Down Expand Up @@ -350,54 +356,50 @@ def check_rule_changes(self, existing_rule, new_rule):
def push_detection_rule(
self, alert_rule=None, rule_content=None, rule_file=None, rule_converted=None
):
headers = {
"Authorization": f"Bearer {self._token}",
"Content-Type": "application/json",
}
existing_rule = self.get_rule(rule_content["id"])
if existing_rule:
self.logger.info("Rule already exists")
if not self.check_rule_changes(existing_rule, alert_rule):
return True
else:
api_url = f"{self._api_base_url}/security/rules/detectionRules/{existing_rule['id']}"
response = requests.patch(api_url, headers=headers, json=alert_rule)
api_url = f"/security/rules/detectionRules/{existing_rule['id']}"
response, status_code = self._patch(url=api_url, payload=alert_rule)
else:
api_url = self._api_base_url + "/security/rules/detectionRules"
response = requests.post(api_url, headers=headers, json=alert_rule)
api_url = "/security/rules/detectionRules"
response, status_code = self._post(url=api_url, payload=alert_rule)

if response.status_code == 400:
if status_code == 400:
self.logger.error(
f"Could not export the rule {rule_file} due to a bad request. {response.json()['error']['message']}",
f"Could not export the rule {rule_file} due to a bad request. {response['error']['message']}",
extra={
"rule_file": rule_file,
"rule_converted": rule_converted,
"rule_content": rule_content,
"error": response.json(),
"error": response,
},
)
elif response.status_code == 403:
elif status_code == 403:
self.logger.error(
f"Could not export the rule {rule_file} due to insufficient permissions. {response.json()}",
f"Could not export the rule {rule_file} due to insufficient permissions. {response}",
extra={
"rule_file": rule_file,
"rule_converted": rule_converted,
"rule_content": rule_content,
"error": response.json(),
"error": response,
},
)
elif response.status_code == 201 or 200:
if "error" in response.json():
elif status_code == 201 or 200:
if "error" in response:
self.logger.error(
f"Could not export the rule {rule_file}",
extra={
"rule_file": rule_file,
"rule_converted": rule_converted,
"rule_content": rule_content,
"error": response.json(),
"error": response,
},
)
raise Exception(response.json())
raise Exception(response)
else:
self.logger.info(
f"Successfully exported the rule {rule_file}",
Expand All @@ -408,8 +410,8 @@ def push_detection_rule(
},
)
else:
print(response.status_code)
pprint(response.json())
print(status_code)
pprint(response)

def parse_actions(self, actions, rule_file=None):
# This whole function is a mess
Expand Down Expand Up @@ -572,26 +574,55 @@ def parse_impactedAssets(self, impactedAssets, rule_file=None):

def _get(self, url=None, headers=None, params=None):
# Send the JSON payload to Microsoft Graph Security API

api_url = self._api_base_url + url
headers = {
"Authorization": f"Bearer {self._token}",
"Content-Type": "application/json",
}
if headers:
headers.update(headers)
response = requests.get(api_url, headers=headers, params=params)
return response.json()
while True:
response = requests.get(api_url, headers=headers, params=params)
if response.status_code == 429:
logger.debug("Rate limit reached, waiting 60 seconds")
time.sleep(60)
else:
break

return response.json(), response.status_code

def _post(self, url=None, payload=None, headers=None, params=None):
# Send the JSON payload to Microsoft Graph Security API
api_url = self._api_base_url + url
headers = {
"Authorization": f"Bearer {self._token}",
"Content-Type": "application/json",
}
if headers:
headers.update(headers)
while True:
response = requests.post(api_url, headers=headers, json=payload)
if response.status_code == 429:
logger.debug("Rate limit reached, waiting 60 seconds")
time.sleep(60)
else:
break
return response.json(), response.status_code

def _patch(self, url=None, payload=None, headers=None, params=None):
# Send the JSON payload to Microsoft Graph Security API
api_url = self._api_base_url + url
headers = {
"Authorization": f"Bearer {self._token}",
"Content-Type": "application/json",
}
if headers:
headers.update(headers)
response = requests.post(api_url, headers=self._headers, json=payload)
return response.json()
while True:
response = requests.patch(api_url, headers=headers, json=payload)
if response.status_code == 429:
logger.debug("Rate limit reached, waiting 60 seconds")
time.sleep(60)
else:
break
return response.json(), response.status_code

0 comments on commit 212c2ee

Please sign in to comment.