Skip to content

Commit

Permalink
Merge pull request ansible#752 from ryanpetrello/multivault
Browse files Browse the repository at this point in the history
support specifying multiple vault IDs for a playbook run
  • Loading branch information
ryanpetrello authored Dec 1, 2017
2 parents fde5a88 + a1f8f65 commit e7918ad
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 24 deletions.
7 changes: 7 additions & 0 deletions awx/main/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,13 @@ def from_db_value(self, value, expression, connection, context):
return value


@JSONSchemaField.format_checker.checks('vault_id')
def format_vault_id(value):
if '@' in value:
raise jsonschema.exceptions.FormatError('@ is not an allowed character')
return True


@JSONSchemaField.format_checker.checks('ssh_private_key')
def format_ssh_private_key(value):
# Sanity check: GCE, in particular, provides JSON-encoded private
Expand Down
2 changes: 2 additions & 0 deletions awx/main/migrations/0009_v330_multi_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db import migrations, models

from awx.main.migrations import _migration_utils as migration_utils
from awx.main.migrations import _credentialtypes as credentialtypes
from awx.main.migrations._multi_cred import migrate_to_multi_cred


Expand Down Expand Up @@ -50,4 +51,5 @@ class Migration(migrations.Migration):
model_name='jobtemplate',
name='vault_credential',
),
migrations.RunPython(credentialtypes.add_vault_id_field)
]
5 changes: 5 additions & 0 deletions awx/main/migrations/_credentialtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,8 @@ def migrate_job_credentials(apps, schema_editor):
finally:
utils.get_current_apps = orig_current_apps


def add_vault_id_field(apps, schema_editor):
vault_credtype = CredentialType.objects.get(kind='vault')
vault_credtype.inputs = CredentialType.defaults.get('vault')().inputs
vault_credtype.save()
10 changes: 10 additions & 0 deletions awx/main/models/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,16 @@ def vault(cls):
'type': 'string',
'secret': True,
'ask_at_runtime': True
}, {
'id': 'vault_id',
'label': 'Vault Identifier',
'type': 'string',
'format': 'vault_id',
'help_text': ('Specify an (optional) Vault ID. This is '
'equivalent to specifying the --vault-id '
'Ansible parameter for providing multiple Vault '
'passwords. Note: this feature only works in '
'Ansible 2.4+.')
}],
'required': ['vault_password'],
}
Expand Down
4 changes: 4 additions & 0 deletions awx/main/models/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ def network_credentials(self):
def cloud_credentials(self):
return list(self.credentials.filter(credential_type__kind='cloud'))

@property
def vault_credentials(self):
return list(self.credentials.filter(credential_type__kind='vault'))

@property
def credential(self):
cred = self.get_deprecated_credential('ssh')
Expand Down
70 changes: 46 additions & 24 deletions awx/main/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def get_instance_timeout(self, instance):
job_timeout = 0
return job_timeout

def get_password_prompts(self):
def get_password_prompts(self, **kwargs):
'''
Return a dictionary where keys are strings or regular expressions for
prompts, and values are password lookup keys (keys that are returned
Expand Down Expand Up @@ -833,7 +833,7 @@ def run(self, pk, isolated_host=None, **kwargs):
job_cwd=cwd, job_env=safe_env, result_stdout_file=stdout_handle.name)

expect_passwords = {}
for k, v in self.get_password_prompts().items():
for k, v in self.get_password_prompts(**kwargs).items():
expect_passwords[k] = kwargs['passwords'].get(v, '') or ''
_kw = dict(
expect_passwords=expect_passwords,
Expand Down Expand Up @@ -961,19 +961,30 @@ def build_passwords(self, job, **kwargs):
and ansible-vault.
'''
passwords = super(RunJob, self).build_passwords(job, **kwargs)
for kind, fields in {
'ssh': ('ssh_key_unlock', 'ssh_password', 'become_password'),
'vault': ('vault_password',)
}.items():
cred = job.get_deprecated_credential(kind)
if cred:
for field in fields:
if field == 'ssh_password':
value = kwargs.get(field, decrypt_field(cred, 'password'))
else:
value = kwargs.get(field, decrypt_field(cred, field))
if value not in ('', 'ASK'):
passwords[field] = value
cred = job.get_deprecated_credential('ssh')
if cred:
for field in ('ssh_key_unlock', 'ssh_password', 'become_password'):
value = kwargs.get(
field,
decrypt_field(cred, 'password' if field == 'ssh_password' else field)
)
if value not in ('', 'ASK'):
passwords[field] = value

