Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bastion] adding rdp file to temp location and adding auth-type for rdp #7006

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/bastion/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

Release History
===============
0.2.6
++++++
* Adding auth type aad for RDP to mimic the enable-mfa flag.
* Fixing issue where if powershell is opened in system32 directory, file generation throws error. Files are now dumped in temp folder.

0.2.5
++++++
* Fixing the command `az network bastion rdp` to avoid the `java.lang.NullPointerException` while calling `get_auth_token` function
Expand Down
2 changes: 2 additions & 0 deletions src/bastion/azext_bastion/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def load_arguments(self, _): # pylint: disable=unused-argument
c.argument("configure", help="Flag to configure RDP session.", action="store_true")
c.argument("disable_gateway", help="Flag to disable access through RD gateway.",
arg_type=get_three_state_flag())
c.argument("auth_type", help="Auth type to use for RDP connections.", required=False,
options_list=["--auth-type"])
c.argument('enable_mfa', help='Enable RDS auth for MFA if supported by the target machine.',
arg_type=get_three_state_flag())
with self.argument_context("network bastion tunnel") as c:
Expand Down
68 changes: 44 additions & 24 deletions src/bastion/azext_bastion/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import threading
import time
import json
import uuid

import requests
from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \
Expand Down Expand Up @@ -148,17 +149,18 @@ def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, reso
if not resource_port:
resource_port = 22

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \
bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)
if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
if int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}"
f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port)
Expand Down Expand Up @@ -227,7 +229,7 @@ def _get_rdp_path(rdp_command="mstsc"):


def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name,
resource_port=None, disable_gateway=False, configure=False, enable_mfa=False):
auth_type=None, resource_port=None, disable_gateway=False, configure=False, enable_mfa=False):
import os
from azure.cli.core._profile import Profile
from ._process_helper import launch_and_wait
Expand All @@ -241,17 +243,31 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
if not resource_port:
resource_port = 3389

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and \
bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)

if auth_type is None:
# do nothing
pass
elif auth_type.lower() == "aad":
enable_mfa = True

if disable_gateway or ip_connect:
raise UnrecognizedArgumentError("AAD login is not supported for Disable Gateway & IP Connect scenarios.")
else:
raise UnrecognizedArgumentError("Unknown auth type, support auth-types: aad. For non aad login, you dont need to provide auth-type flag.")

if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
if int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}"
f"/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

if platform.system() == "Windows":
Expand All @@ -269,7 +285,8 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
profile = Profile(cli_ctx=cmd.cli_ctx)
access_token = profile.get_raw_token()[0][2].get("accessToken")
logger.debug("Response %s", access_token)
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}"
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp"
f"&rdpport={resource_port}&enablerdsaad={enable_mfa}"

headers = {
"Authorization": f"Bearer {access_token}",
Expand All @@ -285,8 +302,10 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
raise ClientRequestError("Request failed with error: " + errorMessage)
raise ClientRequestError("Request to EncodingReservedUnitTypes v2 API endpoint failed.")

_write_to_file(response)
rdpfilepath = os.getcwd() + "/conn.rdp"
tempdir = os.path.realpath(tempfile.gettempdir())
rdpfilepath = os.path.join(tempdir, 'conn_{}.rdp'.format(uuid.uuid4().hex))
_write_to_file(response, rdpfilepath)

command = [_get_rdp_path()]
if configure:
command.append("/edit")
Expand All @@ -296,14 +315,14 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_
raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows")


def _is_ipconnect_request(cmd, bastion, target_ip_address):
def _is_ipconnect_request(bastion, target_ip_address):
if 'enableIpConnect' in bastion and bastion['enableIpConnect'] is True and target_ip_address:
return True

return False


def _validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address):
def _validate_resourceid(target_resource_id):
if not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
Expand All @@ -319,8 +338,8 @@ def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id):
return bastion['dnsName']


def _write_to_file(response):
with open("conn.rdp", "w", encoding="utf-8") as f:
def _write_to_file(response, file_path):
with open(file_path, "w", encoding="utf-8") as f:
for line in response.text.splitlines():
f.write(line + "\n")

Expand Down Expand Up @@ -358,14 +377,15 @@ def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_g
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
ip_connect = _is_ipconnect_request(bastion, target_ip_address)
if ip_connect:
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/"
f"{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"

if ip_connect and int(resource_port) not in [22, 3389]:
raise UnrecognizedArgumentError("Custom ports are not allowed. Allowed ports for Tunnel with IP connect is 22, 3389.")

_validate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
_validate_resourceid(target_resource_id)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port)
Expand Down
2 changes: 1 addition & 1 deletion src/bastion/azext_bastion/developer_sku_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def _get_data_pod(cmd, resource_port, target_resource_id, bastion):

web_address = f"https://{bastion['dnsName']}/api/connection"
response = requests.post(web_address, json=content, headers=headers,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())

return response.content.decode("utf-8")
17 changes: 9 additions & 8 deletions src/bastion/azext_bastion/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@
from contextlib import closing
from datetime import datetime
from threading import Thread
import requests
import urllib3

import websocket
from websocket import create_connection, WebSocket

from msrestazure.azure_exceptions import CloudError
from .BastionServiceConstants import BastionSku

from azure.cli.core._profile import Profile
from azure.cli.core.util import should_disable_connection_verify

import requests
import urllib3

from knack.util import CLIError
from knack.log import get_logger

from .BastionServiceConstants import BastionSku

logger = get_logger(__name__)


Expand Down Expand Up @@ -96,7 +96,7 @@ def _get_auth_token(self):
logger.debug("Content: %s", str(content))
web_address = f"https://{self.bastion_endpoint}/api/tokens"
response = requests.post(web_address, data=content, headers=custom_header,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())
response_json = None

if response.content is not None:
Expand All @@ -121,7 +121,8 @@ def _listen(self):
self.client, _address = self.sock.accept()

auth_token = self._get_auth_token()
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or self.bastion['sku']['name'] == BastionSku.Developer.name:
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or \
self.bastion['sku']['name'] == BastionSku.Developer.name:
host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}"
else:
host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}"
Expand Down Expand Up @@ -204,7 +205,7 @@ def cleanup(self):

web_address = f"https://{self.bastion_endpoint}/api/tokens/{self.last_token}"
response = requests.delete(web_address, headers=custom_header,
verify=(not should_disable_connection_verify()))
verify=not should_disable_connection_verify())
if response.status_code == 404:
logger.info('Session already deleted')
elif response.status_code not in [200, 204]:
Expand Down
2 changes: 1 addition & 1 deletion src/bastion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


# HISTORY.rst entry.
VERSION = '0.2.5'
VERSION = '0.2.6'

# The full list of classifiers is available at
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
Expand Down