Skip to content

Commit

Permalink
{ServiceConnector} add param --new (#7889)
Browse files Browse the repository at this point in the history
* add param --new

* deprecate Postgres

* update version
  • Loading branch information
xfz11 authored Aug 20, 2024
1 parent 074a466 commit 72c981a
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 50 deletions.
4 changes: 4 additions & 0 deletions src/serviceconnector-passwordless/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Release History
===============
3.0.0
++++++
* Add new param --new to override the existing database user and deprecate Postgres single server

2.0.7
++++++
* Fix argument missing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# pylint: disable=line-too-long, consider-using-f-string, too-many-statements
# For db(mysqlFlex/psql/psqlFlex/sql) linker with auth type=systemAssignedIdentity, enable Microsoft Entra auth and create db user on data plane
# For other linker, ignore the steps
def get_enable_mi_for_db_linker_func(yes=False):
def get_enable_mi_for_db_linker_func(yes=False, new=False):
def enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, connection_name):
# return if connection is not for db mi
if auth_info['auth_type'] not in [AUTHTYPES[AUTH_TYPE.SystemIdentity],
Expand All @@ -61,7 +61,7 @@ def enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, c
if source_handler is None:
return None
target_handler = getTargetHandler(
cmd, target_id, target_type, auth_info, client_type, connection_name, skip_prompt=yes)
cmd, target_id, target_type, auth_info, client_type, connection_name, skip_prompt=yes, new_user=new)
if target_handler is None:
return None
target_handler.check_db_existence()
Expand Down Expand Up @@ -149,21 +149,21 @@ def enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, c


# pylint: disable=unused-argument, too-many-instance-attributes
def getTargetHandler(cmd, target_id, target_type, auth_info, client_type, connection_name, skip_prompt):
def getTargetHandler(cmd, target_id, target_type, auth_info, client_type, connection_name, skip_prompt, new_user):
if target_type in {RESOURCE.Sql}:
return SqlHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt)
return SqlHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user)
if target_type in {RESOURCE.Postgres}:
return PostgresSingleHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt)
return PostgresSingleHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user)
if target_type in {RESOURCE.PostgresFlexible}:
return PostgresFlexHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt)
return PostgresFlexHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user)
if target_type in {RESOURCE.MysqlFlexible}:
return MysqlFlexibleHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt)
return MysqlFlexibleHandler(cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user)
return None


class TargetHandler:

def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt):
def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user):
self.cmd = cmd
self.target_id = target_id
self.target_type = target_type
Expand All @@ -186,6 +186,7 @@ def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip
self.aad_username = "aad_" + connection_name
self.connection_name = connection_name
self.skip_prompt = skip_prompt
self.new_user = new_user
self.endpoint = ""
self.user_object_id = ""
self.identity_name = ""
Expand Down Expand Up @@ -250,9 +251,9 @@ def get_auth_config(self, user_object_id):

class MysqlFlexibleHandler(TargetHandler):

def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt):
def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user):
super().__init__(cmd, target_id, target_type,
auth_info, connection_name, skip_prompt)
auth_info, connection_name, skip_prompt, new_user)
self.endpoint = cmd.cli_ctx.cloud.suffixes.mysql_server_endpoint
target_segments = parse_resource_id(target_id)
self.server = target_segments.get('name')
Expand Down Expand Up @@ -446,9 +447,9 @@ def get_create_query(self):

class SqlHandler(TargetHandler):

def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt):
def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user):
super().__init__(cmd, target_id, target_type,
auth_info, connection_name, skip_prompt)
auth_info, connection_name, skip_prompt, new_user)
self.endpoint = cmd.cli_ctx.cloud.suffixes.sql_server_hostname
target_segments = parse_resource_id(target_id)
self.server = target_segments.get('name')
Expand Down Expand Up @@ -570,7 +571,8 @@ def set_target_firewall(self, is_add, ip_name, start_ip=None, end_ip=None):
"Can't remove firewall rule %s. Please manually delete it to avoid security issue. %s", ip_name, str(e))

def create_aad_user_in_sql(self, connection_args, query_list):

if not self.new_user:
query_list = query_list[1:]
if not is_packaged_installed('pyodbc'):
_run_pip(["install", "pyodbc"])

Expand Down Expand Up @@ -624,19 +626,21 @@ def get_create_query(self):
self.aad_username = self.identity_name
if self.auth_type == AUTHTYPES[AUTH_TYPE.UserAccount]:
self.aad_username = self.login_username
delete_q = "DROP USER IF EXISTS \"{}\";".format(
self.aad_username)
role_q = "CREATE USER \"{}\" FROM EXTERNAL PROVIDER;".format(
self.aad_username)
grant_q = "GRANT CONTROL ON DATABASE::\"{}\" TO \"{}\";".format(
self.dbname, self.aad_username)