for cred in job.vault_credentials:
field = 'vault_password'
if cred.inputs.get('vault_id'):
field = 'vault_password.{}'.format(cred.inputs['vault_id'])
if field in passwords:
raise RuntimeError(
'multiple vault credentials were specified with --vault-id {}@prompt'.format(
cred.inputs['vault_id']
)
)
value = kwargs.get(field, decrypt_field(cred, 'vault_password'))
if value not in ('', 'ASK'):
passwords[field] = value

return passwords

def build_env(self, job, **kwargs):
Expand Down Expand Up @@ -1107,9 +1118,16 @@ def build_args(self, job, **kwargs):
args.extend(['--become-user', become_username])
if 'become_password' in kwargs.get('passwords', {}):
args.append('--ask-become-pass')
# Support prompting for a vault password.
if 'vault_password' in kwargs.get('passwords', {}):
args.append('--ask-vault-pass')

# Support prompting for multiple vault passwords
for k, v in kwargs.get('passwords', {}).items():
if k.startswith('vault_password'):
if k == 'vault_password':
args.append('--ask-vault-pass')
else:
vault_id = k.split('.')[1]
args.append('--vault-id')
args.append('{}@prompt'.format(vault_id))

if job.forks: # FIXME: Max limit?
args.append('--forks=%d' % job.forks)
Expand Down Expand Up @@ -1177,8 +1195,8 @@ def build_cwd(self, job, **kwargs):
def get_idle_timeout(self):
return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None)

def get_password_prompts(self):
d = super(RunJob, self).get_password_prompts()
def get_password_prompts(self, **kwargs):
d = super(RunJob, self).get_password_prompts(**kwargs)
d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock'
d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = ''
for method in PRIVILEGE_ESCALATION_METHODS:
Expand All @@ -1187,6 +1205,10 @@ def get_password_prompts(self):
d[re.compile(r'SSH password:\s*?$', re.M)] = 'ssh_password'
d[re.compile(r'Password:\s*?$', re.M)] = 'ssh_password'
d[re.compile(r'Vault password:\s*?$', re.M)] = 'vault_password'
for k, v in kwargs.get('passwords', {}).items():
if k.startswith('vault_password.'):
vault_id = k.split('.')[1]
d[re.compile(r'Vault password \({}\):\s*?$'.format(vault_id), re.M)] = k
return d

def get_stdout_handle(self, instance):
Expand Down Expand Up @@ -1442,8 +1464,8 @@ def build_output_replacements(self, project_update, **kwargs):
output_replacements.append((pattern2 % d_before, pattern2 % d_after))
return output_replacements

def get_password_prompts(self):
d = super(RunProjectUpdate, self).get_password_prompts()
def get_password_prompts(self, **kwargs):
d = super(RunProjectUpdate, self).get_password_prompts(**kwargs)
d[re.compile(r'Username for.*:\s*?$', re.M)] = 'scm_username'
d[re.compile(r'Password for.*:\s*?$', re.M)] = 'scm_password'
d[re.compile(r'Password:\s*?$', re.M)] = 'scm_password'
Expand Down Expand Up @@ -2142,8 +2164,8 @@ def build_cwd(self, ad_hoc_command, **kwargs):
def get_idle_timeout(self):
return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None)

