Skip to content

Commit

Permalink
Merge pull request #11 from ImMin5/master
Browse files Browse the repository at this point in the history
Add default option exclude root tenant group true
  • Loading branch information
ImMin5 authored Apr 5, 2024
2 parents ca4d46b + 2f88fcd commit d56e303
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 21 deletions.
9 changes: 5 additions & 4 deletions src/plugin/connector/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os

from azure.identity import ClientSecretCredential
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.resource import ResourceManagementClient, SubscriptionClient
from azure.mgmt.managementgroups import ManagementGroupsAPI
from azure.mgmt.billing import BillingManagementClient
Expand Down Expand Up @@ -41,9 +41,6 @@ def set_connect(self, secret_data: dict, tenant_id: str = None) -> None:
self.resource_client: ResourceManagementClient = ResourceManagementClient(
credential=credential, subscription_id=subscription_id
)
self.subscription_client: SubscriptionClient = SubscriptionClient(
credential=credential
)
self.management_groups_client: ManagementGroupsAPI = ManagementGroupsAPI(
credential=credential
)
Expand All @@ -52,6 +49,10 @@ def set_connect(self, secret_data: dict, tenant_id: str = None) -> None:
credential=credential, subscription_id=subscription_id
)

self.subscription_client: SubscriptionClient = SubscriptionClient(
credential=DefaultAzureCredential()
)

def _make_request_headers(self, secret_data, client_type=None):
access_token = self._get_access_token(secret_data)
headers = {
Expand Down
14 changes: 14 additions & 0 deletions src/plugin/connector/subscription_connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import asyncio

from azure.core.exceptions import ClientAuthenticationError

from plugin.connector.base import AzureBaseConnector

_LOGGER = logging.getLogger("spaceone")
Expand All @@ -22,3 +24,15 @@ def list_tenants(
def list_subscriptions(self) -> dict:
subscriptions = self.subscription_client.subscriptions.list()
return subscriptions

def get_subscription(
self, secret_data: dict, subscription_id: str, tenant_id: str = None
) -> dict:
try:
if tenant_id:
self.set_connect(secret_data, tenant_id)
subscription = self.subscription_client.subscriptions.get(subscription_id)
return subscription
except ClientAuthenticationError as e:
_LOGGER.error(f"[get_subscription] {e.message} => SKIP")
return {}
2 changes: 1 addition & 1 deletion src/plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def account_collector_init(params: dict) -> dict:
"exclude_tenant_root_group": {
"title": "Exclude Tenant Root Group",
"type": "boolean",
"default": False,
"default": True,
},
},
}
Expand Down
65 changes: 49 additions & 16 deletions src/plugin/manager/resource_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import List, Union

from azure.core.exceptions import ResourceNotFoundError
from azure.core.exceptions import ResourceNotFoundError, ClientAuthenticationError

from plugin.manager.base import AzureBaseManager
from plugin.connector.subscription_connector import SubscriptionConnector
Expand Down Expand Up @@ -46,6 +46,7 @@ def sync(
agreement_type = billing_account_info.get("agreement_type", "")

result_subscription_map = {}
subscription_info_map = {}
for subscription in billing_connector.list_subscription(
secret_data, agreement_type, billing_account_id
):
Expand All @@ -58,7 +59,9 @@ def sync(
)

if subscription_status in ["Active"] and subscription_id:
tenant_id = self._get_tenant_id(subscription_info, agreement_type)
tenant_id = self._get_tenant_id(
secret_data, subscription_info, agreement_type
)

name = self._get_subscription_name(
subscription_info, agreement_type
Expand All @@ -68,9 +71,19 @@ def sync(
subscription_info, agreement_type, tenant_id
)

result = self._make_result_without_secret(
tenant_id, subscription_id, name, location
)
if subscription_info_map.get("subscription_id") is None:
subscription_info_map = self._get_subscription_info_map(
subscription_info_map, secret_data, tenant_id
)

if subscription_info_map.get(subscription_id):
result = self._make_result(
tenant_id, subscription_id, name, location
)
else:
result = self._make_result_without_secret(
tenant_id, subscription_id, name, location
)
result_subscription_map[subscription_id] = result

tenants = subscription_connector.list_tenants()
Expand Down Expand Up @@ -103,9 +116,10 @@ def sync(
{
"tags": tags,
"location": location,
"secret_schema_id": "azure-secret-subscription-id",
"secret_schema_id": "azure-secret-multi-tenant",
"secret_data": {
"subscription_id": subscription_id
"subscription_id": subscription_id,
"tenant_id": tenant_id,
},
}
)
Expand Down Expand Up @@ -137,6 +151,27 @@ def sync(

return results

def _get_subscription_info_map(
self, subscription_info_map: dict, secret_data: dict, tenant_id: str
) -> dict:
try:
subscription_connector = SubscriptionConnector(
secret_data=secret_data, tenant_id=tenant_id
)
subscriptions = subscription_connector.list_subscriptions()
for subscription in subscriptions:
print(subscription)
subscription_info = self.convert_nested_dictionary(subscription)
subscription_id = subscription_info.get("subscription_id")
if subscription_id:
subscription_info_map[subscription_id] = subscription_info
except ClientAuthenticationError as e:
_LOGGER.error(f"[_get_subscription_info_map] {e.message}", exc_info=True)
except Exception as e:
_LOGGER.error(f"[_get_subscription_info_map] {e}", exc_info=True)

return subscription_info_map

@staticmethod
def _create_location_from_entity_info(
entity_info: dict, options: dict
Expand Down Expand Up @@ -167,16 +202,13 @@ def _get_subscription_status(subscription_info: dict, agreement_type: str) -> st
return subscription_info.get("subscription_billing_status", "")

@staticmethod
def _get_tenant_id(subscription_info: dict, agreement_type: str) -> str:
if agreement_type == "EnterpriseAgreement":
# EA has different structure, Use enrollment id as tenant id
tenant_id = subscription_info.get("properties", {}).get(
"enrollmentAccountId", ""
)
elif agreement_type == "MicrosoftPartnerAgreement":
def _get_tenant_id(
secret_data: dict, subscription_info: dict, agreement_type: str
) -> Union[str, None]:
if agreement_type == "MicrosoftPartnerAgreement":
tenant_id = subscription_info.get("customer_id").split("/")[-1]
else:
tenant_id = subscription_info.get("tenant_id")
tenant_id = secret_data["tenant_id"]
return tenant_id

@staticmethod
Expand Down Expand Up @@ -236,9 +268,10 @@ def _make_result(
"subscription_id": subscription_id,
"tenant_id": tenant_id,
},
"secret_schema_id": "azure-secret-subscription-id",
"secret_schema_id": "azure-secret-multi-tenant",
"secret_data": {
"subscription_id": subscription_id,
"tenant_id": tenant_id,
},
"resource_id": subscription_id,
"tags": tags,
Expand Down

0 comments on commit d56e303

Please sign in to comment.