return [role_q, grant_q]
return [delete_q, role_q, grant_q]


class PostgresFlexHandler(TargetHandler):

def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt):
def __init__(self, cmd, target_id, target_type, auth_info, connection_name, skip_prompt, new_user):
super().__init__(cmd, target_id, target_type,
auth_info, connection_name, skip_prompt)
auth_info, connection_name, skip_prompt, new_user)
self.endpoint = cmd.cli_ctx.cloud.suffixes.postgresql_server_endpoint
target_segments = parse_resource_id(target_id)
self.db_server = target_segments.get('name')
Expand Down Expand Up @@ -707,11 +711,15 @@ def create_aad_user(self):
query_list = self.get_create_query()
connection_string = self.get_connection_string()
ip_name = generate_random_string(prefix='svc_').lower()

if self.new_user:
user_query = query_list[0:2]
else:
user_query = query_list[1:2]
permission_query = query_list[2:]
try:
logger.warning("Connecting to database...")
self.create_aad_user_in_pg(connection_string, query_list[0:1])
self.create_aad_user_in_pg(self.get_connection_string(self.dbname), query_list[1:])
self.create_aad_user_in_pg(connection_string, user_query)
self.create_aad_user_in_pg(self.get_connection_string(self.dbname), permission_query)
except AzureConnectionError as e:
logger.warning(e)
if 'password authentication failed' in str(e):
Expand All @@ -726,8 +734,8 @@ def create_aad_user(self):
True, ip_name, ip_address, ip_address)
try:
# create again
self.create_aad_user_in_pg(connection_string, query_list[0:1])
self.create_aad_user_in_pg(self.get_connection_string(self.dbname), query_list[1:])
self.create_aad_user_in_pg(connection_string, user_query)
self.create_aad_user_in_pg(self.get_connection_string(self.dbname), permission_query)
except AzureConnectionError as e:
logger.warning(e)
if not ip_address:
Expand All @@ -739,8 +747,8 @@ def create_aad_user(self):
True, ip_name, '0.0.0.0', '255.255.255.255')
# create again
try:
self.create_aad_user_in_pg(connection_string, query_list[0:1])
self.create_aad_user_in_pg(self.get_connection_string(self.dbname), query_list[1:])
self.create_aad_user_in_pg(connection_string, user_query)
self.create_aad_user_in_pg(self.get_connection_string(self.dbname), permission_query)
except AzureConnectionError as e:
telemetry.set_exception(e, "Connect-Db-Fail")
raise e
Expand Down Expand Up @@ -833,7 +841,7 @@ def get_create_query(self):
object_id = self.user_object_id
object_type = 'user'
return [
# 'drop role IF EXISTS "{0}";'.format(self.aad_username),
'drop role IF EXISTS "{0}";'.format(self.aad_username),
"select * from pgaadauth_create_principal_with_oid('{0}', '{1}', '{2}', false, false);".format(
self.aad_username, object_id, object_type),
'GRANT ALL PRIVILEGES ON DATABASE "{0}" TO "{1}";'.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
help='Do not prompt for confirmation.'
)

new_arg_type = CLIArgumentType(
options_list=['--new'],
help='Deleting existing users with the same name before creating a new user in database.'
)


def add_auth_block(context, source, target):
support_auth_types = EX_SUPPORTED_AUTH_TYPE.get(
Expand Down Expand Up @@ -68,6 +73,7 @@ def load_arguments(self, _):
add_local_connection_block(c)
add_customized_keys_argument(c)
c.argument('yes', arg_type=yes_arg_type)
c.argument('new', arg_type=new_arg_type)

for source in SOURCE_RESOURCES_PARAMS:
for target in TARGET_RESOURCES_PARAMS:
Expand All @@ -85,3 +91,4 @@ def load_arguments(self, _):
add_customized_keys_argument(c)
add_opt_out_argument(c)
c.argument('yes', arg_type=yes_arg_type)
c.argument('new', arg_type=new_arg_type)
Original file line number Diff line number Diff line change
Expand Up @@ -29,46 +29,46 @@
]

PASSWORDLESS_TARGET_RESOURCES = [
RESOURCE.Postgres,
# RESOURCE.Postgres,
RESOURCE.PostgresFlexible,
RESOURCE.MysqlFlexible,
RESOURCE.Sql
]