def get_password_prompts(self):
d = super(RunAdHocCommand, self).get_password_prompts()
def get_password_prompts(self, **kwargs):
d = super(RunAdHocCommand, self).get_password_prompts(**kwargs)
d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock'
d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = ''
for method in PRIVILEGE_ESCALATION_METHODS:
Expand Down
25 changes: 25 additions & 0 deletions awx/main/tests/functional/test_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,31 @@ def test_ssh_key_data_validation(organization, kind, ssh_key_data, ssh_key_unloc
assert e.type in (ValidationError, serializers.ValidationError)


@pytest.mark.django_db
@pytest.mark.parametrize('inputs, valid', [
({'vault_password': 'some-pass'}, True),
({}, False),
({'vault_password': 'dev-pass', 'vault_id': 'dev'}, True),
({'vault_password': 'dev-pass', 'vault_id': 'dev@prompt'}, False), # @ not allowed
])
def test_vault_validation(organization, inputs, valid):
cred_type = CredentialType.defaults['vault']()
cred_type.save()
cred = Credential(
credential_type=cred_type,
name="Best credential ever",
inputs=inputs,
organization=organization
)
cred.save()
if valid:
cred.full_clean()
else:
with pytest.raises(Exception) as e:
cred.full_clean()
assert e.type in (ValidationError, serializers.ValidationError)


@pytest.mark.django_db
@pytest.mark.parametrize('become_method, valid', zip(
dict(V1Credential.FIELDS['become_method'].choices).keys(),
Expand Down
91 changes: 91 additions & 0 deletions awx/main/tests/unit/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,97 @@ def test_vault_password(self):
] == 'vault-me'
assert '--ask-vault-pass' in ' '.join(args)

def test_vault_password_ask(self):
vault = CredentialType.defaults['vault']()
credential = Credential(
pk=1,
credential_type=vault,
inputs={'vault_password': 'ASK'}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)
self.task.run(self.pk, vault_password='provided-at-launch')

assert self.run_pexpect.call_count == 1
call_args, call_kwargs = self.run_pexpect.call_args_list[0]
args, cwd, env, stdout = call_args

assert call_kwargs.get('expect_passwords')[
re.compile(r'Vault password:\s*?$', re.M)
] == 'provided-at-launch'
assert '--ask-vault-pass' in ' '.join(args)

def test_multi_vault_password(self):
vault = CredentialType.defaults['vault']()
for i, label in enumerate(['dev', 'prod']):
credential = Credential(
pk=i,
credential_type=vault,
inputs={'vault_password': 'pass@{}'.format(label), 'vault_id': label}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)
self.task.run(self.pk)

assert self.run_pexpect.call_count == 1
call_args, call_kwargs = self.run_pexpect.call_args_list[0]
args, cwd, env, stdout = call_args

vault_passwords = dict(
(k.pattern, v) for k, v in call_kwargs['expect_passwords'].items()
if 'Vault' in k.pattern
)
assert vault_passwords['Vault password \(prod\):\\s*?$'] == 'pass@prod'
assert vault_passwords['Vault password \(dev\):\\s*?$'] == 'pass@dev'
assert vault_passwords['Vault password:\\s*?$'] == ''
assert '--ask-vault-pass' not in ' '.join(args)
assert '--vault-id dev@prompt' in ' '.join(args)
assert '--vault-id prod@prompt' in ' '.join(args)

def test_multi_vault_id_conflict(self):
vault = CredentialType.defaults['vault']()
for i in range(2):
credential = Credential(
pk=i,
credential_type=vault,
inputs={'vault_password': 'some-pass', 'vault_id': 'conflict'}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)

with pytest.raises(Exception):
self.task.run(self.pk)

def test_multi_vault_password_ask(self):
vault = CredentialType.defaults['vault']()
for i, label in enumerate(['dev', 'prod']):
credential = Credential(
pk=i,
credential_type=vault,
inputs={'vault_password': 'ASK', 'vault_id': label}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)
self.task.run(self.pk, **{
'vault_password.dev': 'provided-at-launch@dev',
'vault_password.prod': 'provided-at-launch@prod'
})

assert self.run_pexpect.call_count == 1
call_args, call_kwargs = self.run_pexpect.call_args_list[0]
args, cwd, env, stdout = call_args

vault_passwords = dict(
(k.pattern, v) for k, v in call_kwargs['expect_passwords'].items()
if 'Vault' in k.pattern
)
assert vault_passwords['Vault password \(prod\):\\s*?$'] == 'provided-at-launch@prod'
assert vault_passwords['Vault password \(dev\):\\s*?$'] == 'provided-at-launch@dev'
assert vault_passwords['Vault password:\\s*?$'] == ''
assert '--ask-vault-pass' not in ' '.join(args)
assert '--vault-id dev@prompt' in ' '.join(args)
assert '--vault-id prod@prompt' in ' '.join(args)

def test_ssh_key_with_agent(self):
ssh = CredentialType.defaults['ssh']()
credential = Credential(
Expand Down

0 comments on commit e7918ad

Please sign in to comment.