Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bastion] adding rdp file to temp location and adding auth-type for rdp #7006

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixing some pylint issues
  • Loading branch information
aavalang committed Nov 17, 2023
commit 8cd545d6eccb63c9a490431d9d3ccbecebd0ae9b
2 changes: 1 addition & 1 deletion src/bastion/azext_bastion/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def load_arguments(self, _): # pylint: disable=unused-argument
c.argument("disable_gateway", help="Flag to disable access through RD gateway.",
arg_type=get_three_state_flag())
c.argument("auth_type", help="Auth type to use for RDP connections.", required=False,
options_list=["--auth-type"])
options_list=["--auth-type"])
c.argument('enable_mfa', help='Enable RDS auth for MFA if supported by the target machine.',
arg_type=get_three_state_flag())
with self.argument_context("network bastion tunnel") as c:
Expand Down
49 changes: 27 additions & 22 deletions src/bastion/azext_bastion/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,18 @@ def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, reso
if not resource_port:
resource_port = 22

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \
bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)
if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
if int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}"
f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port)
Expand Down Expand Up @@ -242,7 +243,8 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
if not resource_port:
resource_port = 3389

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \
bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

if auth_type is None:
Expand All @@ -253,18 +255,19 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
else:
raise UnrecognizedArgumentError("Unknown auth type, support auth-types: aad. For non aad login, you dont need to provide auth-type flag.")

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)

if ip_connect and enable_mfa:
raise UnrecognizedArgumentError("AAD login is not supported for IP Connect scenarios.")
ip_connect = _is_ipconnect_request(bastion, target_ip_address)

if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
if enable_mfa:
raise UnrecognizedArgumentError("AAD login is not supported for IP Connect scenarios.")

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
if int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")

target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}"
f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

if platform.system() == "Windows":
Expand All @@ -282,7 +285,8 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
profile = Profile(cli_ctx=cmd.cli_ctx)
access_token = profile.get_raw_token()[0][2].get("accessToken")
logger.debug("Response %s", access_token)
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}"
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp"
f"&rdpport={resource_port}&enablerdsaad={enable_mfa}"

headers = {
"Authorization": f"Bearer {access_token}",
Expand Down Expand Up @@ -311,14 +315,14 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows")


def _is_ipconnect_request(cmd, bastion, target_ip_address):
def _is_ipconnect_request(bastion, target_ip_address):
if 'enableIpConnect' in bastion and bastion['enableIpConnect'] is True and target_ip_address:
return True

return False


def _validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address):
def _validate_resourceid(target_resource_id):
if not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
Expand Down Expand Up @@ -373,14 +377,15 @@ def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_g
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)
if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/"
f"{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port)
Expand Down
2 changes: 1 addition & 1 deletion src/bastion/azext_bastion/developer_sku_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def _get_data_pod(cmd, resource_port, target_resource_id, bastion):

web_address = f"https://{bastion['dnsName']}/api/connection"
response = requests.post(web_address, json=content, headers=headers,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())

return response.content.decode("utf-8")
17 changes: 9 additions & 8 deletions src/bastion/azext_bastion/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@
from contextlib import closing
from datetime import datetime
from threading import Thread
import requests
import urllib3

import websocket
from websocket import create_connection, WebSocket

from msrestazure.azure_exceptions import CloudError
from .BastionServiceConstants import BastionSku

from azure.cli.core._profile import Profile
from azure.cli.core.util import should_disable_connection_verify

import requests
import urllib3

from knack.util import CLIError
from knack.log import get_logger

from .BastionServiceConstants import BastionSku

logger = get_logger(__name__)


Expand Down Expand Up @@ -96,7 +96,7 @@ def _get_auth_token(self):
logger.debug("Content: %s", str(content))
web_address = f"https://{self.bastion_endpoint}/api/tokens"
response = requests.post(web_address, data=content, headers=custom_header,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())
response_json = None

if response.content is not None:
Expand All @@ -121,7 +121,8 @@ def _listen(self):
self.client, _address = self.sock.accept()

auth_token = self._get_auth_token()
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or self.bastion['sku']['name'] == BastionSku.Developer.name:
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or \
self.bastion['sku']['name'] == BastionSku.Developer.name:
host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}"
else:
host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}"
Expand Down Expand Up @@ -204,7 +205,7 @@ def cleanup(self):

web_address = f"https://{self.bastion_endpoint}/api/tokens/{self.last_token}"
response = requests.delete(web_address, headers=custom_header,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())
if response.status_code == 404:
logger.info('Session already deleted')
elif response.status_code not in [200, 204]:
Expand Down