# pylint: disable=line-too-long
EX_SUPPORTED_AUTH_TYPE[RESOURCE.Local] = {
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.UserAccount, AUTH_TYPE.ServicePrincipalSecret],
# RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.UserAccount, AUTH_TYPE.ServicePrincipalSecret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.UserAccount, AUTH_TYPE.ServicePrincipalSecret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.UserAccount, AUTH_TYPE.ServicePrincipalSecret],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.UserAccount, AUTH_TYPE.ServicePrincipalSecret],
}

for resourceType in PASSWORDLESS_SOURCE_RESOURCES:
EX_SUPPORTED_AUTH_TYPE[resourceType] = {
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity, AUTH_TYPE.UserIdentity, AUTH_TYPE.ServicePrincipalSecret],
# RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity, AUTH_TYPE.UserIdentity, AUTH_TYPE.ServicePrincipalSecret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity, AUTH_TYPE.UserIdentity, AUTH_TYPE.ServicePrincipalSecret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity, AUTH_TYPE.UserIdentity, AUTH_TYPE.ServicePrincipalSecret],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity, AUTH_TYPE.UserIdentity, AUTH_TYPE.ServicePrincipalSecret],
}

TARGET_RESOURCES_PARAMS = {
RESOURCE.Postgres: {
'target_resource_group': {
'options': ['--target-resource-group', '--tg'],
'help': 'The resource group which contains the postgres service',
'placeholder': 'PostgresRG'
},
'server': {
'options': ['--server'],
'help': 'Name of postgres server',
'placeholder': 'MyServer'
},
'database': {
'options': ['--database'],
'help': 'Name of postgres database',
'placeholder': 'MyDB'
}
},
# RESOURCE.Postgres: {
# 'target_resource_group': {
# 'options': ['--target-resource-group', '--tg'],
# 'help': 'The resource group which contains the postgres service',
# 'placeholder': 'PostgresRG'
# },
# 'server': {
# 'options': ['--server'],
# 'help': 'Name of postgres server',
# 'placeholder': 'MyServer'
# },
# 'database': {
# 'options': ['--database'],
# 'help': 'Name of postgres database',
# 'placeholder': 'MyDB'
# }
# },
RESOURCE.PostgresFlexible: {
'target_resource_group': {
'options': ['--target-resource-group', '--tg'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
# --------------------------------------------------------------------------------------------


VERSION = '2.0.7'
VERSION = '3.0.0'
NAME = 'serviceconnector-passwordless'
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def connection_create_ext(cmd, client, # pylint: disable=too-many-locals,too-ma
private_endpoint=None,
store_in_connection_string=False,
new_addon=False, no_wait=False,
yes=False,
yes=False, new=False,
# Resource.KubernetesCluster
cluster=None, scope=None, enable_csi=False,
customized_keys=None,
Expand Down Expand Up @@ -46,7 +46,7 @@ def connection_create_ext(cmd, client, # pylint: disable=too-many-locals,too-ma
site, slot,
spring, app, deployment,
server, database,
enable_mi_for_db_linker=get_enable_mi_for_db_linker_func(yes),
enable_mi_for_db_linker=get_enable_mi_for_db_linker_func(yes, new),
customized_keys=customized_keys,
opt_out_list=opt_out_list,
app_config_id=app_config_id,
Expand All @@ -64,7 +64,7 @@ def local_connection_create_ext(cmd, client, # pylint: disable=too-many-locals,
service_principal_auth_info_secret=None,
no_wait=False,
customized_keys=None,
yes=False,
yes=False, new=False,
# Resource.*Postgres, Resource.*Sql*
server=None, database=None,
**kwargs
Expand All @@ -83,6 +83,6 @@ def local_connection_create_ext(cmd, client, # pylint: disable=too-many-locals,
no_wait,
# Resource.*Postgres, Resource.*Sql*
server, database,
enable_mi_for_db_linker=get_enable_mi_for_db_linker_func(yes),
enable_mi_for_db_linker=get_enable_mi_for_db_linker_func(yes, new),
customized_keys=customized_keys,
**kwargs)
2 changes: 1 addition & 1 deletion src/serviceconnector-passwordless/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
logger.warn("Wheel is not available, disabling bdist_wheel hook")


VERSION = '2.0.7'
VERSION = '3.0.0'
try:
from azext_serviceconnector_passwordless.config import VERSION
except ImportError:
Expand Down

0 comments on commit 72c981a

Please sign in to comment.