Skip to content

Commit

Permalink
Improves error messages on RunnerResults for various scenarios.
Browse files Browse the repository at this point in the history
Fixes handling of `async_` keyword parameter.
  • Loading branch information
Daverball committed Sep 3, 2024
1 parent 7da0e90 commit 17e154f
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 29 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Changelog
---------

- Improves error messages for various scenarios on `RunnerResults`
[Daverball]

- Only sets `ansible_connection` to `local` when `ansible_port`
is `22`, since anything else is likely a SSH tunnel
[Daverball]
Expand Down
17 changes: 15 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Connect to a server using a username and a password::
remote_pass=password
)

print api.command('whoami').stdout() # prints 'admin'
print(api.command('whoami').stdout()) # prints 'admin'

Run a command on multiple servers and get the output for each::

Expand All @@ -81,7 +81,20 @@ Run a command on multiple servers and get the output for each::
result = api.command('whoami')

for server in servers:
print result.stdout(server)
print(result.stdout(server))

Or alternatively::

api = Api(['a.example.org', 'b.example.org'])
results = api.command('whoami')

for server, result in results['contacted'].items():
if 'stdout' in result:
print(server, result['stdout'])

The latter is more robust for optional result components, since not
every server's result may contain it.


Which Modules are Available?
----------------------------
Expand Down
6 changes: 6 additions & 0 deletions scripts/generate_module_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,12 @@ def write_return_type(returns: dict[str, Any] | None) -> None:
return_type += '[Incomplete]'
elif return_type == 'dict':
return_type += '[str, Incomplete]'
elif return_type == 'complex':
# TODO: This seems to be more or less an alias to dict
# but it contains a schema for the contents. If it
# is always dict, then try to merge this with dict
# and generate a TypedDict using `contains`.
return_type = 'Incomplete'
suffix = ' # type:ignore[override]' if name == 'values' else ''
if len(name) + len(return_type) + len(suffix) > 33:
# signature doesn't fit on one line
Expand Down
26 changes: 13 additions & 13 deletions src/suitable/_module_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ class PackageFactsResults(RunnerResults):
"""

def ansible_facts(self, server: str | None = None) -> complex:
def ansible_facts(self, server: str | None = None) -> Incomplete:
"""
Facts to add to ansible_facts.
Expand Down Expand Up @@ -1763,7 +1763,7 @@ class ServiceFactsResults(RunnerResults):
"""

def ansible_facts(self, server: str | None = None) -> complex:
def ansible_facts(self, server: str | None = None) -> Incomplete:
"""
Facts to add to ansible_facts about the services on the system.
Expand Down Expand Up @@ -2072,7 +2072,7 @@ class SysvinitResults(RunnerResults):
"""

def results(self, server: str | None = None) -> complex:
def results(self, server: str | None = None) -> Incomplete:
"""
results from actions taken.
Expand Down Expand Up @@ -3126,7 +3126,7 @@ def stdout_lines(self, server: str | None = None) -> list[Incomplete]:
"""
return self.acquire(server, 'stdout_lines')

def output(self, server: str | None = None) -> complex:
def output(self, server: str | None = None) -> Incomplete:
"""
Based on the value of display option will return either the set of
transformed XML to JSON format from the RPC response with type dict or
Expand Down Expand Up @@ -3177,7 +3177,7 @@ def stdout_lines(self, server: str | None = None) -> list[Incomplete]:
"""
return self.acquire(server, 'stdout_lines')

def output(self, server: str | None = None) -> complex:
def output(self, server: str | None = None) -> Incomplete:
"""
Based on the value of display option will return either the set of
transformed XML to JSON format from the RPC response with type dict or
Expand Down Expand Up @@ -3528,7 +3528,7 @@ def undefined_zones(self, server: str | None = None) -> list[Incomplete]:
"""
return self.acquire(server, 'undefined_zones')

def firewalld_info(self, server: str | None = None) -> complex:
def firewalld_info(self, server: str | None = None) -> Incomplete:
"""
Returns various information about firewalld configuration.
Expand Down Expand Up @@ -3590,7 +3590,7 @@ class RhelFactsResults(RunnerResults):
"""

def ansible_facts(self, server: str | None = None) -> complex:
def ansible_facts(self, server: str | None = None) -> Incomplete:
"""
Relevant Ansible Facts.
Expand Down Expand Up @@ -4340,7 +4340,7 @@ def exitcode(self, server: str | None = None) -> str:
"""
return self.acquire(server, 'exitcode')

def feature_result(self, server: str | None = None) -> complex:
def feature_result(self, server: str | None = None) -> Incomplete:
"""
List of features that were installed or removed.
Expand Down Expand Up @@ -4411,7 +4411,7 @@ def matched(self, server: str | None = None) -> int:
"""
return self.acquire(server, 'matched')

def files(self, server: str | None = None) -> complex:
def files(self, server: str | None = None) -> Incomplete:
"""
Information on the files/folders that match the criteria returned as a
list of dictionary elements for each file matched. The entries are
Expand Down Expand Up @@ -4764,7 +4764,7 @@ class WinPowershellResults(RunnerResults):
"""

def result(self, server: str | None = None) -> complex:
def result(self, server: str | None = None) -> Incomplete:
"""
The values that were set by `$Ansible.Result` in the script.
Expand Down Expand Up @@ -5276,7 +5276,7 @@ def changed(self, server: str | None = None) -> bool:
"""
return self.acquire(server, 'changed')

def stat(self, server: str | None = None) -> complex:
def stat(self, server: str | None = None) -> Incomplete:
"""
dictionary containing all the stat data.
Expand Down Expand Up @@ -5716,7 +5716,7 @@ def privileges(self, server: str | None = None) -> dict[str, Incomplete]:
"""
return self.acquire(server, 'privileges')

def label(self, server: str | None = None) -> complex:
def label(self, server: str | None = None) -> Incomplete:
"""
The mandatory label set to the logon session.
Expand Down Expand Up @@ -5750,7 +5750,7 @@ def groups(self, server: str | None = None) -> list[Incomplete]:
"""
return self.acquire(server, 'groups')

def account(self, server: str | None = None) -> complex:
def account(self, server: str | None = None) -> Incomplete:
"""
The running account SID details.
Expand Down
11 changes: 9 additions & 2 deletions src/suitable/module_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_module_args(
args_str = ' '.join(args).replace('=', '\\=')

kwargs_str = ' '.join(
'{}="{}"'.format(k.rstrip('_'), v.replace('"', '\\"'))
'{}="{}"'.format(k, v.replace('"', '\\"'))
for k, v in kwargs.items()
)

Expand All @@ -189,6 +189,13 @@ def execute(self, *args: Any, **kwargs: Any) -> RunnerResults:
if set_global_context:
set_global_context(self.api.options)

# translate parameters that use a reserved keyword
# TODO: For now async is the only one we know about
# but there may be other ones
if 'async_' in kwargs:
# with conflicts prefer the real name
kwargs.setdefault('async', kwargs.pop('async_'))

# legacy key=value pairs shorthand approach
module_args: dict[str, Any] | str
if args:
Expand Down Expand Up @@ -392,4 +399,4 @@ def evaluate_results(
server: result
for server, result in callback.unreachable.items()
}
})
}, dry_run=self.api.options.check)
32 changes: 24 additions & 8 deletions src/suitable/runner_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,39 @@ class RunnerResults(_Base):
"""

def __init__(self, results: _RunnerResults) -> None:
def __init__(self, results: _RunnerResults, dry_run: bool = False) -> None:
self.dry_run = dry_run
self.update(results) # type:ignore[arg-type]

def __getattr__(self, key: str) -> ResultsCallback:
return lambda server=None: self.acquire(server, key)

def acquire(self, server: str | None, key: str) -> Any:
contacted = self['contacted']

# if no server is given and exactly one contacted server exists
# return the value of said server directly
if server is None and len(self['contacted']) == 1:
server = next((k for k in self['contacted'].keys()), None)

if server not in self['contacted']:
if server is None:
if len(contacted) == 1:
server = next((k for k in contacted.keys()), None)
elif contacted:
raise ValueError(
"When contacting multiple servers you need to "
"specify which server's result you want"
)
elif self.dry_run:
raise ValueError('Results are not available in dry run')
elif (unreachable := self['unreachable']):
raise ValueError(
f"{', '.join(unreachable)} could not be contacted"
)

if server not in contacted:
if self.dry_run:
raise ValueError('Results are not available in dry run')
raise KeyError(f"{server} could not be contacted")

if key not in self['contacted'][server]:
raise AttributeError
if key not in (result := contacted[server]):
raise AttributeError(key)

return self['contacted'][server][key]
return result[key]
22 changes: 18 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def test_results():
result.rc('localhost')


def test_results_dry_run():
result = Api('localhost', dry_run=True).command('whoami')
assert not result['contacted']
with pytest.raises(ValueError, match=r'not available in dry run'):
result.rc()

with pytest.raises(ValueError, match=r'not available in dry run'):
result.rc('localhost')


@pytest.mark.parametrize("server", ('localhost',))
def test_results_single_server(server):
result = Api(server).command('whoami')
Expand All @@ -92,15 +102,17 @@ def test_results_multiple_servers():
result = RunnerResults({
'contacted': {
'web.seantis.dev': {'rc': 0},
'db.seantis.dev': {'rc': 1}
'db.seantis.dev': {'rc': 1},
'buggy.result.dev': {},
}
})

with pytest.raises(KeyError):
result.rc()

assert result.rc('web.seantis.dev') == 0
assert result.rc('db.seantis.dev') == 1
with pytest.raises(AttributeError, match=r'rc'):
result.rc('buggy.result.dev')
with pytest.raises(ValueError, match=r'When contacting multiple'):
result.rc()


@pytest.mark.parametrize("server", (('localhost', 'localhost:22'),))
Expand All @@ -109,6 +121,8 @@ def test_whoami_multiple_servers(server):
results = host.command('whoami')
assert results.rc(server[0]) == 0
assert results.rc(server[1]) == 0
with pytest.raises(ValueError, match=r'When contacting multiple'):
results.rc()


def test_non_scalar_parameter():
Expand Down

0 comments on commit 17e154f

Please sign in to comment.